fix
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from django.db import models
|
||||
from django.db.models import F
|
||||
from django.utils import timezone
|
||||
|
||||
from .models import Submission
|
||||
@@ -7,6 +8,33 @@ from utils.serializers import LanguageNameChoiceField
|
||||
from problemset.models import ProblemSetProgress
|
||||
|
||||
|
||||
def bulk_fetch_problemset_progress(user, problem_ids):
|
||||
"""一次 IN 查询获取该用户对多个题目的题单进度,返回 {problem_id: ProblemSetProgress|None}"""
|
||||
if not problem_ids:
|
||||
return {}
|
||||
rows = (
|
||||
ProblemSetProgress.objects.filter(
|
||||
user=user,
|
||||
problemset__status="active",
|
||||
problemset__problemsetproblem__problem_id__in=problem_ids,
|
||||
)
|
||||
.filter(
|
||||
models.Q(problemset__end_time__isnull=True)
|
||||
| models.Q(problemset__end_time__gt=timezone.now())
|
||||
)
|
||||
.annotate(matched_problem_id=F("problemset__problemsetproblem__problem_id"))
|
||||
.only("join_time", "progress_detail")
|
||||
)
|
||||
cache = {}
|
||||
for row in rows:
|
||||
pid = row.matched_problem_id
|
||||
if pid not in cache:
|
||||
cache[pid] = row
|
||||
for pid in problem_ids:
|
||||
cache.setdefault(pid, None)
|
||||
return cache
|
||||
|
||||
|
||||
class CreateSubmissionSerializer(serializers.Serializer):
|
||||
problem_id = serializers.IntegerField()
|
||||
language = LanguageNameChoiceField()
|
||||
@@ -44,7 +72,10 @@ class SubmissionListSerializer(serializers.ModelSerializer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.user = kwargs.pop("user", None)
|
||||
preloaded = kwargs.pop("problemset_progress_cache", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
if preloaded is not None:
|
||||
self._problemset_progress_cache = preloaded
|
||||
|
||||
class Meta:
|
||||
model = Submission
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..serializers import (
|
||||
SubmissionModelSerializer,
|
||||
ShareSubmissionSerializer,
|
||||
)
|
||||
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
|
||||
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer, bulk_fetch_problemset_progress
|
||||
|
||||
|
||||
class SubmissionAPI(APIView):
|
||||
@@ -193,8 +193,14 @@ class SubmissionListAPI(APIView):
|
||||
)
|
||||
|
||||
data = self.paginate_data(request, submissions)
|
||||
results = data["results"]
|
||||
if request.user.is_authenticated and request.user.is_regular_user():
|
||||
problem_ids = list({s.problem_id for s in results})
|
||||
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
|
||||
else:
|
||||
progress_cache = {}
|
||||
data["results"] = SubmissionListSerializer(
|
||||
data["results"], many=True, user=request.user
|
||||
results, many=True, user=request.user, problemset_progress_cache=progress_cache
|
||||
).data
|
||||
return self.success(data)
|
||||
|
||||
@@ -241,8 +247,14 @@ class ContestSubmissionListAPI(APIView):
|
||||
submissions = submissions.filter(user_id=request.user.id)
|
||||
|
||||
data = self.paginate_data(request, submissions)
|
||||
results = data["results"]
|
||||
if request.user.is_authenticated and request.user.is_regular_user():
|
||||
problem_ids = list({s.problem_id for s in results})
|
||||
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
|
||||
else:
|
||||
progress_cache = {}
|
||||
data["results"] = SubmissionListSerializer(
|
||||
data["results"], many=True, user=request.user
|
||||
results, many=True, user=request.user, problemset_progress_cache=progress_cache
|
||||
).data
|
||||
return self.success(data)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user