From afde8dfc404127acc74afab506520f2d4f300af7 Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Fri, 26 Dec 2025 17:11:35 +0800 Subject: [PATCH] fix problemset list N+1 --- problemset/serializers.py | 40 +++++++++++++++++++------------- problemset/views/oj.py | 48 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 18 deletions(-) diff --git a/problemset/serializers.py b/problemset/serializers.py index ded572f..3015ddd 100644 --- a/problemset/serializers.py +++ b/problemset/serializers.py @@ -78,7 +78,7 @@ class ProblemSetListSerializer(serializers.ModelSerializer): """题单列表序列化器""" created_by = UsernameSerializer() - problems_count = serializers.SerializerMethodField() + problems_count = serializers.IntegerField(read_only=True) user_progress = serializers.SerializerMethodField() badges = serializers.SerializerMethodField() @@ -98,35 +98,43 @@ class ProblemSetListSerializer(serializers.ModelSerializer): "visible", ] - def get_problems_count(self, obj): - """获取题单中的题目数量""" - return ProblemSetProblem.objects.filter(problemset=obj).count() - def get_user_progress(self, obj): """获取当前用户在该题单中的进度""" request = self.context.get("request") - return get_user_progress_data(obj, request) + if request and hasattr(request, "_user_progress_map"): + progress = request._user_progress_map.get(obj.id) + if progress: + return { + "is_joined": True, + "progress_percentage": progress.progress_percentage, + "completed_count": progress.completed_problems_count, + "total_count": progress.total_problems_count, + "is_completed": progress.is_completed, + } + return { + "is_joined": False, + "progress_percentage": 0, + "completed_count": 0, + "total_count": 0, + "is_completed": False, + } def get_badges(self, obj): """获取题单的奖章列表,并标记用户已获得的徽章""" request = self.context.get("request") - badges = ProblemSetBadge.objects.filter(problemset=obj) + + # 使用预加载的奖章数据 + badges = getattr(obj, "badges", []) badge_data = ProblemSetBadgeSerializer(badges, many=True).data # 如果用户已登录,检查哪些徽章已被获得 - if request and request.user.is_authenticated: - earned_badge_ids = set( - UserBadge.objects.filter( - user=request.user, - badge__problemset=obj - ).values_list('badge_id', flat=True) - ) - + if request and request.user.is_authenticated and hasattr(request, "_user_earned_badge_ids"): + earned_badge_ids = request._user_earned_badge_ids # 为每个徽章添加是否已获得的标记 for badge in badge_data: badge['is_earned'] = badge['id'] in earned_badge_ids else: - # 未登录用户,所有徽章都标记为未获得 + # 未登录用户或未预加载,所有徽章都标记为未获得 for badge in badge_data: badge['is_earned'] = False diff --git a/problemset/views/oj.py b/problemset/views/oj.py index d1e11e9..2f59a5b 100644 --- a/problemset/views/oj.py +++ b/problemset/views/oj.py @@ -1,4 +1,4 @@ -from django.db.models import Q, Avg, Count +from django.db.models import Q, Avg, Count, Prefetch from django.utils import timezone from utils.api import APIView, validate_serializer @@ -33,7 +33,13 @@ class ProblemSetAPI(APIView): def get(self, request): """获取题单列表""" - problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft") + # 预加载创建者信息 + problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft").select_related("created_by") + + # 使用annotate在查询时计算题目数量,避免N+1查询 + problem_sets = problem_sets.annotate( + problems_count=Count("problemsetproblem", distinct=True) + ) # 过滤条件 keyword = request.GET.get("keyword", "").strip() @@ -57,6 +63,44 @@ class ProblemSetAPI(APIView): else: problem_sets = problem_sets.order_by("-create_time") + # 批量查询用户进度和已获得的奖章(如果用户已登录) + # 注意:需要在应用prefetch_related之前获取ID列表,避免不必要的预加载 + user_progress_map = {} + user_earned_badge_ids = set() + if request.user.is_authenticated: + # 先获取所有题单ID(不应用prefetch_related,只获取ID) + problem_set_ids = list(problem_sets.values_list("id", flat=True)) + + if problem_set_ids: + # 批量查询用户在这些题单中的进度 + user_progresses = ProblemSetProgress.objects.filter( + problemset_id__in=problem_set_ids, + user=request.user + ).select_related("problemset") + # 构建映射:题单ID -> 进度对象 + user_progress_map = {progress.problemset_id: progress for progress in user_progresses} + + # 批量查询用户已获得的奖章ID(这些题单相关的) + user_earned_badge_ids = set( + UserBadge.objects.filter( + user=request.user, + badge__problemset_id__in=problem_set_ids + ).values_list('badge_id', flat=True) + ) + + # 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载) + problem_sets = problem_sets.prefetch_related( + Prefetch( + "problemsetbadge_set", + queryset=ProblemSetBadge.objects.all(), + to_attr="badges" + ) + ) + + # 将用户进度映射和已获得的奖章ID集合存储到request中,供序列化器使用 + request._user_progress_map = user_progress_map + request._user_earned_badge_ids = user_earned_badge_ids + data = self.paginate_data(request, problem_sets, ProblemSetListSerializer) return self.success(data)