Files
OnlineJudge/problemset/views/oj.py
2025-12-26 17:11:35 +08:00

461 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from django.db.models import Q, Avg, Count, Prefetch
from django.utils import timezone
from utils.api import APIView, validate_serializer
from account.models import User
from problemset.models import (
ProblemSet,
ProblemSetProblem,
ProblemSetBadge,
ProblemSetProgress,
ProblemSetSubmission,
UserBadge,
)
from problemset.serializers import (
ProblemSetSerializer,
ProblemSetListSerializer,
ProblemSetProblemSerializer,
ProblemSetBadgeSerializer,
ProblemSetProgressSerializer,
UserBadgeSerializer,
JoinProblemSetSerializer,
UpdateProgressSerializer,
)
from submission.models import Submission
from problem.models import Problem
class ProblemSetAPI(APIView):
    """Problem set API (user-facing)."""

    @staticmethod
    def _resolve_ordering(sort):
        """Return a safe ``order_by`` argument for the client-supplied *sort*.

        *sort* comes straight from the query string, so it is validated
        before reaching the ORM: only concrete fields declared on
        ``ProblemSet`` (optionally prefixed with ``-``), ``pk`` and the
        ``problems_count`` annotation added by ``get`` are accepted.
        Anything else, including relationship traversals such as
        ``created_by__...``, falls back to newest-first ordering instead of
        failing the request or ordering by columns of a related table.
        """
        default_ordering = "-create_time"
        if not sort:
            return default_ordering

        field_name = sort[1:] if sort.startswith("-") else sort
        allowed = {"pk", "problems_count"}
        for field in ProblemSet._meta.get_fields():
            if field.concrete:
                allowed.add(field.name)
                # Foreign keys can also be addressed by column name ("<name>_id").
                allowed.add(getattr(field, "attname", field.name))
        return sort if field_name in allowed else default_ordering

    def get(self, request):
        """Return the list of visible, non-draft problem sets.

        Supported query parameters: ``keyword`` (case-insensitive match on
        title or description), ``difficulty``, ``status`` and ``sort``.
        Pagination and serialization are delegated to ``self.paginate_data``.
        """
        # Preload the creator and annotate the problem count in the same
        # query to avoid N+1 lookups.
        problem_sets = (
            ProblemSet.objects.filter(visible=True)
            .exclude(status="draft")
            .select_related("created_by")
            .annotate(problems_count=Count("problemsetproblem", distinct=True))
        )

        # Filters
        keyword = request.GET.get("keyword", "").strip()
        if keyword:
            problem_sets = problem_sets.filter(
                Q(title__icontains=keyword) | Q(description__icontains=keyword)
            )
        difficulty = request.GET.get("difficulty")
        if difficulty:
            problem_sets = problem_sets.filter(difficulty=difficulty)
        status_filter = request.GET.get("status")
        if status_filter:
            problem_sets = problem_sets.filter(status=status_filter)

        # Ordering (validated, see _resolve_ordering)
        problem_sets = problem_sets.order_by(
            self._resolve_ordering(request.GET.get("sort"))
        )

        # Batch-load the current user's progress and earned badges.
        # The id list is collected before prefetch_related is applied so the
        # id query does not trigger the badge prefetch.
        user_progress_map = {}
        user_earned_badge_ids = set()
        if request.user.is_authenticated:
            problem_set_ids = list(problem_sets.values_list("id", flat=True))
            if problem_set_ids:
                user_progresses = ProblemSetProgress.objects.filter(
                    problemset_id__in=problem_set_ids,
                    user=request.user,
                ).select_related("problemset")
                # problem set id -> progress object
                user_progress_map = {
                    progress.problemset_id: progress for progress in user_progresses
                }
                # Ids of the badges (belonging to these sets) the user already has.
                user_earned_badge_ids = set(
                    UserBadge.objects.filter(
                        user=request.user,
                        badge__problemset_id__in=problem_set_ids,
                    ).values_list('badge_id', flat=True)
                )

        # Prefetch badges onto each problem set as the ``badges`` attribute.
        problem_sets = problem_sets.prefetch_related(
            Prefetch(
                "problemsetbadge_set",
                queryset=ProblemSetBadge.objects.all(),
                to_attr="badges",
            )
        )

        # Stash the lookups on the request; per the original note they are
        # intended for the serializer to read.
        request._user_progress_map = user_progress_map
        request._user_earned_badge_ids = user_earned_badge_ids

        data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
        return self.success(data)
class ProblemSetDetailAPI(APIView):
    """Problem set detail API (user-facing)."""

    def get(self, request, problem_set_id):
        """Return the serialized problem set identified by *problem_set_id*.

        Only visible, non-draft problem sets are served; otherwise an error
        response is returned.
        """
        non_draft_sets = ProblemSet.objects.exclude(status="draft")
        try:
            target = non_draft_sets.get(id=problem_set_id, visible=True)
        except ProblemSet.DoesNotExist:
            return self.error("题单不存在")

        payload = ProblemSetSerializer(target, context={"request": request}).data
        return self.success(payload)
class ProblemSetProblemAPI(APIView):
    """Problem set problems API (user-facing)."""

    def get(self, request, problem_set_id):
        """Return the problems of a problem set, ordered by their ``order`` field.

        Only visible, non-draft problem sets are served; otherwise an error
        response is returned.
        """
        non_draft_sets = ProblemSet.objects.exclude(status="draft")
        try:
            target = non_draft_sets.get(id=problem_set_id, visible=True)
        except ProblemSet.DoesNotExist:
            return self.error("题单不存在")

        ordered_entries = ProblemSetProblem.objects.filter(problemset=target)
        ordered_entries = ordered_entries.order_by("order")
        payload = ProblemSetProblemSerializer(
            ordered_entries, many=True, context={"request": request}
        ).data
        return self.success(payload)
class ProblemSetProgressAPI(APIView):
    """Problem set progress API: join a set, read progress, update progress."""

    @staticmethod
    def _get_visible_problem_set(problem_set_id):
        """Return the visible, non-draft problem set with this id, or ``None``."""
        try:
            return (
                ProblemSet.objects.filter(id=problem_set_id, visible=True)
                .exclude(status="draft")
                .get()
            )
        except ProblemSet.DoesNotExist:
            return None

    @validate_serializer(JoinProblemSetSerializer)
    def post(self, request):
        """Join the problem set named by ``problemset_id`` in the request body.

        Returns an error response when the user is not logged in, the problem
        set is not available, or the user has already joined it.
        """
        # request.user is used in ORM filters below; an anonymous user cannot
        # be used there, so reject it explicitly.
        if not request.user.is_authenticated:
            return self.error("请先登录")

        problem_set = self._get_visible_problem_set(request.data["problemset_id"])
        if problem_set is None:
            return self.error("题单不存在")

        if ProblemSetProgress.objects.filter(
            problemset=problem_set, user=request.user
        ).exists():
            return self.error("已经加入该题单")

        # Create the progress record and let the model initialise its counters.
        progress = ProblemSetProgress.objects.create(
            problemset=problem_set, user=request.user
        )
        progress.update_progress()
        return self.success("成功加入题单")

    def get(self, request, problem_set_id):
        """Return the current user's serialized progress in a problem set.

        Returns an error response when the user is not logged in, the problem
        set is not available, or the user has not joined it.
        """
        if not request.user.is_authenticated:
            return self.error("请先登录")

        problem_set = self._get_visible_problem_set(problem_set_id)
        if problem_set is None:
            return self.error("题单不存在")

        try:
            progress = ProblemSetProgress.objects.get(
                problemset=problem_set, user=request.user
            )
        except ProblemSetProgress.DoesNotExist:
            return self.error("未加入该题单")

        serializer = ProblemSetProgressSerializer(progress)
        return self.success(serializer.data)

    @validate_serializer(UpdateProgressSerializer)
    def put(self, request):
        """Record progress on one problem of a joined problem set.

        Reads ``problemset_id`` and ``problem_id`` (plus optional
        ``submit_time`` and ``submission_id``) from the request body, stores
        an entry in ``progress.progress_detail``, refreshes the progress,
        optionally records the submission and finally checks badge
        conditions.
        """
        import logging

        if not request.user.is_authenticated:
            return self.error("请先登录")

        data = request.data
        problem_set = self._get_visible_problem_set(data["problemset_id"])
        if problem_set is None:
            return self.error("题单不存在")

        try:
            progress = ProblemSetProgress.objects.get(
                problemset=problem_set, user=request.user
            )
        except ProblemSetProgress.DoesNotExist:
            return self.error("未加入该题单")

        # progress_detail is keyed by the problem id as a string.
        problem_id = str(data["problem_id"])

        # Score configured for this problem inside the problem set; a problem
        # that is not part of the set scores 0.
        try:
            problemset_problem = ProblemSetProblem.objects.get(
                problemset=problem_set, problem_id=problem_id
            )
            problem_score = problemset_problem.score
        except ProblemSetProblem.DoesNotExist:
            problem_score = 0

        progress.progress_detail[problem_id] = {
            "score": problem_score,
            "submit_time": data.get("submit_time", timezone.now().isoformat()),
        }
        progress.update_progress()

        # A ProblemSetSubmission row is only created when a submission id
        # was supplied.
        submission_id = data.get("submission_id")
        if submission_id:
            try:
                submission = Submission.objects.get(id=submission_id)
                problem = Problem.objects.get(id=problem_id)
            except (Submission.DoesNotExist, Problem.DoesNotExist):
                # Best effort: a missing submission or problem must not abort
                # the progress update, but leave a trace for diagnosis.
                logging.getLogger(__name__).warning(
                    "Skipping ProblemSetSubmission: submission %s or problem %s not found",
                    submission_id,
                    problem_id,
                )
            else:
                already_recorded = ProblemSetSubmission.objects.filter(
                    problemset=problem_set,
                    user=request.user,
                    problem=problem,
                ).exists()
                # Keep at most one record per (problem set, user, problem).
                if not already_recorded:
                    ProblemSetSubmission.objects.create(
                        problemset=problem_set,
                        user=request.user,
                        submission=submission,
                        problem=problem,
                    )

        self._check_badges(progress)
        return self.success("进度已更新")

    def _check_badges(self, progress):
        """Award every badge of the problem set whose condition is now met.

        Supported ``condition_type`` values: ``all_problems`` (completed count
        equals total count), ``problem_count`` (completed count reaches
        ``condition_value``) and ``score`` (total score reaches
        ``condition_value``). Badges the user already holds are skipped.
        """
        # One query for all already-earned badges instead of one per badge.
        earned_badge_ids = set(
            UserBadge.objects.filter(
                user=progress.user, badge__problemset=progress.problemset
            ).values_list("badge_id", flat=True)
        )

        for badge in ProblemSetBadge.objects.filter(problemset=progress.problemset):
            if badge.id in earned_badge_ids:
                continue

            if badge.condition_type == "all_problems":
                condition_met = (
                    progress.completed_problems_count == progress.total_problems_count
                )
            elif badge.condition_type == "problem_count":
                condition_met = (
                    progress.completed_problems_count >= badge.condition_value
                )
            elif badge.condition_type == "score":
                condition_met = progress.total_score >= badge.condition_value
            else:
                # Unknown condition types never award a badge.
                condition_met = False

            if condition_met:
                UserBadge.objects.create(user=progress.user, badge=badge)
class UserProgressAPI(APIView):
    """User progress API."""

    def get(self, request):
        """Return the current user's problem set progress, newest join first.

        Returns an error response when the user is not logged in.
        """
        # request.user is used as an ORM filter value below; an anonymous
        # user cannot be used there, so reject it explicitly.
        if not request.user.is_authenticated:
            return self.error("请先登录")

        progress_list = ProblemSetProgress.objects.filter(user=request.user).order_by(
            "-join_time"
        )
        serializer = ProblemSetProgressSerializer(progress_list, many=True)
        return self.success(serializer.data)
class UserBadgeAPI(APIView):
    """User badge API."""

    def get(self, request):
        """Return a user's badges, most recently earned first.

        With a ``username`` query parameter the badges of that (non-disabled)
        user are returned; without it the badges of the current user are
        returned, which requires being logged in.
        """
        username = request.GET.get("username")
        if username:
            # Badges of the requested user; keep the try body to the lookup
            # that can actually raise.
            try:
                target_user = User.objects.get(username=username, is_disabled=False)
            except User.DoesNotExist:
                return self.error("用户不存在")
        else:
            # Badges of the current user. An anonymous user cannot be used as
            # an ORM filter value, so reject it explicitly.
            if not request.user.is_authenticated:
                return self.error("请先登录")
            target_user = request.user

        badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
        serializer = UserBadgeSerializer(badges, many=True)
        return self.success(serializer.data)
class ProblemSetBadgeAPI(APIView):
    """Problem set badge API (user-facing)."""

    def get(self, request, problem_set_id):
        """Return the serialized badges defined for a problem set.

        Only visible, non-draft problem sets are served; otherwise an error
        response is returned.
        """
        non_draft_sets = ProblemSet.objects.exclude(status="draft")
        try:
            target = non_draft_sets.get(id=problem_set_id, visible=True)
        except ProblemSet.DoesNotExist:
            return self.error("题单不存在")

        set_badges = ProblemSetBadge.objects.filter(problemset=target)
        payload = ProblemSetBadgeSerializer(set_badges, many=True).data
        return self.success(payload)
class ProblemSetUserProgressAPI(APIView):
    """Per-user progress list API for a problem set."""

    def get(self, request, problem_set_id: int):
        """Return paginated user progress for a problem set, plus statistics.

        Query parameters: ``class_name`` (case-insensitive substring match on
        the username), ``completion_status`` (``completed``, ``in_progress``
        or ``not_started``), ``limit`` and ``offset``. The response produced
        by ``self.paginate_data`` is extended with a ``statistics`` entry
        (computed over the whole filtered set, not just the current page) and
        a ``problems`` entry listing every problem of the set.
        """
        try:
            problem_set = (
                ProblemSet.objects.filter(id=problem_set_id, visible=True)
                .exclude(status="draft")
                .get()
            )
        except ProblemSet.DoesNotExist:
            return self.error("题单不存在")

        # Progress rows of every participant, with the user preloaded.
        progresses = ProblemSetProgress.objects.filter(
            problemset=problem_set
        ).select_related("user")

        # "Class" filter: implemented as a substring match on the username.
        class_name = request.GET.get("class_name", "").strip()
        if class_name:
            progresses = progresses.filter(user__username__icontains=class_name)

        # Completion filter
        completion_status = request.GET.get("completion_status", "").strip()
        if completion_status == "completed":
            progresses = progresses.filter(is_completed=True)
        elif completion_status == "in_progress":
            # Not finished but at least one problem completed.
            progresses = progresses.filter(
                is_completed=False, completed_problems_count__gt=0
            )
        elif completion_status == "not_started":
            progresses = progresses.filter(completed_problems_count=0)

        progresses = progresses.order_by(
            "-is_completed", "-progress_percentage", "join_time"
        )

        # Statistics over the whole filtered set, fetched in a single query.
        stats = progresses.aggregate(
            total=Count("id"),
            completed=Count("id", filter=Q(is_completed=True)),
            avg_progress=Avg("progress_percentage"),
        )
        total_count = stats["total"]
        completed_count = stats["completed"]
        avg_progress = stats["avg_progress"] or 0

        # Pagination parameters for the page preview below.
        try:
            limit = int(request.GET.get("limit", "10"))
        except ValueError:
            limit = 10
        # A negative limit would produce a negative-bound queryset slice,
        # which Django rejects; fall back to the default page size.
        # NOTE(review): presumably paginate_data applies a similar fallback -
        # confirm the two stay in sync.
        if limit < 0:
            limit = 10
        try:
            offset = int(request.GET.get("offset", "0"))
        except ValueError:
            offset = 0
        if offset < 0:
            offset = 0

        # Load every problem of the set once, restricted to the fields needed.
        all_problemset_problems = (
            ProblemSetProblem.objects.filter(problemset=problem_set)
            .select_related("problem")
            .only("problem__id", "problem___id", "problem__title", "order")
            .order_by("order")
        )

        all_problems_list = []
        all_problems_map = {}
        for psp in all_problemset_problems:
            all_problems_list.append(
                {
                    "id": psp.problem.id,
                    "_id": psp.problem._id,
                    "title": psp.problem.title,
                }
            )
            # Keyed by the string form of the id, matching progress_detail keys.
            all_problems_map[str(psp.problem.id)] = psp.problem

        # Collect the problem ids referenced by the rows of the current page.
        paginated_progresses = list(progresses[offset : offset + limit])
        completed_problem_ids = set()
        for progress in paginated_progresses:
            if progress.progress_detail:
                completed_problem_ids.update(progress.progress_detail.keys())

        # Reuse the already-loaded problems instead of querying again.
        problems_dict = {
            pid: all_problems_map[pid]
            for pid in completed_problem_ids
            if pid in all_problems_map
        }
        # Stash the lookup on the request; per the original note it is
        # intended for the serializer to read.
        request._problems_dict_cache = problems_dict

        data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
        data["statistics"] = {
            "total": total_count,
            "completed": completed_count,
            "avg_progress": round(avg_progress, 2),
        }
        data["problems"] = all_problems_list
        return self.success(data)