fix N+1 query
This commit is contained in:
@@ -241,11 +241,29 @@ class ProblemSetProgressSerializer(serializers.ModelSerializer):
|
|||||||
|
|
||||||
def get_completed_problems(self, obj):
|
def get_completed_problems(self, obj):
|
||||||
"""获取已完成的题目列表"""
|
"""获取已完成的题目列表"""
|
||||||
from problem.models import Problem
|
|
||||||
|
|
||||||
completed_problems = []
|
completed_problems = []
|
||||||
|
|
||||||
|
# 尝试从 request 中获取预加载的问题字典(用于批量查询优化)
|
||||||
|
problems_dict = {}
|
||||||
|
request = self.context.get('request')
|
||||||
|
if request and hasattr(request, '_problems_dict_cache'):
|
||||||
|
problems_dict = request._problems_dict_cache
|
||||||
|
|
||||||
if obj.progress_detail:
|
if obj.progress_detail:
|
||||||
for problem_id in obj.progress_detail.keys():
|
for problem_id in obj.progress_detail.keys():
|
||||||
|
# 优先使用预加载的问题字典
|
||||||
|
if problems_dict:
|
||||||
|
problem = problems_dict.get(problem_id)
|
||||||
|
if problem:
|
||||||
|
completed_problems.append({
|
||||||
|
'id': problem.id,
|
||||||
|
'_id': problem._id,
|
||||||
|
'title': problem.title
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果没有预加载字典,则回退到单独查询(向后兼容)
|
||||||
|
from problem.models import Problem
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(id=problem_id)
|
problem = Problem.objects.get(id=problem_id)
|
||||||
completed_problems.append({
|
completed_problems.append({
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from django.db.models import Q, Avg
|
from django.db.models import Q, Avg, Count
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from utils.api import APIView, validate_serializer
|
from utils.api import APIView, validate_serializer
|
||||||
@@ -306,13 +306,13 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
except ProblemSet.DoesNotExist:
|
except ProblemSet.DoesNotExist:
|
||||||
return self.error("题单不存在")
|
return self.error("题单不存在")
|
||||||
|
|
||||||
# 获取所有参与该题单的用户进度
|
# 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
|
||||||
progresses = ProblemSetProgress.objects.filter(problemset=problem_set)
|
progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related('user')
|
||||||
|
|
||||||
# 班级过滤
|
# 班级过滤
|
||||||
class_name = request.GET.get("class_name", "").strip()
|
class_name = request.GET.get("class_name", "").strip()
|
||||||
if class_name:
|
if class_name:
|
||||||
progresses = progresses.filter(user_username__icontains=class_name)
|
progresses = progresses.filter(user__username__icontains=class_name)
|
||||||
|
|
||||||
# 排序
|
# 排序
|
||||||
progresses = progresses.order_by(
|
progresses = progresses.order_by(
|
||||||
@@ -320,9 +320,43 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 计算统计数据(基于所有数据,而非分页数据)
|
# 计算统计数据(基于所有数据,而非分页数据)
|
||||||
total_count = progresses.count()
|
# 使用一次查询获取所有统计数据
|
||||||
completed_count = progresses.filter(is_completed=True).count()
|
stats = progresses.aggregate(
|
||||||
avg_progress = progresses.aggregate(avg=Avg("progress_percentage"))["avg"] or 0
|
total=Count('id'),
|
||||||
|
completed=Count('id', filter=Q(is_completed=True)),
|
||||||
|
avg_progress=Avg("progress_percentage")
|
||||||
|
)
|
||||||
|
total_count = stats['total']
|
||||||
|
completed_count = stats['completed']
|
||||||
|
avg_progress = stats['avg_progress'] or 0
|
||||||
|
|
||||||
|
# 获取分页参数,用于预先收集当前页需要的问题ID
|
||||||
|
try:
|
||||||
|
limit = int(request.GET.get("limit", "10"))
|
||||||
|
except ValueError:
|
||||||
|
limit = 10
|
||||||
|
try:
|
||||||
|
offset = int(request.GET.get("offset", "0"))
|
||||||
|
except ValueError:
|
||||||
|
offset = 0
|
||||||
|
if offset < 0:
|
||||||
|
offset = 0
|
||||||
|
|
||||||
|
# 只从当前页的数据中收集问题ID,避免遍历所有记录
|
||||||
|
paginated_progresses = list(progresses[offset:offset + limit])
|
||||||
|
all_problem_ids = set()
|
||||||
|
for progress in paginated_progresses:
|
||||||
|
if progress.progress_detail:
|
||||||
|
all_problem_ids.update(progress.progress_detail.keys())
|
||||||
|
|
||||||
|
# 批量加载当前页所需的问题
|
||||||
|
problems_dict = {}
|
||||||
|
if all_problem_ids:
|
||||||
|
problems = Problem.objects.filter(id__in=all_problem_ids).only('id', '_id', 'title')
|
||||||
|
problems_dict = {str(problem.id): problem for problem in problems}
|
||||||
|
|
||||||
|
# 将预加载的问题字典存储到 request 中,供序列化器使用
|
||||||
|
request._problems_dict_cache = problems_dict
|
||||||
|
|
||||||
# 使用分页
|
# 使用分页
|
||||||
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
|
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
|
||||||
|
|||||||
Reference in New Issue
Block a user