diff --git a/account/decorators.py b/account/decorators.py index 7102b17..2c90499 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -1,5 +1,6 @@ import functools import hashlib +import inspect import time from contest.models import Contest, ContestRuleType, ContestStatus, ContestType @@ -15,47 +16,58 @@ class BasePermissionDecorator(object): self.func = func def __get__(self, obj, obj_type): + if inspect.iscoroutinefunction(self.func): + return functools.partial(self._async_call, obj) return functools.partial(self.__call__, obj) def error(self, data): return JSONResponse.response({"error": "permission-denied", "data": data}) def __call__(self, *args, **kwargs): - self.request = args[1] + request = args[1] - if self.check_permission(): - if self.request.user.is_disabled: + if self.check_permission(request): + if request.user.is_disabled: return self.error("Your account is disabled") return self.func(*args, **kwargs) else: return self.error("Please login first") - def check_permission(self): + async def _async_call(self, *args, **kwargs): + request = args[1] + + if self.check_permission(request): + if request.user.is_disabled: + return self.error("Your account is disabled") + return await self.func(*args, **kwargs) + return self.error("Please login first") + + def check_permission(self, request): raise NotImplementedError() class login_required(BasePermissionDecorator): - def check_permission(self): - return self.request.user.is_authenticated + def check_permission(self, request): + return request.user.is_authenticated class super_admin_required(BasePermissionDecorator): - def check_permission(self): - user = self.request.user + def check_permission(self, request): + user = request.user return user.is_authenticated and user.is_super_admin() class admin_role_required(BasePermissionDecorator): - def check_permission(self): - user = self.request.user + def check_permission(self, request): + user = request.user return user.is_authenticated and user.is_admin_role() class problem_permission_required(admin_role_required): - def check_permission(self): - if not super(problem_permission_required, self).check_permission(): + def check_permission(self, request): + if not super().check_permission(request): return False - if self.request.user.problem_permission == ProblemPermission.NONE: + if request.user.problem_permission == ProblemPermission.NONE: return False return True diff --git a/account/management/commands/clean_deleted_problems.py b/account/management/commands/clean_deleted_problems.py index 6fc4410..81a3f08 100644 --- a/account/management/commands/clean_deleted_problems.py +++ b/account/management/commands/clean_deleted_problems.py @@ -4,7 +4,6 @@ from account.models import UserProfile from problem.models import Problem from submission.models import JudgeStatus - ACCEPTED_STATUSES = {JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED} diff --git a/account/models.py b/account/models.py index 3fe9afa..5ee8fab 100644 --- a/account/models.py +++ b/account/models.py @@ -23,6 +23,9 @@ class UserManager(models.Manager): def get_by_natural_key(self, username): return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username}) + async def aget_by_natural_key(self, username): + return await self.aget(**{f"{self.model.USERNAME_FIELD}__iexact": username}) + class User(AbstractBaseUser): username = models.TextField(unique=True) diff --git a/account/urls/admin.py b/account/urls/admin.py index 34ed640..b2d88d1 100644 --- a/account/urls/admin.py +++ b/account/urls/admin.py @@ -4,6 +4,6 @@ from ..views.admin import GenerateUserAPI, ResetUserPasswordAPI, UserAdminAPI urlpatterns = [ path("user", UserAdminAPI.as_view()), - path("generate_user", GenerateUserAPI.as_view()), + path("generate_user", GenerateUserAPI.as_view()), # DEPRECATED: 前端未调用 path("reset_password", ResetUserPasswordAPI.as_view()), ] diff --git a/account/urls/oj.py b/account/urls/oj.py index b5f6ac4..f767d70 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -29,28 +29,28 @@ urlpatterns = [ path("login", UserLoginAPI.as_view()), path("logout", UserLogoutAPI.as_view()), path("register", UserRegisterAPI.as_view()), - path("change_password", UserChangePasswordAPI.as_view()), - path("change_email", UserChangeEmailAPI.as_view()), - path("apply_reset_password", ApplyResetPasswordAPI.as_view()), - path("reset_password", ResetPasswordAPI.as_view()), + path("change_password", UserChangePasswordAPI.as_view()), # DEPRECATED: 前端未调用 + path("change_email", UserChangeEmailAPI.as_view()), # DEPRECATED: 前端未调用 + path("apply_reset_password", ApplyResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用 + path("reset_password", ResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用 path("captcha", CaptchaAPIView.as_view()), - path("check_username_or_email", UsernameOrEmailCheck.as_view()), + path("check_username_or_email", UsernameOrEmailCheck.as_view()), # DEPRECATED: 前端未调用 path("profile", UserProfileAPI.as_view(), name="user_profile_api"), path("profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view()), path("metrics", Metrics.as_view()), path("upload_avatar", AvatarUploadAPI.as_view()), - path("tfa_required", CheckTFARequiredAPI.as_view()), + path("tfa_required", CheckTFARequiredAPI.as_view()), # DEPRECATED: 前端未调用 path( - "two_factor_auth", + "two_factor_auth", # DEPRECATED: 前端未调用 TwoFactorAuthAPI.as_view(), ), path("user_rank", UserRankAPI.as_view()), path("user_activity_rank", UserActivityRankAPI.as_view()), path("user_problem_rank", UserProblemRankAPI.as_view()), - path("sessions", SessionManagementAPI.as_view()), + path("sessions", SessionManagementAPI.as_view()), # DEPRECATED: 前端未调用 path( - "open_api_appkey", + "open_api_appkey", # DEPRECATED: 前端未调用 OpenAPIAppkeyAPI.as_view(), ), - path("sso", SSOAPI.as_view()), + path("sso", SSOAPI.as_view()), # DEPRECATED: 前端未调用 ] diff --git a/account/views/admin.py b/account/views/admin.py index 95140d4..6c15055 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -191,6 +191,7 @@ class UserAdminAPI(APIView): return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class GenerateUserAPI(APIView): @super_admin_required def get(self, request): diff --git a/account/views/oj.py b/account/views/oj.py index 15b650e..4975b3f 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -1,3 +1,4 @@ +import asyncio import os from datetime import timedelta from importlib import import_module @@ -5,7 +6,6 @@ from importlib import import_module import qrcode from django.conf import settings from django.contrib import auth -from django.core.cache import cache from django.db.models import Count, Q from django.template.loader import render_to_string from django.utils import timezone @@ -16,8 +16,9 @@ from otpauth import TOTP from options.options import SysOptions from problem.models import Problem -from submission.models import JudgeStatus, Submission, is_accepted -from utils.api import APIView, CSRFExemptAPIView, validate_serializer +from submission.models import JudgeStatus, Submission +from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer +from utils.async_helpers import async_cache_get, async_cache_set from utils.captcha import Captcha from utils.constants import CacheKey, ContestRuleType from utils.shortcuts import datetime2str, img2base64, rand_str @@ -58,12 +59,9 @@ def _valid_totp(token, code): return _totp(token).verify(code) -class UserProfileAPI(APIView): +class UserProfileAPI(AsyncAPIView): @method_decorator(ensure_csrf_cookie) - def get(self, request, **kwargs): - """ - 判断是否登录, 若登录返回用户信息 - """ + async def get(self, request, **kwargs): user = request.user if not user.is_authenticated: return self.success() @@ -71,52 +69,51 @@ class UserProfileAPI(APIView): username = request.GET.get("username") try: if username: - user = User.objects.get(username=username, is_disabled=False) + user = await User.objects.aget(username=username, is_disabled=False) else: user = request.user - # api返回的是自己的信息,可以返real_name show_real_name = True except User.DoesNotExist: return self.error("User does not exist") - return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data) + profile = await UserProfile.objects.select_related("user").aget(user=user) + return self.success(UserProfileSerializer(profile, show_real_name=show_real_name).data) @validate_serializer(EditUserProfileSerializer) @login_required - def put(self, request): + async def put(self, request): data = request.data - user_profile = request.user.userprofile + user_profile = await UserProfile.objects.select_related("user").aget(user=request.user) for k, v in data.items(): setattr(user_profile, k, v) - user_profile.save() + await user_profile.asave() return self.success(UserProfileSerializer(user_profile, show_real_name=True).data) -class Metrics(APIView): - def get(self, request): +class Metrics(AsyncAPIView): + async def get(self, request): userid = request.GET.get("userid") - submissions = Submission.objects.filter(user_id=userid, contest_id__isnull=True) - if submissions.count() == 0: + qs = Submission.objects.filter(user_id=userid, contest_id__isnull=True) + count, latest, first = await asyncio.gather( + qs.acount(), + qs.order_by("-create_time").afirst(), + qs.order_by("create_time").afirst(), + ) + if count == 0 or not latest or not first: return self.error("暂无提交") - else: - latest_submission = submissions.first() - last_submission = submissions.last() - if last_submission and latest_submission: - return self.success( - { - "now": datetime2str(timezone.now()), - "latest": datetime2str(latest_submission.create_time), - "first": datetime2str(last_submission.create_time), - } - ) - else: - return self.error("暂无提交") + return self.success( + { + "now": datetime2str(timezone.now()), + "latest": datetime2str(latest.create_time), + "first": datetime2str(first.create_time), + } + ) -class AvatarUploadAPI(APIView): +class AvatarUploadAPI(AsyncAPIView): request_parsers = () @login_required - def post(self, request): + async def post(self, request): form = ImageUploadForm(request.POST, request.FILES) if form.is_valid(): avatar = form.cleaned_data["image"] @@ -132,13 +129,14 @@ class AvatarUploadAPI(APIView): with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img: for chunk in avatar: img.write(chunk) - user_profile = request.user.userprofile + user_profile = await UserProfile.objects.aget(user=request.user) user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}" - user_profile.save() + await user_profile.asave() return self.success("Succeeded") +# DEPRECATED: 前端未调用 (2026-05-26) class TwoFactorAuthAPI(APIView): @login_required def get(self, request): @@ -186,6 +184,7 @@ class TwoFactorAuthAPI(APIView): return self.error("Invalid code") +# DEPRECATED: 前端未调用 (2026-05-26) class CheckTFARequiredAPI(APIView): @validate_serializer(UsernameOrEmailCheckSerializer) def post(self, request): @@ -203,31 +202,26 @@ class CheckTFARequiredAPI(APIView): return self.success({"result": result}) -class UserLoginAPI(APIView): +class UserLoginAPI(AsyncAPIView): @validate_serializer(UserLoginSerializer) - def post(self, request): - """ - User login api - """ + async def post(self, request): data = request.data - user = auth.authenticate(username=data["username"], password=data["password"]) - # None is returned if username or password is wrong + user = await auth.aauthenticate(username=data["username"], password=data["password"]) if user: if user.is_disabled: return self.error("Your account has been disabled") if not user.two_factor_auth: prev_login = user.last_login - auth.login(request, user) + await auth.alogin(request, user) request.session["prev_login"] = datetime2str(prev_login) if prev_login else "" return self.success("Succeeded") - # `tfa_code` not in post data if user.two_factor_auth and "tfa_code" not in data: return self.error("tfa_required") if _valid_totp(user.tfa_token, data["tfa_code"]): prev_login = user.last_login - auth.login(request, user) + await auth.alogin(request, user) request.session["prev_login"] = datetime2str(prev_login) if prev_login else "" return self.success("Succeeded") else: @@ -236,12 +230,13 @@ class UserLoginAPI(APIView): return self.error("Invalid username or password") -class UserLogoutAPI(APIView): - def get(self, request): - auth.logout(request) +class UserLogoutAPI(AsyncAPIView): + async def get(self, request): + await auth.alogout(request) return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class UsernameOrEmailCheck(APIView): @validate_serializer(UsernameOrEmailCheckSerializer) def post(self, request): @@ -258,13 +253,10 @@ class UsernameOrEmailCheck(APIView): return self.success(result) -class UserRegisterAPI(APIView): +class UserRegisterAPI(AsyncAPIView): @validate_serializer(UserRegisterSerializer) - def post(self, request): - """ - User register api - """ - if not SysOptions.allow_register: + async def post(self, request): + if not await SysOptions.aget("allow_register"): return self.error("Register function has been disabled by admin") data = request.data @@ -273,17 +265,18 @@ class UserRegisterAPI(APIView): captcha = Captcha(request) if not captcha.check(data["captcha"]): return self.error("Invalid captcha") - if User.objects.filter(username=data["username"]).exists(): + if await User.objects.filter(username=data["username"]).aexists(): return self.error("Username already exists") - if User.objects.filter(email=data["email"]).exists(): + if await User.objects.filter(email=data["email"]).aexists(): return self.error("Email already exists") - user = User.objects.create(username=data["username"], email=data["email"]) + user = await User.objects.acreate(username=data["username"], email=data["email"]) user.set_password(data["password"]) - user.save() - UserProfile.objects.create(user=user) + await user.asave() + await UserProfile.objects.acreate(user=user) return self.success("Succeeded") +# DEPRECATED: 前端未调用 (2026-05-26) class UserChangeEmailAPI(APIView): @validate_serializer(UserChangeEmailSerializer) @login_required @@ -306,6 +299,7 @@ class UserChangeEmailAPI(APIView): return self.error("Wrong password") +# DEPRECATED: 前端未调用 (2026-05-26) class UserChangePasswordAPI(APIView): @validate_serializer(UserChangePasswordSerializer) @login_required @@ -329,6 +323,7 @@ class UserChangePasswordAPI(APIView): return self.error("Invalid old password") +# DEPRECATED: 前端未调用 (2026-05-26) class ApplyResetPasswordAPI(APIView): @validate_serializer(ApplyResetPasswordSerializer) def post(self, request): @@ -363,6 +358,7 @@ class ApplyResetPasswordAPI(APIView): return self.success("Succeeded") +# DEPRECATED: 前端未调用 (2026-05-26) class ResetPasswordAPI(APIView): @validate_serializer(ResetPasswordSerializer) def post(self, request): @@ -383,6 +379,7 @@ class ResetPasswordAPI(APIView): return self.success("Succeeded") +# DEPRECATED: 前端未调用 (2026-05-26) class SessionManagementAPI(APIView): @login_required def get(self, request): @@ -426,8 +423,8 @@ class SessionManagementAPI(APIView): return self.error("Invalid session_key") -class UserRankAPI(APIView): - def get(self, request): +class UserRankAPI(AsyncAPIView): + async def get(self, request): rule_type = request.GET.get("rule") username = request.GET.get("username", "") try: @@ -448,16 +445,16 @@ class UserRankAPI(APIView): profiles = profiles.filter(total_score__gt=0).order_by("-total_score") if n > 0: profiles = profiles[:n] - return self.success(self.paginate_data(request, profiles, RankInfoSerializer)) + return self.success(await self.async_paginate_data(request, profiles, RankInfoSerializer)) -class UserActivityRankAPI(APIView): - def get(self, request): +class UserActivityRankAPI(AsyncAPIView): + async def get(self, request): start = request.GET.get("start") if not start: return self.error("start time is required") cache_key = f"{CacheKey.user_activity_rank}:{start}" - cached = cache.get(cache_key) + cached = await async_cache_get(cache_key) if cached is not None: return self.success(cached) @@ -467,35 +464,40 @@ class UserActivityRankAPI(APIView): create_time__gte=start, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], ).exclude(username__in=hidden_names) - data = list(submissions.values("username").annotate(count=Count("problem_id", distinct=True)).order_by("-count")[:10]) - cache.set(cache_key, data, 600) + data = [ + row + async for row in submissions.values("username") + .annotate(count=Count("problem_id", distinct=True)) + .order_by("-count")[:10] + ] + await async_cache_set(cache_key, data, 600) return self.success(data) -class UserProblemRankAPI(APIView): - def get(self, request): +class UserProblemRankAPI(AsyncAPIView): + async def get(self, request): problem_id = request.GET.get("problem_id") user = request.user if not user.is_authenticated: return self.error("User is not authenticated") - problem = Problem.objects.get(_id__iexact=problem_id, contest_id__isnull=True, visible=True) + problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True) submissions = Submission.objects.filter(problem=problem, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED]) - all_ac_count = submissions.values("user_id").distinct().count() + all_ac_count = await submissions.values("user_id").distinct().acount() class_name = user.class_name or "" class_ac_count = 0 if class_name: users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True) - user_ids = list(users) + user_ids = [user_id async for user_id in users] submissions = submissions.filter(user_id__in=user_ids) - class_ac_count = submissions.values("user_id").distinct().count() + class_ac_count = await submissions.values("user_id").distinct().acount() my_submissions = submissions.filter(user_id=user.id) - if len(my_submissions) == 0: + if not await my_submissions.aexists(): return self.success( { "class_name": class_name, @@ -505,8 +507,8 @@ class UserProblemRankAPI(APIView): } ) - my_first_submission = my_submissions.order_by("create_time").first() - rank = submissions.filter(create_time__lte=my_first_submission.create_time).count() + my_first_submission = await my_submissions.order_by("create_time").afirst() + rank = await submissions.filter(create_time__lte=my_first_submission.create_time).acount() return self.success( { "class_name": class_name, @@ -517,25 +519,26 @@ class UserProblemRankAPI(APIView): ) -class ProfileProblemDisplayIDRefreshAPI(APIView): +class ProfileProblemDisplayIDRefreshAPI(AsyncAPIView): @login_required - def get(self, request): - profile = request.user.userprofile + async def get(self, request): + profile = await UserProfile.objects.aget(user=request.user) acm_problems = profile.acm_problems_status.get("problems", {}) oi_problems = profile.oi_problems_status.get("problems", {}) ids = list(acm_problems.keys()) + list(oi_problems.keys()) if not ids: return self.success() - display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True) + display_ids = [did async for did in Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)] id_map = dict(zip(ids, display_ids)) for k, v in acm_problems.items(): v["_id"] = id_map[k] for k, v in oi_problems.items(): v["_id"] = id_map[k] - profile.save(update_fields=["acm_problems_status", "oi_problems_status"]) + await profile.asave(update_fields=["acm_problems_status", "oi_problems_status"]) return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class OpenAPIAppkeyAPI(APIView): @login_required def post(self, request): @@ -548,6 +551,7 @@ class OpenAPIAppkeyAPI(APIView): return self.success({"appkey": api_appkey}) +# DEPRECATED: 前端未调用 (2026-05-26) class SSOAPI(CSRFExemptAPIView): @login_required def get(self, request): diff --git a/announcement/views/oj.py b/announcement/views/oj.py index a987c90..ea9e81e 100644 --- a/announcement/views/oj.py +++ b/announcement/views/oj.py @@ -1,19 +1,25 @@ from announcement.models import Announcement from announcement.serializers import AnnouncementListSerializer, AnnouncementSerializer -from utils.api import APIView +from utils.api import AsyncAPIView -class AnnouncementAPI(APIView): - def get(self, request): +class AnnouncementAPI(AsyncAPIView): + async def get(self, request): id = request.GET.get("id") if id: try: - announcement = Announcement.objects.get(id=id, visible=True) - return self.success(AnnouncementSerializer(announcement).data) + announcement = await ( + Announcement.objects.select_related("created_by") + .filter(id=id, visible=True) + .afirst() + ) + if announcement is None: + raise Announcement.DoesNotExist + return self.success(await self.async_serialize_data(AnnouncementSerializer, announcement)) except Announcement.DoesNotExist: return self.error("Announcement does not exist") announcements = Announcement.objects.select_related("created_by").filter(visible=True) return self.success( - self.paginate_data(request, announcements, AnnouncementListSerializer) + await self.async_paginate_data(request, announcements, AnnouncementListSerializer) ) diff --git a/ast_checker/engines/nesting.py b/ast_checker/engines/nesting.py index cb7df25..fd36755 100644 --- a/ast_checker/engines/nesting.py +++ b/ast_checker/engines/nesting.py @@ -1,6 +1,7 @@ -from .base import BaseEngine from ast_checker.labels import label +from .base import BaseEngine + class MustHaveNestingEngine(BaseEngine): def _has_inner_in_subtree(self, node, inner_type): diff --git a/ast_checker/engines/node_count.py b/ast_checker/engines/node_count.py index 8aebdcb..54b2349 100644 --- a/ast_checker/engines/node_count.py +++ b/ast_checker/engines/node_count.py @@ -1,6 +1,7 @@ -from .base import BaseEngine from ast_checker.labels import label +from .base import BaseEngine + class CountNodeEngine(BaseEngine): def _message(self, rule, count): diff --git a/ast_checker/engines/node_exists.py b/ast_checker/engines/node_exists.py index 114b449..94c4a72 100644 --- a/ast_checker/engines/node_exists.py +++ b/ast_checker/engines/node_exists.py @@ -1,6 +1,7 @@ -from .base import BaseEngine from ast_checker.labels import label +from .base import BaseEngine + class MustExistNodeEngine(BaseEngine): def _message(self, rule): diff --git a/comment/views/oj.py b/comment/views/oj.py index 0a42afb..d2272ba 100644 --- a/comment/views/oj.py +++ b/comment/views/oj.py @@ -1,45 +1,44 @@ -from django.core.cache import cache -from django.db.models import Avg, Count -from django.db.models.functions import Round - -from account.decorators import login_required +from django.db.models import Avg, Count +from django.db.models.functions import Round + +from account.decorators import login_required from comment.models import Comment from comment.serializers import CommentSerializer, CreateCommentSerializer from problem.models import Problem -from submission.models import JudgeStatus, Submission -from utils.api import APIView -from utils.api.api import validate_serializer -from utils.constants import CacheKey +from submission.models import JudgeStatus, Submission +from utils.api import AsyncAPIView +from utils.api.api import validate_serializer +from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set +from utils.constants import CacheKey -class CommentAPI(APIView): +class CommentAPI(AsyncAPIView): @validate_serializer(CreateCommentSerializer) @login_required - def post(self, request): + async def post(self, request): data = request.data try: - problem = Problem.objects.get(id=data["problem_id"], visible=True) + problem = await Problem.objects.aget(id=data["problem_id"], visible=True) except Problem.DoesNotExist: - self.error("problem is not exists") + return self.error("problem is not exists") - try: - submission = ( - Submission.objects.select_related("problem") - .filter( - user_id=request.user.id, - problem_id=data["problem_id"], - result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], - ) - .first() + submission = await ( + Submission.objects.select_related("problem") + .filter( + user_id=request.user.id, + problem_id=data["problem_id"], + result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], ) - except Submission.DoesNotExist: - self.error("submission is not exists or not accepted") + .afirst() + ) + if not submission: + return self.error("submission is not exists or not accepted") language = submission.language if language == "Python3": language = "Python" - Comment.objects.create( + await Comment.objects.acreate( user=request.user, problem=problem, submission=submission, @@ -49,32 +48,35 @@ class CommentAPI(APIView): comprehensive_rating=data["comprehensive_rating"], content=data["content"], ) - cache.delete(f"{CacheKey.comment_stats}:{problem.id}") - return self.success() + await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}") + return self.success() @login_required - def get(self, request): + async def get(self, request): problem_id = request.GET.get("problem_id") - comment = ( + comment = await ( Comment.objects.select_related("problem") .filter(user=request.user, problem_id=problem_id) - .first() - ) - if comment: - return self.success(CommentSerializer(comment).data) - else: - return self.success() + .afirst() + ) + if comment: + return self.success(await self.async_serialize_data(CommentSerializer, comment)) + else: + return self.success() -class CommentStatisticsAPI(APIView): - def get(self, request): - problem_id = request.GET.get("problem_id") - cache_key = f"{CacheKey.comment_stats}:{problem_id}" - cached = cache.get(cache_key) - if cached is not None: - return self.success(cached) +class CommentStatisticsAPI(AsyncAPIView): + async def get(self, request): + problem_id = request.GET.get("problem_id") + cache_key = f"{CacheKey.comment_stats}:{problem_id}" + cached = await async_cache_get(cache_key) + if cached is not None: + return self.success(cached) - agg = Comment.objects.filter(problem_id=problem_id).aggregate( + from asgiref.sync import sync_to_async + agg = await sync_to_async( + Comment.objects.filter(problem_id=problem_id).aggregate + )( count=Count("id"), description=Round(Avg("description_rating"), 2), difficulty=Round(Avg("difficulty_rating"), 2), @@ -88,5 +90,5 @@ class CommentStatisticsAPI(APIView): "difficulty": agg["difficulty"], "comprehensive": agg["comprehensive"], }} - cache.set(cache_key, data, 3600) - return self.success(data) + await async_cache_set(cache_key, data, 3600) + return self.success(data) diff --git a/conf/urls/admin.py b/conf/urls/admin.py index ab268e6..a9b5f6c 100644 --- a/conf/urls/admin.py +++ b/conf/urls/admin.py @@ -12,12 +12,12 @@ from ..views import ( ) urlpatterns = [ - path("smtp", SMTPAPI.as_view()), - path("smtp_test", SMTPTestAPI.as_view()), + path("smtp", SMTPAPI.as_view()), # DEPRECATED: 前端未调用 + path("smtp_test", SMTPTestAPI.as_view()), # DEPRECATED: 前端未调用 path("website", WebsiteConfigAPI.as_view()), path("random_user", RandomUsernameAPI.as_view()), path("judge_server", JudgeServerAPI.as_view()), path("prune_test_case", TestCasePruneAPI.as_view()), - path("versions", ReleaseNotesAPI.as_view()), + path("versions", ReleaseNotesAPI.as_view()), # DEPRECATED: 前端未调用 path("dashboard_info", DashboardInfoAPI.as_view()), ] diff --git a/conf/urls/oj.py b/conf/urls/oj.py index c3438c8..dfe01ad 100644 --- a/conf/urls/oj.py +++ b/conf/urls/oj.py @@ -12,7 +12,7 @@ urlpatterns = [ path("website", WebsiteConfigAPI.as_view()), # 这里必须要有 / path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()), - path("languages", LanguagesAPI.as_view()), + path("languages", LanguagesAPI.as_view()), # DEPRECATED: 前端未调用 path("hitokoto", HitokotoAPI.as_view()), path("class_usernames", ClassUsernamesAPI.as_view()), ] diff --git a/conf/views.py b/conf/views.py index abe268d..08b5b24 100644 --- a/conf/views.py +++ b/conf/views.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import json import os @@ -6,9 +7,10 @@ import re import shutil import smtplib import time -from datetime import datetime +from datetime import timedelta import requests +from asgiref.sync import sync_to_async from django.conf import settings from django.utils import timezone from requests.exceptions import RequestException @@ -20,7 +22,7 @@ from judge.dispatcher import process_pending_task from options.options import SysOptions from problem.models import Problem from submission.models import Submission -from utils.api import APIView, CSRFExemptAPIView, validate_serializer +from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer from utils.cache import JsonDataLoader from utils.shortcuts import get_env, send_email from utils.websocket import push_config_update @@ -38,6 +40,7 @@ from .serializers import ( ) +# DEPRECATED: 前端未调用 (2026-05-26) class SMTPAPI(APIView): @super_admin_required def get(self, request): @@ -66,6 +69,7 @@ class SMTPAPI(APIView): return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class SMTPTestAPI(APIView): @super_admin_required @validate_serializer(TestSMTPConfigSerializer) @@ -97,35 +101,33 @@ class SMTPTestAPI(APIView): return self.success() -class WebsiteConfigAPI(APIView): - def get(self, request): - ret = { - key: getattr(SysOptions, key) - for key in [ - "website_base_url", - "website_name", - "website_name_shortcut", - "website_footer", - "allow_register", - "submission_list_show_all", - "class_list", - "enable_maxkb", - ] - } +class WebsiteConfigAPI(AsyncAPIView): + async def get(self, request): + ret = await SysOptions.aget_many( + "website_base_url", + "website_name", + "website_name_shortcut", + "website_footer", + "allow_register", + "submission_list_show_all", + "class_list", + "enable_maxkb", + ) return self.success(ret) @super_admin_required @validate_serializer(CreateEditWebsiteConfigSerializer) - def post(self, request): - for k, v in request.data.items(): - if k == "website_footer": - with XSSHtml() as parser: - v = parser.clean(v) - setattr(SysOptions, k, v) - - # 推送配置更新到所有连接的客户端 - push_config_update(k, v) + async def post(self, request): + @sync_to_async + def _update_config(data): + for k, v in data.items(): + if k == "website_footer": + with XSSHtml() as parser: + v = parser.clean(v) + setattr(SysOptions, k, v) + push_config_update(k, v) + await _update_config(request.data) return self.success() @@ -206,6 +208,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView): return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class LanguagesAPI(APIView): def get(self, request): return self.success( @@ -255,6 +258,7 @@ class TestCasePruneAPI(APIView): shutil.rmtree(test_case_dir, ignore_errors=True) +# DEPRECATED: 前端未调用 (2026-05-26) class ReleaseNotesAPI(APIView): def get(self, request): try: @@ -272,24 +276,29 @@ class ReleaseNotesAPI(APIView): return self.success(releases) -class DashboardInfoAPI(APIView): - def get(self, request): - today = datetime.today() - today_submission_count = Submission.objects.filter( - create_time__gte=datetime(today.year, today.month, today.day, 0, 0) - ).count() - recent_contest_count = Contest.objects.exclude( - end_time__lt=timezone.now() - ).count() - judge_server_count = len( - list(filter(lambda x: x.status == "normal", JudgeServer.objects.all())) +class DashboardInfoAPI(AsyncAPIView): + async def get(self, request): + now = timezone.now() + today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) + ( + user_count, + today_submission_count, + recent_contest_count, + judge_servers, + ) = await asyncio.gather( + User.objects.acount(), + Submission.objects.filter(create_time__gte=today_start).acount(), + Contest.objects.exclude(end_time__lt=timezone.now()).acount(), + JudgeServer.objects.filter( + last_heartbeat__gte=timezone.now() - timedelta(seconds=6) + ).acount(), ) return self.success( { - "user_count": User.objects.count(), + "user_count": user_count, "recent_contest_count": recent_contest_count, "today_submission_count": today_submission_count, - "judge_server_count": judge_server_count, + "judge_server_count": judge_servers, "env": { "FORCE_HTTPS": get_env("FORCE_HTTPS", default=False), "STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default=""), @@ -298,24 +307,21 @@ class DashboardInfoAPI(APIView): ) -class RandomUsernameAPI(APIView): - def get(self, request): +class RandomUsernameAPI(AsyncAPIView): + async def get(self, request): classroom = request.GET.get("classroom", "") if not classroom: return self.error("需要班级号") - usernames = ( - User.objects.filter(username__istartswith=classroom) + usernames = [ + u async for u in User.objects.filter(username__istartswith=classroom) .values_list("username", flat=True) - .order_by("?") - ) - if len(usernames) > 10: - return self.success(usernames[:10]) - else: - return self.success(usernames) + .order_by("?")[:10] + ] + return self.success(usernames) -class HitokotoAPI(APIView): - def get(self, request): +class HitokotoAPI(AsyncAPIView): + async def get(self, request): try: categories = JsonDataLoader.load_data( settings.HITOKOTO_DIR, "categories.json" @@ -328,20 +334,14 @@ class HitokotoAPI(APIView): return self.error("获取一言失败,请稍后再试") -class ClassUsernamesAPI(APIView): - def get(self, request): +class ClassUsernamesAPI(AsyncAPIView): + async def get(self, request): classroom = request.GET.get("classroom", "") if not classroom: return self.error("需要班级号") - users = User.objects.filter(class_name=classroom).order_by("-create_time") - names = [] - for user in users: - prefix = f"ks{classroom}" - result = ( - user.username[len(prefix) :] - if user.username.startswith(prefix) - else user.username - ) - names.append(result) - + prefix = f"ks{classroom}" + names = [ + user.username[len(prefix):] if user.username.startswith(prefix) else user.username + async for user in User.objects.filter(class_name=classroom).order_by("-create_time") + ] return self.success(names) diff --git a/contest/urls/admin.py b/contest/urls/admin.py index 107c557..21e8aca 100644 --- a/contest/urls/admin.py +++ b/contest/urls/admin.py @@ -5,7 +5,7 @@ from ..views.admin import ACMContestHelper, ContestAnnouncementAPI, ContestAPI, urlpatterns = [ path("contest", ContestAPI.as_view()), path("contest/clone", ContestCloneAPI.as_view()), - path("contest/announcement", ContestAnnouncementAPI.as_view()), + path("contest/announcement", ContestAnnouncementAPI.as_view()), # DEPRECATED: 前端未调用 path("contest/acm_helper", ACMContestHelper.as_view()), - path("download_submissions", DownloadContestSubmissions.as_view()), + path("download_submissions", DownloadContestSubmissions.as_view()), # DEPRECATED: 前端未调用 ] diff --git a/contest/urls/oj.py b/contest/urls/oj.py index 7377019..b20ba45 100644 --- a/contest/urls/oj.py +++ b/contest/urls/oj.py @@ -6,7 +6,7 @@ urlpatterns = [ path("contests", ContestListAPI.as_view()), path("contest", ContestAPI.as_view()), path("contest/password", ContestPasswordVerifyAPI.as_view()), - path("contest/announcement", ContestAnnouncementListAPI.as_view()), + path("contest/announcement", ContestAnnouncementListAPI.as_view()), # DEPRECATED: 前端未调用 path("contest/access", ContestAccessAPI.as_view()), path("contest_rank", ContestRankAPI.as_view()), ] diff --git a/contest/views/admin.py b/contest/views/admin.py index 2402953..6189830 100644 --- a/contest/views/admin.py +++ b/contest/views/admin.py @@ -97,6 +97,7 @@ class ContestAPI(APIView): return self.success(self.paginate_data(request, contests, ContestAdminSerializer)) +# DEPRECATED: 前端未调用 (2026-05-26) class ContestAnnouncementAPI(APIView): @validate_serializer(CreateContestAnnouncementSerializer) @super_admin_required @@ -212,6 +213,7 @@ class ACMContestHelper(APIView): return self.success() +# DEPRECATED: 前端未调用 (2026-05-26) class DownloadContestSubmissions(APIView): def _dump_submissions(self, contest, exclude_admin=True): problem_ids = contest.problem_set.all().values_list("id", "_id") diff --git a/contest/views/oj.py b/contest/views/oj.py index 0cfa551..87781a9 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -12,7 +12,7 @@ from account.decorators import ( ) from account.models import AdminType from problem.models import Problem -from utils.api import APIView, validate_serializer +from utils.api import APIView, AsyncAPIView, validate_serializer from utils.constants import CONTEST_PASSWORD_SESSION_KEY, CacheKey, ContestRuleType, ContestStatus from utils.shortcuts import check_is_id, datetime2str @@ -20,6 +20,7 @@ from ..models import ACMContestRank, Contest, ContestAnnouncement, OIContestRank from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerializer, ContestPasswordVerifySerializer, ContestSerializer, OIContestRankSerializer +# DEPRECATED: 前端未调用 (2026-05-26) class ContestAnnouncementListAPI(APIView): @check_contest_permission(check_type="announcements") def get(self, request): @@ -35,22 +36,28 @@ class ContestAnnouncementListAPI(APIView): return self.success(ContestAnnouncementSerializer(data, many=True).data) -class ContestAPI(APIView): - def get(self, request): +class ContestAPI(AsyncAPIView): + async def get(self, request): id = request.GET.get("id") if not id or not check_is_id(id): return self.error("Invalid parameter, id is required") try: - contest = Contest.objects.get(id=id, visible=True) + contest = await ( + Contest.objects.select_related("created_by") + .filter(id=id, visible=True) + .afirst() + ) + if contest is None: + raise Contest.DoesNotExist except Contest.DoesNotExist: return self.error("Contest does not exist") - data = ContestSerializer(contest).data + data = await self.async_serialize_data(ContestSerializer, contest) data["now"] = datetime2str(now()) return self.success(data) -class ContestListAPI(APIView): - def get(self, request): +class ContestListAPI(AsyncAPIView): + async def get(self, request): contests = Contest.objects.select_related("created_by").filter(visible=True) keyword = request.GET.get("keyword") rule_type = request.GET.get("rule_type") @@ -70,7 +77,7 @@ class ContestListAPI(APIView): contests = contests.filter(end_time__lt=cur) else: contests = contests.filter(start_time__lte=cur, end_time__gte=cur) - return self.success(self.paginate_data(request, contests, ContestSerializer)) + return self.success(await self.async_paginate_data(request, contests, ContestSerializer)) class ContestPasswordVerifyAPI(APIView): diff --git a/deploy/requirements.txt b/deploy/requirements.txt index 81a66dc..c6eb248 100644 --- a/deploy/requirements.txt +++ b/deploy/requirements.txt @@ -18,6 +18,7 @@ asgiref==3.11.1 \ # channels # channels-redis # django + # onlinejudge certifi==2026.4.22 \ --hash=sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a \ --hash=sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580 diff --git a/docs/superpowers/plans/2026-05-26-backend-async.md b/docs/superpowers/plans/2026-05-26-backend-async.md new file mode 100644 index 0000000..21d4139 --- /dev/null +++ b/docs/superpowers/plans/2026-05-26-backend-async.md @@ -0,0 +1,861 @@ +# Backend Async Hardening Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Make the current backend async work correct first, then establish safe patterns for expanding async views without changing API response shapes. + +**Architecture:** Keep the existing custom `APIView`/DRF serializer stack. Add async-safe tests and helpers around `AsyncAPIView`, explicitly preload serializer relations in converted endpoints, and use `sync_to_async(..., thread_sensitive=True)` for synchronous serializer/cache/helper code that remains inside async views. + +**Tech Stack:** Django 6.0.4, custom class-based API views, Django async ORM, DRF serializers, PostgreSQL, Redis cache, Channels ASGI. + +--- + +## Async Rules For This Repository + +1. Async conversion is valid only when the endpoint preserves URL, method, status code, JSON envelope, and permission behavior. +2. Every async view that serializes model instances must either preload all serializer relations with `select_related()` / `prefetch_related()` or run serializer `.data` through a sync boundary. +3. `asyncio.gather()` is only for independent reads. Do not use it around writes that depend on ordering or transaction semantics. +4. Keep file upload/download, SMTP, test-case pruning, judge heartbeat mutation paths, and contest permission-heavy flows sync until the async decorator and middleware work is complete. +5. Current sync `MiddlewareMixin` middleware means ASGI requests still cross sync boundaries. The first async milestone is correctness and latency cleanup, not a full event-loop purity claim. + +--- + +### Task 1: Add Regression Coverage For Converted Async Detail Views + +**Files:** +- Create: `utils/test_async_view_regressions.py` + +- [ ] **Step 1: Write failing regression tests** + +Create `utils/test_async_view_regressions.py`: + +```python +from datetime import timedelta + +from asgiref.sync import sync_to_async +from django.contrib.auth import get_user_model +from django.test import AsyncClient, TestCase +from django.utils import timezone + +from account.models import UserProfile +from announcement.models import Announcement +from contest.models import Contest +from flowchart.models import FlowchartSubmission +from problem.models import Problem, ProblemRuleType +from utils.constants import ContestRuleType, Difficulty + +User = get_user_model() + + +def make_user(username="async_user"): + user = User.objects.create(username=username, email=f"{username}@example.com") + user.set_password("pass1234") + user.save() + UserProfile.objects.create(user=user) + return user + + +def make_problem(user): + return Problem.objects.create( + _id="ASYNC001", + title="Async Problem", + description="desc", + input_description="input", + output_description="output", + samples=[], + test_case_id="async-test-case", + test_case_score=[], + hint="", + languages=["Python3"], + template={}, + created_by=user, + time_limit=1000, + memory_limit=128, + rule_type=ProblemRuleType.ACM, + difficulty=Difficulty.LOW, + share_submission=False, + allow_flowchart=True, + show_flowchart=True, + ) + + +class AsyncConvertedViewRegressionTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.user = make_user() + cls.announcement = Announcement.objects.create( + title="Async Announcement", + content="content", + tag="notice", + visible=True, + top=False, + created_by=cls.user, + ) + cls.contest = Contest.objects.create( + title="Async Contest", + description="contest desc", + tag="weekly", + real_time_rank=True, + password=None, + rule_type=ContestRuleType.ACM, + start_time=timezone.now() - timedelta(hours=1), + end_time=timezone.now() + timedelta(hours=1), + created_by=cls.user, + visible=True, + allowed_ip_ranges=[], + ) + cls.problem = make_problem(cls.user) + cls.flowchart_submission = FlowchartSubmission.objects.create( + user=cls.user, + problem=cls.problem, + mermaid_code="graph TD\nA-->B", + flowchart_data={}, + ) + + async def test_announcement_detail_serializes_created_by(self): + response = await AsyncClient().get(f"/api/announcement?id={self.announcement.id}") + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertIsNone(body["error"]) + self.assertEqual(body["data"]["created_by"]["username"], self.user.username) + + async def test_contest_detail_serializes_created_by(self): + response = await AsyncClient().get(f"/api/contest?id={self.contest.id}") + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertIsNone(body["error"]) + self.assertEqual(body["data"]["created_by"]["username"], self.user.username) + + async def test_flowchart_detail_serializes_user_and_problem(self): + client = AsyncClient() + await sync_to_async(client.force_login)(self.user) + + response = await client.get(f"/api/flowchart/submission?id={self.flowchart_submission.id}") + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertIsNone(body["error"]) + self.assertEqual(body["data"]["username"], self.user.username) + self.assertEqual(body["data"]["problem"], self.problem.id) +``` + +- [ ] **Step 2: Run the regression tests** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_view_regressions -v 2 +``` + +Expected before Task 2: at least one test returns `server-error` or raises async-unsafe lazy relation access. + +- [ ] **Step 3: Commit the failing tests** + +Run: + +```bash +rtk git add utils/test_async_view_regressions.py +rtk git commit -m "test: cover async view serialization regressions" +``` + +--- + +### Task 2: Preload Relations For Already Converted Async Detail Views + +**Files:** +- Modify: `announcement/views/oj.py` +- Modify: `contest/views/oj.py` +- Modify: `flowchart/views/oj.py` + +- [ ] **Step 1: Fix announcement detail relation loading** + +In `announcement/views/oj.py`, replace the detail query with: + +```python +announcement = await ( + Announcement.objects.select_related("created_by") + .filter(id=id, visible=True) + .afirst() +) +if announcement is None: + raise Announcement.DoesNotExist +``` + +- [ ] **Step 2: Fix contest detail relation loading** + +In `contest/views/oj.py`, replace the detail query with: + +```python +contest = await ( + Contest.objects.select_related("created_by") + .filter(id=id, visible=True) + .afirst() +) +if contest is None: + raise Contest.DoesNotExist +``` + +- [ ] **Step 3: Fix flowchart submission detail relation loading** + +In `flowchart/views/oj.py`, update `FlowchartSubmissionAPI.get()`: + +```python +submission = await ( + FlowchartSubmission.objects.select_related("user", "problem") + .filter(id=submission_id) + .afirst() +) +if submission is None: + raise FlowchartSubmission.DoesNotExist +``` + +- [ ] **Step 4: Fix flowchart retry permission relation loading** + +In `flowchart/views/oj.py`, update `FlowchartSubmissionRetryAPI.post()`: + +```python +submission = await ( + FlowchartSubmission.objects.select_related("problem") + .filter(id=submission_id) + .afirst() +) +if submission is None: + raise FlowchartSubmission.DoesNotExist +``` + +- [ ] **Step 5: Fix flowchart completed-detail relation loading** + +In `flowchart/views/oj.py`, update `FlowchartSubmissionDetailAPI.get()` before serialization: + +```python +submissions = ( + FlowchartSubmission.objects.select_related("user", "problem") + .filter( + user=request.user, + problem=problem, + status=FlowchartSubmissionStatus.COMPLETED, + ) + .order_by("create_time") +) +``` + +- [ ] **Step 6: Run the regression tests** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_view_regressions -v 2 +``` + +Expected: all tests pass. + +- [ ] **Step 7: Run Django system checks** + +Run: + +```bash +rtk uv run python manage.py check +``` + +Expected: `System check identified no issues`. + +- [ ] **Step 8: Commit relation fixes** + +Run: + +```bash +rtk git add announcement/views/oj.py contest/views/oj.py flowchart/views/oj.py +rtk git commit -m "fix: preload relations in async detail views" +``` + +--- + +### Task 3: Add Async Serialization Helpers To `AsyncAPIView` + +**Files:** +- Modify: `utils/api/api.py` +- Modify: `utils/test_async_api.py` + +- [ ] **Step 1: Write tests for async serialization helper** + +Create `utils/test_async_api.py`: + +```python +import json + +from django.test import AsyncRequestFactory, SimpleTestCase + +from utils.api import AsyncAPIView, serializers, validate_serializer + + +class PayloadSerializer(serializers.Serializer): + name = serializers.CharField() + + +class EchoSerializer(serializers.Serializer): + name = serializers.CharField() + + +class AsyncValidatedEchoView(AsyncAPIView): + @validate_serializer(PayloadSerializer) + async def post(self, request): + return self.success({"name": request.data["name"]}) + + +class AsyncSerializationHelperTests(SimpleTestCase): + async def test_validate_serializer_supports_async_view_methods(self): + request = AsyncRequestFactory().post( + "/api/echo", + data=json.dumps({"name": "alice"}), + content_type="application/json", + ) + + response = await AsyncValidatedEchoView.as_view()(request) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["error"], None) + self.assertEqual(response.data["data"], {"name": "alice"}) + + async def test_async_serialize_data_returns_serializer_data(self): + view = AsyncAPIView() + + data = await view.async_serialize_data( + EchoSerializer, + [{"name": "alice"}, {"name": "bob"}], + many=True, + ) + + self.assertEqual(data, [{"name": "alice"}, {"name": "bob"}]) +``` + +- [ ] **Step 2: Run helper tests and confirm failure** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_api -v 2 +``` + +Expected before implementation: `AttributeError: 'AsyncAPIView' object has no attribute 'async_serialize_data'`. + +- [ ] **Step 3: Add `sync_to_async` import** + +In `utils/api/api.py`, add: + +```python +from asgiref.sync import sync_to_async +``` + +- [ ] **Step 4: Add serializer helper methods** + +In `AsyncAPIView`, before `async_paginate_data()`, add: + +```python + def serialize_data(self, object_serializer, data, **kwargs): + return object_serializer(data, **kwargs).data + + async def async_serialize_data(self, object_serializer, data, **kwargs): + return await sync_to_async( + self.serialize_data, + thread_sensitive=True, + )(object_serializer, data, **kwargs) +``` + +- [ ] **Step 5: Use helper inside async pagination** + +In `AsyncAPIView.async_paginate_data()`, replace: + +```python +if object_serializer: + results = object_serializer(results, many=True, context={"request": request}).data +``` + +with: + +```python +if object_serializer: + results = await self.async_serialize_data( + object_serializer, + results, + many=True, + context={"request": request}, + ) +``` + +- [ ] **Step 6: Run helper and regression tests** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_api utils.test_async_view_regressions -v 2 +``` + +Expected: all tests pass. + +- [ ] **Step 7: Commit async serializer helper** + +Run: + +```bash +rtk git add utils/api/api.py utils/test_async_api.py +rtk git commit -m "feat: add async serializer helper" +``` + +--- + +### Task 4: Add Cache Helpers For Async Views + +**Files:** +- Create: `utils/async_helpers.py` +- Create: `utils/test_async_helpers.py` +- Modify: `problem/views/oj.py` +- Modify: `comment/views/oj.py` + +- [ ] **Step 1: Write cache helper tests** + +Create `utils/test_async_helpers.py`: + +```python +from django.test import SimpleTestCase, override_settings + +from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set + + +@override_settings( + CACHES={ + "default": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "async-helper-tests", + } + } +) +class AsyncCacheHelperTests(SimpleTestCase): + async def test_async_cache_round_trip(self): + await async_cache_set("async:key", {"value": 1}, 30) + + value = await async_cache_get("async:key") + + self.assertEqual(value, {"value": 1}) + + async def test_async_cache_delete(self): + await async_cache_set("async:delete", "present", 30) + await async_cache_delete("async:delete") + + value = await async_cache_get("async:delete") + + self.assertIsNone(value) +``` + +- [ ] **Step 2: Run tests and confirm import failure** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_helpers -v 2 +``` + +Expected before implementation: `ModuleNotFoundError: No module named 'utils.async_helpers'`. + +- [ ] **Step 3: Create async cache helpers** + +Create `utils/async_helpers.py`: + +```python +from asgiref.sync import sync_to_async +from django.core.cache import cache + + +async def async_cache_get(key, default=None): + return await sync_to_async(cache.get, thread_sensitive=True)(key, default) + + +async def async_cache_set(key, value, timeout=None): + return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout) + + +async def async_cache_delete(key): + return await sync_to_async(cache.delete, thread_sensitive=True)(key) +``` + +- [ ] **Step 4: Convert async problem cache calls** + +In `problem/views/oj.py`, add: + +```python +from utils.async_helpers import async_cache_get, async_cache_set +``` + +In `ProblemTagAPI.get()`, replace: + +```python +cached = cache.get(cache_key) +``` + +with: + +```python +cached = await async_cache_get(cache_key) +``` + +Replace: + +```python +cache.set(cache_key, data, 3600) +``` + +with: + +```python +await async_cache_set(cache_key, data, 3600) +``` + +- [ ] **Step 5: Convert async comment cache calls** + +In `comment/views/oj.py`, add: + +```python +from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set +``` + +In `CommentAPI.post()`, replace: + +```python +cache.delete(f"{CacheKey.comment_stats}:{problem.id}") +``` + +with: + +```python +await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}") +``` + +In `CommentStatisticsAPI.get()`, replace `cache.get()` and `cache.set()` with: + +```python +cached = await async_cache_get(cache_key) +``` + +and: + +```python +await async_cache_set(cache_key, data, 3600) +``` + +- [ ] **Step 6: Run helper and targeted async tests** + +Run: + +```bash +rtk uv run python manage.py test utils.test_async_helpers utils.test_async_api utils.test_async_view_regressions -v 2 +``` + +Expected: all tests pass. + +- [ ] **Step 7: Commit cache helper work** + +Run: + +```bash +rtk git add utils/async_helpers.py utils/test_async_helpers.py problem/views/oj.py comment/views/oj.py +rtk git commit -m "feat: add async cache helpers" +``` + +--- + +### Task 5: Make Permission Decorators Async-Aware Before More Conversions + +**Files:** +- Modify: `account/decorators.py` +- Create: `account/test_async_decorators.py` + +- [ ] **Step 1: Write tests for async `login_required`** + +Create `account/test_async_decorators.py`: + +```python +from django.contrib.auth.models import AnonymousUser +from django.test import AsyncRequestFactory, SimpleTestCase + +from account.decorators import login_required +from utils.api import AsyncAPIView + + +class DisabledUser: + is_authenticated = True + is_disabled = True + + +class ActiveUser: + is_authenticated = True + is_disabled = False + + +class ProtectedAsyncView(AsyncAPIView): + @login_required + async def get(self, request): + return self.success("ok") + + +class AsyncPermissionDecoratorTests(SimpleTestCase): + async def test_async_login_required_allows_active_user(self): + request = AsyncRequestFactory().get("/api/protected") + request.user = ActiveUser() + + response = await ProtectedAsyncView.as_view()(request) + + self.assertEqual(response.data["error"], None) + self.assertEqual(response.data["data"], "ok") + + async def test_async_login_required_rejects_anonymous_user(self): + request = AsyncRequestFactory().get("/api/protected") + request.user = AnonymousUser() + + response = await ProtectedAsyncView.as_view()(request) + + self.assertEqual(response.data["error"], "permission-denied") + self.assertEqual(response.data["data"], "Please login first") + + async def test_async_login_required_rejects_disabled_user(self): + request = AsyncRequestFactory().get("/api/protected") + request.user = DisabledUser() + + response = await ProtectedAsyncView.as_view()(request) + + self.assertEqual(response.data["error"], "permission-denied") + self.assertEqual(response.data["data"], "Your account is disabled") +``` + +- [ ] **Step 2: Run decorator tests** + +Run: + +```bash +rtk uv run python manage.py test account.test_async_decorators -v 2 +``` + +Expected before implementation: this may pass because `AsyncAPIView.dispatch()` awaits returned coroutine. Keep the test anyway as a contract before refactoring. + +- [ ] **Step 3: Refactor `BasePermissionDecorator` with explicit async path** + +In `account/decorators.py`, add: + +```python +import inspect +``` + +Replace `BasePermissionDecorator.__get__()` with: + +```python + def __get__(self, obj, obj_type): + if inspect.iscoroutinefunction(self.func): + return functools.partial(self._async_call, obj) + return functools.partial(self.__call__, obj) +``` + +Add this method to `BasePermissionDecorator`: + +```python + async def _async_call(self, *args, **kwargs): + self.request = args[1] + + if self.check_permission(): + if self.request.user.is_disabled: + return self.error("Your account is disabled") + return await self.func(*args, **kwargs) + return self.error("Please login first") +``` + +- [ ] **Step 4: Run decorator and async view tests** + +Run: + +```bash +rtk uv run python manage.py test account.test_async_decorators utils.test_async_api utils.test_async_view_regressions -v 2 +``` + +Expected: all tests pass. + +- [ ] **Step 5: Commit decorator refactor** + +Run: + +```bash +rtk git add account/decorators.py account/test_async_decorators.py +rtk git commit -m "refactor: make permission decorators async-aware" +``` + +--- + +### Task 6: Convert One Endpoint Family At A Time + +**Files:** +- Modify only the endpoint family being converted in each batch. +- Add or update tests in the same app, using `AsyncClient` for converted URLs. + +- [ ] **Step 1: Choose the next low-risk batch** + +Use this order: + +```text +1. Pure public GET list/detail endpoints already close to async: + - announcement list/detail + - contest list/detail + - problem tag/list/detail + +2. Authenticated read-only list endpoints: + - message list + - submission list + - flowchart list/detail/current + +3. Simple create/update endpoints with no contest permission decorator: + - message create + - comment create + - flowchart retry/create +``` + +- [ ] **Step 2: For each endpoint, write one async smoke test** + +Use this template and replace the URL and assertions with the exact endpoint response: + +```python +from django.test import AsyncClient, TestCase + + +class EndpointAsyncSmokeTests(TestCase): + async def test_endpoint_returns_success_envelope(self): + response = await AsyncClient().get("/api/endpoint?limit=10") + + self.assertEqual(response.status_code, 200) + body = response.json() + self.assertIn("error", body) + self.assertIn("data", body) +``` + +- [ ] **Step 3: Convert ORM access only where async ORM exists** + +Use these replacements: + +```python +obj = await Model.objects.aget(id=id) +count = await queryset.acount() +first = await queryset.afirst() +last = await queryset.alast() +items = [item async for item in queryset[offset:offset + limit]] +created = await Model.objects.acreate(field=value) +await instance.asave(update_fields=["field"]) +``` + +- [ ] **Step 4: Keep sync-only helpers behind `sync_to_async`** + +Use this pattern: + +```python +from asgiref.sync import sync_to_async + +result = await sync_to_async(sync_helper, thread_sensitive=True)(arg1, arg2) +``` + +- [ ] **Step 5: Run targeted app tests and system check after each batch** + +Run: + +```bash +rtk uv run python manage.py test -v 2 +rtk uv run python manage.py check +``` + +Expected: targeted tests pass and system check reports no issues. + +- [ ] **Step 6: Commit each endpoint family separately** + +Run: + +```bash +rtk git add +rtk git commit -m "refactor: async endpoints" +``` + +--- + +### Task 7: Audit Middleware Before Claiming Full Async Benefit + +**Files:** +- Modify: `account/middleware.py` +- Add tests only for behavior that changes. + +- [ ] **Step 1: Record current sync middleware boundaries** + +Before changing middleware, note these current sync classes: + +```text +account.middleware.APITokenAuthMiddleware +account.middleware.AdminRoleRequiredMiddleware +account.middleware.SessionRecordMiddleware +``` + +- [ ] **Step 2: Keep middleware sync during endpoint correctness work** + +Do not convert middleware in the same commit as endpoint conversions. Middleware affects every request and needs its own review. + +- [ ] **Step 3: If middleware conversion is pursued, convert one class per commit** + +Use Django new-style middleware with explicit sync and async handling. `__acall__` is a local helper; `__call__` dispatches to it when Django passes an async `get_response`: + +```python +from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async + + +class ExampleMiddleware: + sync_capable = True + async_capable = True + + def __init__(self, get_response): + self.get_response = get_response + self.is_async = iscoroutinefunction(get_response) + if self.is_async: + markcoroutinefunction(self) + + def __call__(self, request): + if self.is_async: + return self.__acall__(request) + response = self.process_request(request) + if response is not None: + return response + return self.get_response(request) + + async def __acall__(self, request): + response = await self.aprocess_request(request) + if response is not None: + return response + return await self.get_response(request) + + def process_request(self, request): + return None + + async def aprocess_request(self, request): + return await sync_to_async(self.process_request, thread_sensitive=True)(request) +``` + +- [ ] **Step 4: Run full backend test command after middleware work** + +Run: + +```bash +rtk uv run python manage.py test -v 2 +rtk uv run python manage.py check +``` + +Expected: all tests pass and system check reports no issues. + +--- + +## Definition Of Done + +- `rtk uv run python manage.py check` passes. +- Async regression tests cover converted detail serializers that depend on FK relations. +- Converted async views do not call DRF serializer `.data` directly unless the data is primitive and relation-free. +- Cache access from async views uses async helper wrappers. +- New endpoint conversions are committed in endpoint-family-sized commits. +- No file upload/download, SMTP, judge heartbeat, test-case prune, or contest permission-heavy endpoint is converted without a separate focused plan. diff --git a/flowchart/tests.py b/flowchart/tests.py index 9b4704a..92e7e92 100644 --- a/flowchart/tests.py +++ b/flowchart/tests.py @@ -11,6 +11,6 @@ class FlowchartEvaluationPromptTests(TestCase): self.assertIn("Mermaid节点ID由系统生成", prompt) self.assertIn("不要评价节点ID", prompt) self.assertIn("不要因节点ID扣分", prompt) - self.assertIn("feedback控制在0字以内", prompt) + self.assertIn("feedback控制在100字以内", prompt) self.assertIn("suggestions最多3条", prompt) self.assertIn("重要建议必须以【重点】开头", prompt) diff --git a/flowchart/views/oj.py b/flowchart/views/oj.py index c75a039..60b0e63 100644 --- a/flowchart/views/oj.py +++ b/flowchart/views/oj.py @@ -7,65 +7,63 @@ from flowchart.serializers import ( ) from flowchart.tasks import evaluate_flowchart_task from problem.models import Problem -from utils.api import APIView +from utils.api import AsyncAPIView -class FlowchartSubmissionAPI(APIView): +class FlowchartSubmissionAPI(AsyncAPIView): @login_required - def post(self, request): - """创建流程图提交""" + async def post(self, request): serializer = CreateFlowchartSubmissionSerializer(data=request.data) if not serializer.is_valid(): return self.error(serializer.errors) data = serializer.validated_data - # 验证题目存在 try: - problem = Problem.objects.get(id=data["problem_id"]) + problem = await Problem.objects.aget(id=data["problem_id"]) except Problem.DoesNotExist: return self.error("Problem doesn't exist") - # 验证题目是否允许流程图提交 if not problem.allow_flowchart: return self.error("This problem does not allow flowchart submission") - # 创建提交记录 - submission = FlowchartSubmission.objects.create( + submission = await FlowchartSubmission.objects.acreate( user=request.user, problem=problem, mermaid_code=data["mermaid_code"], flowchart_data=data.get("flowchart_data", {}), ) - # 启动AI评分任务 evaluate_flowchart_task.send(submission.id) return self.success({"submission_id": submission.id, "status": "pending"}) @login_required - def get(self, request): - """获取流程图提交详情""" + async def get(self, request): submission_id = request.GET.get("id") if not submission_id: return self.error("submission_id is required") try: - submission = FlowchartSubmission.objects.get(id=submission_id) + submission = await ( + FlowchartSubmission.objects.select_related("user", "problem") + .filter(id=submission_id) + .afirst() + ) + if submission is None: + raise FlowchartSubmission.DoesNotExist except FlowchartSubmission.DoesNotExist: return self.error("Submission doesn't exist") if not submission.check_user_permission(request.user): return self.error("No permission for this submission") - serializer = FlowchartSubmissionSerializer(submission) - return self.success(serializer.data) + return self.success(await self.async_serialize_data(FlowchartSubmissionSerializer, submission)) -class FlowchartSubmissionListAPI(APIView): +class FlowchartSubmissionListAPI(AsyncAPIView): @login_required - def get(self, request): - """获取流程图提交列表""" + async def get(self, request): username = request.GET.get("username") problem_id = request.GET.get("problem_id") myself = request.GET.get("myself") @@ -74,7 +72,7 @@ class FlowchartSubmissionListAPI(APIView): if problem_id: try: - problem = Problem.objects.get( + problem = await Problem.objects.aget( _id__iexact=problem_id, contest_id__isnull=True, visible=True ) except Problem.DoesNotExist: @@ -88,38 +86,42 @@ class FlowchartSubmissionListAPI(APIView): elif request.user.is_regular_user(): queryset = queryset.filter(user=request.user) - data = self.paginate_data(request, queryset) - data["results"] = FlowchartSubmissionListSerializer( - data["results"], many=True - ).data + data = await self.async_paginate_data(request, queryset) + data["results"] = await self.async_serialize_data( + FlowchartSubmissionListSerializer, + data["results"], + many=True, + ) return self.success(data) -class FlowchartSubmissionRetryAPI(APIView): +class FlowchartSubmissionRetryAPI(AsyncAPIView): @login_required - def post(self, request): - """重新触发AI评分""" + async def post(self, request): submission_id = request.data.get("submission_id") if not submission_id: return self.error("submission_id is required") try: - submission = FlowchartSubmission.objects.get(id=submission_id) + submission = await ( + FlowchartSubmission.objects.select_related("problem") + .filter(id=submission_id) + .afirst() + ) + if submission is None: + raise FlowchartSubmission.DoesNotExist except FlowchartSubmission.DoesNotExist: return self.error("Submission doesn't exist") - # 检查权限 if not submission.check_user_permission(request.user): return self.error("No permission for this submission") - # 检查是否可以重新评分 if submission.status not in [ FlowchartSubmissionStatus.FAILED, FlowchartSubmissionStatus.COMPLETED, ]: return self.error("Submission is not in a state that allows retry") - # 重置状态并重新启动AI评分 submission.status = FlowchartSubmissionStatus.PENDING submission.ai_score = None submission.ai_grade = None @@ -128,9 +130,8 @@ class FlowchartSubmissionRetryAPI(APIView): submission.ai_criteria_details = {} submission.processing_time = None submission.evaluation_time = None - submission.save() + await submission.asave() - # 重新启动AI评分任务 evaluate_flowchart_task.send(submission.id) return self.success( @@ -142,15 +143,14 @@ class FlowchartSubmissionRetryAPI(APIView): ) -class FlowchartSubmissionDetailAPI(APIView): +class FlowchartSubmissionDetailAPI(AsyncAPIView): @login_required - def get(self, request): - """获取当前用户对指定题目的流程图提交详情""" + async def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: return self.error("problem_id is required") try: - problem = Problem.objects.get(id=problem_id) + problem = await Problem.objects.aget(id=problem_id) except Problem.DoesNotExist: return self.error("Problem doesn't exist") @@ -158,34 +158,37 @@ class FlowchartSubmissionDetailAPI(APIView): page = int(request.GET.get("page", 0)) except ValueError: return self.error("page must be an integer") - submissions = FlowchartSubmission.objects.filter( - user=request.user, - problem=problem, - status=FlowchartSubmissionStatus.COMPLETED, - ).order_by("create_time") - count = submissions.count() + submissions = ( + FlowchartSubmission.objects.select_related("user", "problem") + .filter( + user=request.user, + problem=problem, + status=FlowchartSubmissionStatus.COMPLETED, + ) + .order_by("create_time") + ) + count = await submissions.acount() if count == 0: return self.success({"submission": None, "count": 0}) - # page=0 means latest; page=N means the Nth submission (1-indexed, chronological) if page == 0: - submission = submissions.last() + submission = await submissions.alast() else: if page < 0 or page > count: return self.error("Page out of range") - submission = submissions[page - 1] - serializer = FlowchartSubmissionSerializer(submission) - return self.success({"submission": serializer.data, "count": count}) + result = [s async for s in submissions[page - 1:page]] + submission = result[0] + data = await self.async_serialize_data(FlowchartSubmissionSerializer, submission) + return self.success({"submission": data, "count": count}) -class FlowchartSubmissionCurrentAPI(APIView): +class FlowchartSubmissionCurrentAPI(AsyncAPIView): @login_required - def get(self, request): - """获取当前用户对指定题目的最新流程图提交,只返回次数和分数""" + async def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: return self.error("problem_id is required") try: - problem = Problem.objects.get(id=problem_id) + problem = await Problem.objects.aget(id=problem_id) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = ( @@ -197,10 +200,10 @@ class FlowchartSubmissionCurrentAPI(APIView): .values("ai_score", "ai_grade") .order_by("-create_time") ) - count = submissions.count() + count = await submissions.acount() if count == 0: return self.success({"count": 0, "score": 0, "grade": ""}) - submission = submissions[0] + submission = await submissions.afirst() return self.success( { "count": count, diff --git a/message/views/oj.py b/message/views/oj.py index e3d200d..28fb3a4 100644 --- a/message/views/oj.py +++ b/message/views/oj.py @@ -3,33 +3,33 @@ from account.models import User from message.models import Message from message.serializers import CreateMessageSerializer, MessageSerializer from submission.models import Submission -from utils.api import APIView +from utils.api import AsyncAPIView from utils.api.api import validate_serializer -class MessageAPI(APIView): +class MessageAPI(AsyncAPIView): @login_required - def get(self, request): + async def get(self, request): messages = Message.objects.select_related( "recipient", "sender", "submission", "submission__problem" ).filter(recipient=request.user) - return self.success(self.paginate_data(request, messages, MessageSerializer)) + return self.success(await self.async_paginate_data(request, messages, MessageSerializer)) @validate_serializer(CreateMessageSerializer) @super_admin_required - def post(self, request): + async def post(self, request): data = request.data if data["recipient"] == request.user.id: return self.error("Can not send a message to youself") try: - recipient = User.objects.get(id=data["recipient"], is_disabled=False) + recipient = await User.objects.aget(id=data["recipient"], is_disabled=False) except User.DoesNotExist: return self.error("User does not exist") try: - submission = Submission.objects.get(id=data["submission"]) + submission = await Submission.objects.aget(id=data["submission"]) except Submission.DoesNotExist: return self.error("Submission does not exist") - Message.objects.create( + await Message.objects.acreate( submission=submission, message=data["message"], sender=request.user, diff --git a/options/options.py b/options/options.py index b3c5a05..fa673a4 100644 --- a/options/options.py +++ b/options/options.py @@ -292,4 +292,15 @@ class _SysOptionsMeta(type): class SysOptions(metaclass=_SysOptionsMeta): - pass + @classmethod + async def aget(cls, key): + from asgiref.sync import sync_to_async + return await sync_to_async(getattr)(cls, key) + + @classmethod + async def aget_many(cls, *keys): + from asgiref.sync import sync_to_async + + def _get_all(): + return {k: getattr(cls, k) for k in keys} + return await sync_to_async(_get_all)() diff --git a/problem/views/oj.py b/problem/views/oj.py index 0cdded0..8fc4c96 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -1,15 +1,16 @@ import random -from datetime import datetime -from django.core.cache import cache +from asgiref.sync import sync_to_async from django.db.models import BooleanField, Case, Count, Q, Value, When from django.db.models.functions import ExtractYear +from django.utils import timezone from account.decorators import check_contest_permission from account.models import User from contest.models import ContestRuleType from submission.models import JudgeStatus, Submission -from utils.api import APIView +from utils.api import APIView, AsyncAPIView +from utils.async_helpers import async_cache_get, async_cache_set from utils.constants import CacheKey from ..models import Problem, ProblemTag @@ -21,11 +22,11 @@ from ..serializers import ( ) -class ProblemTagAPI(APIView): - def get(self, request): +class ProblemTagAPI(AsyncAPIView): + async def get(self, request): keyword = request.GET.get("keyword", "") cache_key = f"{CacheKey.problem_tags}:{keyword}" - cached = cache.get(cache_key) + cached = await async_cache_get(cache_key) if cached is not None: return self.success(cached) @@ -33,48 +34,48 @@ class ProblemTagAPI(APIView): if keyword: qs = ProblemTag.objects.filter(name__icontains=keyword) tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0) - data = TagSerializer(tags, many=True).data - cache.set(cache_key, data, 3600) + data = await self.async_serialize_data(TagSerializer, [tag async for tag in tags], many=True) + await async_cache_set(cache_key, data, 3600) return self.success(data) -class PickOneAPI(APIView): - def get(self, request): - problems = Problem.objects.filter(contest_id__isnull=True, visible=True) - count = problems.count() +class PickOneAPI(AsyncAPIView): + async def get(self, request): + ids = Problem.objects.filter(contest_id__isnull=True, visible=True).values_list("_id", flat=True) + count = await ids.acount() if count == 0: return self.error("No problem to pick") - return self.success(problems[random.randint(0, count - 1)]._id) + idx = random.randint(0, count - 1) + result = [pid async for pid in ids[idx : idx + 1]] + return self.success(result[0]) -class ProblemAPI(APIView): +class ProblemAPI(AsyncAPIView): @staticmethod - def _add_problem_status(request, queryset_values): - if request.user.is_authenticated: - profile = request.user.userprofile - acm_problems_status = profile.acm_problems_status.get("problems", {}) - # paginate data - results = queryset_values.get("results") - if results is not None: - problems = results - else: - problems = [queryset_values] - for problem in problems: - problem["my_status"] = acm_problems_status.get( - str(problem["id"]), {} - ).get("status") + def _add_problem_status(acm_problems_status, queryset_values): + results = queryset_values.get("results") + if results is not None: + problems = results + else: + problems = [queryset_values] + for problem in problems: + problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status") - def get(self, request): + async def get(self, request): # 问题详情页 problem_id = request.GET.get("problem_id") if problem_id: try: - problem = Problem.objects.select_related("created_by").get( - _id__iexact=problem_id, contest_id__isnull=True, visible=True - ) - problem_data = ProblemSerializer(problem).data - self._add_problem_status(request, problem_data) + problem = await Problem.objects.select_related("created_by").prefetch_related("tags").filter(_id__iexact=problem_id, contest_id__isnull=True, visible=True).afirst() + if problem is None: + raise Problem.DoesNotExist + problem_data = await self.async_serialize_data(ProblemSerializer, problem) if request.user.is_authenticated: + from account.models import UserProfile + + profile = await UserProfile.objects.aget(user=request.user) + acm_problems_status = profile.acm_problems_status.get("problems", {}) + self._add_problem_status(acm_problems_status, problem_data) failed_statuses = [ JudgeStatus.WRONG_ANSWER, JudgeStatus.CPU_TIME_LIMIT_EXCEEDED, @@ -83,11 +84,11 @@ class ProblemAPI(APIView): JudgeStatus.RUNTIME_ERROR, JudgeStatus.COMPILE_ERROR, ] - problem_data["my_failed_count"] = Submission.objects.filter( + problem_data["my_failed_count"] = await Submission.objects.filter( user_id=request.user.id, problem_id=problem.id, result__in=failed_statuses, - ).count() + ).acount() else: problem_data["my_failed_count"] = 0 return self.success(problem_data) @@ -98,12 +99,7 @@ class ProblemAPI(APIView): if not limit: return self.error("Limit is needed") - problems = ( - Problem.objects.select_related("created_by") - .prefetch_related("tags") - .filter(contest_id__isnull=True, visible=True) - .order_by("-create_time") - ) + problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest_id__isnull=True, visible=True).order_by("-create_time") author = request.GET.get("author") if author: @@ -117,9 +113,7 @@ class ProblemAPI(APIView): # 搜索的情况 keyword = request.GET.get("keyword", "").strip() if keyword: - problems = problems.filter( - Q(title__icontains=keyword) | Q(_id__icontains=keyword) - ) + problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword)) # 难度筛选 difficulty = request.GET.get("difficulty") @@ -142,8 +136,13 @@ class ProblemAPI(APIView): problems = problems.order_by(sort) # 根据profile 为做过的题目添加标记 - data = self.paginate_data(request, problems, ProblemListSerializer) - self._add_problem_status(request, data) + data = await self.async_paginate_data(request, problems, ProblemListSerializer) + if request.user.is_authenticated: + from account.models import UserProfile + + profile = await UserProfile.objects.aget(user=request.user) + acm_problems_status = profile.acm_problems_status.get("problems", {}) + self._add_problem_status(acm_problems_status, data) return self.success(data) @@ -152,24 +151,18 @@ class ContestProblemAPI(APIView): if request.user.is_authenticated: profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: - problems_status = profile.acm_problems_status.get( - "contest_problems", {} - ) + problems_status = profile.acm_problems_status.get("contest_problems", {}) else: problems_status = profile.oi_problems_status.get("contest_problems", {}) for problem in queryset_values: - problem["my_status"] = problems_status.get(str(problem["id"]), {}).get( - "status" - ) + problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") @check_contest_permission(check_type="problems") def get(self, request): problem_id = request.GET.get("problem_id") if problem_id: try: - problem = Problem.objects.select_related("created_by").get( - _id__iexact=problem_id, contest=self.contest, visible=True - ) + problem = Problem.objects.select_related("created_by").get(_id__iexact=problem_id, contest=self.contest, visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist.") if self.contest.problem_details_permission(request.user): @@ -184,9 +177,7 @@ class ContestProblemAPI(APIView): problem_data = ProblemSafeSerializer(problem).data return self.success(problem_data) - contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter( - contest=self.contest, visible=True - ) + contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest=self.contest, visible=True) if self.contest.problem_details_permission(request.user): data = ProblemListSerializer(contest_problems, many=True).data self._add_problem_status(request, data) @@ -195,59 +186,60 @@ class ContestProblemAPI(APIView): return self.success(data) -class ProblemSolvedPeopleCount(APIView): - def get(self, request): +class ProblemSolvedPeopleCount(AsyncAPIView): + async def get(self, request): problem_id = request.GET.get("problem_id") rate = "0" if not request.user.is_authenticated: return self.success(rate) - submission_count = Submission.objects.filter( + submission_count = await Submission.objects.filter( user_id=request.user.id, problem_id=problem_id, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], - ).count() + ).acount() if submission_count == 0: return self.success(rate) - today = datetime.today() - years_ago = datetime(today.year - 2, today.month, today.day, 0, 0) - total_count = User.objects.filter( - is_disabled=False, last_login__gte=years_ago - ).count() - accepted_count = Submission.objects.filter( - problem_id=problem_id, - result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], - create_time__gte=years_ago, - ).aggregate(user_count=Count("user_id", distinct=True))["user_count"] - if accepted_count < total_count: + now = timezone.now() + years_ago = now.replace(year=now.year - 2, hour=0, minute=0, second=0, microsecond=0) + total_count = await User.objects.filter(is_disabled=False, last_login__gte=years_ago).acount() + accepted_count = ( + await sync_to_async( + Submission.objects.filter( + problem_id=problem_id, + result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED], + create_time__gte=years_ago, + ).aggregate, + thread_sensitive=True, + )(user_count=Count("user_id", distinct=True)) + )["user_count"] + if total_count and accepted_count < total_count: rate = "%.2f" % ((total_count - accepted_count) / total_count * 100) else: rate = "0" return self.success(rate) -class SimilarProblemAPI(APIView): - def get(self, request): +class SimilarProblemAPI(AsyncAPIView): + async def get(self, request): problem_display_id = request.GET.get("problem_id") if not problem_display_id: return self.error("problem_id is required") try: - problem = Problem.objects.get(_id__iexact=problem_display_id, contest__isnull=True) + problem = await Problem.objects.aget(_id__iexact=problem_display_id, contest__isnull=True) except Problem.DoesNotExist: return self.error("Problem not found") - tag_ids = list(problem.tags.values_list("id", flat=True)) + tag_ids = [tag_id async for tag_id in problem.tags.values_list("id", flat=True)] if not tag_ids: return self.success([]) exclude_ids = [problem_display_id] if request.user.is_authenticated: - profile = request.user.userprofile - ac_display_ids = [ - v["_id"] - for v in profile.acm_problems_status.get("problems", {}).values() - if v.get("status") == JudgeStatus.ACCEPTED - ] + from account.models import UserProfile + + profile = await UserProfile.objects.aget(user=request.user) + ac_display_ids = [v["_id"] for v in profile.acm_problems_status.get("problems", {}).values() if v.get("status") == JudgeStatus.ACCEPTED] exclude_ids.extend(ac_display_ids) similar = ( @@ -258,14 +250,15 @@ class SimilarProblemAPI(APIView): .distinct() .order_by("difficulty")[:5] ) - return self.success(ProblemListSerializer(similar, many=True).data) + similar_list = [problem async for problem in similar] + return self.success(await self.async_serialize_data(ProblemListSerializer, similar_list, many=True)) -class ProblemAuthorAPI(APIView): - def get(self, request): +class ProblemAuthorAPI(AsyncAPIView): + async def get(self, request): show_all = request.GET.get("all", "0") == "1" cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}" - cached_data = cache.get(cache_key) + cached_data = await async_cache_get(cache_key) if cached_data: return self.success(cached_data) @@ -273,38 +266,32 @@ class ProblemAuthorAPI(APIView): if not show_all: problem_filter["visible"] = True - authors = ( - Problem.objects.filter(**problem_filter) - .values("created_by__username") - .annotate(problem_count=Count("id")) - .order_by("-problem_count") - ) + authors = Problem.objects.filter(**problem_filter).values("created_by__username").annotate(problem_count=Count("id")).order_by("-problem_count") result = [ { "username": author["created_by__username"], "problem_count": author["problem_count"], } - for author in authors + async for author in authors ] - cache.set(cache_key, result, 7200) + await async_cache_set(cache_key, result, 7200) + return self.success(result) -class ProblemYearlyACRateAPI(APIView): - def get(self, request): +class ProblemYearlyACRateAPI(AsyncAPIView): + async def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: return self.error("problem_id is required") cache_key = f"{CacheKey.problem_yearly_ac}:{problem_id}" - cached = cache.get(cache_key) + cached = await async_cache_get(cache_key) if cached is not None: return self.success(cached) try: - problem = Problem.objects.get( - _id__iexact=problem_id, contest_id__isnull=True, visible=True - ) + problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -328,12 +315,10 @@ class ProblemYearlyACRateAPI(APIView): "year": row["year"], "total": row["total"], "accepted": row["accepted"], - "ac_rate": round(row["accepted"] / row["total"] * 100, 2) - if row["total"] > 0 - else 0.0, + "ac_rate": round(row["accepted"] / row["total"] * 100, 2) if row["total"] > 0 else 0.0, } - for row in rows + async for row in rows ] - cache.set(cache_key, data, 3600) + await async_cache_set(cache_key, data, 3600) return self.success(data) diff --git a/problemset/urls/admin.py b/problemset/urls/admin.py index 4382633..77726eb 100644 --- a/problemset/urls/admin.py +++ b/problemset/urls/admin.py @@ -63,7 +63,7 @@ urlpatterns = [ name="admin_problemset_progress_detail_api", ), # 题单同步管理API - path( + path( # DEPRECATED: 前端未调用 "problemset//sync", ProblemSetSyncAPI.as_view(), name="admin_problemset_sync_api", diff --git a/problemset/urls/oj.py b/problemset/urls/oj.py index 0f20df8..7d1b5de 100644 --- a/problemset/urls/oj.py +++ b/problemset/urls/oj.py @@ -24,7 +24,7 @@ urlpatterns = [ ProblemSetProblemAPI.as_view(), name="problemset_problems_api", ), - path( + path( # DEPRECATED: 前端未调用 "problemset//problems/", ProblemSetProblemAPI.as_view(), name="problemset_problem_detail_api", @@ -35,12 +35,12 @@ urlpatterns = [ ProblemSetProgressAPI.as_view(), name="problemset_progress_api", ), - path( + path( # DEPRECATED: 前端未调用 "problemset//progress", ProblemSetProgressAPI.as_view(), name="problemset_progress_detail_api", ), - path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"), + path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"), # DEPRECATED: 前端未调用 # 奖章相关API path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"), path( diff --git a/problemset/views/admin.py b/problemset/views/admin.py index 6f26ec8..74dd58e 100644 --- a/problemset/views/admin.py +++ b/problemset/views/admin.py @@ -332,6 +332,7 @@ class ProblemSetProgressAdminAPI(APIView): return self.error("用户未加入该题单") +# DEPRECATED: 前端未调用 (2026-05-26) class ProblemSetSyncAPI(APIView): """题单同步管理API""" diff --git a/problemset/views/oj.py b/problemset/views/oj.py index 8c16c1f..8d8d48b 100644 --- a/problemset/views/oj.py +++ b/problemset/views/oj.py @@ -1,3 +1,4 @@ +from asgiref.sync import sync_to_async from django.db.models import Avg, Count, Prefetch, Q from django.utils import timezone @@ -24,14 +25,14 @@ from problemset.serializers import ( UpdateProgressSerializer, UserBadgeSerializer, ) -from submission.models import JudgeStatus, Submission, is_accepted -from utils.api import APIView, validate_serializer +from submission.models import Submission, is_accepted +from utils.api import APIView, AsyncAPIView, validate_serializer -class ProblemSetAPI(APIView): +class ProblemSetAPI(AsyncAPIView): """题单API - 用户端""" - def get(self, request): + async def get(self, request): """获取题单列表""" # 预加载创建者信息 problem_sets = ProblemSet.objects.filter(visible=True).exclude(status=ProblemSetStatus.DRAFT).select_related("created_by") @@ -65,16 +66,19 @@ class ProblemSetAPI(APIView): user_earned_badge_ids = set() if request.user.is_authenticated: # 先获取所有题单ID(不应用prefetch_related,只获取ID) - problem_set_ids = list(problem_sets.values_list("id", flat=True)) + problem_set_ids = [problem_set_id async for problem_set_id in problem_sets.values_list("id", flat=True)] if problem_set_ids: # 批量查询用户在这些题单中的进度 user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset") # 构建映射:题单ID -> 进度对象 - user_progress_map = {progress.problemset_id: progress for progress in user_progresses} + user_progress_map = {progress.problemset_id: progress async for progress in user_progresses} # 批量查询用户已获得的奖章ID(这些题单相关的) - user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True)) + user_earned_badge_ids = { + badge_id + async for badge_id in UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True) + } # 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载) problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges")) @@ -83,31 +87,35 @@ class ProblemSetAPI(APIView): request._user_progress_map = user_progress_map request._user_earned_badge_ids = user_earned_badge_ids - data = self.paginate_data(request, problem_sets, ProblemSetListSerializer) + data = await self.async_paginate_data(request, problem_sets, ProblemSetListSerializer) return self.success(data) -class ProblemSetDetailAPI(APIView): +class ProblemSetDetailAPI(AsyncAPIView): """题单详情API - 用户端""" - def get(self, request, problem_set_id): + async def get(self, request, problem_set_id): """获取题单详情""" try: - problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get() + problem_set = await ( + ProblemSet.objects.select_related("created_by") + .filter(id=problem_set_id, visible=True) + .exclude(status=ProblemSetStatus.DRAFT) + .aget() + ) except ProblemSet.DoesNotExist: return self.error("题单不存在") - serializer = ProblemSetSerializer(problem_set, context={"request": request}) - return self.success(serializer.data) + return self.success(await self.async_serialize_data(ProblemSetSerializer, problem_set, context={"request": request})) -class ProblemSetProblemAPI(APIView): +class ProblemSetProblemAPI(AsyncAPIView): """题单题目API - 用户端""" - def get(self, request, problem_set_id): + async def get(self, request, problem_set_id): """获取题单中的题目列表""" try: - problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get() + problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget() except ProblemSet.DoesNotExist: return self.error("题单不存在") @@ -115,12 +123,16 @@ class ProblemSetProblemAPI(APIView): # 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1 user_progress = None if request.user.is_authenticated: - try: - user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user) - except ProblemSetProgress.DoesNotExist: - pass - serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress}) - return self.success(serializer.data) + user_progress = await ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).afirst() + problem_list = [problem async for problem in problems] + return self.success( + await self.async_serialize_data( + ProblemSetProblemSerializer, + problem_list, + many=True, + context={"request": request, "user_progress": user_progress}, + ) + ) class ProblemSetProgressAPI(APIView): @@ -236,6 +248,7 @@ class ProblemSetProgressAPI(APIView): UserBadge.objects.create(user=progress.user, badge=badge) +# DEPRECATED: 前端未调用 (2026-05-26) class UserProgressAPI(APIView): """用户进度API""" @@ -247,10 +260,10 @@ class UserProgressAPI(APIView): return self.success(serializer.data) -class UserBadgeAPI(APIView): +class UserBadgeAPI(AsyncAPIView): """用户奖章API""" - def get(self, request): + async def get(self, request): """获取用户的奖章列表""" # 支持通过username参数获取指定用户的徽章 username = request.GET.get("username") @@ -258,41 +271,41 @@ class UserBadgeAPI(APIView): if username: # 获取指定用户的徽章 try: - target_user = User.objects.get(username=username, is_disabled=False) - badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time") + target_user = await User.objects.aget(username=username, is_disabled=False) + badges = UserBadge.objects.select_related("badge").filter(user=target_user).order_by("-earned_time") except User.DoesNotExist: return self.error("用户不存在") else: # 获取当前用户的徽章 - badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time") + badges = UserBadge.objects.select_related("badge").filter(user=request.user).order_by("-earned_time") - serializer = UserBadgeSerializer(badges, many=True) - return self.success(serializer.data) + badge_list = [badge async for badge in badges] + return self.success(await self.async_serialize_data(UserBadgeSerializer, badge_list, many=True)) -class ProblemSetBadgeAPI(APIView): +class ProblemSetBadgeAPI(AsyncAPIView): """题单奖章API - 用户端""" - def get(self, request, problem_set_id): + async def get(self, request, problem_set_id): """获取题单的奖章列表""" try: - problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get() + problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget() except ProblemSet.DoesNotExist: return self.error("题单不存在") badges = ProblemSetBadge.objects.filter(problemset=problem_set) - serializer = ProblemSetBadgeSerializer(badges, many=True) - return self.success(serializer.data) + badge_list = [badge async for badge in badges] + return self.success(await self.async_serialize_data(ProblemSetBadgeSerializer, badge_list, many=True)) -class ProblemSetUserProgressAPI(APIView): +class ProblemSetUserProgressAPI(AsyncAPIView): """题单用户进度列表API""" @admin_role_required - def get(self, request, problem_set_id: int): + async def get(self, request, problem_set_id: int): """获取题单的用户进度列表""" try: - problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get() + problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget() except ProblemSet.DoesNotExist: return self.error("题单不存在") @@ -321,7 +334,7 @@ class ProblemSetUserProgressAPI(APIView): # 计算统计数据(基于所有数据,而非分页数据) # 使用一次查询获取所有统计数据 - stats = progresses.aggregate( + stats = await sync_to_async(progresses.aggregate, thread_sensitive=True)( total=Count("id"), completed=Count("id", filter=Q(is_completed=True)), avg_progress=Avg("progress_percentage"), @@ -351,7 +364,7 @@ class ProblemSetUserProgressAPI(APIView): # 构建题单所有题目的数据结构和映射 all_problems_list = [] all_problems_map = {} - for psp in all_problemset_problems: + async for psp in all_problemset_problems: problem_data = { "id": psp.problem.id, "_id": psp.problem._id, @@ -362,7 +375,7 @@ class ProblemSetUserProgressAPI(APIView): all_problems_map[str(psp.problem.id)] = psp.problem # 从当前页的数据中收集已完成的问题ID,用于序列化器 - paginated_progresses = list(progresses[offset : offset + limit]) + paginated_progresses = [progress async for progress in progresses[offset : offset + limit]] completed_problem_ids = set() for progress in paginated_progresses: if progress.progress_detail: @@ -376,7 +389,7 @@ class ProblemSetUserProgressAPI(APIView): request._problems_dict_cache = problems_dict # 使用分页 - data = self.paginate_data(request, progresses, ProblemSetProgressSerializer) + data = await self.async_paginate_data(request, progresses, ProblemSetProgressSerializer) # 添加统计数据 data["statistics"] = { diff --git a/pyproject.toml b/pyproject.toml index 44aafa8..59f2b9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "tree-sitter-c>=0.24.2", "tree-sitter-python>=0.25.0", "xlsxwriter>=3.2.9,<4", + "asgiref>=3.11.1", ] [dependency-groups] @@ -35,6 +36,10 @@ dev = [ "ruff>=0.15.11", ] +[tool.pyright] +venvPath = "." +venv = ".venv" + [tool.ruff] line-length = 180 exclude = ["*/migrations/*", "*settings.py", "*/apps.py", ".venv"] diff --git a/submission/urls/oj.py b/submission/urls/oj.py index c0b03ca..03320e9 100644 --- a/submission/urls/oj.py +++ b/submission/urls/oj.py @@ -12,6 +12,6 @@ urlpatterns = [ path("submission", SubmissionAPI.as_view()), path("submissions", SubmissionListAPI.as_view()), path("submissions/today_count", SubmissionsTodayCount.as_view()), - path("submission_exists", SubmissionExistsAPI.as_view()), + path("submission_exists", SubmissionExistsAPI.as_view()), # DEPRECATED: 前端未调用 path("contest_submissions", ContestSubmissionListAPI.as_view()), ] diff --git a/submission/views/oj.py b/submission/views/oj.py index 29a0d2e..f1c03d7 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -1,5 +1,7 @@ import ipaddress -from datetime import datetime + +from asgiref.sync import sync_to_async +from django.utils import timezone from account.decorators import check_contest_permission, login_required from contest.models import ContestRuleType, ContestStatus @@ -8,7 +10,7 @@ from options.options import SysOptions # from judge.dispatcher import JudgeDispatcher from problem.models import Problem, ProblemRuleType -from utils.api import APIView, validate_serializer +from utils.api import APIView, AsyncAPIView, validate_serializer from utils.cache import cache from utils.captcha import Captcha from utils.throttling import TokenBucket @@ -154,8 +156,8 @@ class SubmissionAPI(APIView): return self.success() -class SubmissionListAPI(APIView): - def get(self, request): +class SubmissionListAPI(AsyncAPIView): + async def get(self, request): if not request.GET.get("limit"): return self.error("Limit is needed") if request.GET.get("contest_id"): @@ -171,14 +173,15 @@ class SubmissionListAPI(APIView): language = request.GET.get("language") if problem_id: try: - problem = Problem.objects.get( + problem = await Problem.objects.aget( _id__iexact=problem_id, contest_id__isnull=True, visible=True ) except Problem.DoesNotExist: return self.error("Problem doesn't exist") submissions = submissions.filter(problem=problem) - if not SysOptions.submission_list_show_all and request.user.is_regular_user(): + show_all = await SysOptions.aget("submission_list_show_all") + if not show_all and request.user.is_regular_user(): return self.success({"results": [], "total": 0}) if myself and myself == "1": @@ -190,21 +193,25 @@ class SubmissionListAPI(APIView): if language: submissions = submissions.filter(language=language) if request.GET.get("today") == "1": - today = datetime.today() + now = timezone.now() submissions = submissions.filter( - create_time__gte=datetime(today.year, today.month, today.day, 0, 0) + create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0) ) - data = self.paginate_data(request, submissions) + data = await self.async_paginate_data(request, submissions) results = data["results"] if request.user.is_authenticated and request.user.is_regular_user(): problem_ids = list({s.problem_id for s in results}) - progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids) + progress_cache = await sync_to_async(bulk_fetch_problemset_progress)(request.user, problem_ids) else: progress_cache = {} - data["results"] = SubmissionListSerializer( - results, many=True, user=request.user, problemset_progress_cache=progress_cache - ).data + data["results"] = await self.async_serialize_data( + SubmissionListSerializer, + results, + many=True, + user=request.user, + problemset_progress_cache=progress_cache, + ) return self.success(data) @@ -262,6 +269,7 @@ class ContestSubmissionListAPI(APIView): return self.success(data) +# DEPRECATED: 前端未调用 (2026-05-26) class SubmissionExistsAPI(APIView): def get(self, request): if not request.GET.get("problem_id"): @@ -274,10 +282,10 @@ class SubmissionExistsAPI(APIView): ) -class SubmissionsTodayCount(APIView): - def get(self, request): - today = datetime.today() - count = Submission.objects.filter( - create_time__gte=datetime(today.year, today.month, today.day, 0, 0) - ).count() +class SubmissionsTodayCount(AsyncAPIView): + async def get(self, request): + now = timezone.now() + count = await Submission.objects.filter( + create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0) + ).acount() return self.success(count) diff --git a/utils/api/api.py b/utils/api/api.py index 975577f..4c95f6d 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -1,7 +1,10 @@ +import asyncio import functools +import inspect import json import logging +from asgiref.sync import sync_to_async from django.http import HttpResponse, QueryDict from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt @@ -162,6 +165,77 @@ class CSRFExemptAPIView(APIView): return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs) +class AsyncAPIView(APIView): + view_is_async = True + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.view_is_async = True + + async def dispatch(self, request, *args, **kwargs): + if self.request_parsers: + try: + request.data = self._get_request_data(self.request) + except ValueError as e: + return self.error(err="invalid-request", msg=str(e)) + try: + handler = getattr(self, request.method.lower(), self.http_method_not_allowed) + response = handler(request, *args, **kwargs) + if asyncio.iscoroutine(response): + response = await response + return response + except APIError as e: + ret = {"msg": e.msg} + if e.err: + ret["err"] = e.err + return self.error(**ret) + except Exception as e: + logger.exception(e) + return self.server_error() + + def serialize_data(self, object_serializer, data, **kwargs): + return object_serializer(data, **kwargs).data + + async def async_serialize_data(self, object_serializer, data, **kwargs): + return await sync_to_async( + self.serialize_data, + thread_sensitive=True, + )(object_serializer, data, **kwargs) + + async def async_paginate_data(self, request, query_set, object_serializer=None): + try: + limit = int(request.GET.get("limit", "10")) + except ValueError: + limit = 10 + if limit < 0 or limit > 250: + limit = 10 + try: + offset = int(request.GET.get("offset", "0")) + except ValueError: + offset = 0 + if offset < 0: + offset = 0 + count, results = await asyncio.gather( + query_set.acount(), + sync_to_async(lambda: list(query_set[offset:offset + limit]), thread_sensitive=True)(), + ) + if object_serializer: + results = await self.async_serialize_data( + object_serializer, + results, + many=True, + context={"request": request}, + ) + data = {"results": results, "total": count} + return data + + +class CSRFExemptAsyncAPIView(AsyncAPIView): + @method_decorator(csrf_exempt) + async def dispatch(self, request, *args, **kwargs): + return await super().dispatch(request, *args, **kwargs) + + def validate_serializer(serializer): """ @validate_serializer(TestSerializer) @@ -169,6 +243,20 @@ def validate_serializer(serializer): return self.success(request.data) """ def validate(view_method): + if inspect.iscoroutinefunction(view_method): + @functools.wraps(view_method) + async def async_handle(*args, **kwargs): + self = args[0] + request = args[1] + s = serializer(data=request.data) + if s.is_valid(): + request.data = s.data + request.serializer = s + return await view_method(*args, **kwargs) + else: + return self.invalid_serializer(s) + return async_handle + @functools.wraps(view_method) def handle(*args, **kwargs): self = args[0] @@ -180,7 +268,6 @@ def validate_serializer(serializer): return view_method(*args, **kwargs) else: return self.invalid_serializer(s) - return handle return validate diff --git a/utils/async_helpers.py b/utils/async_helpers.py new file mode 100644 index 0000000..8688c9b --- /dev/null +++ b/utils/async_helpers.py @@ -0,0 +1,14 @@ +from asgiref.sync import sync_to_async +from django.core.cache import cache + + +async def async_cache_get(key, default=None): + return await sync_to_async(cache.get, thread_sensitive=True)(key, default) + + +async def async_cache_set(key, value, timeout=None): + return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout) + + +async def async_cache_delete(key): + return await sync_to_async(cache.delete, thread_sensitive=True)(key) diff --git a/utils/urls.py b/utils/urls.py index 5668741..10b3bd7 100644 --- a/utils/urls.py +++ b/utils/urls.py @@ -4,5 +4,5 @@ from .views import SimditorFileUploadAPIView, SimditorImageUploadAPIView urlpatterns = [ path("upload_image", SimditorImageUploadAPIView.as_view()), - path("upload_file", SimditorFileUploadAPIView.as_view()), + path("upload_file", SimditorFileUploadAPIView.as_view()), # DEPRECATED: 前端未调用 ] diff --git a/utils/views.py b/utils/views.py index 17de959..5351f78 100644 --- a/utils/views.py +++ b/utils/views.py @@ -46,6 +46,7 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView): "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) +# DEPRECATED: 前端未调用 (2026-05-26) class SimditorFileUploadAPIView(CSRFExemptAPIView): request_parsers = () diff --git a/uv.lock b/uv.lock index 488282f..5d58ea4 100644 --- a/uv.lock +++ b/uv.lock @@ -550,6 +550,7 @@ name = "onlinejudge" version = "2.0.0" source = { virtual = "." } dependencies = [ + { name = "asgiref" }, { name = "channels" }, { name = "channels-redis" }, { name = "django" }, @@ -582,6 +583,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "asgiref", specifier = ">=3.11.1" }, { name = "channels", specifier = ">=4.3.2,<5" }, { name = "channels-redis", specifier = ">=4.3.0,<5" }, { name = "django", specifier = ">=6.0.4,<6.1" },