From 59a5efd6bd6cc7629e52cedadeaf90cb003594ca Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Mon, 22 Dec 2025 18:46:17 +0800 Subject: [PATCH] fix N+1 query --- problemset/serializers.py | 22 ++++++++++++++++-- problemset/views/oj.py | 48 +++++++++++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 9 deletions(-) diff --git a/problemset/serializers.py b/problemset/serializers.py index 343ac19..ded572f 100644 --- a/problemset/serializers.py +++ b/problemset/serializers.py @@ -241,11 +241,29 @@ class ProblemSetProgressSerializer(serializers.ModelSerializer): def get_completed_problems(self, obj): """获取已完成的题目列表""" - from problem.models import Problem - completed_problems = [] + + # 尝试从 request 中获取预加载的问题字典(用于批量查询优化) + problems_dict = {} + request = self.context.get('request') + if request and hasattr(request, '_problems_dict_cache'): + problems_dict = request._problems_dict_cache + if obj.progress_detail: for problem_id in obj.progress_detail.keys(): + # 优先使用预加载的问题字典 + if problems_dict: + problem = problems_dict.get(problem_id) + if problem: + completed_problems.append({ + 'id': problem.id, + '_id': problem._id, + 'title': problem.title + }) + continue + + # 如果没有预加载字典,则回退到单独查询(向后兼容) + from problem.models import Problem try: problem = Problem.objects.get(id=problem_id) completed_problems.append({ diff --git a/problemset/views/oj.py b/problemset/views/oj.py index 69fe483..198380b 100644 --- a/problemset/views/oj.py +++ b/problemset/views/oj.py @@ -1,4 +1,4 @@ -from django.db.models import Q, Avg +from django.db.models import Q, Avg, Count from django.utils import timezone from utils.api import APIView, validate_serializer @@ -306,13 +306,13 @@ class ProblemSetUserProgressAPI(APIView): except ProblemSet.DoesNotExist: return self.error("题单不存在") - # 获取所有参与该题单的用户进度 - progresses = ProblemSetProgress.objects.filter(problemset=problem_set) + # 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息 + progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related('user') # 班级过滤 class_name = request.GET.get("class_name", "").strip() if class_name: - progresses = progresses.filter(user_username__icontains=class_name) + progresses = progresses.filter(user__username__icontains=class_name) # 排序 progresses = progresses.order_by( @@ -320,9 +320,43 @@ class ProblemSetUserProgressAPI(APIView): ) # 计算统计数据(基于所有数据,而非分页数据) - total_count = progresses.count() - completed_count = progresses.filter(is_completed=True).count() - avg_progress = progresses.aggregate(avg=Avg("progress_percentage"))["avg"] or 0 + # 使用一次查询获取所有统计数据 + stats = progresses.aggregate( + total=Count('id'), + completed=Count('id', filter=Q(is_completed=True)), + avg_progress=Avg("progress_percentage") + ) + total_count = stats['total'] + completed_count = stats['completed'] + avg_progress = stats['avg_progress'] or 0 + + # 获取分页参数,用于预先收集当前页需要的问题ID + try: + limit = int(request.GET.get("limit", "10")) + except ValueError: + limit = 10 + try: + offset = int(request.GET.get("offset", "0")) + except ValueError: + offset = 0 + if offset < 0: + offset = 0 + + # 只从当前页的数据中收集问题ID,避免遍历所有记录 + paginated_progresses = list(progresses[offset:offset + limit]) + all_problem_ids = set() + for progress in paginated_progresses: + if progress.progress_detail: + all_problem_ids.update(progress.progress_detail.keys()) + + # 批量加载当前页所需的问题 + problems_dict = {} + if all_problem_ids: + problems = Problem.objects.filter(id__in=all_problem_ids).only('id', '_id', 'title') + problems_dict = {str(problem.id): problem for problem in problems} + + # 将预加载的问题字典存储到 request 中,供序列化器使用 + request._problems_dict_cache = problems_dict # 使用分页 data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)