This commit is contained in:
2026-05-26 21:25:26 -06:00
parent 8731012f47
commit 57c0572fd9
38 changed files with 1507 additions and 476 deletions

View File

@@ -1,5 +1,6 @@
import functools
import hashlib
import inspect
import time
from contest.models import Contest, ContestRuleType, ContestStatus, ContestType
@@ -15,47 +16,58 @@ class BasePermissionDecorator(object):
self.func = func
def __get__(self, obj, obj_type):
if inspect.iscoroutinefunction(self.func):
return functools.partial(self._async_call, obj)
return functools.partial(self.__call__, obj)
def error(self, data):
return JSONResponse.response({"error": "permission-denied", "data": data})
def __call__(self, *args, **kwargs):
self.request = args[1]
request = args[1]
if self.check_permission():
if self.request.user.is_disabled:
if self.check_permission(request):
if request.user.is_disabled:
return self.error("Your account is disabled")
return self.func(*args, **kwargs)
else:
return self.error("Please login first")
def check_permission(self):
async def _async_call(self, *args, **kwargs):
request = args[1]
if self.check_permission(request):
if request.user.is_disabled:
return self.error("Your account is disabled")
return await self.func(*args, **kwargs)
return self.error("Please login first")
def check_permission(self, request):
raise NotImplementedError()
class login_required(BasePermissionDecorator):
def check_permission(self):
return self.request.user.is_authenticated
def check_permission(self, request):
return request.user.is_authenticated
class super_admin_required(BasePermissionDecorator):
def check_permission(self):
user = self.request.user
def check_permission(self, request):
user = request.user
return user.is_authenticated and user.is_super_admin()
class admin_role_required(BasePermissionDecorator):
def check_permission(self):
user = self.request.user
def check_permission(self, request):
user = request.user
return user.is_authenticated and user.is_admin_role()
class problem_permission_required(admin_role_required):
def check_permission(self):
if not super(problem_permission_required, self).check_permission():
def check_permission(self, request):
if not super().check_permission(request):
return False
if self.request.user.problem_permission == ProblemPermission.NONE:
if request.user.problem_permission == ProblemPermission.NONE:
return False
return True

View File

@@ -4,7 +4,6 @@ from account.models import UserProfile
from problem.models import Problem
from submission.models import JudgeStatus
ACCEPTED_STATUSES = {JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED}

View File

@@ -23,6 +23,9 @@ class UserManager(models.Manager):
def get_by_natural_key(self, username):
return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
async def aget_by_natural_key(self, username):
return await self.aget(**{f"{self.model.USERNAME_FIELD}__iexact": username})
class User(AbstractBaseUser):
username = models.TextField(unique=True)

View File

@@ -4,6 +4,6 @@ from ..views.admin import GenerateUserAPI, ResetUserPasswordAPI, UserAdminAPI
urlpatterns = [
path("user", UserAdminAPI.as_view()),
path("generate_user", GenerateUserAPI.as_view()),
path("generate_user", GenerateUserAPI.as_view()), # DEPRECATED: 前端未调用
path("reset_password", ResetUserPasswordAPI.as_view()),
]

View File

@@ -29,28 +29,28 @@ urlpatterns = [
path("login", UserLoginAPI.as_view()),
path("logout", UserLogoutAPI.as_view()),
path("register", UserRegisterAPI.as_view()),
path("change_password", UserChangePasswordAPI.as_view()),
path("change_email", UserChangeEmailAPI.as_view()),
path("apply_reset_password", ApplyResetPasswordAPI.as_view()),
path("reset_password", ResetPasswordAPI.as_view()),
path("change_password", UserChangePasswordAPI.as_view()), # DEPRECATED: 前端未调用
path("change_email", UserChangeEmailAPI.as_view()), # DEPRECATED: 前端未调用
path("apply_reset_password", ApplyResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
path("reset_password", ResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
path("captcha", CaptchaAPIView.as_view()),
path("check_username_or_email", UsernameOrEmailCheck.as_view()),
path("check_username_or_email", UsernameOrEmailCheck.as_view()), # DEPRECATED: 前端未调用
path("profile", UserProfileAPI.as_view(), name="user_profile_api"),
path("profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view()),
path("metrics", Metrics.as_view()),
path("upload_avatar", AvatarUploadAPI.as_view()),
path("tfa_required", CheckTFARequiredAPI.as_view()),
path("tfa_required", CheckTFARequiredAPI.as_view()), # DEPRECATED: 前端未调用
path(
"two_factor_auth",
"two_factor_auth", # DEPRECATED: 前端未调用
TwoFactorAuthAPI.as_view(),
),
path("user_rank", UserRankAPI.as_view()),
path("user_activity_rank", UserActivityRankAPI.as_view()),
path("user_problem_rank", UserProblemRankAPI.as_view()),
path("sessions", SessionManagementAPI.as_view()),
path("sessions", SessionManagementAPI.as_view()), # DEPRECATED: 前端未调用
path(
"open_api_appkey",
"open_api_appkey", # DEPRECATED: 前端未调用
OpenAPIAppkeyAPI.as_view(),
),
path("sso", SSOAPI.as_view()),
path("sso", SSOAPI.as_view()), # DEPRECATED: 前端未调用
]

View File

@@ -191,6 +191,7 @@ class UserAdminAPI(APIView):
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class GenerateUserAPI(APIView):
@super_admin_required
def get(self, request):

View File

@@ -1,3 +1,4 @@
import asyncio
import os
from datetime import timedelta
from importlib import import_module
@@ -5,7 +6,6 @@ from importlib import import_module
import qrcode
from django.conf import settings
from django.contrib import auth
from django.core.cache import cache
from django.db.models import Count, Q
from django.template.loader import render_to_string
from django.utils import timezone
@@ -16,8 +16,9 @@ from otpauth import TOTP
from options.options import SysOptions
from problem.models import Problem
from submission.models import JudgeStatus, Submission, is_accepted
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from submission.models import JudgeStatus, Submission
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
from utils.async_helpers import async_cache_get, async_cache_set
from utils.captcha import Captcha
from utils.constants import CacheKey, ContestRuleType
from utils.shortcuts import datetime2str, img2base64, rand_str
@@ -58,12 +59,9 @@ def _valid_totp(token, code):
return _totp(token).verify(code)
class UserProfileAPI(APIView):
class UserProfileAPI(AsyncAPIView):
@method_decorator(ensure_csrf_cookie)
def get(self, request, **kwargs):
"""
判断是否登录, 若登录返回用户信息
"""
async def get(self, request, **kwargs):
user = request.user
if not user.is_authenticated:
return self.success()
@@ -71,52 +69,51 @@ class UserProfileAPI(APIView):
username = request.GET.get("username")
try:
if username:
user = User.objects.get(username=username, is_disabled=False)
user = await User.objects.aget(username=username, is_disabled=False)
else:
user = request.user
# api返回的是自己的信息可以返real_name
show_real_name = True
except User.DoesNotExist:
return self.error("User does not exist")
return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data)
profile = await UserProfile.objects.select_related("user").aget(user=user)
return self.success(UserProfileSerializer(profile, show_real_name=show_real_name).data)
@validate_serializer(EditUserProfileSerializer)
@login_required
def put(self, request):
async def put(self, request):
data = request.data
user_profile = request.user.userprofile
user_profile = await UserProfile.objects.select_related("user").aget(user=request.user)
for k, v in data.items():
setattr(user_profile, k, v)
user_profile.save()
await user_profile.asave()
return self.success(UserProfileSerializer(user_profile, show_real_name=True).data)
class Metrics(APIView):
def get(self, request):
class Metrics(AsyncAPIView):
async def get(self, request):
userid = request.GET.get("userid")
submissions = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
if submissions.count() == 0:
qs = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
count, latest, first = await asyncio.gather(
qs.acount(),
qs.order_by("-create_time").afirst(),
qs.order_by("create_time").afirst(),
)
if count == 0 or not latest or not first:
return self.error("暂无提交")
else:
latest_submission = submissions.first()
last_submission = submissions.last()
if last_submission and latest_submission:
return self.success(
{
"now": datetime2str(timezone.now()),
"latest": datetime2str(latest_submission.create_time),
"first": datetime2str(last_submission.create_time),
}
)
else:
return self.error("暂无提交")
return self.success(
{
"now": datetime2str(timezone.now()),
"latest": datetime2str(latest.create_time),
"first": datetime2str(first.create_time),
}
)
class AvatarUploadAPI(APIView):
class AvatarUploadAPI(AsyncAPIView):
request_parsers = ()
@login_required
def post(self, request):
async def post(self, request):
form = ImageUploadForm(request.POST, request.FILES)
if form.is_valid():
avatar = form.cleaned_data["image"]
@@ -132,13 +129,14 @@ class AvatarUploadAPI(APIView):
with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
for chunk in avatar:
img.write(chunk)
user_profile = request.user.userprofile
user_profile = await UserProfile.objects.aget(user=request.user)
user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
user_profile.save()
await user_profile.asave()
return self.success("Succeeded")
# DEPRECATED: 前端未调用 (2026-05-26)
class TwoFactorAuthAPI(APIView):
@login_required
def get(self, request):
@@ -186,6 +184,7 @@ class TwoFactorAuthAPI(APIView):
return self.error("Invalid code")
# DEPRECATED: 前端未调用 (2026-05-26)
class CheckTFARequiredAPI(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
@@ -203,31 +202,26 @@ class CheckTFARequiredAPI(APIView):
return self.success({"result": result})
class UserLoginAPI(APIView):
class UserLoginAPI(AsyncAPIView):
@validate_serializer(UserLoginSerializer)
def post(self, request):
"""
User login api
"""
async def post(self, request):
data = request.data
user = auth.authenticate(username=data["username"], password=data["password"])
# None is returned if username or password is wrong
user = await auth.aauthenticate(username=data["username"], password=data["password"])
if user:
if user.is_disabled:
return self.error("Your account has been disabled")
if not user.two_factor_auth:
prev_login = user.last_login
auth.login(request, user)
await auth.alogin(request, user)
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
return self.success("Succeeded")
# `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data:
return self.error("tfa_required")
if _valid_totp(user.tfa_token, data["tfa_code"]):
prev_login = user.last_login
auth.login(request, user)
await auth.alogin(request, user)
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
return self.success("Succeeded")
else:
@@ -236,12 +230,13 @@ class UserLoginAPI(APIView):
return self.error("Invalid username or password")
class UserLogoutAPI(APIView):
def get(self, request):
auth.logout(request)
class UserLogoutAPI(AsyncAPIView):
async def get(self, request):
await auth.alogout(request)
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class UsernameOrEmailCheck(APIView):
@validate_serializer(UsernameOrEmailCheckSerializer)
def post(self, request):
@@ -258,13 +253,10 @@ class UsernameOrEmailCheck(APIView):
return self.success(result)
class UserRegisterAPI(APIView):
class UserRegisterAPI(AsyncAPIView):
@validate_serializer(UserRegisterSerializer)
def post(self, request):
"""
User register api
"""
if not SysOptions.allow_register:
async def post(self, request):
if not await SysOptions.aget("allow_register"):
return self.error("Register function has been disabled by admin")
data = request.data
@@ -273,17 +265,18 @@ class UserRegisterAPI(APIView):
captcha = Captcha(request)
if not captcha.check(data["captcha"]):
return self.error("Invalid captcha")
if User.objects.filter(username=data["username"]).exists():
if await User.objects.filter(username=data["username"]).aexists():
return self.error("Username already exists")
if User.objects.filter(email=data["email"]).exists():
if await User.objects.filter(email=data["email"]).aexists():
return self.error("Email already exists")
user = User.objects.create(username=data["username"], email=data["email"])
user = await User.objects.acreate(username=data["username"], email=data["email"])
user.set_password(data["password"])
user.save()
UserProfile.objects.create(user=user)
await user.asave()
await UserProfile.objects.acreate(user=user)
return self.success("Succeeded")
# DEPRECATED: 前端未调用 (2026-05-26)
class UserChangeEmailAPI(APIView):
@validate_serializer(UserChangeEmailSerializer)
@login_required
@@ -306,6 +299,7 @@ class UserChangeEmailAPI(APIView):
return self.error("Wrong password")
# DEPRECATED: 前端未调用 (2026-05-26)
class UserChangePasswordAPI(APIView):
@validate_serializer(UserChangePasswordSerializer)
@login_required
@@ -329,6 +323,7 @@ class UserChangePasswordAPI(APIView):
return self.error("Invalid old password")
# DEPRECATED: 前端未调用 (2026-05-26)
class ApplyResetPasswordAPI(APIView):
@validate_serializer(ApplyResetPasswordSerializer)
def post(self, request):
@@ -363,6 +358,7 @@ class ApplyResetPasswordAPI(APIView):
return self.success("Succeeded")
# DEPRECATED: 前端未调用 (2026-05-26)
class ResetPasswordAPI(APIView):
@validate_serializer(ResetPasswordSerializer)
def post(self, request):
@@ -383,6 +379,7 @@ class ResetPasswordAPI(APIView):
return self.success("Succeeded")
# DEPRECATED: 前端未调用 (2026-05-26)
class SessionManagementAPI(APIView):
@login_required
def get(self, request):
@@ -426,8 +423,8 @@ class SessionManagementAPI(APIView):
return self.error("Invalid session_key")
class UserRankAPI(APIView):
def get(self, request):
class UserRankAPI(AsyncAPIView):
async def get(self, request):
rule_type = request.GET.get("rule")
username = request.GET.get("username", "")
try:
@@ -448,16 +445,16 @@ class UserRankAPI(APIView):
profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
if n > 0:
profiles = profiles[:n]
return self.success(self.paginate_data(request, profiles, RankInfoSerializer))
return self.success(await self.async_paginate_data(request, profiles, RankInfoSerializer))
class UserActivityRankAPI(APIView):
def get(self, request):
class UserActivityRankAPI(AsyncAPIView):
async def get(self, request):
start = request.GET.get("start")
if not start:
return self.error("start time is required")
cache_key = f"{CacheKey.user_activity_rank}:{start}"
cached = cache.get(cache_key)
cached = await async_cache_get(cache_key)
if cached is not None:
return self.success(cached)
@@ -467,35 +464,40 @@ class UserActivityRankAPI(APIView):
create_time__gte=start,
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
).exclude(username__in=hidden_names)
data = list(submissions.values("username").annotate(count=Count("problem_id", distinct=True)).order_by("-count")[:10])
cache.set(cache_key, data, 600)
data = [
row
async for row in submissions.values("username")
.annotate(count=Count("problem_id", distinct=True))
.order_by("-count")[:10]
]
await async_cache_set(cache_key, data, 600)
return self.success(data)
class UserProblemRankAPI(APIView):
def get(self, request):
class UserProblemRankAPI(AsyncAPIView):
async def get(self, request):
problem_id = request.GET.get("problem_id")
user = request.user
if not user.is_authenticated:
return self.error("User is not authenticated")
problem = Problem.objects.get(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
submissions = Submission.objects.filter(problem=problem, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED])
all_ac_count = submissions.values("user_id").distinct().count()
all_ac_count = await submissions.values("user_id").distinct().acount()
class_name = user.class_name or ""
class_ac_count = 0
if class_name:
users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True)
user_ids = list(users)
user_ids = [user_id async for user_id in users]
submissions = submissions.filter(user_id__in=user_ids)
class_ac_count = submissions.values("user_id").distinct().count()
class_ac_count = await submissions.values("user_id").distinct().acount()
my_submissions = submissions.filter(user_id=user.id)
if len(my_submissions) == 0:
if not await my_submissions.aexists():
return self.success(
{
"class_name": class_name,
@@ -505,8 +507,8 @@ class UserProblemRankAPI(APIView):
}
)
my_first_submission = my_submissions.order_by("create_time").first()
rank = submissions.filter(create_time__lte=my_first_submission.create_time).count()
my_first_submission = await my_submissions.order_by("create_time").afirst()
rank = await submissions.filter(create_time__lte=my_first_submission.create_time).acount()
return self.success(
{
"class_name": class_name,
@@ -517,25 +519,26 @@ class UserProblemRankAPI(APIView):
)
class ProfileProblemDisplayIDRefreshAPI(APIView):
class ProfileProblemDisplayIDRefreshAPI(AsyncAPIView):
@login_required
def get(self, request):
profile = request.user.userprofile
async def get(self, request):
profile = await UserProfile.objects.aget(user=request.user)
acm_problems = profile.acm_problems_status.get("problems", {})
oi_problems = profile.oi_problems_status.get("problems", {})
ids = list(acm_problems.keys()) + list(oi_problems.keys())
if not ids:
return self.success()
display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)
display_ids = [did async for did in Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)]
id_map = dict(zip(ids, display_ids))
for k, v in acm_problems.items():
v["_id"] = id_map[k]
for k, v in oi_problems.items():
v["_id"] = id_map[k]
profile.save(update_fields=["acm_problems_status", "oi_problems_status"])
await profile.asave(update_fields=["acm_problems_status", "oi_problems_status"])
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class OpenAPIAppkeyAPI(APIView):
@login_required
def post(self, request):
@@ -548,6 +551,7 @@ class OpenAPIAppkeyAPI(APIView):
return self.success({"appkey": api_appkey})
# DEPRECATED: 前端未调用 (2026-05-26)
class SSOAPI(CSRFExemptAPIView):
@login_required
def get(self, request):

View File

@@ -1,19 +1,25 @@
from announcement.models import Announcement
from announcement.serializers import AnnouncementListSerializer, AnnouncementSerializer
from utils.api import APIView
from utils.api import AsyncAPIView
class AnnouncementAPI(APIView):
def get(self, request):
class AnnouncementAPI(AsyncAPIView):
async def get(self, request):
id = request.GET.get("id")
if id:
try:
announcement = Announcement.objects.get(id=id, visible=True)
return self.success(AnnouncementSerializer(announcement).data)
announcement = await (
Announcement.objects.select_related("created_by")
.filter(id=id, visible=True)
.afirst()
)
if announcement is None:
raise Announcement.DoesNotExist
return self.success(await self.async_serialize_data(AnnouncementSerializer, announcement))
except Announcement.DoesNotExist:
return self.error("Announcement does not exist")
announcements = Announcement.objects.select_related("created_by").filter(visible=True)
return self.success(
self.paginate_data(request, announcements, AnnouncementListSerializer)
await self.async_paginate_data(request, announcements, AnnouncementListSerializer)
)

View File

@@ -1,6 +1,7 @@
from .base import BaseEngine
from ast_checker.labels import label
from .base import BaseEngine
class MustHaveNestingEngine(BaseEngine):
def _has_inner_in_subtree(self, node, inner_type):

View File

@@ -1,6 +1,7 @@
from .base import BaseEngine
from ast_checker.labels import label
from .base import BaseEngine
class CountNodeEngine(BaseEngine):
def _message(self, rule, count):

View File

@@ -1,6 +1,7 @@
from .base import BaseEngine
from ast_checker.labels import label
from .base import BaseEngine
class MustExistNodeEngine(BaseEngine):
def _message(self, rule):

View File

@@ -1,45 +1,44 @@
from django.core.cache import cache
from django.db.models import Avg, Count
from django.db.models.functions import Round
from account.decorators import login_required
from django.db.models import Avg, Count
from django.db.models.functions import Round
from account.decorators import login_required
from comment.models import Comment
from comment.serializers import CommentSerializer, CreateCommentSerializer
from problem.models import Problem
from submission.models import JudgeStatus, Submission
from utils.api import APIView
from utils.api.api import validate_serializer
from utils.constants import CacheKey
from submission.models import JudgeStatus, Submission
from utils.api import AsyncAPIView
from utils.api.api import validate_serializer
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
from utils.constants import CacheKey
class CommentAPI(APIView):
class CommentAPI(AsyncAPIView):
@validate_serializer(CreateCommentSerializer)
@login_required
def post(self, request):
async def post(self, request):
data = request.data
try:
problem = Problem.objects.get(id=data["problem_id"], visible=True)
problem = await Problem.objects.aget(id=data["problem_id"], visible=True)
except Problem.DoesNotExist:
self.error("problem is not exists")
return self.error("problem is not exists")
try:
submission = (
Submission.objects.select_related("problem")
.filter(
user_id=request.user.id,
problem_id=data["problem_id"],
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
)
.first()
submission = await (
Submission.objects.select_related("problem")
.filter(
user_id=request.user.id,
problem_id=data["problem_id"],
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
)
except Submission.DoesNotExist:
self.error("submission is not exists or not accepted")
.afirst()
)
if not submission:
return self.error("submission is not exists or not accepted")
language = submission.language
if language == "Python3":
language = "Python"
Comment.objects.create(
await Comment.objects.acreate(
user=request.user,
problem=problem,
submission=submission,
@@ -49,32 +48,35 @@ class CommentAPI(APIView):
comprehensive_rating=data["comprehensive_rating"],
content=data["content"],
)
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
return self.success()
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
return self.success()
@login_required
def get(self, request):
async def get(self, request):
problem_id = request.GET.get("problem_id")
comment = (
comment = await (
Comment.objects.select_related("problem")
.filter(user=request.user, problem_id=problem_id)
.first()
)
if comment:
return self.success(CommentSerializer(comment).data)
else:
return self.success()
.afirst()
)
if comment:
return self.success(await self.async_serialize_data(CommentSerializer, comment))
else:
return self.success()
class CommentStatisticsAPI(APIView):
def get(self, request):
problem_id = request.GET.get("problem_id")
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
cached = cache.get(cache_key)
if cached is not None:
return self.success(cached)
class CommentStatisticsAPI(AsyncAPIView):
async def get(self, request):
problem_id = request.GET.get("problem_id")
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
cached = await async_cache_get(cache_key)
if cached is not None:
return self.success(cached)
agg = Comment.objects.filter(problem_id=problem_id).aggregate(
from asgiref.sync import sync_to_async
agg = await sync_to_async(
Comment.objects.filter(problem_id=problem_id).aggregate
)(
count=Count("id"),
description=Round(Avg("description_rating"), 2),
difficulty=Round(Avg("difficulty_rating"), 2),
@@ -88,5 +90,5 @@ class CommentStatisticsAPI(APIView):
"difficulty": agg["difficulty"],
"comprehensive": agg["comprehensive"],
}}
cache.set(cache_key, data, 3600)
return self.success(data)
await async_cache_set(cache_key, data, 3600)
return self.success(data)

View File

@@ -12,12 +12,12 @@ from ..views import (
)
urlpatterns = [
path("smtp", SMTPAPI.as_view()),
path("smtp_test", SMTPTestAPI.as_view()),
path("smtp", SMTPAPI.as_view()), # DEPRECATED: 前端未调用
path("smtp_test", SMTPTestAPI.as_view()), # DEPRECATED: 前端未调用
path("website", WebsiteConfigAPI.as_view()),
path("random_user", RandomUsernameAPI.as_view()),
path("judge_server", JudgeServerAPI.as_view()),
path("prune_test_case", TestCasePruneAPI.as_view()),
path("versions", ReleaseNotesAPI.as_view()),
path("versions", ReleaseNotesAPI.as_view()), # DEPRECATED: 前端未调用
path("dashboard_info", DashboardInfoAPI.as_view()),
]

View File

@@ -12,7 +12,7 @@ urlpatterns = [
path("website", WebsiteConfigAPI.as_view()),
# 这里必须要有 /
path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()),
path("languages", LanguagesAPI.as_view()),
path("languages", LanguagesAPI.as_view()), # DEPRECATED: 前端未调用
path("hitokoto", HitokotoAPI.as_view()),
path("class_usernames", ClassUsernamesAPI.as_view()),
]

View File

@@ -1,3 +1,4 @@
import asyncio
import hashlib
import json
import os
@@ -6,9 +7,10 @@ import re
import shutil
import smtplib
import time
from datetime import datetime
from datetime import timedelta
import requests
from asgiref.sync import sync_to_async
from django.conf import settings
from django.utils import timezone
from requests.exceptions import RequestException
@@ -20,7 +22,7 @@ from judge.dispatcher import process_pending_task
from options.options import SysOptions
from problem.models import Problem
from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
from utils.cache import JsonDataLoader
from utils.shortcuts import get_env, send_email
from utils.websocket import push_config_update
@@ -38,6 +40,7 @@ from .serializers import (
)
# DEPRECATED: 前端未调用 (2026-05-26)
class SMTPAPI(APIView):
@super_admin_required
def get(self, request):
@@ -66,6 +69,7 @@ class SMTPAPI(APIView):
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class SMTPTestAPI(APIView):
@super_admin_required
@validate_serializer(TestSMTPConfigSerializer)
@@ -97,35 +101,33 @@ class SMTPTestAPI(APIView):
return self.success()
class WebsiteConfigAPI(APIView):
def get(self, request):
ret = {
key: getattr(SysOptions, key)
for key in [
"website_base_url",
"website_name",
"website_name_shortcut",
"website_footer",
"allow_register",
"submission_list_show_all",
"class_list",
"enable_maxkb",
]
}
class WebsiteConfigAPI(AsyncAPIView):
async def get(self, request):
ret = await SysOptions.aget_many(
"website_base_url",
"website_name",
"website_name_shortcut",
"website_footer",
"allow_register",
"submission_list_show_all",
"class_list",
"enable_maxkb",
)
return self.success(ret)
@super_admin_required
@validate_serializer(CreateEditWebsiteConfigSerializer)
def post(self, request):
for k, v in request.data.items():
if k == "website_footer":
with XSSHtml() as parser:
v = parser.clean(v)
setattr(SysOptions, k, v)
# 推送配置更新到所有连接的客户端
push_config_update(k, v)
async def post(self, request):
@sync_to_async
def _update_config(data):
for k, v in data.items():
if k == "website_footer":
with XSSHtml() as parser:
v = parser.clean(v)
setattr(SysOptions, k, v)
push_config_update(k, v)
await _update_config(request.data)
return self.success()
@@ -206,6 +208,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class LanguagesAPI(APIView):
def get(self, request):
return self.success(
@@ -255,6 +258,7 @@ class TestCasePruneAPI(APIView):
shutil.rmtree(test_case_dir, ignore_errors=True)
# DEPRECATED: 前端未调用 (2026-05-26)
class ReleaseNotesAPI(APIView):
def get(self, request):
try:
@@ -272,24 +276,29 @@ class ReleaseNotesAPI(APIView):
return self.success(releases)
class DashboardInfoAPI(APIView):
def get(self, request):
today = datetime.today()
today_submission_count = Submission.objects.filter(
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
).count()
recent_contest_count = Contest.objects.exclude(
end_time__lt=timezone.now()
).count()
judge_server_count = len(
list(filter(lambda x: x.status == "normal", JudgeServer.objects.all()))
class DashboardInfoAPI(AsyncAPIView):
async def get(self, request):
now = timezone.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
(
user_count,
today_submission_count,
recent_contest_count,
judge_servers,
) = await asyncio.gather(
User.objects.acount(),
Submission.objects.filter(create_time__gte=today_start).acount(),
Contest.objects.exclude(end_time__lt=timezone.now()).acount(),
JudgeServer.objects.filter(
last_heartbeat__gte=timezone.now() - timedelta(seconds=6)
).acount(),
)
return self.success(
{
"user_count": User.objects.count(),
"user_count": user_count,
"recent_contest_count": recent_contest_count,
"today_submission_count": today_submission_count,
"judge_server_count": judge_server_count,
"judge_server_count": judge_servers,
"env": {
"FORCE_HTTPS": get_env("FORCE_HTTPS", default=False),
"STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default=""),
@@ -298,24 +307,21 @@ class DashboardInfoAPI(APIView):
)
class RandomUsernameAPI(APIView):
def get(self, request):
class RandomUsernameAPI(AsyncAPIView):
async def get(self, request):
classroom = request.GET.get("classroom", "")
if not classroom:
return self.error("需要班级号")
usernames = (
User.objects.filter(username__istartswith=classroom)
usernames = [
u async for u in User.objects.filter(username__istartswith=classroom)
.values_list("username", flat=True)
.order_by("?")
)
if len(usernames) > 10:
return self.success(usernames[:10])
else:
return self.success(usernames)
.order_by("?")[:10]
]
return self.success(usernames)
class HitokotoAPI(APIView):
def get(self, request):
class HitokotoAPI(AsyncAPIView):
async def get(self, request):
try:
categories = JsonDataLoader.load_data(
settings.HITOKOTO_DIR, "categories.json"
@@ -328,20 +334,14 @@ class HitokotoAPI(APIView):
return self.error("获取一言失败,请稍后再试")
class ClassUsernamesAPI(APIView):
def get(self, request):
class ClassUsernamesAPI(AsyncAPIView):
async def get(self, request):
classroom = request.GET.get("classroom", "")
if not classroom:
return self.error("需要班级号")
users = User.objects.filter(class_name=classroom).order_by("-create_time")
names = []
for user in users:
prefix = f"ks{classroom}"
result = (
user.username[len(prefix) :]
if user.username.startswith(prefix)
else user.username
)
names.append(result)
prefix = f"ks{classroom}"
names = [
user.username[len(prefix):] if user.username.startswith(prefix) else user.username
async for user in User.objects.filter(class_name=classroom).order_by("-create_time")
]
return self.success(names)

View File

@@ -5,7 +5,7 @@ from ..views.admin import ACMContestHelper, ContestAnnouncementAPI, ContestAPI,
urlpatterns = [
path("contest", ContestAPI.as_view()),
path("contest/clone", ContestCloneAPI.as_view()),
path("contest/announcement", ContestAnnouncementAPI.as_view()),
path("contest/announcement", ContestAnnouncementAPI.as_view()), # DEPRECATED: 前端未调用
path("contest/acm_helper", ACMContestHelper.as_view()),
path("download_submissions", DownloadContestSubmissions.as_view()),
path("download_submissions", DownloadContestSubmissions.as_view()), # DEPRECATED: 前端未调用
]

View File

@@ -6,7 +6,7 @@ urlpatterns = [
path("contests", ContestListAPI.as_view()),
path("contest", ContestAPI.as_view()),
path("contest/password", ContestPasswordVerifyAPI.as_view()),
path("contest/announcement", ContestAnnouncementListAPI.as_view()),
path("contest/announcement", ContestAnnouncementListAPI.as_view()), # DEPRECATED: 前端未调用
path("contest/access", ContestAccessAPI.as_view()),
path("contest_rank", ContestRankAPI.as_view()),
]

View File

@@ -97,6 +97,7 @@ class ContestAPI(APIView):
return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
# DEPRECATED: 前端未调用 (2026-05-26)
class ContestAnnouncementAPI(APIView):
@validate_serializer(CreateContestAnnouncementSerializer)
@super_admin_required
@@ -212,6 +213,7 @@ class ACMContestHelper(APIView):
return self.success()
# DEPRECATED: 前端未调用 (2026-05-26)
class DownloadContestSubmissions(APIView):
def _dump_submissions(self, contest, exclude_admin=True):
problem_ids = contest.problem_set.all().values_list("id", "_id")

View File

@@ -12,7 +12,7 @@ from account.decorators import (
)
from account.models import AdminType
from problem.models import Problem
from utils.api import APIView, validate_serializer
from utils.api import APIView, AsyncAPIView, validate_serializer
from utils.constants import CONTEST_PASSWORD_SESSION_KEY, CacheKey, ContestRuleType, ContestStatus
from utils.shortcuts import check_is_id, datetime2str
@@ -20,6 +20,7 @@ from ..models import ACMContestRank, Contest, ContestAnnouncement, OIContestRank
from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerializer, ContestPasswordVerifySerializer, ContestSerializer, OIContestRankSerializer
# DEPRECATED: 前端未调用 (2026-05-26)
class ContestAnnouncementListAPI(APIView):
@check_contest_permission(check_type="announcements")
def get(self, request):
@@ -35,22 +36,28 @@ class ContestAnnouncementListAPI(APIView):
return self.success(ContestAnnouncementSerializer(data, many=True).data)
class ContestAPI(APIView):
def get(self, request):
class ContestAPI(AsyncAPIView):
async def get(self, request):
id = request.GET.get("id")
if not id or not check_is_id(id):
return self.error("Invalid parameter, id is required")
try:
contest = Contest.objects.get(id=id, visible=True)
contest = await (
Contest.objects.select_related("created_by")
.filter(id=id, visible=True)
.afirst()
)
if contest is None:
raise Contest.DoesNotExist
except Contest.DoesNotExist:
return self.error("Contest does not exist")
data = ContestSerializer(contest).data
data = await self.async_serialize_data(ContestSerializer, contest)
data["now"] = datetime2str(now())
return self.success(data)
class ContestListAPI(APIView):
def get(self, request):
class ContestListAPI(AsyncAPIView):
async def get(self, request):
contests = Contest.objects.select_related("created_by").filter(visible=True)
keyword = request.GET.get("keyword")
rule_type = request.GET.get("rule_type")
@@ -70,7 +77,7 @@ class ContestListAPI(APIView):
contests = contests.filter(end_time__lt=cur)
else:
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
return self.success(self.paginate_data(request, contests, ContestSerializer))
return self.success(await self.async_paginate_data(request, contests, ContestSerializer))
class ContestPasswordVerifyAPI(APIView):

View File

@@ -18,6 +18,7 @@ asgiref==3.11.1 \
# channels
# channels-redis
# django
# onlinejudge
certifi==2026.4.22 \
--hash=sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a \
--hash=sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580

View File

@@ -0,0 +1,861 @@
# Backend Async Hardening Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Make the current backend async work correct first, then establish safe patterns for expanding async views without changing API response shapes.
**Architecture:** Keep the existing custom `APIView`/DRF serializer stack. Add async-safe tests and helpers around `AsyncAPIView`, explicitly preload serializer relations in converted endpoints, and use `sync_to_async(..., thread_sensitive=True)` for synchronous serializer/cache/helper code that remains inside async views.
**Tech Stack:** Django 6.0.4, custom class-based API views, Django async ORM, DRF serializers, PostgreSQL, Redis cache, Channels ASGI.
---
## Async Rules For This Repository
1. Async conversion is valid only when the endpoint preserves URL, method, status code, JSON envelope, and permission behavior.
2. Every async view that serializes model instances must either preload all serializer relations with `select_related()` / `prefetch_related()` or run serializer `.data` through a sync boundary.
3. `asyncio.gather()` is only for independent reads. Do not use it around writes that depend on ordering or transaction semantics.
4. Keep file upload/download, SMTP, test-case pruning, judge heartbeat mutation paths, and contest permission-heavy flows sync until the async decorator and middleware work is complete.
5. Current sync `MiddlewareMixin` middleware means ASGI requests still cross sync boundaries. The first async milestone is correctness and latency cleanup, not a full event-loop purity claim.
---
### Task 1: Add Regression Coverage For Converted Async Detail Views
**Files:**
- Create: `utils/test_async_view_regressions.py`
- [ ] **Step 1: Write failing regression tests**
Create `utils/test_async_view_regressions.py`:
```python
from datetime import timedelta
from asgiref.sync import sync_to_async
from django.contrib.auth import get_user_model
from django.test import AsyncClient, TestCase
from django.utils import timezone
from account.models import UserProfile
from announcement.models import Announcement
from contest.models import Contest
from flowchart.models import FlowchartSubmission
from problem.models import Problem, ProblemRuleType
from utils.constants import ContestRuleType, Difficulty
User = get_user_model()
def make_user(username="async_user"):
user = User.objects.create(username=username, email=f"{username}@example.com")
user.set_password("pass1234")
user.save()
UserProfile.objects.create(user=user)
return user
def make_problem(user):
return Problem.objects.create(
_id="ASYNC001",
title="Async Problem",
description="desc",
input_description="input",
output_description="output",
samples=[],
test_case_id="async-test-case",
test_case_score=[],
hint="",
languages=["Python3"],
template={},
created_by=user,
time_limit=1000,
memory_limit=128,
rule_type=ProblemRuleType.ACM,
difficulty=Difficulty.LOW,
share_submission=False,
allow_flowchart=True,
show_flowchart=True,
)
class AsyncConvertedViewRegressionTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.user = make_user()
cls.announcement = Announcement.objects.create(
title="Async Announcement",
content="content",
tag="notice",
visible=True,
top=False,
created_by=cls.user,
)
cls.contest = Contest.objects.create(
title="Async Contest",
description="contest desc",
tag="weekly",
real_time_rank=True,
password=None,
rule_type=ContestRuleType.ACM,
start_time=timezone.now() - timedelta(hours=1),
end_time=timezone.now() + timedelta(hours=1),
created_by=cls.user,
visible=True,
allowed_ip_ranges=[],
)
cls.problem = make_problem(cls.user)
cls.flowchart_submission = FlowchartSubmission.objects.create(
user=cls.user,
problem=cls.problem,
mermaid_code="graph TD\nA-->B",
flowchart_data={},
)
async def test_announcement_detail_serializes_created_by(self):
response = await AsyncClient().get(f"/api/announcement?id={self.announcement.id}")
self.assertEqual(response.status_code, 200)
body = response.json()
self.assertIsNone(body["error"])
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
async def test_contest_detail_serializes_created_by(self):
response = await AsyncClient().get(f"/api/contest?id={self.contest.id}")
self.assertEqual(response.status_code, 200)
body = response.json()
self.assertIsNone(body["error"])
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
async def test_flowchart_detail_serializes_user_and_problem(self):
client = AsyncClient()
await sync_to_async(client.force_login)(self.user)
response = await client.get(f"/api/flowchart/submission?id={self.flowchart_submission.id}")
self.assertEqual(response.status_code, 200)
body = response.json()
self.assertIsNone(body["error"])
self.assertEqual(body["data"]["username"], self.user.username)
self.assertEqual(body["data"]["problem"], self.problem.id)
```
- [ ] **Step 2: Run the regression tests**
Run:
```bash
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
```
Expected before Task 2: at least one test returns `server-error` or raises async-unsafe lazy relation access.
- [ ] **Step 3: Commit the failing tests**
Run:
```bash
rtk git add utils/test_async_view_regressions.py
rtk git commit -m "test: cover async view serialization regressions"
```
---
### Task 2: Preload Relations For Already Converted Async Detail Views
**Files:**
- Modify: `announcement/views/oj.py`
- Modify: `contest/views/oj.py`
- Modify: `flowchart/views/oj.py`
- [ ] **Step 1: Fix announcement detail relation loading**
In `announcement/views/oj.py`, replace the detail query with:
```python
announcement = await (
Announcement.objects.select_related("created_by")
.filter(id=id, visible=True)
.afirst()
)
if announcement is None:
raise Announcement.DoesNotExist
```
- [ ] **Step 2: Fix contest detail relation loading**
In `contest/views/oj.py`, replace the detail query with:
```python
contest = await (
Contest.objects.select_related("created_by")
.filter(id=id, visible=True)
.afirst()
)
if contest is None:
raise Contest.DoesNotExist
```
- [ ] **Step 3: Fix flowchart submission detail relation loading**
In `flowchart/views/oj.py`, update `FlowchartSubmissionAPI.get()`:
```python
submission = await (
FlowchartSubmission.objects.select_related("user", "problem")
.filter(id=submission_id)
.afirst()
)
if submission is None:
raise FlowchartSubmission.DoesNotExist
```
- [ ] **Step 4: Fix flowchart retry permission relation loading**
In `flowchart/views/oj.py`, update `FlowchartSubmissionRetryAPI.post()`:
```python
submission = await (
FlowchartSubmission.objects.select_related("problem")
.filter(id=submission_id)
.afirst()
)
if submission is None:
raise FlowchartSubmission.DoesNotExist
```
- [ ] **Step 5: Fix flowchart completed-detail relation loading**
In `flowchart/views/oj.py`, update `FlowchartSubmissionDetailAPI.get()` before serialization:
```python
submissions = (
FlowchartSubmission.objects.select_related("user", "problem")
.filter(
user=request.user,
problem=problem,
status=FlowchartSubmissionStatus.COMPLETED,
)
.order_by("create_time")
)
```
- [ ] **Step 6: Run the regression tests**
Run:
```bash
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
```
Expected: all tests pass.
- [ ] **Step 7: Run Django system checks**
Run:
```bash
rtk uv run python manage.py check
```
Expected: `System check identified no issues`.
- [ ] **Step 8: Commit relation fixes**
Run:
```bash
rtk git add announcement/views/oj.py contest/views/oj.py flowchart/views/oj.py
rtk git commit -m "fix: preload relations in async detail views"
```
---
### Task 3: Add Async Serialization Helpers To `AsyncAPIView`
**Files:**
- Modify: `utils/api/api.py`
- Modify: `utils/test_async_api.py`
- [ ] **Step 1: Write tests for async serialization helper**
Create `utils/test_async_api.py`:
```python
import json
from django.test import AsyncRequestFactory, SimpleTestCase
from utils.api import AsyncAPIView, serializers, validate_serializer
class PayloadSerializer(serializers.Serializer):
name = serializers.CharField()
class EchoSerializer(serializers.Serializer):
name = serializers.CharField()
class AsyncValidatedEchoView(AsyncAPIView):
@validate_serializer(PayloadSerializer)
async def post(self, request):
return self.success({"name": request.data["name"]})
class AsyncSerializationHelperTests(SimpleTestCase):
async def test_validate_serializer_supports_async_view_methods(self):
request = AsyncRequestFactory().post(
"/api/echo",
data=json.dumps({"name": "alice"}),
content_type="application/json",
)
response = await AsyncValidatedEchoView.as_view()(request)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["error"], None)
self.assertEqual(response.data["data"], {"name": "alice"})
async def test_async_serialize_data_returns_serializer_data(self):
view = AsyncAPIView()
data = await view.async_serialize_data(
EchoSerializer,
[{"name": "alice"}, {"name": "bob"}],
many=True,
)
self.assertEqual(data, [{"name": "alice"}, {"name": "bob"}])
```
- [ ] **Step 2: Run helper tests and confirm failure**
Run:
```bash
rtk uv run python manage.py test utils.test_async_api -v 2
```
Expected before implementation: `AttributeError: 'AsyncAPIView' object has no attribute 'async_serialize_data'`.
- [ ] **Step 3: Add `sync_to_async` import**
In `utils/api/api.py`, add:
```python
from asgiref.sync import sync_to_async
```
- [ ] **Step 4: Add serializer helper methods**
In `AsyncAPIView`, before `async_paginate_data()`, add:
```python
def serialize_data(self, object_serializer, data, **kwargs):
return object_serializer(data, **kwargs).data
async def async_serialize_data(self, object_serializer, data, **kwargs):
return await sync_to_async(
self.serialize_data,
thread_sensitive=True,
)(object_serializer, data, **kwargs)
```
- [ ] **Step 5: Use helper inside async pagination**
In `AsyncAPIView.async_paginate_data()`, replace:
```python
if object_serializer:
results = object_serializer(results, many=True, context={"request": request}).data
```
with:
```python
if object_serializer:
results = await self.async_serialize_data(
object_serializer,
results,
many=True,
context={"request": request},
)
```
- [ ] **Step 6: Run helper and regression tests**
Run:
```bash
rtk uv run python manage.py test utils.test_async_api utils.test_async_view_regressions -v 2
```
Expected: all tests pass.
- [ ] **Step 7: Commit async serializer helper**
Run:
```bash
rtk git add utils/api/api.py utils/test_async_api.py
rtk git commit -m "feat: add async serializer helper"
```
---
### Task 4: Add Cache Helpers For Async Views
**Files:**
- Create: `utils/async_helpers.py`
- Create: `utils/test_async_helpers.py`
- Modify: `problem/views/oj.py`
- Modify: `comment/views/oj.py`
- [ ] **Step 1: Write cache helper tests**
Create `utils/test_async_helpers.py`:
```python
from django.test import SimpleTestCase, override_settings
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
@override_settings(
CACHES={
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
"LOCATION": "async-helper-tests",
}
}
)
class AsyncCacheHelperTests(SimpleTestCase):
async def test_async_cache_round_trip(self):
await async_cache_set("async:key", {"value": 1}, 30)
value = await async_cache_get("async:key")
self.assertEqual(value, {"value": 1})
async def test_async_cache_delete(self):
await async_cache_set("async:delete", "present", 30)
await async_cache_delete("async:delete")
value = await async_cache_get("async:delete")
self.assertIsNone(value)
```
- [ ] **Step 2: Run tests and confirm import failure**
Run:
```bash
rtk uv run python manage.py test utils.test_async_helpers -v 2
```
Expected before implementation: `ModuleNotFoundError: No module named 'utils.async_helpers'`.
- [ ] **Step 3: Create async cache helpers**
Create `utils/async_helpers.py`:
```python
from asgiref.sync import sync_to_async
from django.core.cache import cache
async def async_cache_get(key, default=None):
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
async def async_cache_set(key, value, timeout=None):
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
async def async_cache_delete(key):
return await sync_to_async(cache.delete, thread_sensitive=True)(key)
```
- [ ] **Step 4: Convert async problem cache calls**
In `problem/views/oj.py`, add:
```python
from utils.async_helpers import async_cache_get, async_cache_set
```
In `ProblemTagAPI.get()`, replace:
```python
cached = cache.get(cache_key)
```
with:
```python
cached = await async_cache_get(cache_key)
```
Replace:
```python
cache.set(cache_key, data, 3600)
```
with:
```python
await async_cache_set(cache_key, data, 3600)
```
- [ ] **Step 5: Convert async comment cache calls**
In `comment/views/oj.py`, add:
```python
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
```
In `CommentAPI.post()`, replace:
```python
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
```
with:
```python
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
```
In `CommentStatisticsAPI.get()`, replace `cache.get()` and `cache.set()` with:
```python
cached = await async_cache_get(cache_key)
```
and:
```python
await async_cache_set(cache_key, data, 3600)
```
- [ ] **Step 6: Run helper and targeted async tests**
Run:
```bash
rtk uv run python manage.py test utils.test_async_helpers utils.test_async_api utils.test_async_view_regressions -v 2
```
Expected: all tests pass.
- [ ] **Step 7: Commit cache helper work**
Run:
```bash
rtk git add utils/async_helpers.py utils/test_async_helpers.py problem/views/oj.py comment/views/oj.py
rtk git commit -m "feat: add async cache helpers"
```
---
### Task 5: Make Permission Decorators Async-Aware Before More Conversions
**Files:**
- Modify: `account/decorators.py`
- Create: `account/test_async_decorators.py`
- [ ] **Step 1: Write tests for async `login_required`**
Create `account/test_async_decorators.py`:
```python
from django.contrib.auth.models import AnonymousUser
from django.test import AsyncRequestFactory, SimpleTestCase
from account.decorators import login_required
from utils.api import AsyncAPIView
class DisabledUser:
is_authenticated = True
is_disabled = True
class ActiveUser:
is_authenticated = True
is_disabled = False
class ProtectedAsyncView(AsyncAPIView):
@login_required
async def get(self, request):
return self.success("ok")
class AsyncPermissionDecoratorTests(SimpleTestCase):
async def test_async_login_required_allows_active_user(self):
request = AsyncRequestFactory().get("/api/protected")
request.user = ActiveUser()
response = await ProtectedAsyncView.as_view()(request)
self.assertEqual(response.data["error"], None)
self.assertEqual(response.data["data"], "ok")
async def test_async_login_required_rejects_anonymous_user(self):
request = AsyncRequestFactory().get("/api/protected")
request.user = AnonymousUser()
response = await ProtectedAsyncView.as_view()(request)
self.assertEqual(response.data["error"], "permission-denied")
self.assertEqual(response.data["data"], "Please login first")
async def test_async_login_required_rejects_disabled_user(self):
request = AsyncRequestFactory().get("/api/protected")
request.user = DisabledUser()
response = await ProtectedAsyncView.as_view()(request)
self.assertEqual(response.data["error"], "permission-denied")
self.assertEqual(response.data["data"], "Your account is disabled")
```
- [ ] **Step 2: Run decorator tests**
Run:
```bash
rtk uv run python manage.py test account.test_async_decorators -v 2
```
Expected before implementation: this may pass because `AsyncAPIView.dispatch()` awaits returned coroutine. Keep the test anyway as a contract before refactoring.
- [ ] **Step 3: Refactor `BasePermissionDecorator` with explicit async path**
In `account/decorators.py`, add:
```python
import inspect
```
Replace `BasePermissionDecorator.__get__()` with:
```python
def __get__(self, obj, obj_type):
if inspect.iscoroutinefunction(self.func):
return functools.partial(self._async_call, obj)
return functools.partial(self.__call__, obj)
```
Add this method to `BasePermissionDecorator`:
```python
async def _async_call(self, *args, **kwargs):
self.request = args[1]
if self.check_permission():
if self.request.user.is_disabled:
return self.error("Your account is disabled")
return await self.func(*args, **kwargs)
return self.error("Please login first")
```
- [ ] **Step 4: Run decorator and async view tests**
Run:
```bash
rtk uv run python manage.py test account.test_async_decorators utils.test_async_api utils.test_async_view_regressions -v 2
```
Expected: all tests pass.
- [ ] **Step 5: Commit decorator refactor**
Run:
```bash
rtk git add account/decorators.py account/test_async_decorators.py
rtk git commit -m "refactor: make permission decorators async-aware"
```
---
### Task 6: Convert One Endpoint Family At A Time
**Files:**
- Modify only the endpoint family being converted in each batch.
- Add or update tests in the same app, using `AsyncClient` for converted URLs.
- [ ] **Step 1: Choose the next low-risk batch**
Use this order:
```text
1. Pure public GET list/detail endpoints already close to async:
- announcement list/detail
- contest list/detail
- problem tag/list/detail
2. Authenticated read-only list endpoints:
- message list
- submission list
- flowchart list/detail/current
3. Simple create/update endpoints with no contest permission decorator:
- message create
- comment create
- flowchart retry/create
```
- [ ] **Step 2: For each endpoint, write one async smoke test**
Use this template and replace the URL and assertions with the exact endpoint response:
```python
from django.test import AsyncClient, TestCase
class EndpointAsyncSmokeTests(TestCase):
async def test_endpoint_returns_success_envelope(self):
response = await AsyncClient().get("/api/endpoint?limit=10")
self.assertEqual(response.status_code, 200)
body = response.json()
self.assertIn("error", body)
self.assertIn("data", body)
```
- [ ] **Step 3: Convert ORM access only where async ORM exists**
Use these replacements:
```python
obj = await Model.objects.aget(id=id)
count = await queryset.acount()
first = await queryset.afirst()
last = await queryset.alast()
items = [item async for item in queryset[offset:offset + limit]]
created = await Model.objects.acreate(field=value)
await instance.asave(update_fields=["field"])
```
- [ ] **Step 4: Keep sync-only helpers behind `sync_to_async`**
Use this pattern:
```python
from asgiref.sync import sync_to_async
result = await sync_to_async(sync_helper, thread_sensitive=True)(arg1, arg2)
```
- [ ] **Step 5: Run targeted app tests and system check after each batch**
Run:
```bash
rtk uv run python manage.py test <app_label> -v 2
rtk uv run python manage.py check
```
Expected: targeted tests pass and system check reports no issues.
- [ ] **Step 6: Commit each endpoint family separately**
Run:
```bash
rtk git add <changed-files>
rtk git commit -m "refactor: async <endpoint-family> endpoints"
```
---
### Task 7: Audit Middleware Before Claiming Full Async Benefit
**Files:**
- Modify: `account/middleware.py`
- Add tests only for behavior that changes.
- [ ] **Step 1: Record current sync middleware boundaries**
Before changing middleware, note these current sync classes:
```text
account.middleware.APITokenAuthMiddleware
account.middleware.AdminRoleRequiredMiddleware
account.middleware.SessionRecordMiddleware
```
- [ ] **Step 2: Keep middleware sync during endpoint correctness work**
Do not convert middleware in the same commit as endpoint conversions. Middleware affects every request and needs its own review.
- [ ] **Step 3: If middleware conversion is pursued, convert one class per commit**
Use Django new-style middleware with explicit sync and async handling. `__acall__` is a local helper; `__call__` dispatches to it when Django passes an async `get_response`:
```python
from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async
class ExampleMiddleware:
sync_capable = True
async_capable = True
def __init__(self, get_response):
self.get_response = get_response
self.is_async = iscoroutinefunction(get_response)
if self.is_async:
markcoroutinefunction(self)
def __call__(self, request):
if self.is_async:
return self.__acall__(request)
response = self.process_request(request)
if response is not None:
return response
return self.get_response(request)
async def __acall__(self, request):
response = await self.aprocess_request(request)
if response is not None:
return response
return await self.get_response(request)
def process_request(self, request):
return None
async def aprocess_request(self, request):
return await sync_to_async(self.process_request, thread_sensitive=True)(request)
```
- [ ] **Step 4: Run full backend test command after middleware work**
Run:
```bash
rtk uv run python manage.py test -v 2
rtk uv run python manage.py check
```
Expected: all tests pass and system check reports no issues.
---
## Definition Of Done
- `rtk uv run python manage.py check` passes.
- Async regression tests cover converted detail serializers that depend on FK relations.
- Converted async views do not call DRF serializer `.data` directly unless the data is primitive and relation-free.
- Cache access from async views uses async helper wrappers.
- New endpoint conversions are committed in endpoint-family-sized commits.
- No file upload/download, SMTP, judge heartbeat, test-case prune, or contest permission-heavy endpoint is converted without a separate focused plan.

View File

@@ -11,6 +11,6 @@ class FlowchartEvaluationPromptTests(TestCase):
self.assertIn("Mermaid节点ID由系统生成", prompt)
self.assertIn("不要评价节点ID", prompt)
self.assertIn("不要因节点ID扣分", prompt)
self.assertIn("feedback控制在0字以内", prompt)
self.assertIn("feedback控制在100字以内", prompt)
self.assertIn("suggestions最多3条", prompt)
self.assertIn("重要建议必须以【重点】开头", prompt)

View File

@@ -7,65 +7,63 @@ from flowchart.serializers import (
)
from flowchart.tasks import evaluate_flowchart_task
from problem.models import Problem
from utils.api import APIView
from utils.api import AsyncAPIView
class FlowchartSubmissionAPI(APIView):
class FlowchartSubmissionAPI(AsyncAPIView):
@login_required
def post(self, request):
"""创建流程图提交"""
async def post(self, request):
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
if not serializer.is_valid():
return self.error(serializer.errors)
data = serializer.validated_data
# 验证题目存在
try:
problem = Problem.objects.get(id=data["problem_id"])
problem = await Problem.objects.aget(id=data["problem_id"])
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
# 验证题目是否允许流程图提交
if not problem.allow_flowchart:
return self.error("This problem does not allow flowchart submission")
# 创建提交记录
submission = FlowchartSubmission.objects.create(
submission = await FlowchartSubmission.objects.acreate(
user=request.user,
problem=problem,
mermaid_code=data["mermaid_code"],
flowchart_data=data.get("flowchart_data", {}),
)
# 启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success({"submission_id": submission.id, "status": "pending"})
@login_required
def get(self, request):
"""获取流程图提交详情"""
async def get(self, request):
submission_id = request.GET.get("id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
submission = await (
FlowchartSubmission.objects.select_related("user", "problem")
.filter(id=submission_id)
.afirst()
)
if submission is None:
raise FlowchartSubmission.DoesNotExist
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
serializer = FlowchartSubmissionSerializer(submission)
return self.success(serializer.data)
return self.success(await self.async_serialize_data(FlowchartSubmissionSerializer, submission))
class FlowchartSubmissionListAPI(APIView):
class FlowchartSubmissionListAPI(AsyncAPIView):
@login_required
def get(self, request):
"""获取流程图提交列表"""
async def get(self, request):
username = request.GET.get("username")
problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself")
@@ -74,7 +72,7 @@ class FlowchartSubmissionListAPI(APIView):
if problem_id:
try:
problem = Problem.objects.get(
problem = await Problem.objects.aget(
_id__iexact=problem_id, contest_id__isnull=True, visible=True
)
except Problem.DoesNotExist:
@@ -88,38 +86,42 @@ class FlowchartSubmissionListAPI(APIView):
elif request.user.is_regular_user():
queryset = queryset.filter(user=request.user)
data = self.paginate_data(request, queryset)
data["results"] = FlowchartSubmissionListSerializer(
data["results"], many=True
).data
data = await self.async_paginate_data(request, queryset)
data["results"] = await self.async_serialize_data(
FlowchartSubmissionListSerializer,
data["results"],
many=True,
)
return self.success(data)
class FlowchartSubmissionRetryAPI(APIView):
class FlowchartSubmissionRetryAPI(AsyncAPIView):
@login_required
def post(self, request):
"""重新触发AI评分"""
async def post(self, request):
submission_id = request.data.get("submission_id")
if not submission_id:
return self.error("submission_id is required")
try:
submission = FlowchartSubmission.objects.get(id=submission_id)
submission = await (
FlowchartSubmission.objects.select_related("problem")
.filter(id=submission_id)
.afirst()
)
if submission is None:
raise FlowchartSubmission.DoesNotExist
except FlowchartSubmission.DoesNotExist:
return self.error("Submission doesn't exist")
# 检查权限
if not submission.check_user_permission(request.user):
return self.error("No permission for this submission")
# 检查是否可以重新评分
if submission.status not in [
FlowchartSubmissionStatus.FAILED,
FlowchartSubmissionStatus.COMPLETED,
]:
return self.error("Submission is not in a state that allows retry")
# 重置状态并重新启动AI评分
submission.status = FlowchartSubmissionStatus.PENDING
submission.ai_score = None
submission.ai_grade = None
@@ -128,9 +130,8 @@ class FlowchartSubmissionRetryAPI(APIView):
submission.ai_criteria_details = {}
submission.processing_time = None
submission.evaluation_time = None
submission.save()
await submission.asave()
# 重新启动AI评分任务
evaluate_flowchart_task.send(submission.id)
return self.success(
@@ -142,15 +143,14 @@ class FlowchartSubmissionRetryAPI(APIView):
)
class FlowchartSubmissionDetailAPI(APIView):
class FlowchartSubmissionDetailAPI(AsyncAPIView):
@login_required
def get(self, request):
"""获取当前用户对指定题目的流程图提交详情"""
async def get(self, request):
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
problem = await Problem.objects.aget(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
@@ -158,34 +158,37 @@ class FlowchartSubmissionDetailAPI(APIView):
page = int(request.GET.get("page", 0))
except ValueError:
return self.error("page must be an integer")
submissions = FlowchartSubmission.objects.filter(
user=request.user,
problem=problem,
status=FlowchartSubmissionStatus.COMPLETED,
).order_by("create_time")
count = submissions.count()
submissions = (
FlowchartSubmission.objects.select_related("user", "problem")
.filter(
user=request.user,
problem=problem,
status=FlowchartSubmissionStatus.COMPLETED,
)
.order_by("create_time")
)
count = await submissions.acount()
if count == 0:
return self.success({"submission": None, "count": 0})
# page=0 means latest; page=N means the Nth submission (1-indexed, chronological)
if page == 0:
submission = submissions.last()
submission = await submissions.alast()
else:
if page < 0 or page > count:
return self.error("Page out of range")
submission = submissions[page - 1]
serializer = FlowchartSubmissionSerializer(submission)
return self.success({"submission": serializer.data, "count": count})
result = [s async for s in submissions[page - 1:page]]
submission = result[0]
data = await self.async_serialize_data(FlowchartSubmissionSerializer, submission)
return self.success({"submission": data, "count": count})
class FlowchartSubmissionCurrentAPI(APIView):
class FlowchartSubmissionCurrentAPI(AsyncAPIView):
@login_required
def get(self, request):
"""获取当前用户对指定题目的最新流程图提交,只返回次数和分数"""
async def get(self, request):
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(id=problem_id)
problem = await Problem.objects.aget(id=problem_id)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
submissions = (
@@ -197,10 +200,10 @@ class FlowchartSubmissionCurrentAPI(APIView):
.values("ai_score", "ai_grade")
.order_by("-create_time")
)
count = submissions.count()
count = await submissions.acount()
if count == 0:
return self.success({"count": 0, "score": 0, "grade": ""})
submission = submissions[0]
submission = await submissions.afirst()
return self.success(
{
"count": count,

View File

@@ -3,33 +3,33 @@ from account.models import User
from message.models import Message
from message.serializers import CreateMessageSerializer, MessageSerializer
from submission.models import Submission
from utils.api import APIView
from utils.api import AsyncAPIView
from utils.api.api import validate_serializer
class MessageAPI(APIView):
class MessageAPI(AsyncAPIView):
@login_required
def get(self, request):
async def get(self, request):
messages = Message.objects.select_related(
"recipient", "sender", "submission", "submission__problem"
).filter(recipient=request.user)
return self.success(self.paginate_data(request, messages, MessageSerializer))
return self.success(await self.async_paginate_data(request, messages, MessageSerializer))
@validate_serializer(CreateMessageSerializer)
@super_admin_required
def post(self, request):
async def post(self, request):
data = request.data
if data["recipient"] == request.user.id:
return self.error("Can not send a message to youself")
try:
recipient = User.objects.get(id=data["recipient"], is_disabled=False)
recipient = await User.objects.aget(id=data["recipient"], is_disabled=False)
except User.DoesNotExist:
return self.error("User does not exist")
try:
submission = Submission.objects.get(id=data["submission"])
submission = await Submission.objects.aget(id=data["submission"])
except Submission.DoesNotExist:
return self.error("Submission does not exist")
Message.objects.create(
await Message.objects.acreate(
submission=submission,
message=data["message"],
sender=request.user,

View File

@@ -292,4 +292,15 @@ class _SysOptionsMeta(type):
class SysOptions(metaclass=_SysOptionsMeta):
pass
@classmethod
async def aget(cls, key):
from asgiref.sync import sync_to_async
return await sync_to_async(getattr)(cls, key)
@classmethod
async def aget_many(cls, *keys):
from asgiref.sync import sync_to_async
def _get_all():
return {k: getattr(cls, k) for k in keys}
return await sync_to_async(_get_all)()

View File

@@ -1,15 +1,16 @@
import random
from datetime import datetime
from django.core.cache import cache
from asgiref.sync import sync_to_async
from django.db.models import BooleanField, Case, Count, Q, Value, When
from django.db.models.functions import ExtractYear
from django.utils import timezone
from account.decorators import check_contest_permission
from account.models import User
from contest.models import ContestRuleType
from submission.models import JudgeStatus, Submission
from utils.api import APIView
from utils.api import APIView, AsyncAPIView
from utils.async_helpers import async_cache_get, async_cache_set
from utils.constants import CacheKey
from ..models import Problem, ProblemTag
@@ -21,11 +22,11 @@ from ..serializers import (
)
class ProblemTagAPI(APIView):
def get(self, request):
class ProblemTagAPI(AsyncAPIView):
async def get(self, request):
keyword = request.GET.get("keyword", "")
cache_key = f"{CacheKey.problem_tags}:{keyword}"
cached = cache.get(cache_key)
cached = await async_cache_get(cache_key)
if cached is not None:
return self.success(cached)
@@ -33,48 +34,48 @@ class ProblemTagAPI(APIView):
if keyword:
qs = ProblemTag.objects.filter(name__icontains=keyword)
tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0)
data = TagSerializer(tags, many=True).data
cache.set(cache_key, data, 3600)
data = await self.async_serialize_data(TagSerializer, [tag async for tag in tags], many=True)
await async_cache_set(cache_key, data, 3600)
return self.success(data)
class PickOneAPI(APIView):
def get(self, request):
problems = Problem.objects.filter(contest_id__isnull=True, visible=True)
count = problems.count()
class PickOneAPI(AsyncAPIView):
async def get(self, request):
ids = Problem.objects.filter(contest_id__isnull=True, visible=True).values_list("_id", flat=True)
count = await ids.acount()
if count == 0:
return self.error("No problem to pick")
return self.success(problems[random.randint(0, count - 1)]._id)
idx = random.randint(0, count - 1)
result = [pid async for pid in ids[idx : idx + 1]]
return self.success(result[0])
class ProblemAPI(APIView):
class ProblemAPI(AsyncAPIView):
@staticmethod
def _add_problem_status(request, queryset_values):
if request.user.is_authenticated:
profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {})
# paginate data
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values]
for problem in problems:
problem["my_status"] = acm_problems_status.get(
str(problem["id"]), {}
).get("status")
def _add_problem_status(acm_problems_status, queryset_values):
results = queryset_values.get("results")
if results is not None:
problems = results
else:
problems = [queryset_values]
for problem in problems:
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
def get(self, request):
async def get(self, request):
# 问题详情页
problem_id = request.GET.get("problem_id")
if problem_id:
try:
problem = Problem.objects.select_related("created_by").get(
_id__iexact=problem_id, contest_id__isnull=True, visible=True
)
problem_data = ProblemSerializer(problem).data
self._add_problem_status(request, problem_data)
problem = await Problem.objects.select_related("created_by").prefetch_related("tags").filter(_id__iexact=problem_id, contest_id__isnull=True, visible=True).afirst()
if problem is None:
raise Problem.DoesNotExist
problem_data = await self.async_serialize_data(ProblemSerializer, problem)
if request.user.is_authenticated:
from account.models import UserProfile
profile = await UserProfile.objects.aget(user=request.user)
acm_problems_status = profile.acm_problems_status.get("problems", {})
self._add_problem_status(acm_problems_status, problem_data)
failed_statuses = [
JudgeStatus.WRONG_ANSWER,
JudgeStatus.CPU_TIME_LIMIT_EXCEEDED,
@@ -83,11 +84,11 @@ class ProblemAPI(APIView):
JudgeStatus.RUNTIME_ERROR,
JudgeStatus.COMPILE_ERROR,
]
problem_data["my_failed_count"] = Submission.objects.filter(
problem_data["my_failed_count"] = await Submission.objects.filter(
user_id=request.user.id,
problem_id=problem.id,
result__in=failed_statuses,
).count()
).acount()
else:
problem_data["my_failed_count"] = 0
return self.success(problem_data)
@@ -98,12 +99,7 @@ class ProblemAPI(APIView):
if not limit:
return self.error("Limit is needed")
problems = (
Problem.objects.select_related("created_by")
.prefetch_related("tags")
.filter(contest_id__isnull=True, visible=True)
.order_by("-create_time")
)
problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest_id__isnull=True, visible=True).order_by("-create_time")
author = request.GET.get("author")
if author:
@@ -117,9 +113,7 @@ class ProblemAPI(APIView):
# 搜索的情况
keyword = request.GET.get("keyword", "").strip()
if keyword:
problems = problems.filter(
Q(title__icontains=keyword) | Q(_id__icontains=keyword)
)
problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
# 难度筛选
difficulty = request.GET.get("difficulty")
@@ -142,8 +136,13 @@ class ProblemAPI(APIView):
problems = problems.order_by(sort)
# 根据profile 为做过的题目添加标记
data = self.paginate_data(request, problems, ProblemListSerializer)
self._add_problem_status(request, data)
data = await self.async_paginate_data(request, problems, ProblemListSerializer)
if request.user.is_authenticated:
from account.models import UserProfile
profile = await UserProfile.objects.aget(user=request.user)
acm_problems_status = profile.acm_problems_status.get("problems", {})
self._add_problem_status(acm_problems_status, data)
return self.success(data)
@@ -152,24 +151,18 @@ class ContestProblemAPI(APIView):
if request.user.is_authenticated:
profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get(
"contest_problems", {}
)
problems_status = profile.acm_problems_status.get("contest_problems", {})
else:
problems_status = profile.oi_problems_status.get("contest_problems", {})
for problem in queryset_values:
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get(
"status"
)
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
@check_contest_permission(check_type="problems")
def get(self, request):
problem_id = request.GET.get("problem_id")
if problem_id:
try:
problem = Problem.objects.select_related("created_by").get(
_id__iexact=problem_id, contest=self.contest, visible=True
)
problem = Problem.objects.select_related("created_by").get(_id__iexact=problem_id, contest=self.contest, visible=True)
except Problem.DoesNotExist:
return self.error("Problem does not exist.")
if self.contest.problem_details_permission(request.user):
@@ -184,9 +177,7 @@ class ContestProblemAPI(APIView):
problem_data = ProblemSafeSerializer(problem).data
return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(
contest=self.contest, visible=True
)
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest=self.contest, visible=True)
if self.contest.problem_details_permission(request.user):
data = ProblemListSerializer(contest_problems, many=True).data
self._add_problem_status(request, data)
@@ -195,59 +186,60 @@ class ContestProblemAPI(APIView):
return self.success(data)
class ProblemSolvedPeopleCount(APIView):
def get(self, request):
class ProblemSolvedPeopleCount(AsyncAPIView):
async def get(self, request):
problem_id = request.GET.get("problem_id")
rate = "0"
if not request.user.is_authenticated:
return self.success(rate)
submission_count = Submission.objects.filter(
submission_count = await Submission.objects.filter(
user_id=request.user.id,
problem_id=problem_id,
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
).count()
).acount()
if submission_count == 0:
return self.success(rate)
today = datetime.today()
years_ago = datetime(today.year - 2, today.month, today.day, 0, 0)
total_count = User.objects.filter(
is_disabled=False, last_login__gte=years_ago
).count()
accepted_count = Submission.objects.filter(
problem_id=problem_id,
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
create_time__gte=years_ago,
).aggregate(user_count=Count("user_id", distinct=True))["user_count"]
if accepted_count < total_count:
now = timezone.now()
years_ago = now.replace(year=now.year - 2, hour=0, minute=0, second=0, microsecond=0)
total_count = await User.objects.filter(is_disabled=False, last_login__gte=years_ago).acount()
accepted_count = (
await sync_to_async(
Submission.objects.filter(
problem_id=problem_id,
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
create_time__gte=years_ago,
).aggregate,
thread_sensitive=True,
)(user_count=Count("user_id", distinct=True))
)["user_count"]
if total_count and accepted_count < total_count:
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
else:
rate = "0"
return self.success(rate)
class SimilarProblemAPI(APIView):
def get(self, request):
class SimilarProblemAPI(AsyncAPIView):
async def get(self, request):
problem_display_id = request.GET.get("problem_id")
if not problem_display_id:
return self.error("problem_id is required")
try:
problem = Problem.objects.get(_id__iexact=problem_display_id, contest__isnull=True)
problem = await Problem.objects.aget(_id__iexact=problem_display_id, contest__isnull=True)
except Problem.DoesNotExist:
return self.error("Problem not found")
tag_ids = list(problem.tags.values_list("id", flat=True))
tag_ids = [tag_id async for tag_id in problem.tags.values_list("id", flat=True)]
if not tag_ids:
return self.success([])
exclude_ids = [problem_display_id]
if request.user.is_authenticated:
profile = request.user.userprofile
ac_display_ids = [
v["_id"]
for v in profile.acm_problems_status.get("problems", {}).values()
if v.get("status") == JudgeStatus.ACCEPTED
]
from account.models import UserProfile
profile = await UserProfile.objects.aget(user=request.user)
ac_display_ids = [v["_id"] for v in profile.acm_problems_status.get("problems", {}).values() if v.get("status") == JudgeStatus.ACCEPTED]
exclude_ids.extend(ac_display_ids)
similar = (
@@ -258,14 +250,15 @@ class SimilarProblemAPI(APIView):
.distinct()
.order_by("difficulty")[:5]
)
return self.success(ProblemListSerializer(similar, many=True).data)
similar_list = [problem async for problem in similar]
return self.success(await self.async_serialize_data(ProblemListSerializer, similar_list, many=True))
class ProblemAuthorAPI(APIView):
def get(self, request):
class ProblemAuthorAPI(AsyncAPIView):
async def get(self, request):
show_all = request.GET.get("all", "0") == "1"
cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
cached_data = cache.get(cache_key)
cached_data = await async_cache_get(cache_key)
if cached_data:
return self.success(cached_data)
@@ -273,38 +266,32 @@ class ProblemAuthorAPI(APIView):
if not show_all:
problem_filter["visible"] = True
authors = (
Problem.objects.filter(**problem_filter)
.values("created_by__username")
.annotate(problem_count=Count("id"))
.order_by("-problem_count")
)
authors = Problem.objects.filter(**problem_filter).values("created_by__username").annotate(problem_count=Count("id")).order_by("-problem_count")
result = [
{
"username": author["created_by__username"],
"problem_count": author["problem_count"],
}
for author in authors
async for author in authors
]
cache.set(cache_key, result, 7200)
await async_cache_set(cache_key, result, 7200)
return self.success(result)
class ProblemYearlyACRateAPI(APIView):
def get(self, request):
class ProblemYearlyACRateAPI(AsyncAPIView):
async def get(self, request):
problem_id = request.GET.get("problem_id")
if not problem_id:
return self.error("problem_id is required")
cache_key = f"{CacheKey.problem_yearly_ac}:{problem_id}"
cached = cache.get(cache_key)
cached = await async_cache_get(cache_key)
if cached is not None:
return self.success(cached)
try:
problem = Problem.objects.get(
_id__iexact=problem_id, contest_id__isnull=True, visible=True
)
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
@@ -328,12 +315,10 @@ class ProblemYearlyACRateAPI(APIView):
"year": row["year"],
"total": row["total"],
"accepted": row["accepted"],
"ac_rate": round(row["accepted"] / row["total"] * 100, 2)
if row["total"] > 0
else 0.0,
"ac_rate": round(row["accepted"] / row["total"] * 100, 2) if row["total"] > 0 else 0.0,
}
for row in rows
async for row in rows
]
cache.set(cache_key, data, 3600)
await async_cache_set(cache_key, data, 3600)
return self.success(data)

View File

@@ -63,7 +63,7 @@ urlpatterns = [
name="admin_problemset_progress_detail_api",
),
# 题单同步管理API
path(
path( # DEPRECATED: 前端未调用
"problemset/<int:problem_set_id>/sync",
ProblemSetSyncAPI.as_view(),
name="admin_problemset_sync_api",

View File

@@ -24,7 +24,7 @@ urlpatterns = [
ProblemSetProblemAPI.as_view(),
name="problemset_problems_api",
),
path(
path( # DEPRECATED: 前端未调用
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
ProblemSetProblemAPI.as_view(),
name="problemset_problem_detail_api",
@@ -35,12 +35,12 @@ urlpatterns = [
ProblemSetProgressAPI.as_view(),
name="problemset_progress_api",
),
path(
path( # DEPRECATED: 前端未调用
"problemset/<int:problem_set_id>/progress",
ProblemSetProgressAPI.as_view(),
name="problemset_progress_detail_api",
),
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"),
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"), # DEPRECATED: 前端未调用
# 奖章相关API
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
path(

View File

@@ -332,6 +332,7 @@ class ProblemSetProgressAdminAPI(APIView):
return self.error("用户未加入该题单")
# DEPRECATED: 前端未调用 (2026-05-26)
class ProblemSetSyncAPI(APIView):
"""题单同步管理API"""

View File

@@ -1,3 +1,4 @@
from asgiref.sync import sync_to_async
from django.db.models import Avg, Count, Prefetch, Q
from django.utils import timezone
@@ -24,14 +25,14 @@ from problemset.serializers import (
UpdateProgressSerializer,
UserBadgeSerializer,
)
from submission.models import JudgeStatus, Submission, is_accepted
from utils.api import APIView, validate_serializer
from submission.models import Submission, is_accepted
from utils.api import APIView, AsyncAPIView, validate_serializer
class ProblemSetAPI(APIView):
class ProblemSetAPI(AsyncAPIView):
"""题单API - 用户端"""
def get(self, request):
async def get(self, request):
"""获取题单列表"""
# 预加载创建者信息
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status=ProblemSetStatus.DRAFT).select_related("created_by")
@@ -65,16 +66,19 @@ class ProblemSetAPI(APIView):
user_earned_badge_ids = set()
if request.user.is_authenticated:
# 先获取所有题单ID不应用prefetch_related只获取ID
problem_set_ids = list(problem_sets.values_list("id", flat=True))
problem_set_ids = [problem_set_id async for problem_set_id in problem_sets.values_list("id", flat=True)]
if problem_set_ids:
# 批量查询用户在这些题单中的进度
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
# 构建映射题单ID -> 进度对象
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
user_progress_map = {progress.problemset_id: progress async for progress in user_progresses}
# 批量查询用户已获得的奖章ID这些题单相关的
user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True))
user_earned_badge_ids = {
badge_id
async for badge_id in UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True)
}
# 预加载奖章信息在获取ID之后应用避免在获取ID时也预加载
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
@@ -83,31 +87,35 @@ class ProblemSetAPI(APIView):
request._user_progress_map = user_progress_map
request._user_earned_badge_ids = user_earned_badge_ids
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
data = await self.async_paginate_data(request, problem_sets, ProblemSetListSerializer)
return self.success(data)
class ProblemSetDetailAPI(APIView):
class ProblemSetDetailAPI(AsyncAPIView):
"""题单详情API - 用户端"""
def get(self, request, problem_set_id):
async def get(self, request, problem_set_id):
"""获取题单详情"""
try:
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
problem_set = await (
ProblemSet.objects.select_related("created_by")
.filter(id=problem_set_id, visible=True)
.exclude(status=ProblemSetStatus.DRAFT)
.aget()
)
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
serializer = ProblemSetSerializer(problem_set, context={"request": request})
return self.success(serializer.data)
return self.success(await self.async_serialize_data(ProblemSetSerializer, problem_set, context={"request": request}))
class ProblemSetProblemAPI(APIView):
class ProblemSetProblemAPI(AsyncAPIView):
"""题单题目API - 用户端"""
def get(self, request, problem_set_id):
async def get(self, request, problem_set_id):
"""获取题单中的题目列表"""
try:
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -115,12 +123,16 @@ class ProblemSetProblemAPI(APIView):
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
user_progress = None
if request.user.is_authenticated:
try:
user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
except ProblemSetProgress.DoesNotExist:
pass
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress})
return self.success(serializer.data)
user_progress = await ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).afirst()
problem_list = [problem async for problem in problems]
return self.success(
await self.async_serialize_data(
ProblemSetProblemSerializer,
problem_list,
many=True,
context={"request": request, "user_progress": user_progress},
)
)
class ProblemSetProgressAPI(APIView):
@@ -236,6 +248,7 @@ class ProblemSetProgressAPI(APIView):
UserBadge.objects.create(user=progress.user, badge=badge)
# DEPRECATED: 前端未调用 (2026-05-26)
class UserProgressAPI(APIView):
"""用户进度API"""
@@ -247,10 +260,10 @@ class UserProgressAPI(APIView):
return self.success(serializer.data)
class UserBadgeAPI(APIView):
class UserBadgeAPI(AsyncAPIView):
"""用户奖章API"""
def get(self, request):
async def get(self, request):
"""获取用户的奖章列表"""
# 支持通过username参数获取指定用户的徽章
username = request.GET.get("username")
@@ -258,41 +271,41 @@ class UserBadgeAPI(APIView):
if username:
# 获取指定用户的徽章
try:
target_user = User.objects.get(username=username, is_disabled=False)
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
target_user = await User.objects.aget(username=username, is_disabled=False)
badges = UserBadge.objects.select_related("badge").filter(user=target_user).order_by("-earned_time")
except User.DoesNotExist:
return self.error("用户不存在")
else:
# 获取当前用户的徽章
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
badges = UserBadge.objects.select_related("badge").filter(user=request.user).order_by("-earned_time")
serializer = UserBadgeSerializer(badges, many=True)
return self.success(serializer.data)
badge_list = [badge async for badge in badges]
return self.success(await self.async_serialize_data(UserBadgeSerializer, badge_list, many=True))
class ProblemSetBadgeAPI(APIView):
class ProblemSetBadgeAPI(AsyncAPIView):
"""题单奖章API - 用户端"""
def get(self, request, problem_set_id):
async def get(self, request, problem_set_id):
"""获取题单的奖章列表"""
try:
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
serializer = ProblemSetBadgeSerializer(badges, many=True)
return self.success(serializer.data)
badge_list = [badge async for badge in badges]
return self.success(await self.async_serialize_data(ProblemSetBadgeSerializer, badge_list, many=True))
class ProblemSetUserProgressAPI(APIView):
class ProblemSetUserProgressAPI(AsyncAPIView):
"""题单用户进度列表API"""
@admin_role_required
def get(self, request, problem_set_id: int):
async def get(self, request, problem_set_id: int):
"""获取题单的用户进度列表"""
try:
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
except ProblemSet.DoesNotExist:
return self.error("题单不存在")
@@ -321,7 +334,7 @@ class ProblemSetUserProgressAPI(APIView):
# 计算统计数据(基于所有数据,而非分页数据)
# 使用一次查询获取所有统计数据
stats = progresses.aggregate(
stats = await sync_to_async(progresses.aggregate, thread_sensitive=True)(
total=Count("id"),
completed=Count("id", filter=Q(is_completed=True)),
avg_progress=Avg("progress_percentage"),
@@ -351,7 +364,7 @@ class ProblemSetUserProgressAPI(APIView):
# 构建题单所有题目的数据结构和映射
all_problems_list = []
all_problems_map = {}
for psp in all_problemset_problems:
async for psp in all_problemset_problems:
problem_data = {
"id": psp.problem.id,
"_id": psp.problem._id,
@@ -362,7 +375,7 @@ class ProblemSetUserProgressAPI(APIView):
all_problems_map[str(psp.problem.id)] = psp.problem
# 从当前页的数据中收集已完成的问题ID用于序列化器
paginated_progresses = list(progresses[offset : offset + limit])
paginated_progresses = [progress async for progress in progresses[offset : offset + limit]]
completed_problem_ids = set()
for progress in paginated_progresses:
if progress.progress_detail:
@@ -376,7 +389,7 @@ class ProblemSetUserProgressAPI(APIView):
request._problems_dict_cache = problems_dict
# 使用分页
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
data = await self.async_paginate_data(request, progresses, ProblemSetProgressSerializer)
# 添加统计数据
data["statistics"] = {

View File

@@ -28,6 +28,7 @@ dependencies = [
"tree-sitter-c>=0.24.2",
"tree-sitter-python>=0.25.0",
"xlsxwriter>=3.2.9,<4",
"asgiref>=3.11.1",
]
[dependency-groups]
@@ -35,6 +36,10 @@ dev = [
"ruff>=0.15.11",
]
[tool.pyright]
venvPath = "."
venv = ".venv"
[tool.ruff]
line-length = 180
exclude = ["*/migrations/*", "*settings.py", "*/apps.py", ".venv"]

View File

@@ -12,6 +12,6 @@ urlpatterns = [
path("submission", SubmissionAPI.as_view()),
path("submissions", SubmissionListAPI.as_view()),
path("submissions/today_count", SubmissionsTodayCount.as_view()),
path("submission_exists", SubmissionExistsAPI.as_view()),
path("submission_exists", SubmissionExistsAPI.as_view()), # DEPRECATED: 前端未调用
path("contest_submissions", ContestSubmissionListAPI.as_view()),
]

View File

@@ -1,5 +1,7 @@
import ipaddress
from datetime import datetime
from asgiref.sync import sync_to_async
from django.utils import timezone
from account.decorators import check_contest_permission, login_required
from contest.models import ContestRuleType, ContestStatus
@@ -8,7 +10,7 @@ from options.options import SysOptions
# from judge.dispatcher import JudgeDispatcher
from problem.models import Problem, ProblemRuleType
from utils.api import APIView, validate_serializer
from utils.api import APIView, AsyncAPIView, validate_serializer
from utils.cache import cache
from utils.captcha import Captcha
from utils.throttling import TokenBucket
@@ -154,8 +156,8 @@ class SubmissionAPI(APIView):
return self.success()
class SubmissionListAPI(APIView):
def get(self, request):
class SubmissionListAPI(AsyncAPIView):
async def get(self, request):
if not request.GET.get("limit"):
return self.error("Limit is needed")
if request.GET.get("contest_id"):
@@ -171,14 +173,15 @@ class SubmissionListAPI(APIView):
language = request.GET.get("language")
if problem_id:
try:
problem = Problem.objects.get(
problem = await Problem.objects.aget(
_id__iexact=problem_id, contest_id__isnull=True, visible=True
)
except Problem.DoesNotExist:
return self.error("Problem doesn't exist")
submissions = submissions.filter(problem=problem)
if not SysOptions.submission_list_show_all and request.user.is_regular_user():
show_all = await SysOptions.aget("submission_list_show_all")
if not show_all and request.user.is_regular_user():
return self.success({"results": [], "total": 0})
if myself and myself == "1":
@@ -190,21 +193,25 @@ class SubmissionListAPI(APIView):
if language:
submissions = submissions.filter(language=language)
if request.GET.get("today") == "1":
today = datetime.today()
now = timezone.now()
submissions = submissions.filter(
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
)
data = self.paginate_data(request, submissions)
data = await self.async_paginate_data(request, submissions)
results = data["results"]
if request.user.is_authenticated and request.user.is_regular_user():
problem_ids = list({s.problem_id for s in results})
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
progress_cache = await sync_to_async(bulk_fetch_problemset_progress)(request.user, problem_ids)
else:
progress_cache = {}
data["results"] = SubmissionListSerializer(
results, many=True, user=request.user, problemset_progress_cache=progress_cache
).data
data["results"] = await self.async_serialize_data(
SubmissionListSerializer,
results,
many=True,
user=request.user,
problemset_progress_cache=progress_cache,
)
return self.success(data)
@@ -262,6 +269,7 @@ class ContestSubmissionListAPI(APIView):
return self.success(data)
# DEPRECATED: 前端未调用 (2026-05-26)
class SubmissionExistsAPI(APIView):
def get(self, request):
if not request.GET.get("problem_id"):
@@ -274,10 +282,10 @@ class SubmissionExistsAPI(APIView):
)
class SubmissionsTodayCount(APIView):
def get(self, request):
today = datetime.today()
count = Submission.objects.filter(
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
).count()
class SubmissionsTodayCount(AsyncAPIView):
async def get(self, request):
now = timezone.now()
count = await Submission.objects.filter(
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
).acount()
return self.success(count)

View File

@@ -1,7 +1,10 @@
import asyncio
import functools
import inspect
import json
import logging
from asgiref.sync import sync_to_async
from django.http import HttpResponse, QueryDict
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
@@ -162,6 +165,77 @@ class CSRFExemptAPIView(APIView):
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
class AsyncAPIView(APIView):
view_is_async = True
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.view_is_async = True
async def dispatch(self, request, *args, **kwargs):
if self.request_parsers:
try:
request.data = self._get_request_data(self.request)
except ValueError as e:
return self.error(err="invalid-request", msg=str(e))
try:
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
response = handler(request, *args, **kwargs)
if asyncio.iscoroutine(response):
response = await response
return response
except APIError as e:
ret = {"msg": e.msg}
if e.err:
ret["err"] = e.err
return self.error(**ret)
except Exception as e:
logger.exception(e)
return self.server_error()
def serialize_data(self, object_serializer, data, **kwargs):
return object_serializer(data, **kwargs).data
async def async_serialize_data(self, object_serializer, data, **kwargs):
return await sync_to_async(
self.serialize_data,
thread_sensitive=True,
)(object_serializer, data, **kwargs)
async def async_paginate_data(self, request, query_set, object_serializer=None):
try:
limit = int(request.GET.get("limit", "10"))
except ValueError:
limit = 10
if limit < 0 or limit > 250:
limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
offset = 0
if offset < 0:
offset = 0
count, results = await asyncio.gather(
query_set.acount(),
sync_to_async(lambda: list(query_set[offset:offset + limit]), thread_sensitive=True)(),
)
if object_serializer:
results = await self.async_serialize_data(
object_serializer,
results,
many=True,
context={"request": request},
)
data = {"results": results, "total": count}
return data
class CSRFExemptAsyncAPIView(AsyncAPIView):
@method_decorator(csrf_exempt)
async def dispatch(self, request, *args, **kwargs):
return await super().dispatch(request, *args, **kwargs)
def validate_serializer(serializer):
"""
@validate_serializer(TestSerializer)
@@ -169,6 +243,20 @@ def validate_serializer(serializer):
return self.success(request.data)
"""
def validate(view_method):
if inspect.iscoroutinefunction(view_method):
@functools.wraps(view_method)
async def async_handle(*args, **kwargs):
self = args[0]
request = args[1]
s = serializer(data=request.data)
if s.is_valid():
request.data = s.data
request.serializer = s
return await view_method(*args, **kwargs)
else:
return self.invalid_serializer(s)
return async_handle
@functools.wraps(view_method)
def handle(*args, **kwargs):
self = args[0]
@@ -180,7 +268,6 @@ def validate_serializer(serializer):
return view_method(*args, **kwargs)
else:
return self.invalid_serializer(s)
return handle
return validate

14
utils/async_helpers.py Normal file
View File

@@ -0,0 +1,14 @@
from asgiref.sync import sync_to_async
from django.core.cache import cache
async def async_cache_get(key, default=None):
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
async def async_cache_set(key, value, timeout=None):
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
async def async_cache_delete(key):
return await sync_to_async(cache.delete, thread_sensitive=True)(key)

View File

@@ -4,5 +4,5 @@ from .views import SimditorFileUploadAPIView, SimditorImageUploadAPIView
urlpatterns = [
path("upload_image", SimditorImageUploadAPIView.as_view()),
path("upload_file", SimditorFileUploadAPIView.as_view()),
path("upload_file", SimditorFileUploadAPIView.as_view()), # DEPRECATED: 前端未调用
]

View File

@@ -46,6 +46,7 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView):
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
# DEPRECATED: 前端未调用 (2026-05-26)
class SimditorFileUploadAPIView(CSRFExemptAPIView):
request_parsers = ()

2
uv.lock generated
View File

@@ -550,6 +550,7 @@ name = "onlinejudge"
version = "2.0.0"
source = { virtual = "." }
dependencies = [
{ name = "asgiref" },
{ name = "channels" },
{ name = "channels-redis" },
{ name = "django" },
@@ -582,6 +583,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "asgiref", specifier = ">=3.11.1" },
{ name = "channels", specifier = ">=4.3.2,<5" },
{ name = "channels-redis", specifier = ">=4.3.0,<5" },
{ name = "django", specifier = ">=6.0.4,<6.1" },