diff --git a/problemset/views/oj.py b/problemset/views/oj.py index 9dc1654..946e153 100644 --- a/problemset/views/oj.py +++ b/problemset/views/oj.py @@ -336,7 +336,7 @@ class ProblemSetUserProgressAPI(APIView): completed_count = stats["completed"] avg_progress = stats["avg_progress"] or 0 - # 获取分页参数,用于预先收集当前页需要的问题ID + # 获取分页参数 try: limit = int(request.GET.get("limit", "10")) except ValueError: @@ -348,20 +348,42 @@ class ProblemSetUserProgressAPI(APIView): if offset < 0: offset = 0 - # 只从当前页的数据中收集问题ID,避免遍历所有记录 + # 提前获取题单的所有题目(用于前端显示未完成题目和序列化器) + # 使用 select_related 和 only 优化查询,只选择需要的字段 + all_problemset_problems = ( + ProblemSetProblem.objects.filter(problemset=problem_set) + .select_related("problem") + .only("problem__id", "problem___id", "problem__title", "order") + .order_by("order") + ) + + # 构建题单所有题目的数据结构和映射 + all_problems_list = [] + all_problems_map = {} + for psp in all_problemset_problems: + problem_data = { + "id": psp.problem.id, + "_id": psp.problem._id, + "title": psp.problem.title, + } + all_problems_list.append(problem_data) + # 用于序列化器查找,key 使用字符串格式(与 progress_detail 的 key 格式一致) + all_problems_map[str(psp.problem.id)] = psp.problem + + # 从当前页的数据中收集已完成的问题ID,用于序列化器 paginated_progresses = list(progresses[offset : offset + limit]) - all_problem_ids = set() + completed_problem_ids = set() for progress in paginated_progresses: if progress.progress_detail: - all_problem_ids.update(progress.progress_detail.keys()) + # progress_detail 的 key 是字符串格式的 problem_id + completed_problem_ids.update(progress.progress_detail.keys()) - # 批量加载当前页所需的问题 - problems_dict = {} - if all_problem_ids: - problems = Problem.objects.filter(id__in=all_problem_ids).only( - "id", "_id", "title" - ) - problems_dict = {str(problem.id): problem for problem in problems} + # 从已加载的题单题目中构建 problems_dict,避免重复查询 + problems_dict = { + pid: all_problems_map[pid] + for pid in completed_problem_ids + if pid in all_problems_map + } # 将预加载的问题字典存储到 request 中,供序列化器使用 request._problems_dict_cache = problems_dict @@ -376,14 +398,7 @@ class ProblemSetUserProgressAPI(APIView): "avg_progress": round(avg_progress, 2), } - # 返回问题 ID 列表 - data["problems"] = [ - { - "id": problem.id, - "_id": problem._id, - "title": problem.title, - } - for problem in problems_dict.values() - ] + # 返回题单的所有题目 + data["problems"] = all_problems_list return self.success(data)