change enum

This commit is contained in:
2026-05-09 02:30:47 -06:00
parent 78158471b2
commit c466dfd3c6
23 changed files with 451 additions and 503 deletions

View File

@@ -0,0 +1,54 @@
# Generated by Django 6.0.4 on 2026-05-09 08:18
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("problem", "0008_alter_problem_unique_together_and_more"),
("problemset", "0007_problemset_end_time"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AlterUniqueTogether(
name="problemsetproblem",
unique_together=set(),
),
migrations.AlterUniqueTogether(
name="problemsetprogress",
unique_together=set(),
),
migrations.AlterUniqueTogether(
name="userbadge",
unique_together=set(),
),
migrations.AlterField(
model_name="problemset",
name="status",
field=models.TextField(choices=[("draft", "Draft"), ("active", "Active"), ("archived", "Archived")], default="draft", verbose_name="状态"),
),
migrations.AlterField(
model_name="problemset",
name="difficulty",
field=models.TextField(choices=[("Easy", "Easy"), ("Medium", "Medium"), ("Hard", "Hard")], default="Easy", verbose_name="难度等级"),
),
migrations.AlterField(
model_name="problemsetbadge",
name="condition_type",
field=models.TextField(choices=[("all_problems", "All Problems"), ("problem_count", "Problem Count"), ("score", "Score")], verbose_name="获得条件类型"),
),
migrations.AddConstraint(
model_name="problemsetproblem",
constraint=models.UniqueConstraint(fields=("problemset", "problem"), name="unique_problemset_problem"),
),
migrations.AddConstraint(
model_name="problemsetprogress",
constraint=models.UniqueConstraint(fields=("problemset", "user"), name="unique_problemset_progress_user"),
),
migrations.AddConstraint(
model_name="userbadge",
constraint=models.UniqueConstraint(fields=("user", "badge"), name="unique_user_badge"),
),
]

View File

@@ -6,15 +6,31 @@ from problem.models import Problem
from utils.models import JSONField, RichTextField
class ProblemSetStatus(models.TextChoices):
DRAFT = "draft", "Draft"
ACTIVE = "active", "Active"
ARCHIVED = "archived", "Archived"
class ProblemSetDifficulty(models.TextChoices):
EASY = "Easy", "Easy"
MEDIUM = "Medium", "Medium"
HARD = "Hard", "Hard"
class BadgeConditionType(models.TextChoices):
ALL_PROBLEMS = "all_problems", "All Problems"
PROBLEM_COUNT = "problem_count", "Problem Count"
SCORE = "score", "Score"
class ProblemSet(models.Model):
"""题单模型"""
title = models.TextField(verbose_name="题单标题")
description = RichTextField(verbose_name="题单描述")
# 创建者
created_by = models.ForeignKey(
User, on_delete=models.CASCADE, verbose_name="创建者"
)
created_by = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="创建者")
# 创建时间
create_time = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
# 更新时间
@@ -22,11 +38,13 @@ class ProblemSet(models.Model):
# 是否可见
visible = models.BooleanField(default=True, verbose_name="是否可见")
# 题单难度等级
difficulty = models.TextField(default="Easy", verbose_name="难度等级")
difficulty = models.TextField(
default=ProblemSetDifficulty.EASY,
choices=ProblemSetDifficulty.choices,
verbose_name="难度等级",
)
# 题单状态
status = models.TextField(
default="draft", verbose_name="状态"
) # active, archived, draft
status = models.TextField(default=ProblemSetStatus.DRAFT, choices=ProblemSetStatus.choices, verbose_name="状态")
# 截止时间(到期后自动解除防作弊隐藏)
end_time = models.DateTimeField(null=True, blank=True, verbose_name="截止时间")
@@ -43,9 +61,7 @@ class ProblemSet(models.Model):
class ProblemSetProblem(models.Model):
"""题单题目关联模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单")
problem = models.ForeignKey(Problem, on_delete=models.CASCADE, verbose_name="题目")
# 在题单中的顺序
order = models.IntegerField(default=0, verbose_name="顺序")
@@ -58,7 +74,9 @@ class ProblemSetProblem(models.Model):
class Meta:
db_table = "problemset_problem"
unique_together = (("problemset", "problem"),)
constraints = [
models.UniqueConstraint(fields=["problemset", "problem"], name="unique_problemset_problem"),
]
ordering = ("order",)
verbose_name = "题单题目"
verbose_name_plural = "题单题目"
@@ -70,17 +88,13 @@ class ProblemSetProblem(models.Model):
class ProblemSetBadge(models.Model):
"""题单奖章模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单")
name = models.TextField(verbose_name="奖章名称")
description = models.TextField(verbose_name="奖章描述")
# 奖章图标路径
icon = models.TextField(verbose_name="奖章图标")
# 获得条件:完成所有题目、完成指定数量题目、达到指定分数等
condition_type = models.TextField(
verbose_name="获得条件类型"
) # all_problems, problem_count, score
condition_type = models.TextField(choices=BadgeConditionType.choices, verbose_name="获得条件类型")
condition_value = models.IntegerField(default=0, verbose_name="条件值")
class Meta:
@@ -90,17 +104,13 @@ class ProblemSetBadge(models.Model):
def __str__(self):
return f"{self.problemset.title} - {self.name}"
def recalculate_user_badges(self):
"""重新计算所有用户的徽章资格"""
from django.db import transaction
user_progresses = ProblemSetProgress.objects.filter(problemset=self.problemset)
new_badges = [
UserBadge(user=progress.user, badge=self)
for progress in user_progresses
if self._is_eligible(progress)
]
new_badges = [UserBadge(user=progress.user, badge=self) for progress in user_progresses if self._is_eligible(progress)]
with transaction.atomic():
UserBadge.objects.filter(badge=self).delete()
if new_badges:
@@ -118,9 +128,7 @@ class ProblemSetBadge(models.Model):
def _check_user_badge_eligibility(self, progress):
"""检查并授予单个用户的徽章(供外部单次调用)"""
if self._is_eligible(progress) and not UserBadge.objects.filter(
user=progress.user, badge=self
).exists():
if self._is_eligible(progress) and not UserBadge.objects.filter(user=progress.user, badge=self).exists():
UserBadge.objects.create(user=progress.user, badge=self)
return True
return False
@@ -129,9 +137,7 @@ class ProblemSetBadge(models.Model):
class ProblemSetProgress(models.Model):
"""题单进度模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单")
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
# 加入时间
join_time = models.DateTimeField(auto_now_add=True, verbose_name="加入时间")
@@ -142,9 +148,7 @@ class ProblemSetProgress(models.Model):
# 完成进度百分比
progress_percentage = models.FloatField(default=0.0, verbose_name="完成进度")
# 已完成的题目数量
completed_problems_count = models.IntegerField(
default=0, verbose_name="已完成题目数"
)
completed_problems_count = models.IntegerField(default=0, verbose_name="已完成题目数")
# 总题目数量
total_problems_count = models.IntegerField(default=0, verbose_name="总题目数")
# 获得的总分
@@ -155,7 +159,9 @@ class ProblemSetProgress(models.Model):
class Meta:
db_table = "problemset_progress"
unique_together = (("problemset", "user"),)
constraints = [
models.UniqueConstraint(fields=["problemset", "user"], name="unique_problemset_progress_user"),
]
verbose_name = "题单进度"
verbose_name_plural = "题单进度"
@@ -165,9 +171,7 @@ class ProblemSetProgress(models.Model):
def update_progress(self):
"""更新进度信息"""
# 获取题单中的所有题目
problemset_problems = ProblemSetProblem.objects.filter(
problemset=self.problemset
)
problemset_problems = ProblemSetProblem.objects.filter(problemset=self.problemset)
self.total_problems_count = problemset_problems.count()
# 获取当前题单中所有题目的ID集合直接用 problem_id FK 字段,无需额外查询)
@@ -199,9 +203,7 @@ class ProblemSetProgress(models.Model):
# 计算完成百分比
if self.total_problems_count > 0:
self.progress_percentage = (
completed_count / self.total_problems_count
) * 100
self.progress_percentage = (completed_count / self.total_problems_count) * 100
else:
self.progress_percentage = 0
@@ -223,17 +225,11 @@ class ProblemSetProgress(models.Model):
class ProblemSetSubmission(models.Model):
"""题单提交记录模型"""
problemset = models.ForeignKey(
ProblemSet, on_delete=models.CASCADE, verbose_name="题单"
)
problemset = models.ForeignKey(ProblemSet, on_delete=models.CASCADE, verbose_name="题单")
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
submission = models.ForeignKey(
"submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录"
)
problem = models.ForeignKey(
"problem.Problem", on_delete=models.CASCADE, verbose_name="题目"
)
submission = models.ForeignKey("submission.Submission", on_delete=models.CASCADE, verbose_name="提交记录")
problem = models.ForeignKey("problem.Problem", on_delete=models.CASCADE, verbose_name="题目")
class Meta:
db_table = "problemset_submission"
@@ -253,34 +249,33 @@ class ProblemSetSubmission(models.Model):
def submit_time(self):
"""提交时间"""
return self.submission.create_time
@property
def result(self):
"""提交结果"""
return self.submission.result
@property
def language(self):
"""编程语言"""
return self.submission.language
class UserBadge(models.Model):
"""用户奖章模型"""
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户")
badge = models.ForeignKey(
ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章"
)
badge = models.ForeignKey(ProblemSetBadge, on_delete=models.CASCADE, verbose_name="奖章")
# 获得时间
earned_time = models.DateTimeField(auto_now_add=True, verbose_name="获得时间")
class Meta:
db_table = "user_badge"
unique_together = (("user", "badge"),)
constraints = [
models.UniqueConstraint(fields=["user", "badge"], name="unique_user_badge"),
]
verbose_name = "用户奖章"
verbose_name_plural = "用户奖章"
def __str__(self):
return f"{self.user.username} - {self.badge.name}"

View File

@@ -1,10 +1,13 @@
from utils.api import UsernameSerializer, serializers
from .models import (
BadgeConditionType,
ProblemSet,
ProblemSetBadge,
ProblemSetDifficulty,
ProblemSetProblem,
ProblemSetProgress,
ProblemSetStatus,
UserBadge,
)
@@ -13,9 +16,7 @@ def get_user_progress_data(problemset, request):
"""获取当前用户在该题单中的进度 - 公共方法"""
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=problemset, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=problemset, user=request.user)
return {
"is_joined": True,
"progress_percentage": progress.progress_percentage,
@@ -61,9 +62,7 @@ class ProblemSetSerializer(serializers.ModelSerializer):
request = self.context.get("request")
if request and request.user.is_authenticated:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=obj, user=request.user)
return progress.completed_problems_count
except ProblemSetProgress.DoesNotExist:
return 0
@@ -124,22 +123,22 @@ class ProblemSetListSerializer(serializers.ModelSerializer):
def get_badges(self, obj):
"""获取题单的奖章列表,并标记用户已获得的徽章"""
request = self.context.get("request")
# 使用预加载的奖章数据
badges = getattr(obj, "badges", [])
badge_data = ProblemSetBadgeSerializer(badges, many=True).data
# 如果用户已登录,检查哪些徽章已被获得
if request and request.user.is_authenticated and hasattr(request, "_user_earned_badge_ids"):
earned_badge_ids = request._user_earned_badge_ids
# 为每个徽章添加是否已获得的标记
for badge in badge_data:
badge['is_earned'] = badge['id'] in earned_badge_ids
badge["is_earned"] = badge["id"] in earned_badge_ids
else:
# 未登录用户或未预加载,所有徽章都标记为未获得
for badge in badge_data:
badge['is_earned'] = False
badge["is_earned"] = False
return badge_data
@@ -148,8 +147,8 @@ class CreateProblemSetSerializer(serializers.Serializer):
title = serializers.CharField(max_length=200)
description = serializers.CharField()
difficulty = serializers.CharField(default="Easy")
status = serializers.CharField(default="active")
difficulty = serializers.ChoiceField(choices=ProblemSetDifficulty.choices, default=ProblemSetDifficulty.EASY)
status = serializers.ChoiceField(choices=ProblemSetStatus.choices, default=ProblemSetStatus.ACTIVE)
end_time = serializers.DateTimeField(required=False)
@@ -159,8 +158,8 @@ class EditProblemSetSerializer(serializers.Serializer):
id = serializers.IntegerField()
title = serializers.CharField(max_length=200, required=False)
description = serializers.CharField(required=False)
difficulty = serializers.CharField(required=False)
status = serializers.CharField(required=False)
difficulty = serializers.ChoiceField(choices=ProblemSetDifficulty.choices, required=False)
status = serializers.ChoiceField(choices=ProblemSetStatus.choices, required=False)
visible = serializers.BooleanField(required=False)
end_time = serializers.DateTimeField(required=False, allow_null=True)
@@ -190,9 +189,7 @@ class ProblemSetProblemSerializer(serializers.ModelSerializer):
progress = self.context.get("user_progress")
if progress is None:
try:
progress = ProblemSetProgress.objects.get(
problemset=obj.problemset, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=obj.problemset, user=request.user)
except ProblemSetProgress.DoesNotExist:
return False
return str(obj.problem.id) in progress.progress_detail
@@ -227,19 +224,21 @@ class ProblemSetBadgeSerializer(serializers.ModelSerializer):
class CreateProblemSetBadgeSerializer(serializers.Serializer):
"""创建题单奖章序列化器"""
name = serializers.CharField(max_length=100)
description = serializers.CharField()
icon = serializers.CharField()
condition_type = serializers.CharField() # all_problems, problem_count, score
condition_type = serializers.ChoiceField(choices=BadgeConditionType.choices)
condition_value = serializers.IntegerField(required=False)
class EditProblemSetBadgeSerializer(serializers.Serializer):
"""编辑题单奖章序列化器"""
name = serializers.CharField(max_length=100, required=False)
description = serializers.CharField(required=False)
icon = serializers.CharField(required=False)
condition_type = serializers.CharField(required=False) # all_problems, problem_count, score
condition_type = serializers.ChoiceField(choices=BadgeConditionType.choices, required=False)
condition_value = serializers.IntegerField(required=False)
@@ -252,42 +251,35 @@ class ProblemSetProgressSerializer(serializers.ModelSerializer):
class Meta:
model = ProblemSetProgress
fields = "__all__"
def get_completed_problems(self, obj):
"""获取已完成的题目列表"""
completed_problems = []
# 尝试从 request 中获取预加载的问题字典(用于批量查询优化)
problems_dict = {}
request = self.context.get('request')
if request and hasattr(request, '_problems_dict_cache'):
request = self.context.get("request")
if request and hasattr(request, "_problems_dict_cache"):
problems_dict = request._problems_dict_cache
if obj.progress_detail:
for problem_id in obj.progress_detail.keys():
# 优先使用预加载的问题字典
if problems_dict:
problem = problems_dict.get(problem_id)
if problem:
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
completed_problems.append({"id": problem.id, "_id": problem._id, "title": problem.title})
continue
# 如果没有预加载字典,则回退到单独查询(向后兼容)
from problem.models import Problem
try:
problem = Problem.objects.get(id=problem_id)
completed_problems.append({
'id': problem.id,
'_id': problem._id,
'title': problem.title
})
completed_problems.append({"id": problem.id, "_id": problem._id, "title": problem.title})
except Problem.DoesNotExist:
continue
return completed_problems
@@ -313,5 +305,3 @@ class UpdateProgressSerializer(serializers.Serializer):
problemset_id = serializers.IntegerField()
problem_id = serializers.IntegerField()
submission_id = serializers.CharField(required=False)

View File

@@ -7,6 +7,7 @@ from problemset.models import (
ProblemSetBadge,
ProblemSetProblem,
ProblemSetProgress,
ProblemSetStatus,
)
from problemset.serializers import (
AddProblemToSetSerializer,
@@ -35,9 +36,7 @@ class ProblemSetAdminAPI(APIView):
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
problem_sets = problem_sets.filter(Q(title__icontains=keyword) | Q(description__icontains=keyword))
difficulty = request.GET.get("difficulty")
if difficulty:
@@ -129,12 +128,8 @@ class ProblemSetProblemAdminAPI(APIView):
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by(
"order"
)
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request}
)
problems = ProblemSetProblem.objects.filter(problemset=problem_set).order_by("order")
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request})
return self.success(serializer.data)
@super_admin_required
@@ -158,9 +153,7 @@ class ProblemSetProblemAdminAPI(APIView):
return self.error("题目不存在或不可见")
# 检查题目是否已经在题单中
if ProblemSetProblem.objects.filter(
problemset=problem_set, problem=problem
).exists():
if ProblemSetProblem.objects.filter(problemset=problem_set, problem=problem).exists():
return self.error("题目已在该题单中")
ProblemSetProblem.objects.create(
@@ -188,9 +181,7 @@ class ProblemSetProblemAdminAPI(APIView):
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
problem_set_problem = ProblemSetProblem.objects.get(id=problem_set_problem_id, problemset=problem_set)
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
@@ -206,10 +197,10 @@ class ProblemSetProblemAdminAPI(APIView):
problem_set_problem.hint = data["hint"]
problem_set_problem.save()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已更新")
@super_admin_required
@@ -222,14 +213,12 @@ class ProblemSetProblemAdminAPI(APIView):
return self.error("题单不存在")
try:
problem_set_problem = ProblemSetProblem.objects.get(
id=problem_set_problem_id, problemset=problem_set
)
problem_set_problem = ProblemSetProblem.objects.get(id=problem_set_problem_id, problemset=problem_set)
problem_set_problem.delete()
# 同步所有用户的进度
ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success("题目已从题单中移除")
except ProblemSetProblem.DoesNotExist:
return self.error("题目不在该题单中")
@@ -283,10 +272,10 @@ class ProblemSetBadgeAdminAPI(APIView):
return self.error("奖章不存在")
data = request.data
# 记录是否修改了条件相关的字段
condition_changed = False
# 更新奖章属性
if "name" in data:
badge.name = data["name"]
@@ -304,7 +293,7 @@ class ProblemSetBadgeAdminAPI(APIView):
badge.level = data["level"]
badge.save()
# 如果修改了条件,重新计算所有用户的徽章资格
if condition_changed:
try:
@@ -312,7 +301,7 @@ class ProblemSetBadgeAdminAPI(APIView):
return self.success("奖章已更新,并重新计算了所有用户的徽章资格")
except Exception as e:
return self.error(f"奖章已更新,但重新计算徽章资格时出错: {str(e)}")
return self.success("奖章已更新")
@super_admin_required
@@ -344,9 +333,7 @@ class ProblemSetProgressAdminAPI(APIView):
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
progress_list = ProblemSetProgress.objects.filter(
problemset=problem_set
).order_by("-join_time")
progress_list = ProblemSetProgress.objects.filter(problemset=problem_set).order_by("-join_time")
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
@@ -360,9 +347,7 @@ class ProblemSetProgressAdminAPI(APIView):
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user_id=user_id
)
progress = ProblemSetProgress.objects.get(problemset=problem_set, user_id=user_id)
progress.delete()
return self.success("用户已从题单中移除")
except ProblemSetProgress.DoesNotExist:
@@ -371,7 +356,7 @@ class ProblemSetProgressAdminAPI(APIView):
class ProblemSetSyncAPI(APIView):
"""题单同步管理API"""
@super_admin_required
def post(self, request, problem_set_id):
"""手动同步题单的所有用户进度(管理员)"""
@@ -380,10 +365,10 @@ class ProblemSetSyncAPI(APIView):
ensure_created_by(problem_set, request.user)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 同步所有用户的进度
synced_count = ProblemSetProgress.sync_all_progress_for_problemset(problem_set)
return self.success(f"已同步 {synced_count} 个用户的进度")
@@ -419,7 +404,7 @@ class ProblemSetStatusAPI(APIView):
return self.error("题单不存在")
status = data.get("status")
if status not in ["active", "archived", "draft"]:
if status not in ProblemSetStatus.values:
return self.error("无效的状态")
problem_set.status = status

View File

@@ -32,18 +32,14 @@ class ProblemSetAPI(APIView):
"""获取题单列表"""
# 预加载创建者信息
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft").select_related("created_by")
# 使用annotate在查询时计算题目数量避免N+1查询
problem_sets = problem_sets.annotate(
problems_count=Count("problemsetproblem", distinct=True)
)
problem_sets = problem_sets.annotate(problems_count=Count("problemsetproblem", distinct=True))
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
problem_sets = problem_sets.filter(Q(title__icontains=keyword) | Q(description__icontains=keyword))
difficulty = request.GET.get("difficulty")
if difficulty:
@@ -67,33 +63,19 @@ class ProblemSetAPI(APIView):
if request.user.is_authenticated:
# 先获取所有题单ID不应用prefetch_related只获取ID
problem_set_ids = list(problem_sets.values_list("id", flat=True))
if problem_set_ids:
# 批量查询用户在这些题单中的进度
user_progresses = ProblemSetProgress.objects.filter(
problemset_id__in=problem_set_ids,
user=request.user
).select_related("problemset")
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
# 构建映射题单ID -> 进度对象
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
# 批量查询用户已获得的奖章ID这些题单相关的
user_earned_badge_ids = set(
UserBadge.objects.filter(
user=request.user,
badge__problemset_id__in=problem_set_ids
).values_list('badge_id', flat=True)
)
user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True))
# 预加载奖章信息在获取ID之后应用避免在获取ID时也预加载
problem_sets = problem_sets.prefetch_related(
Prefetch(
"problemsetbadge_set",
queryset=ProblemSetBadge.objects.all(),
to_attr="badges"
)
)
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
# 将用户进度映射和已获得的奖章ID集合存储到request中供序列化器使用
request._user_progress_map = user_progress_map
request._user_earned_badge_ids = user_earned_badge_ids
@@ -108,11 +90,7 @@ class ProblemSetDetailAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单详情"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -126,32 +104,19 @@ class ProblemSetProblemAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单中的题目列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem__created_by")
.prefetch_related("problem__tags")
.order_by("order")
)
problems = ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem__created_by").prefetch_related("problem__tags").order_by("order")
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
user_progress = None
if request.user.is_authenticated:
try:
user_progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
pass
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request, "user_progress": user_progress}
)
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress})
return self.success(serializer.data)
@@ -163,23 +128,15 @@ class ProblemSetProgressAPI(APIView):
"""加入题单"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
if ProblemSetProgress.objects.filter(
problemset=problem_set, user=request.user
).exists():
if ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).exists():
return self.error("已经加入该题单")
# 创建进度记录
progress = ProblemSetProgress.objects.create(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.create(problemset=problem_set, user=request.user)
progress.update_progress()
return self.success("成功加入题单")
@@ -187,18 +144,12 @@ class ProblemSetProgressAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单进度"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
@@ -210,18 +161,12 @@ class ProblemSetProgressAPI(APIView):
"""更新进度"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
@@ -230,9 +175,7 @@ class ProblemSetProgressAPI(APIView):
# 获取该题目在题单中的分值
try:
problemset_problem = ProblemSetProblem.objects.get(
problemset=problem_set, problem_id=problem_id
)
problemset_problem = ProblemSetProblem.objects.get(problemset=problem_set, problem_id=problem_id)
problem_score = problemset_problem.score
except ProblemSetProblem.DoesNotExist:
problem_score = 0
@@ -296,9 +239,7 @@ class UserProgressAPI(APIView):
def get(self, request):
"""获取用户的题单进度列表"""
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by(
"-join_time"
)
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by("-join_time")
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
@@ -315,16 +256,12 @@ class UserBadgeAPI(APIView):
# 获取指定用户的徽章
try:
target_user = User.objects.get(username=username, is_disabled=False)
badges = UserBadge.objects.filter(user=target_user).order_by(
"-earned_time"
)
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
except User.DoesNotExist:
return self.error("用户不存在")
else:
# 获取当前用户的徽章
badges = UserBadge.objects.filter(user=request.user).order_by(
"-earned_time"
)
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
serializer = UserBadgeSerializer(badges, many=True)
return self.success(serializer.data)
@@ -336,11 +273,7 @@ class ProblemSetBadgeAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单的奖章列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -355,18 +288,12 @@ class ProblemSetUserProgressAPI(APIView):
def get(self, request, problem_set_id: int):
"""获取题单的用户进度列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
progresses = ProblemSetProgress.objects.filter(
problemset=problem_set
).select_related("user")
progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related("user")
# 班级过滤
class_name = request.GET.get("class_name", "").strip()
@@ -386,9 +313,7 @@ class ProblemSetUserProgressAPI(APIView):
progresses = progresses.filter(completed_problems_count=0)
# 排序
progresses = progresses.order_by(
"-is_completed", "-progress_percentage", "join_time"
)
progresses = progresses.order_by("-is_completed", "-progress_percentage", "join_time")
# 计算统计数据(基于所有数据,而非分页数据)
# 使用一次查询获取所有统计数据
@@ -416,12 +341,9 @@ class ProblemSetUserProgressAPI(APIView):
# 提前获取题单的所有题目(用于前端显示未完成题目和序列化器)
# 使用 select_related 和 only 优化查询,只选择需要的字段
all_problemset_problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem")
.only("problem__id", "problem___id", "problem__title", "order")
.order_by("order")
ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem").only("problem__id", "problem___id", "problem__title", "order").order_by("order")
)
# 构建题单所有题目的数据结构和映射
all_problems_list = []
all_problems_map = {}
@@ -444,11 +366,7 @@ class ProblemSetUserProgressAPI(APIView):
completed_problem_ids.update(progress.progress_detail.keys())
# 从已加载的题单题目中构建 problems_dict避免重复查询
problems_dict = {
pid: all_problems_map[pid]
for pid in completed_problem_ids
if pid in all_problems_map
}
problems_dict = {pid: all_problems_map[pid] for pid in completed_problem_ids if pid in all_problems_map}
# 将预加载的问题字典存储到 request 中,供序列化器使用
request._problems_dict_cache = problems_dict