This commit is contained in:
2026-03-30 09:34:38 -06:00
parent a12a665fde
commit 24ff67ec0c
14 changed files with 149 additions and 42 deletions

View File

@@ -14,7 +14,7 @@ class AnnouncementAPI(APIView):
except Announcement.DoesNotExist:
return self.error("Announcement does not exist")
announcements = Announcement.objects.filter(visible=True)
announcements = Announcement.objects.select_related("created_by").filter(visible=True)
return self.success(
self.paginate_data(request, announcements, AnnouncementListSerializer)
)

View File

@@ -0,0 +1,23 @@
# Generated by Django 6.0 on 2026-03-30 15:28
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0001_initial'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddIndex(
model_name='acmcontestrank',
index=models.Index(fields=['contest', 'accepted_number', 'total_time'], name='acm_rank_order_idx'),
),
migrations.AddIndex(
model_name='oicontestrank',
index=models.Index(fields=['contest', 'total_score'], name='oi_rank_order_idx'),
),
]

View File

@@ -79,6 +79,10 @@ class ACMContestRank(AbstractContestRank):
class Meta:
db_table = "acm_contest_rank"
unique_together = (("user", "contest"),)
indexes = [
models.Index(fields=["contest", "accepted_number", "total_time"],
name="acm_rank_order_idx"),
]
class OIContestRank(AbstractContestRank):
@@ -90,6 +94,9 @@ class OIContestRank(AbstractContestRank):
class Meta:
db_table = "oi_contest_rank"
unique_together = (("user", "contest"),)
indexes = [
models.Index(fields=["contest", "total_score"], name="oi_rank_order_idx"),
]
class ContestAnnouncement(models.Model):

View File

@@ -169,15 +169,16 @@ class ContestRankAPI(APIView):
cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
qs = cache.get(cache_key)
if not qs:
qs = self.get_rank()
qs = list(self.get_rank())
cache.set(cache_key, qs)
if download_csv:
data = serializer(qs, many=True, is_contest_admin=is_contest_admin).data
contest_problems = Problem.objects.filter(
contest_problems = list(Problem.objects.filter(
contest=self.contest, visible=True
).order_by("_id")
problem_ids = [item.id for item in contest_problems]
).order_by("_id"))
# 预建 problem_id → 列索引 的字典,避免循环中 O(n) list.index()
problem_id_to_col = {p.id: i for i, p in enumerate(contest_problems)}
f = io.BytesIO()
workbook = xlsxwriter.Workbook(f)
@@ -187,11 +188,8 @@ class ContestRankAPI(APIView):
worksheet.write("C1", "Real Name")
if self.contest.rule_type == ContestRuleType.OI:
worksheet.write("D1", "Total Score")
for item in range(contest_problems.count()):
worksheet.write(
self.column_string(5 + item) + "1",
f"{contest_problems[item].title}",
)
for i, p in enumerate(contest_problems):
worksheet.write(self.column_string(5 + i) + "1", p.title)
for index, item in enumerate(data):
worksheet.write_string(index + 1, 0, str(item["user"]["id"]))
worksheet.write_string(index + 1, 1, item["user"]["username"])
@@ -201,17 +199,14 @@ class ContestRankAPI(APIView):
worksheet.write_string(index + 1, 3, str(item["total_score"]))
for k, v in item["submission_info"].items():
worksheet.write_string(
index + 1, 4 + problem_ids.index(int(k)), str(v)
index + 1, 4 + problem_id_to_col[int(k)], str(v)
)
else:
worksheet.write("D1", "AC")
worksheet.write("E1", "Total Submission")
worksheet.write("F1", "Total Time")
for item in range(contest_problems.count()):
worksheet.write(
self.column_string(7 + item) + "1",
f"{contest_problems[item].title}",
)
for i, p in enumerate(contest_problems):
worksheet.write(self.column_string(7 + i) + "1", p.title)
for index, item in enumerate(data):
worksheet.write_string(index + 1, 0, str(item["user"]["id"]))
@@ -224,7 +219,7 @@ class ContestRankAPI(APIView):
worksheet.write_string(index + 1, 5, str(item["total_time"]))
for k, v in item["submission_info"].items():
worksheet.write_string(
index + 1, 6 + problem_ids.index(int(k)), str(v["is_ac"])
index + 1, 6 + problem_id_to_col[int(k)], str(v["is_ac"])
)
workbook.close()

View File

@@ -13,7 +13,7 @@ class MessageAPI(APIView):
@login_required
def get(self, request):
messages = Message.objects.select_related(
"recipient", "sender", "submission"
"recipient", "sender", "submission", "submission__problem"
).filter(recipient=request.user)
return self.success(self.paginate_data(request, messages, MessageSerializer))

View File

@@ -217,7 +217,7 @@ class _SysOptionsMeta(type):
def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value)
@my_property
@my_property(ttl=DEFAULT_SHORT_TTL)
def allow_register(cls):
return cls._get_option(OptionKeys.allow_register)
@@ -249,7 +249,7 @@ class _SysOptionsMeta(type):
def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value)
@my_property
@my_property(ttl=DEFAULT_SHORT_TTL)
def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token)
@@ -257,7 +257,7 @@ class _SysOptionsMeta(type):
def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value)
@my_property
@my_property(ttl=DEFAULT_SHORT_TTL)
def throttling(cls):
return cls._get_option(OptionKeys.throttling)

View File

@@ -0,0 +1,20 @@
# Generated by Django 6.0 on 2026-03-30 15:28
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('contest', '0002_acmcontestrank_acm_rank_order_idx_and_more'),
('problem', '0005_remove_spj_fields'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddIndex(
model_name='problem',
index=models.Index(fields=['contest', 'visible'], name='problem_contest_visible_idx'),
),
]

View File

@@ -93,6 +93,9 @@ class Problem(models.Model):
db_table = "problem"
unique_together = (("_id", "contest"),)
ordering = ("create_time",)
indexes = [
models.Index(fields=["contest", "visible"], name="problem_contest_visible_idx"),
]
def add_submission_number(self):
self.submission_number = models.F("submission_number") + 1

View File

@@ -89,6 +89,7 @@ class ProblemAPI(APIView):
problems = (
Problem.objects.select_related("created_by")
.prefetch_related("tags")
.filter(contest_id__isnull=True, visible=True)
.order_by("-create_time")
)
@@ -162,7 +163,7 @@ class ContestProblemAPI(APIView):
problem_data = ProblemSafeSerializer(problem).data
return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").filter(
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(
contest=self.contest, visible=True
)
if self.contest.problem_details_permission(request.user):
@@ -229,7 +230,9 @@ class SimilarProblemAPI(APIView):
exclude_ids.extend(ac_display_ids)
similar = (
Problem.objects.filter(tags__in=tag_ids, visible=True, contest__isnull=True)
Problem.objects.select_related("created_by")
.prefetch_related("tags")
.filter(tags__in=tag_ids, visible=True, contest__isnull=True)
.exclude(_id__in=exclude_ids)
.distinct()
.order_by("difficulty")[:5]
@@ -240,9 +243,8 @@ class SimilarProblemAPI(APIView):
class ProblemAuthorAPI(APIView):
def get(self, request):
show_all = request.GET.get("all", "0") == "1"
cached_data = cache.get(
f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
)
cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
cached_data = cache.get(cache_key)
if cached_data:
return self.success(cached_data)
@@ -264,5 +266,5 @@ class ProblemAuthorAPI(APIView):
for author in authors
]
cache.set(CacheKey.problem_authors, result, 7200)
cache.set(cache_key, result, 7200)
return self.success(result)

View File

@@ -170,15 +170,15 @@ class ProblemSetProgress(models.Model):
)
self.total_problems_count = problemset_problems.count()
# 获取当前题单中所有题目的ID集合
current_problem_ids = {str(psp.problem.id) for psp in problemset_problems}
# 获取当前题单中所有题目的ID集合(直接用 problem_id FK 字段,无需额外查询)
current_problem_ids = {str(psp.problem_id) for psp in problemset_problems}
# 清理已删除题目的进度记录
progress_detail_to_remove = []
for problem_id in self.progress_detail.keys():
if problem_id not in current_problem_ids:
progress_detail_to_remove.append(problem_id)
for problem_id in progress_detail_to_remove:
del self.progress_detail[problem_id]
@@ -187,7 +187,7 @@ class ProblemSetProgress(models.Model):
total_score = 0
for psp in problemset_problems:
problem_id = str(psp.problem.id)
problem_id = str(psp.problem_id)
if problem_id in self.progress_detail:
problem_progress = self.progress_detail[problem_id]
completed_count += 1

View File

@@ -183,16 +183,18 @@ class ProblemSetProblemSerializer(serializers.ModelSerializer):
def get_is_completed(self, obj):
"""获取当前用户是否已完成该题目"""
request = self.context.get("request")
if request and request.user.is_authenticated:
if not (request and request.user.is_authenticated):
return False
# 优先使用 view 层预取的进度对象,避免 N+1
progress = self.context.get("user_progress")
if progress is None:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj.problemset, user=request.user
)
problem_id = str(obj.problem.id)
return problem_id in progress.progress_detail
except ProblemSetProgress.DoesNotExist:
return False
return False
return str(obj.problem.id) in progress.progress_detail
class AddProblemToSetSerializer(serializers.Serializer):

View File

@@ -137,11 +137,23 @@ class ProblemSetProblemAPI(APIView):
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem__created_by")
.prefetch_related("problem__tags")
.order_by("order")
)
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
user_progress = None
if request.user.is_authenticated:
try:
user_progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
except ProblemSetProgress.DoesNotExist:
pass
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
problems, many=True, context={"request": request, "user_progress": user_progress}
)
return self.success(serializer.data)

View File

@@ -1,4 +1,5 @@
from django.db import models
from django.db.models import F
from django.utils import timezone
from .models import Submission
@@ -7,6 +8,33 @@ from utils.serializers import LanguageNameChoiceField
from problemset.models import ProblemSetProgress
def bulk_fetch_problemset_progress(user, problem_ids):
"""一次 IN 查询获取该用户对多个题目的题单进度,返回 {problem_id: ProblemSetProgress|None}"""
if not problem_ids:
return {}
rows = (
ProblemSetProgress.objects.filter(
user=user,
problemset__status="active",
problemset__problemsetproblem__problem_id__in=problem_ids,
)
.filter(
models.Q(problemset__end_time__isnull=True)
| models.Q(problemset__end_time__gt=timezone.now())
)
.annotate(matched_problem_id=F("problemset__problemsetproblem__problem_id"))
.only("join_time", "progress_detail")
)
cache = {}
for row in rows:
pid = row.matched_problem_id
if pid not in cache:
cache[pid] = row
for pid in problem_ids:
cache.setdefault(pid, None)
return cache
class CreateSubmissionSerializer(serializers.Serializer):
problem_id = serializers.IntegerField()
language = LanguageNameChoiceField()
@@ -44,7 +72,10 @@ class SubmissionListSerializer(serializers.ModelSerializer):
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user", None)
preloaded = kwargs.pop("problemset_progress_cache", None)
super().__init__(*args, **kwargs)
if preloaded is not None:
self._problemset_progress_cache = preloaded
class Meta:
model = Submission

View File

@@ -18,7 +18,7 @@ from ..serializers import (
SubmissionModelSerializer,
ShareSubmissionSerializer,
)
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer, bulk_fetch_problemset_progress
class SubmissionAPI(APIView):
@@ -193,8 +193,14 @@ class SubmissionListAPI(APIView):
)
data = self.paginate_data(request, submissions)
results = data["results"]
if request.user.is_authenticated and request.user.is_regular_user():
problem_ids = list({s.problem_id for s in results})
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
else:
progress_cache = {}
data["results"] = SubmissionListSerializer(
data["results"], many=True, user=request.user
results, many=True, user=request.user, problemset_progress_cache=progress_cache
).data
return self.success(data)
@@ -241,8 +247,14 @@ class ContestSubmissionListAPI(APIView):
submissions = submissions.filter(user_id=request.user.id)
data = self.paginate_data(request, submissions)
results = data["results"]
if request.user.is_authenticated and request.user.is_regular_user():
problem_ids = list({s.problem_id for s in results})
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
else:
progress_cache = {}
data["results"] = SubmissionListSerializer(
data["results"], many=True, user=request.user
results, many=True, user=request.user, problemset_progress_cache=progress_cache
).data
return self.success(data)