fix N+1 query

This commit is contained in:
2025-12-22 18:46:17 +08:00
parent 391647785c
commit 59a5efd6bd
2 changed files with 61 additions and 9 deletions

View File

@@ -241,11 +241,29 @@ class ProblemSetProgressSerializer(serializers.ModelSerializer):
def get_completed_problems(self, obj): def get_completed_problems(self, obj):
"""获取已完成的题目列表""" """获取已完成的题目列表"""
from problem.models import Problem
completed_problems = [] completed_problems = []
# 尝试从 request 中获取预加载的问题字典(用于批量查询优化)
problems_dict = {}
request = self.context.get('request')
if request and hasattr(request, '_problems_dict_cache'):
problems_dict = request._problems_dict_cache
if obj.progress_detail: if obj.progress_detail:
for problem_id in obj.progress_detail.keys(): for problem_id in obj.progress_detail.keys():
# 优先使用预加载的问题字典
if problems_dict:
problem = problems_dict.get(problem_id)
if problem:
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
continue
# 如果没有预加载字典,则回退到单独查询(向后兼容)
from problem.models import Problem
try: try:
problem = Problem.objects.get(id=problem_id) problem = Problem.objects.get(id=problem_id)
completed_problems.append({ completed_problems.append({

View File

@@ -1,4 +1,4 @@
from django.db.models import Q, Avg from django.db.models import Q, Avg, Count
from django.utils import timezone from django.utils import timezone
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
@@ -306,13 +306,13 @@ class ProblemSetUserProgressAPI(APIView):
except ProblemSet.DoesNotExist: except ProblemSet.DoesNotExist:
return self.error("题单不存在") return self.error("题单不存在")
# 获取所有参与该题单的用户进度 # 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
progresses = ProblemSetProgress.objects.filter(problemset=problem_set) progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related('user')
# 班级过滤 # 班级过滤
class_name = request.GET.get("class_name", "").strip() class_name = request.GET.get("class_name", "").strip()
if class_name: if class_name:
progresses = progresses.filter(user_username__icontains=class_name) progresses = progresses.filter(user__username__icontains=class_name)
# 排序 # 排序
progresses = progresses.order_by( progresses = progresses.order_by(
@@ -320,9 +320,43 @@ class ProblemSetUserProgressAPI(APIView):
) )
# 计算统计数据(基于所有数据,而非分页数据) # 计算统计数据(基于所有数据,而非分页数据)
total_count = progresses.count() # 使用一次查询获取所有统计数据
completed_count = progresses.filter(is_completed=True).count() stats = progresses.aggregate(
avg_progress = progresses.aggregate(avg=Avg("progress_percentage"))["avg"] or 0 total=Count('id'),
completed=Count('id', filter=Q(is_completed=True)),
avg_progress=Avg("progress_percentage")
)
total_count = stats['total']
completed_count = stats['completed']
avg_progress = stats['avg_progress'] or 0
# 获取分页参数用于预先收集当前页需要的问题ID
try:
limit = int(request.GET.get("limit", "10"))
except ValueError:
limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
offset = 0
if offset < 0:
offset = 0
# 只从当前页的数据中收集问题ID避免遍历所有记录
paginated_progresses = list(progresses[offset:offset + limit])
all_problem_ids = set()
for progress in paginated_progresses:
if progress.progress_detail:
all_problem_ids.update(progress.progress_detail.keys())
# 批量加载当前页所需的问题
problems_dict = {}
if all_problem_ids:
problems = Problem.objects.filter(id__in=all_problem_ids).only('id', '_id', 'title')
problems_dict = {str(problem.id): problem for problem in problems}
# 将预加载的问题字典存储到 request 中,供序列化器使用
request._problems_dict_cache = problems_dict
# 使用分页 # 使用分页
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer) data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)