change enum

This commit is contained in:
2026-05-09 02:30:47 -06:00
parent 78158471b2
commit c466dfd3c6
23 changed files with 451 additions and 503 deletions

View File

@@ -32,18 +32,14 @@ class ProblemSetAPI(APIView):
"""获取题单列表"""
# 预加载创建者信息
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status="draft").select_related("created_by")
# 使用annotate在查询时计算题目数量避免N+1查询
problem_sets = problem_sets.annotate(
problems_count=Count("problemsetproblem", distinct=True)
)
problem_sets = problem_sets.annotate(problems_count=Count("problemsetproblem", distinct=True))
# 过滤条件
keyword = request.GET.get("keyword", "").strip()
if keyword:
problem_sets = problem_sets.filter(
Q(title__icontains=keyword) | Q(description__icontains=keyword)
)
problem_sets = problem_sets.filter(Q(title__icontains=keyword) | Q(description__icontains=keyword))
difficulty = request.GET.get("difficulty")
if difficulty:
@@ -67,33 +63,19 @@ class ProblemSetAPI(APIView):
if request.user.is_authenticated:
# 先获取所有题单ID不应用prefetch_related只获取ID
problem_set_ids = list(problem_sets.values_list("id", flat=True))
if problem_set_ids:
# 批量查询用户在这些题单中的进度
user_progresses = ProblemSetProgress.objects.filter(
problemset_id__in=problem_set_ids,
user=request.user
).select_related("problemset")
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
# 构建映射题单ID -> 进度对象
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
# 批量查询用户已获得的奖章ID这些题单相关的
user_earned_badge_ids = set(
UserBadge.objects.filter(
user=request.user,
badge__problemset_id__in=problem_set_ids
).values_list('badge_id', flat=True)
)
user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True))
# 预加载奖章信息在获取ID之后应用避免在获取ID时也预加载
problem_sets = problem_sets.prefetch_related(
Prefetch(
"problemsetbadge_set",
queryset=ProblemSetBadge.objects.all(),
to_attr="badges"
)
)
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
# 将用户进度映射和已获得的奖章ID集合存储到request中供序列化器使用
request._user_progress_map = user_progress_map
request._user_earned_badge_ids = user_earned_badge_ids
@@ -108,11 +90,7 @@ class ProblemSetDetailAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单详情"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -126,32 +104,19 @@ class ProblemSetProblemAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单中的题目列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem__created_by")
.prefetch_related("problem__tags")
.order_by("order")
)
problems = ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem__created_by").prefetch_related("problem__tags").order_by("order")
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
user_progress = None
if request.user.is_authenticated:
try:
user_progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
pass
serializer = ProblemSetProblemSerializer(
problems, many=True, context={"request": request, "user_progress": user_progress}
)
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress})
return self.success(serializer.data)
@@ -163,23 +128,15 @@ class ProblemSetProgressAPI(APIView):
"""加入题单"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
if ProblemSetProgress.objects.filter(
problemset=problem_set, user=request.user
).exists():
if ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).exists():
return self.error("已经加入该题单")
# 创建进度记录
progress = ProblemSetProgress.objects.create(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.create(problemset=problem_set, user=request.user)
progress.update_progress()
return self.success("成功加入题单")
@@ -187,18 +144,12 @@ class ProblemSetProgressAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单进度"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
@@ -210,18 +161,12 @@ class ProblemSetProgressAPI(APIView):
"""更新进度"""
data = request.data
try:
problem_set = (
ProblemSet.objects.filter(id=data["problemset_id"], visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=data["problemset_id"], visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
try:
progress = ProblemSetProgress.objects.get(
problemset=problem_set, user=request.user
)
progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
return self.error("未加入该题单")
@@ -230,9 +175,7 @@ class ProblemSetProgressAPI(APIView):
# 获取该题目在题单中的分值
try:
problemset_problem = ProblemSetProblem.objects.get(
problemset=problem_set, problem_id=problem_id
)
problemset_problem = ProblemSetProblem.objects.get(problemset=problem_set, problem_id=problem_id)
problem_score = problemset_problem.score
except ProblemSetProblem.DoesNotExist:
problem_score = 0
@@ -296,9 +239,7 @@ class UserProgressAPI(APIView):
def get(self, request):
"""获取用户的题单进度列表"""
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by(
"-join_time"
)
progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by("-join_time")
serializer = ProblemSetProgressSerializer(progress_list, many=True)
return self.success(serializer.data)
@@ -315,16 +256,12 @@ class UserBadgeAPI(APIView):
# 获取指定用户的徽章
try:
target_user = User.objects.get(username=username, is_disabled=False)
badges = UserBadge.objects.filter(user=target_user).order_by(
"-earned_time"
)
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
except User.DoesNotExist:
return self.error("用户不存在")
else:
# 获取当前用户的徽章
badges = UserBadge.objects.filter(user=request.user).order_by(
"-earned_time"
)
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
serializer = UserBadgeSerializer(badges, many=True)
return self.success(serializer.data)
@@ -336,11 +273,7 @@ class ProblemSetBadgeAPI(APIView):
def get(self, request, problem_set_id):
"""获取题单的奖章列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -355,18 +288,12 @@ class ProblemSetUserProgressAPI(APIView):
def get(self, request, problem_set_id: int):
"""获取题单的用户进度列表"""
try:
problem_set = (
ProblemSet.objects.filter(id=problem_set_id, visible=True)
.exclude(status="draft")
.get()
)
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status="draft").get()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
# 获取所有参与该题单的用户进度,使用 select_related 预加载用户信息
progresses = ProblemSetProgress.objects.filter(
problemset=problem_set
).select_related("user")
progresses = ProblemSetProgress.objects.filter(problemset=problem_set).select_related("user")
# 班级过滤
class_name = request.GET.get("class_name", "").strip()
@@ -386,9 +313,7 @@ class ProblemSetUserProgressAPI(APIView):
progresses = progresses.filter(completed_problems_count=0)
# 排序
progresses = progresses.order_by(
"-is_completed", "-progress_percentage", "join_time"
)
progresses = progresses.order_by("-is_completed", "-progress_percentage", "join_time")
# 计算统计数据(基于所有数据,而非分页数据)
# 使用一次查询获取所有统计数据
@@ -416,12 +341,9 @@ class ProblemSetUserProgressAPI(APIView):
# 提前获取题单的所有题目(用于前端显示未完成题目和序列化器)
# 使用 select_related 和 only 优化查询,只选择需要的字段
all_problemset_problems = (
ProblemSetProblem.objects.filter(problemset=problem_set)
.select_related("problem")
.only("problem__id", "problem___id", "problem__title", "order")
.order_by("order")
ProblemSetProblem.objects.filter(problemset=problem_set).select_related("problem").only("problem__id", "problem___id", "problem__title", "order").order_by("order")
)
# 构建题单所有题目的数据结构和映射
all_problems_list = []
all_problems_map = {}
@@ -444,11 +366,7 @@ class ProblemSetUserProgressAPI(APIView):
completed_problem_ids.update(progress.progress_detail.keys())
# 从已加载的题单题目中构建 problems_dict避免重复查询
problems_dict = {
pid: all_problems_map[pid]
for pid in completed_problem_ids
if pid in all_problems_map
}
problems_dict = {pid: all_problems_map[pid] for pid in completed_problem_ids if pid in all_problems_map}
# 将预加载的问题字典存储到 request 中,供序列化器使用
request._problems_dict_cache = problems_dict