diff --git a/submission/api.py b/submission/api.py index fe1a9d1..150eed0 100644 --- a/submission/api.py +++ b/submission/api.py @@ -5,6 +5,7 @@ from ninja.errors import HttpError from ninja.pagination import paginate from django.shortcuts import get_object_or_404 from django.contrib.auth.decorators import login_required +from django.db.models import OuterRef, Subquery, IntegerField from .schemas import ( @@ -55,19 +56,20 @@ def list_submissions(request, filters: SubmissionFilter = Query(...)): if filters.username: submissions = submissions.filter(user__username__icontains=filters.username) - # 获取所有提交 - submissions = submissions.prefetch_related("ratings") - - # 获取当前用户的评分 - user_ratings = { - rating.submission_id: rating.score - for rating in Rating.objects.filter( - user=request.user, - submission__in=submissions - ) - } + user_rating_subquery = Subquery( + Rating.objects.filter(user=request.user, submission=OuterRef("pk")).values( + "score" + )[:1], + output_field=IntegerField(), + ) + submissions = submissions.annotate(my_score=user_rating_subquery) - return [SubmissionOut.list(submission, user_ratings) for submission in submissions] + def get_submission_data(submission): + """从 submission 对象构建 SubmissionOut 数据""" + my_score = getattr(submission, "my_score", None) or 0 + return SubmissionOut.list(submission, {submission.id: my_score}) + + return [get_submission_data(submission) for submission in submissions] @router.get("/{submission_id}", response=SubmissionOut)