fix N+1 query

This commit is contained in:
2025-12-22 18:46:17 +08:00
parent 391647785c
commit 59a5efd6bd
2 changed files with 61 additions and 9 deletions

View File

@@ -1,4 +1,4 @@
from django.db.models import Q, Avg
from django.db.models import Q, Avg, Count
from django.utils import timezone
from utils.api import APIView, validate_serializer
@@ -306,13 +306,13 @@ class ProblemSetUserProgressAPI(APIView):
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 获取所有参与该题单的用户进度
progresses = ProblemSetProgress.objects.filter(problemset=problem_set)
# 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related('user')
# 班级过滤
class_name = request.GET.get("class_name", "").strip()
if class_name:
progresses = progresses.filter(user_username__icontains=class_name)
progresses = progresses.filter(user__username__icontains=class_name)
# 排序
progresses = progresses.order_by(
@@ -320,9 +320,43 @@ class ProblemSetUserProgressAPI(APIView):
)
# 计算统计数据(基于所有数据,而非分页数据)
total_count = progresses.count()
completed_count = progresses.filter(is_completed=True).count()
avg_progress = progresses.aggregate(avg=Avg("progress_percentage"))["avg"] or 0
# 使用一次查询获取所有统计数据
stats = progresses.aggregate(
total=Count('id'),
completed=Count('id', filter=Q(is_completed=True)),
avg_progress=Avg("progress_percentage")
)
total_count = stats['total']
completed_count = stats['completed']
avg_progress = stats['avg_progress'] or 0
# 获取分页参数用于预先收集当前页需要的问题ID
try:
limit = int(request.GET.get("limit", "10"))
except ValueError:
limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
offset = 0
if offset < 0:
offset = 0
# 只从当前页的数据中收集问题ID避免遍历所有记录
paginated_progresses = list(progresses[offset:offset + limit])
all_problem_ids = set()
for progress in paginated_progresses:
if progress.progress_detail:
all_problem_ids.update(progress.progress_detail.keys())
# 批量加载当前页所需的问题
problems_dict = {}
if all_problem_ids:
problems = Problem.objects.filter(id__in=all_problem_ids).only('id', '_id', 'title')
problems_dict = {str(problem.id): problem for problem in problems}
# 将预加载的问题字典存储到 request 中,供序列化器使用
request._problems_dict_cache = problems_dict
# 使用分页
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)