async
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import time
|
||||
|
||||
from contest.models import Contest, ContestRuleType, ContestStatus, ContestType
|
||||
@@ -15,47 +16,58 @@ class BasePermissionDecorator(object):
|
||||
self.func = func
|
||||
|
||||
def __get__(self, obj, obj_type):
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return functools.partial(self._async_call, obj)
|
||||
return functools.partial(self.__call__, obj)
|
||||
|
||||
def error(self, data):
|
||||
return JSONResponse.response({"error": "permission-denied", "data": data})
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.request = args[1]
|
||||
request = args[1]
|
||||
|
||||
if self.check_permission():
|
||||
if self.request.user.is_disabled:
|
||||
if self.check_permission(request):
|
||||
if request.user.is_disabled:
|
||||
return self.error("Your account is disabled")
|
||||
return self.func(*args, **kwargs)
|
||||
else:
|
||||
return self.error("Please login first")
|
||||
|
||||
def check_permission(self):
|
||||
async def _async_call(self, *args, **kwargs):
|
||||
request = args[1]
|
||||
|
||||
if self.check_permission(request):
|
||||
if request.user.is_disabled:
|
||||
return self.error("Your account is disabled")
|
||||
return await self.func(*args, **kwargs)
|
||||
return self.error("Please login first")
|
||||
|
||||
def check_permission(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class login_required(BasePermissionDecorator):
|
||||
def check_permission(self):
|
||||
return self.request.user.is_authenticated
|
||||
def check_permission(self, request):
|
||||
return request.user.is_authenticated
|
||||
|
||||
|
||||
class super_admin_required(BasePermissionDecorator):
|
||||
def check_permission(self):
|
||||
user = self.request.user
|
||||
def check_permission(self, request):
|
||||
user = request.user
|
||||
return user.is_authenticated and user.is_super_admin()
|
||||
|
||||
|
||||
class admin_role_required(BasePermissionDecorator):
|
||||
def check_permission(self):
|
||||
user = self.request.user
|
||||
def check_permission(self, request):
|
||||
user = request.user
|
||||
return user.is_authenticated and user.is_admin_role()
|
||||
|
||||
|
||||
class problem_permission_required(admin_role_required):
|
||||
def check_permission(self):
|
||||
if not super(problem_permission_required, self).check_permission():
|
||||
def check_permission(self, request):
|
||||
if not super().check_permission(request):
|
||||
return False
|
||||
if self.request.user.problem_permission == ProblemPermission.NONE:
|
||||
if request.user.problem_permission == ProblemPermission.NONE:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from account.models import UserProfile
|
||||
from problem.models import Problem
|
||||
from submission.models import JudgeStatus
|
||||
|
||||
|
||||
ACCEPTED_STATUSES = {JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED}
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ class UserManager(models.Manager):
|
||||
def get_by_natural_key(self, username):
|
||||
return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
|
||||
|
||||
async def aget_by_natural_key(self, username):
|
||||
return await self.aget(**{f"{self.model.USERNAME_FIELD}__iexact": username})
|
||||
|
||||
|
||||
class User(AbstractBaseUser):
|
||||
username = models.TextField(unique=True)
|
||||
|
||||
@@ -4,6 +4,6 @@ from ..views.admin import GenerateUserAPI, ResetUserPasswordAPI, UserAdminAPI
|
||||
|
||||
urlpatterns = [
|
||||
path("user", UserAdminAPI.as_view()),
|
||||
path("generate_user", GenerateUserAPI.as_view()),
|
||||
path("generate_user", GenerateUserAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("reset_password", ResetUserPasswordAPI.as_view()),
|
||||
]
|
||||
|
||||
@@ -29,28 +29,28 @@ urlpatterns = [
|
||||
path("login", UserLoginAPI.as_view()),
|
||||
path("logout", UserLogoutAPI.as_view()),
|
||||
path("register", UserRegisterAPI.as_view()),
|
||||
path("change_password", UserChangePasswordAPI.as_view()),
|
||||
path("change_email", UserChangeEmailAPI.as_view()),
|
||||
path("apply_reset_password", ApplyResetPasswordAPI.as_view()),
|
||||
path("reset_password", ResetPasswordAPI.as_view()),
|
||||
path("change_password", UserChangePasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("change_email", UserChangeEmailAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("apply_reset_password", ApplyResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("reset_password", ResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("captcha", CaptchaAPIView.as_view()),
|
||||
path("check_username_or_email", UsernameOrEmailCheck.as_view()),
|
||||
path("check_username_or_email", UsernameOrEmailCheck.as_view()), # DEPRECATED: 前端未调用
|
||||
path("profile", UserProfileAPI.as_view(), name="user_profile_api"),
|
||||
path("profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view()),
|
||||
path("metrics", Metrics.as_view()),
|
||||
path("upload_avatar", AvatarUploadAPI.as_view()),
|
||||
path("tfa_required", CheckTFARequiredAPI.as_view()),
|
||||
path("tfa_required", CheckTFARequiredAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path(
|
||||
"two_factor_auth",
|
||||
"two_factor_auth", # DEPRECATED: 前端未调用
|
||||
TwoFactorAuthAPI.as_view(),
|
||||
),
|
||||
path("user_rank", UserRankAPI.as_view()),
|
||||
path("user_activity_rank", UserActivityRankAPI.as_view()),
|
||||
path("user_problem_rank", UserProblemRankAPI.as_view()),
|
||||
path("sessions", SessionManagementAPI.as_view()),
|
||||
path("sessions", SessionManagementAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path(
|
||||
"open_api_appkey",
|
||||
"open_api_appkey", # DEPRECATED: 前端未调用
|
||||
OpenAPIAppkeyAPI.as_view(),
|
||||
),
|
||||
path("sso", SSOAPI.as_view()),
|
||||
path("sso", SSOAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
]
|
||||
|
||||
@@ -191,6 +191,7 @@ class UserAdminAPI(APIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class GenerateUserAPI(APIView):
|
||||
@super_admin_required
|
||||
def get(self, request):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from importlib import import_module
|
||||
@@ -5,7 +6,6 @@ from importlib import import_module
|
||||
import qrcode
|
||||
from django.conf import settings
|
||||
from django.contrib import auth
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Count, Q
|
||||
from django.template.loader import render_to_string
|
||||
from django.utils import timezone
|
||||
@@ -16,8 +16,9 @@ from otpauth import TOTP
|
||||
|
||||
from options.options import SysOptions
|
||||
from problem.models import Problem
|
||||
from submission.models import JudgeStatus, Submission, is_accepted
|
||||
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
|
||||
from submission.models import JudgeStatus, Submission
|
||||
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
|
||||
from utils.async_helpers import async_cache_get, async_cache_set
|
||||
from utils.captcha import Captcha
|
||||
from utils.constants import CacheKey, ContestRuleType
|
||||
from utils.shortcuts import datetime2str, img2base64, rand_str
|
||||
@@ -58,12 +59,9 @@ def _valid_totp(token, code):
|
||||
return _totp(token).verify(code)
|
||||
|
||||
|
||||
class UserProfileAPI(APIView):
|
||||
class UserProfileAPI(AsyncAPIView):
|
||||
@method_decorator(ensure_csrf_cookie)
|
||||
def get(self, request, **kwargs):
|
||||
"""
|
||||
判断是否登录, 若登录返回用户信息
|
||||
"""
|
||||
async def get(self, request, **kwargs):
|
||||
user = request.user
|
||||
if not user.is_authenticated:
|
||||
return self.success()
|
||||
@@ -71,52 +69,51 @@ class UserProfileAPI(APIView):
|
||||
username = request.GET.get("username")
|
||||
try:
|
||||
if username:
|
||||
user = User.objects.get(username=username, is_disabled=False)
|
||||
user = await User.objects.aget(username=username, is_disabled=False)
|
||||
else:
|
||||
user = request.user
|
||||
# api返回的是自己的信息,可以返real_name
|
||||
show_real_name = True
|
||||
except User.DoesNotExist:
|
||||
return self.error("User does not exist")
|
||||
return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data)
|
||||
profile = await UserProfile.objects.select_related("user").aget(user=user)
|
||||
return self.success(UserProfileSerializer(profile, show_real_name=show_real_name).data)
|
||||
|
||||
@validate_serializer(EditUserProfileSerializer)
|
||||
@login_required
|
||||
def put(self, request):
|
||||
async def put(self, request):
|
||||
data = request.data
|
||||
user_profile = request.user.userprofile
|
||||
user_profile = await UserProfile.objects.select_related("user").aget(user=request.user)
|
||||
for k, v in data.items():
|
||||
setattr(user_profile, k, v)
|
||||
user_profile.save()
|
||||
await user_profile.asave()
|
||||
return self.success(UserProfileSerializer(user_profile, show_real_name=True).data)
|
||||
|
||||
|
||||
class Metrics(APIView):
|
||||
def get(self, request):
|
||||
class Metrics(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
userid = request.GET.get("userid")
|
||||
submissions = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
|
||||
if submissions.count() == 0:
|
||||
qs = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
|
||||
count, latest, first = await asyncio.gather(
|
||||
qs.acount(),
|
||||
qs.order_by("-create_time").afirst(),
|
||||
qs.order_by("create_time").afirst(),
|
||||
)
|
||||
if count == 0 or not latest or not first:
|
||||
return self.error("暂无提交")
|
||||
else:
|
||||
latest_submission = submissions.first()
|
||||
last_submission = submissions.last()
|
||||
if last_submission and latest_submission:
|
||||
return self.success(
|
||||
{
|
||||
"now": datetime2str(timezone.now()),
|
||||
"latest": datetime2str(latest_submission.create_time),
|
||||
"first": datetime2str(last_submission.create_time),
|
||||
}
|
||||
)
|
||||
else:
|
||||
return self.error("暂无提交")
|
||||
return self.success(
|
||||
{
|
||||
"now": datetime2str(timezone.now()),
|
||||
"latest": datetime2str(latest.create_time),
|
||||
"first": datetime2str(first.create_time),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class AvatarUploadAPI(APIView):
|
||||
class AvatarUploadAPI(AsyncAPIView):
|
||||
request_parsers = ()
|
||||
|
||||
@login_required
|
||||
def post(self, request):
|
||||
async def post(self, request):
|
||||
form = ImageUploadForm(request.POST, request.FILES)
|
||||
if form.is_valid():
|
||||
avatar = form.cleaned_data["image"]
|
||||
@@ -132,13 +129,14 @@ class AvatarUploadAPI(APIView):
|
||||
with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
|
||||
for chunk in avatar:
|
||||
img.write(chunk)
|
||||
user_profile = request.user.userprofile
|
||||
user_profile = await UserProfile.objects.aget(user=request.user)
|
||||
|
||||
user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
|
||||
user_profile.save()
|
||||
await user_profile.asave()
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class TwoFactorAuthAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
@@ -186,6 +184,7 @@ class TwoFactorAuthAPI(APIView):
|
||||
return self.error("Invalid code")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class CheckTFARequiredAPI(APIView):
|
||||
@validate_serializer(UsernameOrEmailCheckSerializer)
|
||||
def post(self, request):
|
||||
@@ -203,31 +202,26 @@ class CheckTFARequiredAPI(APIView):
|
||||
return self.success({"result": result})
|
||||
|
||||
|
||||
class UserLoginAPI(APIView):
|
||||
class UserLoginAPI(AsyncAPIView):
|
||||
@validate_serializer(UserLoginSerializer)
|
||||
def post(self, request):
|
||||
"""
|
||||
User login api
|
||||
"""
|
||||
async def post(self, request):
|
||||
data = request.data
|
||||
user = auth.authenticate(username=data["username"], password=data["password"])
|
||||
# None is returned if username or password is wrong
|
||||
user = await auth.aauthenticate(username=data["username"], password=data["password"])
|
||||
if user:
|
||||
if user.is_disabled:
|
||||
return self.error("Your account has been disabled")
|
||||
if not user.two_factor_auth:
|
||||
prev_login = user.last_login
|
||||
auth.login(request, user)
|
||||
await auth.alogin(request, user)
|
||||
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
||||
return self.success("Succeeded")
|
||||
|
||||
# `tfa_code` not in post data
|
||||
if user.two_factor_auth and "tfa_code" not in data:
|
||||
return self.error("tfa_required")
|
||||
|
||||
if _valid_totp(user.tfa_token, data["tfa_code"]):
|
||||
prev_login = user.last_login
|
||||
auth.login(request, user)
|
||||
await auth.alogin(request, user)
|
||||
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
||||
return self.success("Succeeded")
|
||||
else:
|
||||
@@ -236,12 +230,13 @@ class UserLoginAPI(APIView):
|
||||
return self.error("Invalid username or password")
|
||||
|
||||
|
||||
class UserLogoutAPI(APIView):
|
||||
def get(self, request):
|
||||
auth.logout(request)
|
||||
class UserLogoutAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
await auth.alogout(request)
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class UsernameOrEmailCheck(APIView):
|
||||
@validate_serializer(UsernameOrEmailCheckSerializer)
|
||||
def post(self, request):
|
||||
@@ -258,13 +253,10 @@ class UsernameOrEmailCheck(APIView):
|
||||
return self.success(result)
|
||||
|
||||
|
||||
class UserRegisterAPI(APIView):
|
||||
class UserRegisterAPI(AsyncAPIView):
|
||||
@validate_serializer(UserRegisterSerializer)
|
||||
def post(self, request):
|
||||
"""
|
||||
User register api
|
||||
"""
|
||||
if not SysOptions.allow_register:
|
||||
async def post(self, request):
|
||||
if not await SysOptions.aget("allow_register"):
|
||||
return self.error("Register function has been disabled by admin")
|
||||
|
||||
data = request.data
|
||||
@@ -273,17 +265,18 @@ class UserRegisterAPI(APIView):
|
||||
captcha = Captcha(request)
|
||||
if not captcha.check(data["captcha"]):
|
||||
return self.error("Invalid captcha")
|
||||
if User.objects.filter(username=data["username"]).exists():
|
||||
if await User.objects.filter(username=data["username"]).aexists():
|
||||
return self.error("Username already exists")
|
||||
if User.objects.filter(email=data["email"]).exists():
|
||||
if await User.objects.filter(email=data["email"]).aexists():
|
||||
return self.error("Email already exists")
|
||||
user = User.objects.create(username=data["username"], email=data["email"])
|
||||
user = await User.objects.acreate(username=data["username"], email=data["email"])
|
||||
user.set_password(data["password"])
|
||||
user.save()
|
||||
UserProfile.objects.create(user=user)
|
||||
await user.asave()
|
||||
await UserProfile.objects.acreate(user=user)
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class UserChangeEmailAPI(APIView):
|
||||
@validate_serializer(UserChangeEmailSerializer)
|
||||
@login_required
|
||||
@@ -306,6 +299,7 @@ class UserChangeEmailAPI(APIView):
|
||||
return self.error("Wrong password")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class UserChangePasswordAPI(APIView):
|
||||
@validate_serializer(UserChangePasswordSerializer)
|
||||
@login_required
|
||||
@@ -329,6 +323,7 @@ class UserChangePasswordAPI(APIView):
|
||||
return self.error("Invalid old password")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ApplyResetPasswordAPI(APIView):
|
||||
@validate_serializer(ApplyResetPasswordSerializer)
|
||||
def post(self, request):
|
||||
@@ -363,6 +358,7 @@ class ApplyResetPasswordAPI(APIView):
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ResetPasswordAPI(APIView):
|
||||
@validate_serializer(ResetPasswordSerializer)
|
||||
def post(self, request):
|
||||
@@ -383,6 +379,7 @@ class ResetPasswordAPI(APIView):
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SessionManagementAPI(APIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
@@ -426,8 +423,8 @@ class SessionManagementAPI(APIView):
|
||||
return self.error("Invalid session_key")
|
||||
|
||||
|
||||
class UserRankAPI(APIView):
|
||||
def get(self, request):
|
||||
class UserRankAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
rule_type = request.GET.get("rule")
|
||||
username = request.GET.get("username", "")
|
||||
try:
|
||||
@@ -448,16 +445,16 @@ class UserRankAPI(APIView):
|
||||
profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
|
||||
if n > 0:
|
||||
profiles = profiles[:n]
|
||||
return self.success(self.paginate_data(request, profiles, RankInfoSerializer))
|
||||
return self.success(await self.async_paginate_data(request, profiles, RankInfoSerializer))
|
||||
|
||||
|
||||
class UserActivityRankAPI(APIView):
|
||||
def get(self, request):
|
||||
class UserActivityRankAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
start = request.GET.get("start")
|
||||
if not start:
|
||||
return self.error("start time is required")
|
||||
cache_key = f"{CacheKey.user_activity_rank}:{start}"
|
||||
cached = cache.get(cache_key)
|
||||
cached = await async_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
return self.success(cached)
|
||||
|
||||
@@ -467,35 +464,40 @@ class UserActivityRankAPI(APIView):
|
||||
create_time__gte=start,
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
).exclude(username__in=hidden_names)
|
||||
data = list(submissions.values("username").annotate(count=Count("problem_id", distinct=True)).order_by("-count")[:10])
|
||||
cache.set(cache_key, data, 600)
|
||||
data = [
|
||||
row
|
||||
async for row in submissions.values("username")
|
||||
.annotate(count=Count("problem_id", distinct=True))
|
||||
.order_by("-count")[:10]
|
||||
]
|
||||
await async_cache_set(cache_key, data, 600)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class UserProblemRankAPI(APIView):
|
||||
def get(self, request):
|
||||
class UserProblemRankAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
user = request.user
|
||||
if not user.is_authenticated:
|
||||
return self.error("User is not authenticated")
|
||||
|
||||
problem = Problem.objects.get(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
||||
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
||||
submissions = Submission.objects.filter(problem=problem, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED])
|
||||
|
||||
all_ac_count = submissions.values("user_id").distinct().count()
|
||||
all_ac_count = await submissions.values("user_id").distinct().acount()
|
||||
|
||||
class_name = user.class_name or ""
|
||||
class_ac_count = 0
|
||||
|
||||
if class_name:
|
||||
users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True)
|
||||
user_ids = list(users)
|
||||
user_ids = [user_id async for user_id in users]
|
||||
submissions = submissions.filter(user_id__in=user_ids)
|
||||
class_ac_count = submissions.values("user_id").distinct().count()
|
||||
class_ac_count = await submissions.values("user_id").distinct().acount()
|
||||
|
||||
my_submissions = submissions.filter(user_id=user.id)
|
||||
|
||||
if len(my_submissions) == 0:
|
||||
if not await my_submissions.aexists():
|
||||
return self.success(
|
||||
{
|
||||
"class_name": class_name,
|
||||
@@ -505,8 +507,8 @@ class UserProblemRankAPI(APIView):
|
||||
}
|
||||
)
|
||||
|
||||
my_first_submission = my_submissions.order_by("create_time").first()
|
||||
rank = submissions.filter(create_time__lte=my_first_submission.create_time).count()
|
||||
my_first_submission = await my_submissions.order_by("create_time").afirst()
|
||||
rank = await submissions.filter(create_time__lte=my_first_submission.create_time).acount()
|
||||
return self.success(
|
||||
{
|
||||
"class_name": class_name,
|
||||
@@ -517,25 +519,26 @@ class UserProblemRankAPI(APIView):
|
||||
)
|
||||
|
||||
|
||||
class ProfileProblemDisplayIDRefreshAPI(APIView):
|
||||
class ProfileProblemDisplayIDRefreshAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
profile = request.user.userprofile
|
||||
async def get(self, request):
|
||||
profile = await UserProfile.objects.aget(user=request.user)
|
||||
acm_problems = profile.acm_problems_status.get("problems", {})
|
||||
oi_problems = profile.oi_problems_status.get("problems", {})
|
||||
ids = list(acm_problems.keys()) + list(oi_problems.keys())
|
||||
if not ids:
|
||||
return self.success()
|
||||
display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)
|
||||
display_ids = [did async for did in Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)]
|
||||
id_map = dict(zip(ids, display_ids))
|
||||
for k, v in acm_problems.items():
|
||||
v["_id"] = id_map[k]
|
||||
for k, v in oi_problems.items():
|
||||
v["_id"] = id_map[k]
|
||||
profile.save(update_fields=["acm_problems_status", "oi_problems_status"])
|
||||
await profile.asave(update_fields=["acm_problems_status", "oi_problems_status"])
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class OpenAPIAppkeyAPI(APIView):
|
||||
@login_required
|
||||
def post(self, request):
|
||||
@@ -548,6 +551,7 @@ class OpenAPIAppkeyAPI(APIView):
|
||||
return self.success({"appkey": api_appkey})
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SSOAPI(CSRFExemptAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
from announcement.models import Announcement
|
||||
from announcement.serializers import AnnouncementListSerializer, AnnouncementSerializer
|
||||
from utils.api import APIView
|
||||
from utils.api import AsyncAPIView
|
||||
|
||||
|
||||
class AnnouncementAPI(APIView):
|
||||
def get(self, request):
|
||||
class AnnouncementAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
id = request.GET.get("id")
|
||||
if id:
|
||||
try:
|
||||
announcement = Announcement.objects.get(id=id, visible=True)
|
||||
return self.success(AnnouncementSerializer(announcement).data)
|
||||
announcement = await (
|
||||
Announcement.objects.select_related("created_by")
|
||||
.filter(id=id, visible=True)
|
||||
.afirst()
|
||||
)
|
||||
if announcement is None:
|
||||
raise Announcement.DoesNotExist
|
||||
return self.success(await self.async_serialize_data(AnnouncementSerializer, announcement))
|
||||
except Announcement.DoesNotExist:
|
||||
return self.error("Announcement does not exist")
|
||||
|
||||
announcements = Announcement.objects.select_related("created_by").filter(visible=True)
|
||||
return self.success(
|
||||
self.paginate_data(request, announcements, AnnouncementListSerializer)
|
||||
await self.async_paginate_data(request, announcements, AnnouncementListSerializer)
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .base import BaseEngine
|
||||
from ast_checker.labels import label
|
||||
|
||||
from .base import BaseEngine
|
||||
|
||||
|
||||
class MustHaveNestingEngine(BaseEngine):
|
||||
def _has_inner_in_subtree(self, node, inner_type):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .base import BaseEngine
|
||||
from ast_checker.labels import label
|
||||
|
||||
from .base import BaseEngine
|
||||
|
||||
|
||||
class CountNodeEngine(BaseEngine):
|
||||
def _message(self, rule, count):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .base import BaseEngine
|
||||
from ast_checker.labels import label
|
||||
|
||||
from .base import BaseEngine
|
||||
|
||||
|
||||
class MustExistNodeEngine(BaseEngine):
|
||||
def _message(self, rule):
|
||||
|
||||
@@ -1,45 +1,44 @@
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Avg, Count
|
||||
from django.db.models.functions import Round
|
||||
|
||||
from account.decorators import login_required
|
||||
from django.db.models import Avg, Count
|
||||
from django.db.models.functions import Round
|
||||
|
||||
from account.decorators import login_required
|
||||
from comment.models import Comment
|
||||
from comment.serializers import CommentSerializer, CreateCommentSerializer
|
||||
from problem.models import Problem
|
||||
from submission.models import JudgeStatus, Submission
|
||||
from utils.api import APIView
|
||||
from utils.api.api import validate_serializer
|
||||
from utils.constants import CacheKey
|
||||
from submission.models import JudgeStatus, Submission
|
||||
from utils.api import AsyncAPIView
|
||||
from utils.api.api import validate_serializer
|
||||
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||
from utils.constants import CacheKey
|
||||
|
||||
|
||||
class CommentAPI(APIView):
|
||||
class CommentAPI(AsyncAPIView):
|
||||
@validate_serializer(CreateCommentSerializer)
|
||||
@login_required
|
||||
def post(self, request):
|
||||
async def post(self, request):
|
||||
data = request.data
|
||||
try:
|
||||
problem = Problem.objects.get(id=data["problem_id"], visible=True)
|
||||
problem = await Problem.objects.aget(id=data["problem_id"], visible=True)
|
||||
except Problem.DoesNotExist:
|
||||
self.error("problem is not exists")
|
||||
return self.error("problem is not exists")
|
||||
|
||||
try:
|
||||
submission = (
|
||||
Submission.objects.select_related("problem")
|
||||
.filter(
|
||||
user_id=request.user.id,
|
||||
problem_id=data["problem_id"],
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
)
|
||||
.first()
|
||||
submission = await (
|
||||
Submission.objects.select_related("problem")
|
||||
.filter(
|
||||
user_id=request.user.id,
|
||||
problem_id=data["problem_id"],
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
)
|
||||
except Submission.DoesNotExist:
|
||||
self.error("submission is not exists or not accepted")
|
||||
.afirst()
|
||||
)
|
||||
if not submission:
|
||||
return self.error("submission is not exists or not accepted")
|
||||
|
||||
language = submission.language
|
||||
if language == "Python3":
|
||||
language = "Python"
|
||||
|
||||
Comment.objects.create(
|
||||
await Comment.objects.acreate(
|
||||
user=request.user,
|
||||
problem=problem,
|
||||
submission=submission,
|
||||
@@ -49,32 +48,35 @@ class CommentAPI(APIView):
|
||||
comprehensive_rating=data["comprehensive_rating"],
|
||||
content=data["content"],
|
||||
)
|
||||
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||
return self.success()
|
||||
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||
return self.success()
|
||||
|
||||
@login_required
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
comment = (
|
||||
comment = await (
|
||||
Comment.objects.select_related("problem")
|
||||
.filter(user=request.user, problem_id=problem_id)
|
||||
.first()
|
||||
)
|
||||
if comment:
|
||||
return self.success(CommentSerializer(comment).data)
|
||||
else:
|
||||
return self.success()
|
||||
.afirst()
|
||||
)
|
||||
if comment:
|
||||
return self.success(await self.async_serialize_data(CommentSerializer, comment))
|
||||
else:
|
||||
return self.success()
|
||||
|
||||
|
||||
class CommentStatisticsAPI(APIView):
|
||||
def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
|
||||
cached = cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return self.success(cached)
|
||||
class CommentStatisticsAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
|
||||
cached = await async_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
return self.success(cached)
|
||||
|
||||
agg = Comment.objects.filter(problem_id=problem_id).aggregate(
|
||||
from asgiref.sync import sync_to_async
|
||||
agg = await sync_to_async(
|
||||
Comment.objects.filter(problem_id=problem_id).aggregate
|
||||
)(
|
||||
count=Count("id"),
|
||||
description=Round(Avg("description_rating"), 2),
|
||||
difficulty=Round(Avg("difficulty_rating"), 2),
|
||||
@@ -88,5 +90,5 @@ class CommentStatisticsAPI(APIView):
|
||||
"difficulty": agg["difficulty"],
|
||||
"comprehensive": agg["comprehensive"],
|
||||
}}
|
||||
cache.set(cache_key, data, 3600)
|
||||
return self.success(data)
|
||||
await async_cache_set(cache_key, data, 3600)
|
||||
return self.success(data)
|
||||
|
||||
@@ -12,12 +12,12 @@ from ..views import (
|
||||
)
|
||||
|
||||
urlpatterns = [
|
||||
path("smtp", SMTPAPI.as_view()),
|
||||
path("smtp_test", SMTPTestAPI.as_view()),
|
||||
path("smtp", SMTPAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("smtp_test", SMTPTestAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("website", WebsiteConfigAPI.as_view()),
|
||||
path("random_user", RandomUsernameAPI.as_view()),
|
||||
path("judge_server", JudgeServerAPI.as_view()),
|
||||
path("prune_test_case", TestCasePruneAPI.as_view()),
|
||||
path("versions", ReleaseNotesAPI.as_view()),
|
||||
path("versions", ReleaseNotesAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("dashboard_info", DashboardInfoAPI.as_view()),
|
||||
]
|
||||
|
||||
@@ -12,7 +12,7 @@ urlpatterns = [
|
||||
path("website", WebsiteConfigAPI.as_view()),
|
||||
# 这里必须要有 /
|
||||
path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()),
|
||||
path("languages", LanguagesAPI.as_view()),
|
||||
path("languages", LanguagesAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("hitokoto", HitokotoAPI.as_view()),
|
||||
path("class_usernames", ClassUsernamesAPI.as_view()),
|
||||
]
|
||||
|
||||
128
conf/views.py
128
conf/views.py
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
@@ -6,9 +7,10 @@ import re
|
||||
import shutil
|
||||
import smtplib
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import requests
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from requests.exceptions import RequestException
|
||||
@@ -20,7 +22,7 @@ from judge.dispatcher import process_pending_task
|
||||
from options.options import SysOptions
|
||||
from problem.models import Problem
|
||||
from submission.models import Submission
|
||||
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
|
||||
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
|
||||
from utils.cache import JsonDataLoader
|
||||
from utils.shortcuts import get_env, send_email
|
||||
from utils.websocket import push_config_update
|
||||
@@ -38,6 +40,7 @@ from .serializers import (
|
||||
)
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SMTPAPI(APIView):
|
||||
@super_admin_required
|
||||
def get(self, request):
|
||||
@@ -66,6 +69,7 @@ class SMTPAPI(APIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SMTPTestAPI(APIView):
|
||||
@super_admin_required
|
||||
@validate_serializer(TestSMTPConfigSerializer)
|
||||
@@ -97,35 +101,33 @@ class SMTPTestAPI(APIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
class WebsiteConfigAPI(APIView):
|
||||
def get(self, request):
|
||||
ret = {
|
||||
key: getattr(SysOptions, key)
|
||||
for key in [
|
||||
"website_base_url",
|
||||
"website_name",
|
||||
"website_name_shortcut",
|
||||
"website_footer",
|
||||
"allow_register",
|
||||
"submission_list_show_all",
|
||||
"class_list",
|
||||
"enable_maxkb",
|
||||
]
|
||||
}
|
||||
class WebsiteConfigAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
ret = await SysOptions.aget_many(
|
||||
"website_base_url",
|
||||
"website_name",
|
||||
"website_name_shortcut",
|
||||
"website_footer",
|
||||
"allow_register",
|
||||
"submission_list_show_all",
|
||||
"class_list",
|
||||
"enable_maxkb",
|
||||
)
|
||||
return self.success(ret)
|
||||
|
||||
@super_admin_required
|
||||
@validate_serializer(CreateEditWebsiteConfigSerializer)
|
||||
def post(self, request):
|
||||
for k, v in request.data.items():
|
||||
if k == "website_footer":
|
||||
with XSSHtml() as parser:
|
||||
v = parser.clean(v)
|
||||
setattr(SysOptions, k, v)
|
||||
|
||||
# 推送配置更新到所有连接的客户端
|
||||
push_config_update(k, v)
|
||||
async def post(self, request):
|
||||
@sync_to_async
|
||||
def _update_config(data):
|
||||
for k, v in data.items():
|
||||
if k == "website_footer":
|
||||
with XSSHtml() as parser:
|
||||
v = parser.clean(v)
|
||||
setattr(SysOptions, k, v)
|
||||
push_config_update(k, v)
|
||||
|
||||
await _update_config(request.data)
|
||||
return self.success()
|
||||
|
||||
|
||||
@@ -206,6 +208,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class LanguagesAPI(APIView):
|
||||
def get(self, request):
|
||||
return self.success(
|
||||
@@ -255,6 +258,7 @@ class TestCasePruneAPI(APIView):
|
||||
shutil.rmtree(test_case_dir, ignore_errors=True)
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ReleaseNotesAPI(APIView):
|
||||
def get(self, request):
|
||||
try:
|
||||
@@ -272,24 +276,29 @@ class ReleaseNotesAPI(APIView):
|
||||
return self.success(releases)
|
||||
|
||||
|
||||
class DashboardInfoAPI(APIView):
|
||||
def get(self, request):
|
||||
today = datetime.today()
|
||||
today_submission_count = Submission.objects.filter(
|
||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
||||
).count()
|
||||
recent_contest_count = Contest.objects.exclude(
|
||||
end_time__lt=timezone.now()
|
||||
).count()
|
||||
judge_server_count = len(
|
||||
list(filter(lambda x: x.status == "normal", JudgeServer.objects.all()))
|
||||
class DashboardInfoAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
now = timezone.now()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
(
|
||||
user_count,
|
||||
today_submission_count,
|
||||
recent_contest_count,
|
||||
judge_servers,
|
||||
) = await asyncio.gather(
|
||||
User.objects.acount(),
|
||||
Submission.objects.filter(create_time__gte=today_start).acount(),
|
||||
Contest.objects.exclude(end_time__lt=timezone.now()).acount(),
|
||||
JudgeServer.objects.filter(
|
||||
last_heartbeat__gte=timezone.now() - timedelta(seconds=6)
|
||||
).acount(),
|
||||
)
|
||||
return self.success(
|
||||
{
|
||||
"user_count": User.objects.count(),
|
||||
"user_count": user_count,
|
||||
"recent_contest_count": recent_contest_count,
|
||||
"today_submission_count": today_submission_count,
|
||||
"judge_server_count": judge_server_count,
|
||||
"judge_server_count": judge_servers,
|
||||
"env": {
|
||||
"FORCE_HTTPS": get_env("FORCE_HTTPS", default=False),
|
||||
"STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default=""),
|
||||
@@ -298,24 +307,21 @@ class DashboardInfoAPI(APIView):
|
||||
)
|
||||
|
||||
|
||||
class RandomUsernameAPI(APIView):
|
||||
def get(self, request):
|
||||
class RandomUsernameAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
classroom = request.GET.get("classroom", "")
|
||||
if not classroom:
|
||||
return self.error("需要班级号")
|
||||
usernames = (
|
||||
User.objects.filter(username__istartswith=classroom)
|
||||
usernames = [
|
||||
u async for u in User.objects.filter(username__istartswith=classroom)
|
||||
.values_list("username", flat=True)
|
||||
.order_by("?")
|
||||
)
|
||||
if len(usernames) > 10:
|
||||
return self.success(usernames[:10])
|
||||
else:
|
||||
return self.success(usernames)
|
||||
.order_by("?")[:10]
|
||||
]
|
||||
return self.success(usernames)
|
||||
|
||||
|
||||
class HitokotoAPI(APIView):
|
||||
def get(self, request):
|
||||
class HitokotoAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
try:
|
||||
categories = JsonDataLoader.load_data(
|
||||
settings.HITOKOTO_DIR, "categories.json"
|
||||
@@ -328,20 +334,14 @@ class HitokotoAPI(APIView):
|
||||
return self.error("获取一言失败,请稍后再试")
|
||||
|
||||
|
||||
class ClassUsernamesAPI(APIView):
|
||||
def get(self, request):
|
||||
class ClassUsernamesAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
classroom = request.GET.get("classroom", "")
|
||||
if not classroom:
|
||||
return self.error("需要班级号")
|
||||
users = User.objects.filter(class_name=classroom).order_by("-create_time")
|
||||
names = []
|
||||
for user in users:
|
||||
prefix = f"ks{classroom}"
|
||||
result = (
|
||||
user.username[len(prefix) :]
|
||||
if user.username.startswith(prefix)
|
||||
else user.username
|
||||
)
|
||||
names.append(result)
|
||||
|
||||
prefix = f"ks{classroom}"
|
||||
names = [
|
||||
user.username[len(prefix):] if user.username.startswith(prefix) else user.username
|
||||
async for user in User.objects.filter(class_name=classroom).order_by("-create_time")
|
||||
]
|
||||
return self.success(names)
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..views.admin import ACMContestHelper, ContestAnnouncementAPI, ContestAPI,
|
||||
urlpatterns = [
|
||||
path("contest", ContestAPI.as_view()),
|
||||
path("contest/clone", ContestCloneAPI.as_view()),
|
||||
path("contest/announcement", ContestAnnouncementAPI.as_view()),
|
||||
path("contest/announcement", ContestAnnouncementAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("contest/acm_helper", ACMContestHelper.as_view()),
|
||||
path("download_submissions", DownloadContestSubmissions.as_view()),
|
||||
path("download_submissions", DownloadContestSubmissions.as_view()), # DEPRECATED: 前端未调用
|
||||
]
|
||||
|
||||
@@ -6,7 +6,7 @@ urlpatterns = [
|
||||
path("contests", ContestListAPI.as_view()),
|
||||
path("contest", ContestAPI.as_view()),
|
||||
path("contest/password", ContestPasswordVerifyAPI.as_view()),
|
||||
path("contest/announcement", ContestAnnouncementListAPI.as_view()),
|
||||
path("contest/announcement", ContestAnnouncementListAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("contest/access", ContestAccessAPI.as_view()),
|
||||
path("contest_rank", ContestRankAPI.as_view()),
|
||||
]
|
||||
|
||||
@@ -97,6 +97,7 @@ class ContestAPI(APIView):
|
||||
return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ContestAnnouncementAPI(APIView):
|
||||
@validate_serializer(CreateContestAnnouncementSerializer)
|
||||
@super_admin_required
|
||||
@@ -212,6 +213,7 @@ class ACMContestHelper(APIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class DownloadContestSubmissions(APIView):
|
||||
def _dump_submissions(self, contest, exclude_admin=True):
|
||||
problem_ids = contest.problem_set.all().values_list("id", "_id")
|
||||
|
||||
@@ -12,7 +12,7 @@ from account.decorators import (
|
||||
)
|
||||
from account.models import AdminType
|
||||
from problem.models import Problem
|
||||
from utils.api import APIView, validate_serializer
|
||||
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||
from utils.constants import CONTEST_PASSWORD_SESSION_KEY, CacheKey, ContestRuleType, ContestStatus
|
||||
from utils.shortcuts import check_is_id, datetime2str
|
||||
|
||||
@@ -20,6 +20,7 @@ from ..models import ACMContestRank, Contest, ContestAnnouncement, OIContestRank
|
||||
from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerializer, ContestPasswordVerifySerializer, ContestSerializer, OIContestRankSerializer
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ContestAnnouncementListAPI(APIView):
|
||||
@check_contest_permission(check_type="announcements")
|
||||
def get(self, request):
|
||||
@@ -35,22 +36,28 @@ class ContestAnnouncementListAPI(APIView):
|
||||
return self.success(ContestAnnouncementSerializer(data, many=True).data)
|
||||
|
||||
|
||||
class ContestAPI(APIView):
|
||||
def get(self, request):
|
||||
class ContestAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
id = request.GET.get("id")
|
||||
if not id or not check_is_id(id):
|
||||
return self.error("Invalid parameter, id is required")
|
||||
try:
|
||||
contest = Contest.objects.get(id=id, visible=True)
|
||||
contest = await (
|
||||
Contest.objects.select_related("created_by")
|
||||
.filter(id=id, visible=True)
|
||||
.afirst()
|
||||
)
|
||||
if contest is None:
|
||||
raise Contest.DoesNotExist
|
||||
except Contest.DoesNotExist:
|
||||
return self.error("Contest does not exist")
|
||||
data = ContestSerializer(contest).data
|
||||
data = await self.async_serialize_data(ContestSerializer, contest)
|
||||
data["now"] = datetime2str(now())
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class ContestListAPI(APIView):
|
||||
def get(self, request):
|
||||
class ContestListAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
||||
keyword = request.GET.get("keyword")
|
||||
rule_type = request.GET.get("rule_type")
|
||||
@@ -70,7 +77,7 @@ class ContestListAPI(APIView):
|
||||
contests = contests.filter(end_time__lt=cur)
|
||||
else:
|
||||
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
||||
return self.success(self.paginate_data(request, contests, ContestSerializer))
|
||||
return self.success(await self.async_paginate_data(request, contests, ContestSerializer))
|
||||
|
||||
|
||||
class ContestPasswordVerifyAPI(APIView):
|
||||
|
||||
@@ -18,6 +18,7 @@ asgiref==3.11.1 \
|
||||
# channels
|
||||
# channels-redis
|
||||
# django
|
||||
# onlinejudge
|
||||
certifi==2026.4.22 \
|
||||
--hash=sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a \
|
||||
--hash=sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580
|
||||
|
||||
861
docs/superpowers/plans/2026-05-26-backend-async.md
Normal file
861
docs/superpowers/plans/2026-05-26-backend-async.md
Normal file
@@ -0,0 +1,861 @@
|
||||
# Backend Async Hardening Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Make the current backend async work correct first, then establish safe patterns for expanding async views without changing API response shapes.
|
||||
|
||||
**Architecture:** Keep the existing custom `APIView`/DRF serializer stack. Add async-safe tests and helpers around `AsyncAPIView`, explicitly preload serializer relations in converted endpoints, and use `sync_to_async(..., thread_sensitive=True)` for synchronous serializer/cache/helper code that remains inside async views.
|
||||
|
||||
**Tech Stack:** Django 6.0.4, custom class-based API views, Django async ORM, DRF serializers, PostgreSQL, Redis cache, Channels ASGI.
|
||||
|
||||
---
|
||||
|
||||
## Async Rules For This Repository
|
||||
|
||||
1. Async conversion is valid only when the endpoint preserves URL, method, status code, JSON envelope, and permission behavior.
|
||||
2. Every async view that serializes model instances must either preload all serializer relations with `select_related()` / `prefetch_related()` or run serializer `.data` through a sync boundary.
|
||||
3. `asyncio.gather()` is only for independent reads. Do not use it around writes that depend on ordering or transaction semantics.
|
||||
4. Keep file upload/download, SMTP, test-case pruning, judge heartbeat mutation paths, and contest permission-heavy flows sync until the async decorator and middleware work is complete.
|
||||
5. Current sync `MiddlewareMixin` middleware means ASGI requests still cross sync boundaries. The first async milestone is correctness and latency cleanup, not a full event-loop purity claim.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Add Regression Coverage For Converted Async Detail Views
|
||||
|
||||
**Files:**
|
||||
- Create: `utils/test_async_view_regressions.py`
|
||||
|
||||
- [ ] **Step 1: Write failing regression tests**
|
||||
|
||||
Create `utils/test_async_view_regressions.py`:
|
||||
|
||||
```python
|
||||
from datetime import timedelta
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import AsyncClient, TestCase
|
||||
from django.utils import timezone
|
||||
|
||||
from account.models import UserProfile
|
||||
from announcement.models import Announcement
|
||||
from contest.models import Contest
|
||||
from flowchart.models import FlowchartSubmission
|
||||
from problem.models import Problem, ProblemRuleType
|
||||
from utils.constants import ContestRuleType, Difficulty
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def make_user(username="async_user"):
|
||||
user = User.objects.create(username=username, email=f"{username}@example.com")
|
||||
user.set_password("pass1234")
|
||||
user.save()
|
||||
UserProfile.objects.create(user=user)
|
||||
return user
|
||||
|
||||
|
||||
def make_problem(user):
|
||||
return Problem.objects.create(
|
||||
_id="ASYNC001",
|
||||
title="Async Problem",
|
||||
description="desc",
|
||||
input_description="input",
|
||||
output_description="output",
|
||||
samples=[],
|
||||
test_case_id="async-test-case",
|
||||
test_case_score=[],
|
||||
hint="",
|
||||
languages=["Python3"],
|
||||
template={},
|
||||
created_by=user,
|
||||
time_limit=1000,
|
||||
memory_limit=128,
|
||||
rule_type=ProblemRuleType.ACM,
|
||||
difficulty=Difficulty.LOW,
|
||||
share_submission=False,
|
||||
allow_flowchart=True,
|
||||
show_flowchart=True,
|
||||
)
|
||||
|
||||
|
||||
class AsyncConvertedViewRegressionTests(TestCase):
|
||||
@classmethod
|
||||
def setUpTestData(cls):
|
||||
cls.user = make_user()
|
||||
cls.announcement = Announcement.objects.create(
|
||||
title="Async Announcement",
|
||||
content="content",
|
||||
tag="notice",
|
||||
visible=True,
|
||||
top=False,
|
||||
created_by=cls.user,
|
||||
)
|
||||
cls.contest = Contest.objects.create(
|
||||
title="Async Contest",
|
||||
description="contest desc",
|
||||
tag="weekly",
|
||||
real_time_rank=True,
|
||||
password=None,
|
||||
rule_type=ContestRuleType.ACM,
|
||||
start_time=timezone.now() - timedelta(hours=1),
|
||||
end_time=timezone.now() + timedelta(hours=1),
|
||||
created_by=cls.user,
|
||||
visible=True,
|
||||
allowed_ip_ranges=[],
|
||||
)
|
||||
cls.problem = make_problem(cls.user)
|
||||
cls.flowchart_submission = FlowchartSubmission.objects.create(
|
||||
user=cls.user,
|
||||
problem=cls.problem,
|
||||
mermaid_code="graph TD\nA-->B",
|
||||
flowchart_data={},
|
||||
)
|
||||
|
||||
async def test_announcement_detail_serializes_created_by(self):
|
||||
response = await AsyncClient().get(f"/api/announcement?id={self.announcement.id}")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = response.json()
|
||||
self.assertIsNone(body["error"])
|
||||
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
|
||||
|
||||
async def test_contest_detail_serializes_created_by(self):
|
||||
response = await AsyncClient().get(f"/api/contest?id={self.contest.id}")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = response.json()
|
||||
self.assertIsNone(body["error"])
|
||||
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
|
||||
|
||||
async def test_flowchart_detail_serializes_user_and_problem(self):
|
||||
client = AsyncClient()
|
||||
await sync_to_async(client.force_login)(self.user)
|
||||
|
||||
response = await client.get(f"/api/flowchart/submission?id={self.flowchart_submission.id}")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = response.json()
|
||||
self.assertIsNone(body["error"])
|
||||
self.assertEqual(body["data"]["username"], self.user.username)
|
||||
self.assertEqual(body["data"]["problem"], self.problem.id)
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run the regression tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
|
||||
```
|
||||
|
||||
Expected before Task 2: at least one test returns `server-error` or raises async-unsafe lazy relation access.
|
||||
|
||||
- [ ] **Step 3: Commit the failing tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add utils/test_async_view_regressions.py
|
||||
rtk git commit -m "test: cover async view serialization regressions"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 2: Preload Relations For Already Converted Async Detail Views
|
||||
|
||||
**Files:**
|
||||
- Modify: `announcement/views/oj.py`
|
||||
- Modify: `contest/views/oj.py`
|
||||
- Modify: `flowchart/views/oj.py`
|
||||
|
||||
- [ ] **Step 1: Fix announcement detail relation loading**
|
||||
|
||||
In `announcement/views/oj.py`, replace the detail query with:
|
||||
|
||||
```python
|
||||
announcement = await (
|
||||
Announcement.objects.select_related("created_by")
|
||||
.filter(id=id, visible=True)
|
||||
.afirst()
|
||||
)
|
||||
if announcement is None:
|
||||
raise Announcement.DoesNotExist
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Fix contest detail relation loading**
|
||||
|
||||
In `contest/views/oj.py`, replace the detail query with:
|
||||
|
||||
```python
|
||||
contest = await (
|
||||
Contest.objects.select_related("created_by")
|
||||
.filter(id=id, visible=True)
|
||||
.afirst()
|
||||
)
|
||||
if contest is None:
|
||||
raise Contest.DoesNotExist
|
||||
```
|
||||
|
||||
- [ ] **Step 3: Fix flowchart submission detail relation loading**
|
||||
|
||||
In `flowchart/views/oj.py`, update `FlowchartSubmissionAPI.get()`:
|
||||
|
||||
```python
|
||||
submission = await (
|
||||
FlowchartSubmission.objects.select_related("user", "problem")
|
||||
.filter(id=submission_id)
|
||||
.afirst()
|
||||
)
|
||||
if submission is None:
|
||||
raise FlowchartSubmission.DoesNotExist
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Fix flowchart retry permission relation loading**
|
||||
|
||||
In `flowchart/views/oj.py`, update `FlowchartSubmissionRetryAPI.post()`:
|
||||
|
||||
```python
|
||||
submission = await (
|
||||
FlowchartSubmission.objects.select_related("problem")
|
||||
.filter(id=submission_id)
|
||||
.afirst()
|
||||
)
|
||||
if submission is None:
|
||||
raise FlowchartSubmission.DoesNotExist
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Fix flowchart completed-detail relation loading**
|
||||
|
||||
In `flowchart/views/oj.py`, update `FlowchartSubmissionDetailAPI.get()` before serialization:
|
||||
|
||||
```python
|
||||
submissions = (
|
||||
FlowchartSubmission.objects.select_related("user", "problem")
|
||||
.filter(
|
||||
user=request.user,
|
||||
problem=problem,
|
||||
status=FlowchartSubmissionStatus.COMPLETED,
|
||||
)
|
||||
.order_by("create_time")
|
||||
)
|
||||
```
|
||||
|
||||
- [ ] **Step 6: Run the regression tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
|
||||
```
|
||||
|
||||
Expected: all tests pass.
|
||||
|
||||
- [ ] **Step 7: Run Django system checks**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py check
|
||||
```
|
||||
|
||||
Expected: `System check identified no issues`.
|
||||
|
||||
- [ ] **Step 8: Commit relation fixes**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add announcement/views/oj.py contest/views/oj.py flowchart/views/oj.py
|
||||
rtk git commit -m "fix: preload relations in async detail views"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 3: Add Async Serialization Helpers To `AsyncAPIView`
|
||||
|
||||
**Files:**
|
||||
- Modify: `utils/api/api.py`
|
||||
- Modify: `utils/test_async_api.py`
|
||||
|
||||
- [ ] **Step 1: Write tests for async serialization helper**
|
||||
|
||||
Create `utils/test_async_api.py`:
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
from django.test import AsyncRequestFactory, SimpleTestCase
|
||||
|
||||
from utils.api import AsyncAPIView, serializers, validate_serializer
|
||||
|
||||
|
||||
class PayloadSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
|
||||
|
||||
class EchoSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
|
||||
|
||||
class AsyncValidatedEchoView(AsyncAPIView):
|
||||
@validate_serializer(PayloadSerializer)
|
||||
async def post(self, request):
|
||||
return self.success({"name": request.data["name"]})
|
||||
|
||||
|
||||
class AsyncSerializationHelperTests(SimpleTestCase):
|
||||
async def test_validate_serializer_supports_async_view_methods(self):
|
||||
request = AsyncRequestFactory().post(
|
||||
"/api/echo",
|
||||
data=json.dumps({"name": "alice"}),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
response = await AsyncValidatedEchoView.as_view()(request)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.data["error"], None)
|
||||
self.assertEqual(response.data["data"], {"name": "alice"})
|
||||
|
||||
async def test_async_serialize_data_returns_serializer_data(self):
|
||||
view = AsyncAPIView()
|
||||
|
||||
data = await view.async_serialize_data(
|
||||
EchoSerializer,
|
||||
[{"name": "alice"}, {"name": "bob"}],
|
||||
many=True,
|
||||
)
|
||||
|
||||
self.assertEqual(data, [{"name": "alice"}, {"name": "bob"}])
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run helper tests and confirm failure**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_api -v 2
|
||||
```
|
||||
|
||||
Expected before implementation: `AttributeError: 'AsyncAPIView' object has no attribute 'async_serialize_data'`.
|
||||
|
||||
- [ ] **Step 3: Add `sync_to_async` import**
|
||||
|
||||
In `utils/api/api.py`, add:
|
||||
|
||||
```python
|
||||
from asgiref.sync import sync_to_async
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Add serializer helper methods**
|
||||
|
||||
In `AsyncAPIView`, before `async_paginate_data()`, add:
|
||||
|
||||
```python
|
||||
def serialize_data(self, object_serializer, data, **kwargs):
|
||||
return object_serializer(data, **kwargs).data
|
||||
|
||||
async def async_serialize_data(self, object_serializer, data, **kwargs):
|
||||
return await sync_to_async(
|
||||
self.serialize_data,
|
||||
thread_sensitive=True,
|
||||
)(object_serializer, data, **kwargs)
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Use helper inside async pagination**
|
||||
|
||||
In `AsyncAPIView.async_paginate_data()`, replace:
|
||||
|
||||
```python
|
||||
if object_serializer:
|
||||
results = object_serializer(results, many=True, context={"request": request}).data
|
||||
```
|
||||
|
||||
with:
|
||||
|
||||
```python
|
||||
if object_serializer:
|
||||
results = await self.async_serialize_data(
|
||||
object_serializer,
|
||||
results,
|
||||
many=True,
|
||||
context={"request": request},
|
||||
)
|
||||
```
|
||||
|
||||
- [ ] **Step 6: Run helper and regression tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_api utils.test_async_view_regressions -v 2
|
||||
```
|
||||
|
||||
Expected: all tests pass.
|
||||
|
||||
- [ ] **Step 7: Commit async serializer helper**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add utils/api/api.py utils/test_async_api.py
|
||||
rtk git commit -m "feat: add async serializer helper"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 4: Add Cache Helpers For Async Views
|
||||
|
||||
**Files:**
|
||||
- Create: `utils/async_helpers.py`
|
||||
- Create: `utils/test_async_helpers.py`
|
||||
- Modify: `problem/views/oj.py`
|
||||
- Modify: `comment/views/oj.py`
|
||||
|
||||
- [ ] **Step 1: Write cache helper tests**
|
||||
|
||||
Create `utils/test_async_helpers.py`:
|
||||
|
||||
```python
|
||||
from django.test import SimpleTestCase, override_settings
|
||||
|
||||
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||
|
||||
|
||||
@override_settings(
|
||||
CACHES={
|
||||
"default": {
|
||||
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
|
||||
"LOCATION": "async-helper-tests",
|
||||
}
|
||||
}
|
||||
)
|
||||
class AsyncCacheHelperTests(SimpleTestCase):
|
||||
async def test_async_cache_round_trip(self):
|
||||
await async_cache_set("async:key", {"value": 1}, 30)
|
||||
|
||||
value = await async_cache_get("async:key")
|
||||
|
||||
self.assertEqual(value, {"value": 1})
|
||||
|
||||
async def test_async_cache_delete(self):
|
||||
await async_cache_set("async:delete", "present", 30)
|
||||
await async_cache_delete("async:delete")
|
||||
|
||||
value = await async_cache_get("async:delete")
|
||||
|
||||
self.assertIsNone(value)
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run tests and confirm import failure**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_helpers -v 2
|
||||
```
|
||||
|
||||
Expected before implementation: `ModuleNotFoundError: No module named 'utils.async_helpers'`.
|
||||
|
||||
- [ ] **Step 3: Create async cache helpers**
|
||||
|
||||
Create `utils/async_helpers.py`:
|
||||
|
||||
```python
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.core.cache import cache
|
||||
|
||||
|
||||
async def async_cache_get(key, default=None):
|
||||
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
|
||||
|
||||
|
||||
async def async_cache_set(key, value, timeout=None):
|
||||
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
|
||||
|
||||
|
||||
async def async_cache_delete(key):
|
||||
return await sync_to_async(cache.delete, thread_sensitive=True)(key)
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Convert async problem cache calls**
|
||||
|
||||
In `problem/views/oj.py`, add:
|
||||
|
||||
```python
|
||||
from utils.async_helpers import async_cache_get, async_cache_set
|
||||
```
|
||||
|
||||
In `ProblemTagAPI.get()`, replace:
|
||||
|
||||
```python
|
||||
cached = cache.get(cache_key)
|
||||
```
|
||||
|
||||
with:
|
||||
|
||||
```python
|
||||
cached = await async_cache_get(cache_key)
|
||||
```
|
||||
|
||||
Replace:
|
||||
|
||||
```python
|
||||
cache.set(cache_key, data, 3600)
|
||||
```
|
||||
|
||||
with:
|
||||
|
||||
```python
|
||||
await async_cache_set(cache_key, data, 3600)
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Convert async comment cache calls**
|
||||
|
||||
In `comment/views/oj.py`, add:
|
||||
|
||||
```python
|
||||
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||
```
|
||||
|
||||
In `CommentAPI.post()`, replace:
|
||||
|
||||
```python
|
||||
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||
```
|
||||
|
||||
with:
|
||||
|
||||
```python
|
||||
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||
```
|
||||
|
||||
In `CommentStatisticsAPI.get()`, replace `cache.get()` and `cache.set()` with:
|
||||
|
||||
```python
|
||||
cached = await async_cache_get(cache_key)
|
||||
```
|
||||
|
||||
and:
|
||||
|
||||
```python
|
||||
await async_cache_set(cache_key, data, 3600)
|
||||
```
|
||||
|
||||
- [ ] **Step 6: Run helper and targeted async tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test utils.test_async_helpers utils.test_async_api utils.test_async_view_regressions -v 2
|
||||
```
|
||||
|
||||
Expected: all tests pass.
|
||||
|
||||
- [ ] **Step 7: Commit cache helper work**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add utils/async_helpers.py utils/test_async_helpers.py problem/views/oj.py comment/views/oj.py
|
||||
rtk git commit -m "feat: add async cache helpers"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 5: Make Permission Decorators Async-Aware Before More Conversions
|
||||
|
||||
**Files:**
|
||||
- Modify: `account/decorators.py`
|
||||
- Create: `account/test_async_decorators.py`
|
||||
|
||||
- [ ] **Step 1: Write tests for async `login_required`**
|
||||
|
||||
Create `account/test_async_decorators.py`:
|
||||
|
||||
```python
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.test import AsyncRequestFactory, SimpleTestCase
|
||||
|
||||
from account.decorators import login_required
|
||||
from utils.api import AsyncAPIView
|
||||
|
||||
|
||||
class DisabledUser:
|
||||
is_authenticated = True
|
||||
is_disabled = True
|
||||
|
||||
|
||||
class ActiveUser:
|
||||
is_authenticated = True
|
||||
is_disabled = False
|
||||
|
||||
|
||||
class ProtectedAsyncView(AsyncAPIView):
|
||||
@login_required
|
||||
async def get(self, request):
|
||||
return self.success("ok")
|
||||
|
||||
|
||||
class AsyncPermissionDecoratorTests(SimpleTestCase):
|
||||
async def test_async_login_required_allows_active_user(self):
|
||||
request = AsyncRequestFactory().get("/api/protected")
|
||||
request.user = ActiveUser()
|
||||
|
||||
response = await ProtectedAsyncView.as_view()(request)
|
||||
|
||||
self.assertEqual(response.data["error"], None)
|
||||
self.assertEqual(response.data["data"], "ok")
|
||||
|
||||
async def test_async_login_required_rejects_anonymous_user(self):
|
||||
request = AsyncRequestFactory().get("/api/protected")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
response = await ProtectedAsyncView.as_view()(request)
|
||||
|
||||
self.assertEqual(response.data["error"], "permission-denied")
|
||||
self.assertEqual(response.data["data"], "Please login first")
|
||||
|
||||
async def test_async_login_required_rejects_disabled_user(self):
|
||||
request = AsyncRequestFactory().get("/api/protected")
|
||||
request.user = DisabledUser()
|
||||
|
||||
response = await ProtectedAsyncView.as_view()(request)
|
||||
|
||||
self.assertEqual(response.data["error"], "permission-denied")
|
||||
self.assertEqual(response.data["data"], "Your account is disabled")
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Run decorator tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test account.test_async_decorators -v 2
|
||||
```
|
||||
|
||||
Expected before implementation: this may pass because `AsyncAPIView.dispatch()` awaits returned coroutine. Keep the test anyway as a contract before refactoring.
|
||||
|
||||
- [ ] **Step 3: Refactor `BasePermissionDecorator` with explicit async path**
|
||||
|
||||
In `account/decorators.py`, add:
|
||||
|
||||
```python
|
||||
import inspect
|
||||
```
|
||||
|
||||
Replace `BasePermissionDecorator.__get__()` with:
|
||||
|
||||
```python
|
||||
def __get__(self, obj, obj_type):
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return functools.partial(self._async_call, obj)
|
||||
return functools.partial(self.__call__, obj)
|
||||
```
|
||||
|
||||
Add this method to `BasePermissionDecorator`:
|
||||
|
||||
```python
|
||||
async def _async_call(self, *args, **kwargs):
|
||||
self.request = args[1]
|
||||
|
||||
if self.check_permission():
|
||||
if self.request.user.is_disabled:
|
||||
return self.error("Your account is disabled")
|
||||
return await self.func(*args, **kwargs)
|
||||
return self.error("Please login first")
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Run decorator and async view tests**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test account.test_async_decorators utils.test_async_api utils.test_async_view_regressions -v 2
|
||||
```
|
||||
|
||||
Expected: all tests pass.
|
||||
|
||||
- [ ] **Step 5: Commit decorator refactor**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add account/decorators.py account/test_async_decorators.py
|
||||
rtk git commit -m "refactor: make permission decorators async-aware"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 6: Convert One Endpoint Family At A Time
|
||||
|
||||
**Files:**
|
||||
- Modify only the endpoint family being converted in each batch.
|
||||
- Add or update tests in the same app, using `AsyncClient` for converted URLs.
|
||||
|
||||
- [ ] **Step 1: Choose the next low-risk batch**
|
||||
|
||||
Use this order:
|
||||
|
||||
```text
|
||||
1. Pure public GET list/detail endpoints already close to async:
|
||||
- announcement list/detail
|
||||
- contest list/detail
|
||||
- problem tag/list/detail
|
||||
|
||||
2. Authenticated read-only list endpoints:
|
||||
- message list
|
||||
- submission list
|
||||
- flowchart list/detail/current
|
||||
|
||||
3. Simple create/update endpoints with no contest permission decorator:
|
||||
- message create
|
||||
- comment create
|
||||
- flowchart retry/create
|
||||
```
|
||||
|
||||
- [ ] **Step 2: For each endpoint, write one async smoke test**
|
||||
|
||||
Use this template and replace the URL and assertions with the exact endpoint response:
|
||||
|
||||
```python
|
||||
from django.test import AsyncClient, TestCase
|
||||
|
||||
|
||||
class EndpointAsyncSmokeTests(TestCase):
|
||||
async def test_endpoint_returns_success_envelope(self):
|
||||
response = await AsyncClient().get("/api/endpoint?limit=10")
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = response.json()
|
||||
self.assertIn("error", body)
|
||||
self.assertIn("data", body)
|
||||
```
|
||||
|
||||
- [ ] **Step 3: Convert ORM access only where async ORM exists**
|
||||
|
||||
Use these replacements:
|
||||
|
||||
```python
|
||||
obj = await Model.objects.aget(id=id)
|
||||
count = await queryset.acount()
|
||||
first = await queryset.afirst()
|
||||
last = await queryset.alast()
|
||||
items = [item async for item in queryset[offset:offset + limit]]
|
||||
created = await Model.objects.acreate(field=value)
|
||||
await instance.asave(update_fields=["field"])
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Keep sync-only helpers behind `sync_to_async`**
|
||||
|
||||
Use this pattern:
|
||||
|
||||
```python
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
result = await sync_to_async(sync_helper, thread_sensitive=True)(arg1, arg2)
|
||||
```
|
||||
|
||||
- [ ] **Step 5: Run targeted app tests and system check after each batch**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test <app_label> -v 2
|
||||
rtk uv run python manage.py check
|
||||
```
|
||||
|
||||
Expected: targeted tests pass and system check reports no issues.
|
||||
|
||||
- [ ] **Step 6: Commit each endpoint family separately**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk git add <changed-files>
|
||||
rtk git commit -m "refactor: async <endpoint-family> endpoints"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Task 7: Audit Middleware Before Claiming Full Async Benefit
|
||||
|
||||
**Files:**
|
||||
- Modify: `account/middleware.py`
|
||||
- Add tests only for behavior that changes.
|
||||
|
||||
- [ ] **Step 1: Record current sync middleware boundaries**
|
||||
|
||||
Before changing middleware, note these current sync classes:
|
||||
|
||||
```text
|
||||
account.middleware.APITokenAuthMiddleware
|
||||
account.middleware.AdminRoleRequiredMiddleware
|
||||
account.middleware.SessionRecordMiddleware
|
||||
```
|
||||
|
||||
- [ ] **Step 2: Keep middleware sync during endpoint correctness work**
|
||||
|
||||
Do not convert middleware in the same commit as endpoint conversions. Middleware affects every request and needs its own review.
|
||||
|
||||
- [ ] **Step 3: If middleware conversion is pursued, convert one class per commit**
|
||||
|
||||
Use Django new-style middleware with explicit sync and async handling. `__acall__` is a local helper; `__call__` dispatches to it when Django passes an async `get_response`:
|
||||
|
||||
```python
|
||||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async
|
||||
|
||||
|
||||
class ExampleMiddleware:
|
||||
sync_capable = True
|
||||
async_capable = True
|
||||
|
||||
def __init__(self, get_response):
|
||||
self.get_response = get_response
|
||||
self.is_async = iscoroutinefunction(get_response)
|
||||
if self.is_async:
|
||||
markcoroutinefunction(self)
|
||||
|
||||
def __call__(self, request):
|
||||
if self.is_async:
|
||||
return self.__acall__(request)
|
||||
response = self.process_request(request)
|
||||
if response is not None:
|
||||
return response
|
||||
return self.get_response(request)
|
||||
|
||||
async def __acall__(self, request):
|
||||
response = await self.aprocess_request(request)
|
||||
if response is not None:
|
||||
return response
|
||||
return await self.get_response(request)
|
||||
|
||||
def process_request(self, request):
|
||||
return None
|
||||
|
||||
async def aprocess_request(self, request):
|
||||
return await sync_to_async(self.process_request, thread_sensitive=True)(request)
|
||||
```
|
||||
|
||||
- [ ] **Step 4: Run full backend test command after middleware work**
|
||||
|
||||
Run:
|
||||
|
||||
```bash
|
||||
rtk uv run python manage.py test -v 2
|
||||
rtk uv run python manage.py check
|
||||
```
|
||||
|
||||
Expected: all tests pass and system check reports no issues.
|
||||
|
||||
---
|
||||
|
||||
## Definition Of Done
|
||||
|
||||
- `rtk uv run python manage.py check` passes.
|
||||
- Async regression tests cover converted detail serializers that depend on FK relations.
|
||||
- Converted async views do not call DRF serializer `.data` directly unless the data is primitive and relation-free.
|
||||
- Cache access from async views uses async helper wrappers.
|
||||
- New endpoint conversions are committed in endpoint-family-sized commits.
|
||||
- No file upload/download, SMTP, judge heartbeat, test-case prune, or contest permission-heavy endpoint is converted without a separate focused plan.
|
||||
@@ -11,6 +11,6 @@ class FlowchartEvaluationPromptTests(TestCase):
|
||||
self.assertIn("Mermaid节点ID由系统生成", prompt)
|
||||
self.assertIn("不要评价节点ID", prompt)
|
||||
self.assertIn("不要因节点ID扣分", prompt)
|
||||
self.assertIn("feedback控制在0字以内", prompt)
|
||||
self.assertIn("feedback控制在100字以内", prompt)
|
||||
self.assertIn("suggestions最多3条", prompt)
|
||||
self.assertIn("重要建议必须以【重点】开头", prompt)
|
||||
|
||||
@@ -7,65 +7,63 @@ from flowchart.serializers import (
|
||||
)
|
||||
from flowchart.tasks import evaluate_flowchart_task
|
||||
from problem.models import Problem
|
||||
from utils.api import APIView
|
||||
from utils.api import AsyncAPIView
|
||||
|
||||
|
||||
class FlowchartSubmissionAPI(APIView):
|
||||
class FlowchartSubmissionAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def post(self, request):
|
||||
"""创建流程图提交"""
|
||||
async def post(self, request):
|
||||
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
|
||||
if not serializer.is_valid():
|
||||
return self.error(serializer.errors)
|
||||
|
||||
data = serializer.validated_data
|
||||
|
||||
# 验证题目存在
|
||||
try:
|
||||
problem = Problem.objects.get(id=data["problem_id"])
|
||||
problem = await Problem.objects.aget(id=data["problem_id"])
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem doesn't exist")
|
||||
|
||||
# 验证题目是否允许流程图提交
|
||||
if not problem.allow_flowchart:
|
||||
return self.error("This problem does not allow flowchart submission")
|
||||
|
||||
# 创建提交记录
|
||||
submission = FlowchartSubmission.objects.create(
|
||||
submission = await FlowchartSubmission.objects.acreate(
|
||||
user=request.user,
|
||||
problem=problem,
|
||||
mermaid_code=data["mermaid_code"],
|
||||
flowchart_data=data.get("flowchart_data", {}),
|
||||
)
|
||||
|
||||
# 启动AI评分任务
|
||||
evaluate_flowchart_task.send(submission.id)
|
||||
|
||||
return self.success({"submission_id": submission.id, "status": "pending"})
|
||||
|
||||
@login_required
|
||||
def get(self, request):
|
||||
"""获取流程图提交详情"""
|
||||
async def get(self, request):
|
||||
submission_id = request.GET.get("id")
|
||||
if not submission_id:
|
||||
return self.error("submission_id is required")
|
||||
|
||||
try:
|
||||
submission = FlowchartSubmission.objects.get(id=submission_id)
|
||||
submission = await (
|
||||
FlowchartSubmission.objects.select_related("user", "problem")
|
||||
.filter(id=submission_id)
|
||||
.afirst()
|
||||
)
|
||||
if submission is None:
|
||||
raise FlowchartSubmission.DoesNotExist
|
||||
except FlowchartSubmission.DoesNotExist:
|
||||
return self.error("Submission doesn't exist")
|
||||
|
||||
if not submission.check_user_permission(request.user):
|
||||
return self.error("No permission for this submission")
|
||||
|
||||
serializer = FlowchartSubmissionSerializer(submission)
|
||||
return self.success(serializer.data)
|
||||
return self.success(await self.async_serialize_data(FlowchartSubmissionSerializer, submission))
|
||||
|
||||
|
||||
class FlowchartSubmissionListAPI(APIView):
|
||||
class FlowchartSubmissionListAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
"""获取流程图提交列表"""
|
||||
async def get(self, request):
|
||||
username = request.GET.get("username")
|
||||
problem_id = request.GET.get("problem_id")
|
||||
myself = request.GET.get("myself")
|
||||
@@ -74,7 +72,7 @@ class FlowchartSubmissionListAPI(APIView):
|
||||
|
||||
if problem_id:
|
||||
try:
|
||||
problem = Problem.objects.get(
|
||||
problem = await Problem.objects.aget(
|
||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||
)
|
||||
except Problem.DoesNotExist:
|
||||
@@ -88,38 +86,42 @@ class FlowchartSubmissionListAPI(APIView):
|
||||
elif request.user.is_regular_user():
|
||||
queryset = queryset.filter(user=request.user)
|
||||
|
||||
data = self.paginate_data(request, queryset)
|
||||
data["results"] = FlowchartSubmissionListSerializer(
|
||||
data["results"], many=True
|
||||
).data
|
||||
data = await self.async_paginate_data(request, queryset)
|
||||
data["results"] = await self.async_serialize_data(
|
||||
FlowchartSubmissionListSerializer,
|
||||
data["results"],
|
||||
many=True,
|
||||
)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class FlowchartSubmissionRetryAPI(APIView):
|
||||
class FlowchartSubmissionRetryAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def post(self, request):
|
||||
"""重新触发AI评分"""
|
||||
async def post(self, request):
|
||||
submission_id = request.data.get("submission_id")
|
||||
if not submission_id:
|
||||
return self.error("submission_id is required")
|
||||
|
||||
try:
|
||||
submission = FlowchartSubmission.objects.get(id=submission_id)
|
||||
submission = await (
|
||||
FlowchartSubmission.objects.select_related("problem")
|
||||
.filter(id=submission_id)
|
||||
.afirst()
|
||||
)
|
||||
if submission is None:
|
||||
raise FlowchartSubmission.DoesNotExist
|
||||
except FlowchartSubmission.DoesNotExist:
|
||||
return self.error("Submission doesn't exist")
|
||||
|
||||
# 检查权限
|
||||
if not submission.check_user_permission(request.user):
|
||||
return self.error("No permission for this submission")
|
||||
|
||||
# 检查是否可以重新评分
|
||||
if submission.status not in [
|
||||
FlowchartSubmissionStatus.FAILED,
|
||||
FlowchartSubmissionStatus.COMPLETED,
|
||||
]:
|
||||
return self.error("Submission is not in a state that allows retry")
|
||||
|
||||
# 重置状态并重新启动AI评分
|
||||
submission.status = FlowchartSubmissionStatus.PENDING
|
||||
submission.ai_score = None
|
||||
submission.ai_grade = None
|
||||
@@ -128,9 +130,8 @@ class FlowchartSubmissionRetryAPI(APIView):
|
||||
submission.ai_criteria_details = {}
|
||||
submission.processing_time = None
|
||||
submission.evaluation_time = None
|
||||
submission.save()
|
||||
await submission.asave()
|
||||
|
||||
# 重新启动AI评分任务
|
||||
evaluate_flowchart_task.send(submission.id)
|
||||
|
||||
return self.success(
|
||||
@@ -142,15 +143,14 @@ class FlowchartSubmissionRetryAPI(APIView):
|
||||
)
|
||||
|
||||
|
||||
class FlowchartSubmissionDetailAPI(APIView):
|
||||
class FlowchartSubmissionDetailAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
"""获取当前用户对指定题目的流程图提交详情"""
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
if not problem_id:
|
||||
return self.error("problem_id is required")
|
||||
try:
|
||||
problem = Problem.objects.get(id=problem_id)
|
||||
problem = await Problem.objects.aget(id=problem_id)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem doesn't exist")
|
||||
|
||||
@@ -158,34 +158,37 @@ class FlowchartSubmissionDetailAPI(APIView):
|
||||
page = int(request.GET.get("page", 0))
|
||||
except ValueError:
|
||||
return self.error("page must be an integer")
|
||||
submissions = FlowchartSubmission.objects.filter(
|
||||
user=request.user,
|
||||
problem=problem,
|
||||
status=FlowchartSubmissionStatus.COMPLETED,
|
||||
).order_by("create_time")
|
||||
count = submissions.count()
|
||||
submissions = (
|
||||
FlowchartSubmission.objects.select_related("user", "problem")
|
||||
.filter(
|
||||
user=request.user,
|
||||
problem=problem,
|
||||
status=FlowchartSubmissionStatus.COMPLETED,
|
||||
)
|
||||
.order_by("create_time")
|
||||
)
|
||||
count = await submissions.acount()
|
||||
if count == 0:
|
||||
return self.success({"submission": None, "count": 0})
|
||||
# page=0 means latest; page=N means the Nth submission (1-indexed, chronological)
|
||||
if page == 0:
|
||||
submission = submissions.last()
|
||||
submission = await submissions.alast()
|
||||
else:
|
||||
if page < 0 or page > count:
|
||||
return self.error("Page out of range")
|
||||
submission = submissions[page - 1]
|
||||
serializer = FlowchartSubmissionSerializer(submission)
|
||||
return self.success({"submission": serializer.data, "count": count})
|
||||
result = [s async for s in submissions[page - 1:page]]
|
||||
submission = result[0]
|
||||
data = await self.async_serialize_data(FlowchartSubmissionSerializer, submission)
|
||||
return self.success({"submission": data, "count": count})
|
||||
|
||||
|
||||
class FlowchartSubmissionCurrentAPI(APIView):
|
||||
class FlowchartSubmissionCurrentAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
"""获取当前用户对指定题目的最新流程图提交,只返回次数和分数"""
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
if not problem_id:
|
||||
return self.error("problem_id is required")
|
||||
try:
|
||||
problem = Problem.objects.get(id=problem_id)
|
||||
problem = await Problem.objects.aget(id=problem_id)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem doesn't exist")
|
||||
submissions = (
|
||||
@@ -197,10 +200,10 @@ class FlowchartSubmissionCurrentAPI(APIView):
|
||||
.values("ai_score", "ai_grade")
|
||||
.order_by("-create_time")
|
||||
)
|
||||
count = submissions.count()
|
||||
count = await submissions.acount()
|
||||
if count == 0:
|
||||
return self.success({"count": 0, "score": 0, "grade": ""})
|
||||
submission = submissions[0]
|
||||
submission = await submissions.afirst()
|
||||
return self.success(
|
||||
{
|
||||
"count": count,
|
||||
|
||||
@@ -3,33 +3,33 @@ from account.models import User
|
||||
from message.models import Message
|
||||
from message.serializers import CreateMessageSerializer, MessageSerializer
|
||||
from submission.models import Submission
|
||||
from utils.api import APIView
|
||||
from utils.api import AsyncAPIView
|
||||
from utils.api.api import validate_serializer
|
||||
|
||||
|
||||
class MessageAPI(APIView):
|
||||
class MessageAPI(AsyncAPIView):
|
||||
@login_required
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
messages = Message.objects.select_related(
|
||||
"recipient", "sender", "submission", "submission__problem"
|
||||
).filter(recipient=request.user)
|
||||
return self.success(self.paginate_data(request, messages, MessageSerializer))
|
||||
return self.success(await self.async_paginate_data(request, messages, MessageSerializer))
|
||||
|
||||
@validate_serializer(CreateMessageSerializer)
|
||||
@super_admin_required
|
||||
def post(self, request):
|
||||
async def post(self, request):
|
||||
data = request.data
|
||||
if data["recipient"] == request.user.id:
|
||||
return self.error("Can not send a message to youself")
|
||||
try:
|
||||
recipient = User.objects.get(id=data["recipient"], is_disabled=False)
|
||||
recipient = await User.objects.aget(id=data["recipient"], is_disabled=False)
|
||||
except User.DoesNotExist:
|
||||
return self.error("User does not exist")
|
||||
try:
|
||||
submission = Submission.objects.get(id=data["submission"])
|
||||
submission = await Submission.objects.aget(id=data["submission"])
|
||||
except Submission.DoesNotExist:
|
||||
return self.error("Submission does not exist")
|
||||
Message.objects.create(
|
||||
await Message.objects.acreate(
|
||||
submission=submission,
|
||||
message=data["message"],
|
||||
sender=request.user,
|
||||
|
||||
@@ -292,4 +292,15 @@ class _SysOptionsMeta(type):
|
||||
|
||||
|
||||
class SysOptions(metaclass=_SysOptionsMeta):
|
||||
pass
|
||||
@classmethod
|
||||
async def aget(cls, key):
|
||||
from asgiref.sync import sync_to_async
|
||||
return await sync_to_async(getattr)(cls, key)
|
||||
|
||||
@classmethod
|
||||
async def aget_many(cls, *keys):
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
def _get_all():
|
||||
return {k: getattr(cls, k) for k in keys}
|
||||
return await sync_to_async(_get_all)()
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
from django.core.cache import cache
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.db.models import BooleanField, Case, Count, Q, Value, When
|
||||
from django.db.models.functions import ExtractYear
|
||||
from django.utils import timezone
|
||||
|
||||
from account.decorators import check_contest_permission
|
||||
from account.models import User
|
||||
from contest.models import ContestRuleType
|
||||
from submission.models import JudgeStatus, Submission
|
||||
from utils.api import APIView
|
||||
from utils.api import APIView, AsyncAPIView
|
||||
from utils.async_helpers import async_cache_get, async_cache_set
|
||||
from utils.constants import CacheKey
|
||||
|
||||
from ..models import Problem, ProblemTag
|
||||
@@ -21,11 +22,11 @@ from ..serializers import (
|
||||
)
|
||||
|
||||
|
||||
class ProblemTagAPI(APIView):
|
||||
def get(self, request):
|
||||
class ProblemTagAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
keyword = request.GET.get("keyword", "")
|
||||
cache_key = f"{CacheKey.problem_tags}:{keyword}"
|
||||
cached = cache.get(cache_key)
|
||||
cached = await async_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
return self.success(cached)
|
||||
|
||||
@@ -33,48 +34,48 @@ class ProblemTagAPI(APIView):
|
||||
if keyword:
|
||||
qs = ProblemTag.objects.filter(name__icontains=keyword)
|
||||
tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0)
|
||||
data = TagSerializer(tags, many=True).data
|
||||
cache.set(cache_key, data, 3600)
|
||||
data = await self.async_serialize_data(TagSerializer, [tag async for tag in tags], many=True)
|
||||
await async_cache_set(cache_key, data, 3600)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class PickOneAPI(APIView):
|
||||
def get(self, request):
|
||||
problems = Problem.objects.filter(contest_id__isnull=True, visible=True)
|
||||
count = problems.count()
|
||||
class PickOneAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
ids = Problem.objects.filter(contest_id__isnull=True, visible=True).values_list("_id", flat=True)
|
||||
count = await ids.acount()
|
||||
if count == 0:
|
||||
return self.error("No problem to pick")
|
||||
return self.success(problems[random.randint(0, count - 1)]._id)
|
||||
idx = random.randint(0, count - 1)
|
||||
result = [pid async for pid in ids[idx : idx + 1]]
|
||||
return self.success(result[0])
|
||||
|
||||
|
||||
class ProblemAPI(APIView):
|
||||
class ProblemAPI(AsyncAPIView):
|
||||
@staticmethod
|
||||
def _add_problem_status(request, queryset_values):
|
||||
if request.user.is_authenticated:
|
||||
profile = request.user.userprofile
|
||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||
# paginate data
|
||||
results = queryset_values.get("results")
|
||||
if results is not None:
|
||||
problems = results
|
||||
else:
|
||||
problems = [queryset_values]
|
||||
for problem in problems:
|
||||
problem["my_status"] = acm_problems_status.get(
|
||||
str(problem["id"]), {}
|
||||
).get("status")
|
||||
def _add_problem_status(acm_problems_status, queryset_values):
|
||||
results = queryset_values.get("results")
|
||||
if results is not None:
|
||||
problems = results
|
||||
else:
|
||||
problems = [queryset_values]
|
||||
for problem in problems:
|
||||
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
|
||||
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
# 问题详情页
|
||||
problem_id = request.GET.get("problem_id")
|
||||
if problem_id:
|
||||
try:
|
||||
problem = Problem.objects.select_related("created_by").get(
|
||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||
)
|
||||
problem_data = ProblemSerializer(problem).data
|
||||
self._add_problem_status(request, problem_data)
|
||||
problem = await Problem.objects.select_related("created_by").prefetch_related("tags").filter(_id__iexact=problem_id, contest_id__isnull=True, visible=True).afirst()
|
||||
if problem is None:
|
||||
raise Problem.DoesNotExist
|
||||
problem_data = await self.async_serialize_data(ProblemSerializer, problem)
|
||||
if request.user.is_authenticated:
|
||||
from account.models import UserProfile
|
||||
|
||||
profile = await UserProfile.objects.aget(user=request.user)
|
||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||
self._add_problem_status(acm_problems_status, problem_data)
|
||||
failed_statuses = [
|
||||
JudgeStatus.WRONG_ANSWER,
|
||||
JudgeStatus.CPU_TIME_LIMIT_EXCEEDED,
|
||||
@@ -83,11 +84,11 @@ class ProblemAPI(APIView):
|
||||
JudgeStatus.RUNTIME_ERROR,
|
||||
JudgeStatus.COMPILE_ERROR,
|
||||
]
|
||||
problem_data["my_failed_count"] = Submission.objects.filter(
|
||||
problem_data["my_failed_count"] = await Submission.objects.filter(
|
||||
user_id=request.user.id,
|
||||
problem_id=problem.id,
|
||||
result__in=failed_statuses,
|
||||
).count()
|
||||
).acount()
|
||||
else:
|
||||
problem_data["my_failed_count"] = 0
|
||||
return self.success(problem_data)
|
||||
@@ -98,12 +99,7 @@ class ProblemAPI(APIView):
|
||||
if not limit:
|
||||
return self.error("Limit is needed")
|
||||
|
||||
problems = (
|
||||
Problem.objects.select_related("created_by")
|
||||
.prefetch_related("tags")
|
||||
.filter(contest_id__isnull=True, visible=True)
|
||||
.order_by("-create_time")
|
||||
)
|
||||
problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest_id__isnull=True, visible=True).order_by("-create_time")
|
||||
|
||||
author = request.GET.get("author")
|
||||
if author:
|
||||
@@ -117,9 +113,7 @@ class ProblemAPI(APIView):
|
||||
# 搜索的情况
|
||||
keyword = request.GET.get("keyword", "").strip()
|
||||
if keyword:
|
||||
problems = problems.filter(
|
||||
Q(title__icontains=keyword) | Q(_id__icontains=keyword)
|
||||
)
|
||||
problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
|
||||
|
||||
# 难度筛选
|
||||
difficulty = request.GET.get("difficulty")
|
||||
@@ -142,8 +136,13 @@ class ProblemAPI(APIView):
|
||||
problems = problems.order_by(sort)
|
||||
|
||||
# 根据profile 为做过的题目添加标记
|
||||
data = self.paginate_data(request, problems, ProblemListSerializer)
|
||||
self._add_problem_status(request, data)
|
||||
data = await self.async_paginate_data(request, problems, ProblemListSerializer)
|
||||
if request.user.is_authenticated:
|
||||
from account.models import UserProfile
|
||||
|
||||
profile = await UserProfile.objects.aget(user=request.user)
|
||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||
self._add_problem_status(acm_problems_status, data)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
@@ -152,24 +151,18 @@ class ContestProblemAPI(APIView):
|
||||
if request.user.is_authenticated:
|
||||
profile = request.user.userprofile
|
||||
if self.contest.rule_type == ContestRuleType.ACM:
|
||||
problems_status = profile.acm_problems_status.get(
|
||||
"contest_problems", {}
|
||||
)
|
||||
problems_status = profile.acm_problems_status.get("contest_problems", {})
|
||||
else:
|
||||
problems_status = profile.oi_problems_status.get("contest_problems", {})
|
||||
for problem in queryset_values:
|
||||
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get(
|
||||
"status"
|
||||
)
|
||||
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
|
||||
|
||||
@check_contest_permission(check_type="problems")
|
||||
def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
if problem_id:
|
||||
try:
|
||||
problem = Problem.objects.select_related("created_by").get(
|
||||
_id__iexact=problem_id, contest=self.contest, visible=True
|
||||
)
|
||||
problem = Problem.objects.select_related("created_by").get(_id__iexact=problem_id, contest=self.contest, visible=True)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem does not exist.")
|
||||
if self.contest.problem_details_permission(request.user):
|
||||
@@ -184,9 +177,7 @@ class ContestProblemAPI(APIView):
|
||||
problem_data = ProblemSafeSerializer(problem).data
|
||||
return self.success(problem_data)
|
||||
|
||||
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(
|
||||
contest=self.contest, visible=True
|
||||
)
|
||||
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest=self.contest, visible=True)
|
||||
if self.contest.problem_details_permission(request.user):
|
||||
data = ProblemListSerializer(contest_problems, many=True).data
|
||||
self._add_problem_status(request, data)
|
||||
@@ -195,59 +186,60 @@ class ContestProblemAPI(APIView):
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class ProblemSolvedPeopleCount(APIView):
|
||||
def get(self, request):
|
||||
class ProblemSolvedPeopleCount(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
rate = "0"
|
||||
if not request.user.is_authenticated:
|
||||
return self.success(rate)
|
||||
submission_count = Submission.objects.filter(
|
||||
submission_count = await Submission.objects.filter(
|
||||
user_id=request.user.id,
|
||||
problem_id=problem_id,
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
).count()
|
||||
).acount()
|
||||
if submission_count == 0:
|
||||
return self.success(rate)
|
||||
today = datetime.today()
|
||||
years_ago = datetime(today.year - 2, today.month, today.day, 0, 0)
|
||||
total_count = User.objects.filter(
|
||||
is_disabled=False, last_login__gte=years_ago
|
||||
).count()
|
||||
accepted_count = Submission.objects.filter(
|
||||
problem_id=problem_id,
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
create_time__gte=years_ago,
|
||||
).aggregate(user_count=Count("user_id", distinct=True))["user_count"]
|
||||
if accepted_count < total_count:
|
||||
now = timezone.now()
|
||||
years_ago = now.replace(year=now.year - 2, hour=0, minute=0, second=0, microsecond=0)
|
||||
total_count = await User.objects.filter(is_disabled=False, last_login__gte=years_ago).acount()
|
||||
accepted_count = (
|
||||
await sync_to_async(
|
||||
Submission.objects.filter(
|
||||
problem_id=problem_id,
|
||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||
create_time__gte=years_ago,
|
||||
).aggregate,
|
||||
thread_sensitive=True,
|
||||
)(user_count=Count("user_id", distinct=True))
|
||||
)["user_count"]
|
||||
if total_count and accepted_count < total_count:
|
||||
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
|
||||
else:
|
||||
rate = "0"
|
||||
return self.success(rate)
|
||||
|
||||
|
||||
class SimilarProblemAPI(APIView):
|
||||
def get(self, request):
|
||||
class SimilarProblemAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
problem_display_id = request.GET.get("problem_id")
|
||||
if not problem_display_id:
|
||||
return self.error("problem_id is required")
|
||||
|
||||
try:
|
||||
problem = Problem.objects.get(_id__iexact=problem_display_id, contest__isnull=True)
|
||||
problem = await Problem.objects.aget(_id__iexact=problem_display_id, contest__isnull=True)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem not found")
|
||||
|
||||
tag_ids = list(problem.tags.values_list("id", flat=True))
|
||||
tag_ids = [tag_id async for tag_id in problem.tags.values_list("id", flat=True)]
|
||||
if not tag_ids:
|
||||
return self.success([])
|
||||
|
||||
exclude_ids = [problem_display_id]
|
||||
if request.user.is_authenticated:
|
||||
profile = request.user.userprofile
|
||||
ac_display_ids = [
|
||||
v["_id"]
|
||||
for v in profile.acm_problems_status.get("problems", {}).values()
|
||||
if v.get("status") == JudgeStatus.ACCEPTED
|
||||
]
|
||||
from account.models import UserProfile
|
||||
|
||||
profile = await UserProfile.objects.aget(user=request.user)
|
||||
ac_display_ids = [v["_id"] for v in profile.acm_problems_status.get("problems", {}).values() if v.get("status") == JudgeStatus.ACCEPTED]
|
||||
exclude_ids.extend(ac_display_ids)
|
||||
|
||||
similar = (
|
||||
@@ -258,14 +250,15 @@ class SimilarProblemAPI(APIView):
|
||||
.distinct()
|
||||
.order_by("difficulty")[:5]
|
||||
)
|
||||
return self.success(ProblemListSerializer(similar, many=True).data)
|
||||
similar_list = [problem async for problem in similar]
|
||||
return self.success(await self.async_serialize_data(ProblemListSerializer, similar_list, many=True))
|
||||
|
||||
|
||||
class ProblemAuthorAPI(APIView):
|
||||
def get(self, request):
|
||||
class ProblemAuthorAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
show_all = request.GET.get("all", "0") == "1"
|
||||
cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
|
||||
cached_data = cache.get(cache_key)
|
||||
cached_data = await async_cache_get(cache_key)
|
||||
if cached_data:
|
||||
return self.success(cached_data)
|
||||
|
||||
@@ -273,38 +266,32 @@ class ProblemAuthorAPI(APIView):
|
||||
if not show_all:
|
||||
problem_filter["visible"] = True
|
||||
|
||||
authors = (
|
||||
Problem.objects.filter(**problem_filter)
|
||||
.values("created_by__username")
|
||||
.annotate(problem_count=Count("id"))
|
||||
.order_by("-problem_count")
|
||||
)
|
||||
authors = Problem.objects.filter(**problem_filter).values("created_by__username").annotate(problem_count=Count("id")).order_by("-problem_count")
|
||||
result = [
|
||||
{
|
||||
"username": author["created_by__username"],
|
||||
"problem_count": author["problem_count"],
|
||||
}
|
||||
for author in authors
|
||||
async for author in authors
|
||||
]
|
||||
|
||||
cache.set(cache_key, result, 7200)
|
||||
await async_cache_set(cache_key, result, 7200)
|
||||
return self.success(result)
|
||||
|
||||
|
||||
class ProblemYearlyACRateAPI(APIView):
|
||||
def get(self, request):
|
||||
class ProblemYearlyACRateAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
problem_id = request.GET.get("problem_id")
|
||||
if not problem_id:
|
||||
return self.error("problem_id is required")
|
||||
|
||||
cache_key = f"{CacheKey.problem_yearly_ac}:{problem_id}"
|
||||
cached = cache.get(cache_key)
|
||||
cached = await async_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
return self.success(cached)
|
||||
|
||||
try:
|
||||
problem = Problem.objects.get(
|
||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||
)
|
||||
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem does not exist")
|
||||
|
||||
@@ -328,12 +315,10 @@ class ProblemYearlyACRateAPI(APIView):
|
||||
"year": row["year"],
|
||||
"total": row["total"],
|
||||
"accepted": row["accepted"],
|
||||
"ac_rate": round(row["accepted"] / row["total"] * 100, 2)
|
||||
if row["total"] > 0
|
||||
else 0.0,
|
||||
"ac_rate": round(row["accepted"] / row["total"] * 100, 2) if row["total"] > 0 else 0.0,
|
||||
}
|
||||
for row in rows
|
||||
async for row in rows
|
||||
]
|
||||
|
||||
cache.set(cache_key, data, 3600)
|
||||
await async_cache_set(cache_key, data, 3600)
|
||||
return self.success(data)
|
||||
|
||||
@@ -63,7 +63,7 @@ urlpatterns = [
|
||||
name="admin_problemset_progress_detail_api",
|
||||
),
|
||||
# 题单同步管理API
|
||||
path(
|
||||
path( # DEPRECATED: 前端未调用
|
||||
"problemset/<int:problem_set_id>/sync",
|
||||
ProblemSetSyncAPI.as_view(),
|
||||
name="admin_problemset_sync_api",
|
||||
|
||||
@@ -24,7 +24,7 @@ urlpatterns = [
|
||||
ProblemSetProblemAPI.as_view(),
|
||||
name="problemset_problems_api",
|
||||
),
|
||||
path(
|
||||
path( # DEPRECATED: 前端未调用
|
||||
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
|
||||
ProblemSetProblemAPI.as_view(),
|
||||
name="problemset_problem_detail_api",
|
||||
@@ -35,12 +35,12 @@ urlpatterns = [
|
||||
ProblemSetProgressAPI.as_view(),
|
||||
name="problemset_progress_api",
|
||||
),
|
||||
path(
|
||||
path( # DEPRECATED: 前端未调用
|
||||
"problemset/<int:problem_set_id>/progress",
|
||||
ProblemSetProgressAPI.as_view(),
|
||||
name="problemset_progress_detail_api",
|
||||
),
|
||||
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"),
|
||||
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"), # DEPRECATED: 前端未调用
|
||||
# 奖章相关API
|
||||
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
|
||||
path(
|
||||
|
||||
@@ -332,6 +332,7 @@ class ProblemSetProgressAdminAPI(APIView):
|
||||
return self.error("用户未加入该题单")
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class ProblemSetSyncAPI(APIView):
|
||||
"""题单同步管理API"""
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.db.models import Avg, Count, Prefetch, Q
|
||||
from django.utils import timezone
|
||||
|
||||
@@ -24,14 +25,14 @@ from problemset.serializers import (
|
||||
UpdateProgressSerializer,
|
||||
UserBadgeSerializer,
|
||||
)
|
||||
from submission.models import JudgeStatus, Submission, is_accepted
|
||||
from utils.api import APIView, validate_serializer
|
||||
from submission.models import Submission, is_accepted
|
||||
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||
|
||||
|
||||
class ProblemSetAPI(APIView):
|
||||
class ProblemSetAPI(AsyncAPIView):
|
||||
"""题单API - 用户端"""
|
||||
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
"""获取题单列表"""
|
||||
# 预加载创建者信息
|
||||
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status=ProblemSetStatus.DRAFT).select_related("created_by")
|
||||
@@ -65,16 +66,19 @@ class ProblemSetAPI(APIView):
|
||||
user_earned_badge_ids = set()
|
||||
if request.user.is_authenticated:
|
||||
# 先获取所有题单ID(不应用prefetch_related,只获取ID)
|
||||
problem_set_ids = list(problem_sets.values_list("id", flat=True))
|
||||
problem_set_ids = [problem_set_id async for problem_set_id in problem_sets.values_list("id", flat=True)]
|
||||
|
||||
if problem_set_ids:
|
||||
# 批量查询用户在这些题单中的进度
|
||||
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
|
||||
# 构建映射:题单ID -> 进度对象
|
||||
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
|
||||
user_progress_map = {progress.problemset_id: progress async for progress in user_progresses}
|
||||
|
||||
# 批量查询用户已获得的奖章ID(这些题单相关的)
|
||||
user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True))
|
||||
user_earned_badge_ids = {
|
||||
badge_id
|
||||
async for badge_id in UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True)
|
||||
}
|
||||
|
||||
# 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载)
|
||||
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
|
||||
@@ -83,31 +87,35 @@ class ProblemSetAPI(APIView):
|
||||
request._user_progress_map = user_progress_map
|
||||
request._user_earned_badge_ids = user_earned_badge_ids
|
||||
|
||||
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
|
||||
data = await self.async_paginate_data(request, problem_sets, ProblemSetListSerializer)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
class ProblemSetDetailAPI(APIView):
|
||||
class ProblemSetDetailAPI(AsyncAPIView):
|
||||
"""题单详情API - 用户端"""
|
||||
|
||||
def get(self, request, problem_set_id):
|
||||
async def get(self, request, problem_set_id):
|
||||
"""获取题单详情"""
|
||||
try:
|
||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
||||
problem_set = await (
|
||||
ProblemSet.objects.select_related("created_by")
|
||||
.filter(id=problem_set_id, visible=True)
|
||||
.exclude(status=ProblemSetStatus.DRAFT)
|
||||
.aget()
|
||||
)
|
||||
except ProblemSet.DoesNotExist:
|
||||
return self.error("题单不存在")
|
||||
|
||||
serializer = ProblemSetSerializer(problem_set, context={"request": request})
|
||||
return self.success(serializer.data)
|
||||
return self.success(await self.async_serialize_data(ProblemSetSerializer, problem_set, context={"request": request}))
|
||||
|
||||
|
||||
class ProblemSetProblemAPI(APIView):
|
||||
class ProblemSetProblemAPI(AsyncAPIView):
|
||||
"""题单题目API - 用户端"""
|
||||
|
||||
def get(self, request, problem_set_id):
|
||||
async def get(self, request, problem_set_id):
|
||||
"""获取题单中的题目列表"""
|
||||
try:
|
||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
||||
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||
except ProblemSet.DoesNotExist:
|
||||
return self.error("题单不存在")
|
||||
|
||||
@@ -115,12 +123,16 @@ class ProblemSetProblemAPI(APIView):
|
||||
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
|
||||
user_progress = None
|
||||
if request.user.is_authenticated:
|
||||
try:
|
||||
user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
|
||||
except ProblemSetProgress.DoesNotExist:
|
||||
pass
|
||||
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress})
|
||||
return self.success(serializer.data)
|
||||
user_progress = await ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).afirst()
|
||||
problem_list = [problem async for problem in problems]
|
||||
return self.success(
|
||||
await self.async_serialize_data(
|
||||
ProblemSetProblemSerializer,
|
||||
problem_list,
|
||||
many=True,
|
||||
context={"request": request, "user_progress": user_progress},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ProblemSetProgressAPI(APIView):
|
||||
@@ -236,6 +248,7 @@ class ProblemSetProgressAPI(APIView):
|
||||
UserBadge.objects.create(user=progress.user, badge=badge)
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class UserProgressAPI(APIView):
|
||||
"""用户进度API"""
|
||||
|
||||
@@ -247,10 +260,10 @@ class UserProgressAPI(APIView):
|
||||
return self.success(serializer.data)
|
||||
|
||||
|
||||
class UserBadgeAPI(APIView):
|
||||
class UserBadgeAPI(AsyncAPIView):
|
||||
"""用户奖章API"""
|
||||
|
||||
def get(self, request):
|
||||
async def get(self, request):
|
||||
"""获取用户的奖章列表"""
|
||||
# 支持通过username参数获取指定用户的徽章
|
||||
username = request.GET.get("username")
|
||||
@@ -258,41 +271,41 @@ class UserBadgeAPI(APIView):
|
||||
if username:
|
||||
# 获取指定用户的徽章
|
||||
try:
|
||||
target_user = User.objects.get(username=username, is_disabled=False)
|
||||
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
|
||||
target_user = await User.objects.aget(username=username, is_disabled=False)
|
||||
badges = UserBadge.objects.select_related("badge").filter(user=target_user).order_by("-earned_time")
|
||||
except User.DoesNotExist:
|
||||
return self.error("用户不存在")
|
||||
else:
|
||||
# 获取当前用户的徽章
|
||||
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
|
||||
badges = UserBadge.objects.select_related("badge").filter(user=request.user).order_by("-earned_time")
|
||||
|
||||
serializer = UserBadgeSerializer(badges, many=True)
|
||||
return self.success(serializer.data)
|
||||
badge_list = [badge async for badge in badges]
|
||||
return self.success(await self.async_serialize_data(UserBadgeSerializer, badge_list, many=True))
|
||||
|
||||
|
||||
class ProblemSetBadgeAPI(APIView):
|
||||
class ProblemSetBadgeAPI(AsyncAPIView):
|
||||
"""题单奖章API - 用户端"""
|
||||
|
||||
def get(self, request, problem_set_id):
|
||||
async def get(self, request, problem_set_id):
|
||||
"""获取题单的奖章列表"""
|
||||
try:
|
||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
||||
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||
except ProblemSet.DoesNotExist:
|
||||
return self.error("题单不存在")
|
||||
|
||||
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
|
||||
serializer = ProblemSetBadgeSerializer(badges, many=True)
|
||||
return self.success(serializer.data)
|
||||
badge_list = [badge async for badge in badges]
|
||||
return self.success(await self.async_serialize_data(ProblemSetBadgeSerializer, badge_list, many=True))
|
||||
|
||||
|
||||
class ProblemSetUserProgressAPI(APIView):
|
||||
class ProblemSetUserProgressAPI(AsyncAPIView):
|
||||
"""题单用户进度列表API"""
|
||||
|
||||
@admin_role_required
|
||||
def get(self, request, problem_set_id: int):
|
||||
async def get(self, request, problem_set_id: int):
|
||||
"""获取题单的用户进度列表"""
|
||||
try:
|
||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
||||
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||
except ProblemSet.DoesNotExist:
|
||||
return self.error("题单不存在")
|
||||
|
||||
@@ -321,7 +334,7 @@ class ProblemSetUserProgressAPI(APIView):
|
||||
|
||||
# 计算统计数据(基于所有数据,而非分页数据)
|
||||
# 使用一次查询获取所有统计数据
|
||||
stats = progresses.aggregate(
|
||||
stats = await sync_to_async(progresses.aggregate, thread_sensitive=True)(
|
||||
total=Count("id"),
|
||||
completed=Count("id", filter=Q(is_completed=True)),
|
||||
avg_progress=Avg("progress_percentage"),
|
||||
@@ -351,7 +364,7 @@ class ProblemSetUserProgressAPI(APIView):
|
||||
# 构建题单所有题目的数据结构和映射
|
||||
all_problems_list = []
|
||||
all_problems_map = {}
|
||||
for psp in all_problemset_problems:
|
||||
async for psp in all_problemset_problems:
|
||||
problem_data = {
|
||||
"id": psp.problem.id,
|
||||
"_id": psp.problem._id,
|
||||
@@ -362,7 +375,7 @@ class ProblemSetUserProgressAPI(APIView):
|
||||
all_problems_map[str(psp.problem.id)] = psp.problem
|
||||
|
||||
# 从当前页的数据中收集已完成的问题ID,用于序列化器
|
||||
paginated_progresses = list(progresses[offset : offset + limit])
|
||||
paginated_progresses = [progress async for progress in progresses[offset : offset + limit]]
|
||||
completed_problem_ids = set()
|
||||
for progress in paginated_progresses:
|
||||
if progress.progress_detail:
|
||||
@@ -376,7 +389,7 @@ class ProblemSetUserProgressAPI(APIView):
|
||||
request._problems_dict_cache = problems_dict
|
||||
|
||||
# 使用分页
|
||||
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
|
||||
data = await self.async_paginate_data(request, progresses, ProblemSetProgressSerializer)
|
||||
|
||||
# 添加统计数据
|
||||
data["statistics"] = {
|
||||
|
||||
@@ -28,6 +28,7 @@ dependencies = [
|
||||
"tree-sitter-c>=0.24.2",
|
||||
"tree-sitter-python>=0.25.0",
|
||||
"xlsxwriter>=3.2.9,<4",
|
||||
"asgiref>=3.11.1",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -35,6 +36,10 @@ dev = [
|
||||
"ruff>=0.15.11",
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 180
|
||||
exclude = ["*/migrations/*", "*settings.py", "*/apps.py", ".venv"]
|
||||
|
||||
@@ -12,6 +12,6 @@ urlpatterns = [
|
||||
path("submission", SubmissionAPI.as_view()),
|
||||
path("submissions", SubmissionListAPI.as_view()),
|
||||
path("submissions/today_count", SubmissionsTodayCount.as_view()),
|
||||
path("submission_exists", SubmissionExistsAPI.as_view()),
|
||||
path("submission_exists", SubmissionExistsAPI.as_view()), # DEPRECATED: 前端未调用
|
||||
path("contest_submissions", ContestSubmissionListAPI.as_view()),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import ipaddress
|
||||
from datetime import datetime
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.utils import timezone
|
||||
|
||||
from account.decorators import check_contest_permission, login_required
|
||||
from contest.models import ContestRuleType, ContestStatus
|
||||
@@ -8,7 +10,7 @@ from options.options import SysOptions
|
||||
|
||||
# from judge.dispatcher import JudgeDispatcher
|
||||
from problem.models import Problem, ProblemRuleType
|
||||
from utils.api import APIView, validate_serializer
|
||||
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||
from utils.cache import cache
|
||||
from utils.captcha import Captcha
|
||||
from utils.throttling import TokenBucket
|
||||
@@ -154,8 +156,8 @@ class SubmissionAPI(APIView):
|
||||
return self.success()
|
||||
|
||||
|
||||
class SubmissionListAPI(APIView):
|
||||
def get(self, request):
|
||||
class SubmissionListAPI(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
if not request.GET.get("limit"):
|
||||
return self.error("Limit is needed")
|
||||
if request.GET.get("contest_id"):
|
||||
@@ -171,14 +173,15 @@ class SubmissionListAPI(APIView):
|
||||
language = request.GET.get("language")
|
||||
if problem_id:
|
||||
try:
|
||||
problem = Problem.objects.get(
|
||||
problem = await Problem.objects.aget(
|
||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||
)
|
||||
except Problem.DoesNotExist:
|
||||
return self.error("Problem doesn't exist")
|
||||
submissions = submissions.filter(problem=problem)
|
||||
|
||||
if not SysOptions.submission_list_show_all and request.user.is_regular_user():
|
||||
show_all = await SysOptions.aget("submission_list_show_all")
|
||||
if not show_all and request.user.is_regular_user():
|
||||
return self.success({"results": [], "total": 0})
|
||||
|
||||
if myself and myself == "1":
|
||||
@@ -190,21 +193,25 @@ class SubmissionListAPI(APIView):
|
||||
if language:
|
||||
submissions = submissions.filter(language=language)
|
||||
if request.GET.get("today") == "1":
|
||||
today = datetime.today()
|
||||
now = timezone.now()
|
||||
submissions = submissions.filter(
|
||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
||||
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
)
|
||||
|
||||
data = self.paginate_data(request, submissions)
|
||||
data = await self.async_paginate_data(request, submissions)
|
||||
results = data["results"]
|
||||
if request.user.is_authenticated and request.user.is_regular_user():
|
||||
problem_ids = list({s.problem_id for s in results})
|
||||
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
|
||||
progress_cache = await sync_to_async(bulk_fetch_problemset_progress)(request.user, problem_ids)
|
||||
else:
|
||||
progress_cache = {}
|
||||
data["results"] = SubmissionListSerializer(
|
||||
results, many=True, user=request.user, problemset_progress_cache=progress_cache
|
||||
).data
|
||||
data["results"] = await self.async_serialize_data(
|
||||
SubmissionListSerializer,
|
||||
results,
|
||||
many=True,
|
||||
user=request.user,
|
||||
problemset_progress_cache=progress_cache,
|
||||
)
|
||||
return self.success(data)
|
||||
|
||||
|
||||
@@ -262,6 +269,7 @@ class ContestSubmissionListAPI(APIView):
|
||||
return self.success(data)
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SubmissionExistsAPI(APIView):
|
||||
def get(self, request):
|
||||
if not request.GET.get("problem_id"):
|
||||
@@ -274,10 +282,10 @@ class SubmissionExistsAPI(APIView):
|
||||
)
|
||||
|
||||
|
||||
class SubmissionsTodayCount(APIView):
|
||||
def get(self, request):
|
||||
today = datetime.today()
|
||||
count = Submission.objects.filter(
|
||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
||||
).count()
|
||||
class SubmissionsTodayCount(AsyncAPIView):
|
||||
async def get(self, request):
|
||||
now = timezone.now()
|
||||
count = await Submission.objects.filter(
|
||||
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
).acount()
|
||||
return self.success(count)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.http import HttpResponse, QueryDict
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
@@ -162,6 +165,77 @@ class CSRFExemptAPIView(APIView):
|
||||
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class AsyncAPIView(APIView):
|
||||
view_is_async = True
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.view_is_async = True
|
||||
|
||||
async def dispatch(self, request, *args, **kwargs):
|
||||
if self.request_parsers:
|
||||
try:
|
||||
request.data = self._get_request_data(self.request)
|
||||
except ValueError as e:
|
||||
return self.error(err="invalid-request", msg=str(e))
|
||||
try:
|
||||
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
|
||||
response = handler(request, *args, **kwargs)
|
||||
if asyncio.iscoroutine(response):
|
||||
response = await response
|
||||
return response
|
||||
except APIError as e:
|
||||
ret = {"msg": e.msg}
|
||||
if e.err:
|
||||
ret["err"] = e.err
|
||||
return self.error(**ret)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return self.server_error()
|
||||
|
||||
def serialize_data(self, object_serializer, data, **kwargs):
|
||||
return object_serializer(data, **kwargs).data
|
||||
|
||||
async def async_serialize_data(self, object_serializer, data, **kwargs):
|
||||
return await sync_to_async(
|
||||
self.serialize_data,
|
||||
thread_sensitive=True,
|
||||
)(object_serializer, data, **kwargs)
|
||||
|
||||
async def async_paginate_data(self, request, query_set, object_serializer=None):
|
||||
try:
|
||||
limit = int(request.GET.get("limit", "10"))
|
||||
except ValueError:
|
||||
limit = 10
|
||||
if limit < 0 or limit > 250:
|
||||
limit = 10
|
||||
try:
|
||||
offset = int(request.GET.get("offset", "0"))
|
||||
except ValueError:
|
||||
offset = 0
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
count, results = await asyncio.gather(
|
||||
query_set.acount(),
|
||||
sync_to_async(lambda: list(query_set[offset:offset + limit]), thread_sensitive=True)(),
|
||||
)
|
||||
if object_serializer:
|
||||
results = await self.async_serialize_data(
|
||||
object_serializer,
|
||||
results,
|
||||
many=True,
|
||||
context={"request": request},
|
||||
)
|
||||
data = {"results": results, "total": count}
|
||||
return data
|
||||
|
||||
|
||||
class CSRFExemptAsyncAPIView(AsyncAPIView):
|
||||
@method_decorator(csrf_exempt)
|
||||
async def dispatch(self, request, *args, **kwargs):
|
||||
return await super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
def validate_serializer(serializer):
|
||||
"""
|
||||
@validate_serializer(TestSerializer)
|
||||
@@ -169,6 +243,20 @@ def validate_serializer(serializer):
|
||||
return self.success(request.data)
|
||||
"""
|
||||
def validate(view_method):
|
||||
if inspect.iscoroutinefunction(view_method):
|
||||
@functools.wraps(view_method)
|
||||
async def async_handle(*args, **kwargs):
|
||||
self = args[0]
|
||||
request = args[1]
|
||||
s = serializer(data=request.data)
|
||||
if s.is_valid():
|
||||
request.data = s.data
|
||||
request.serializer = s
|
||||
return await view_method(*args, **kwargs)
|
||||
else:
|
||||
return self.invalid_serializer(s)
|
||||
return async_handle
|
||||
|
||||
@functools.wraps(view_method)
|
||||
def handle(*args, **kwargs):
|
||||
self = args[0]
|
||||
@@ -180,7 +268,6 @@ def validate_serializer(serializer):
|
||||
return view_method(*args, **kwargs)
|
||||
else:
|
||||
return self.invalid_serializer(s)
|
||||
|
||||
return handle
|
||||
|
||||
return validate
|
||||
|
||||
14
utils/async_helpers.py
Normal file
14
utils/async_helpers.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.core.cache import cache
|
||||
|
||||
|
||||
async def async_cache_get(key, default=None):
|
||||
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
|
||||
|
||||
|
||||
async def async_cache_set(key, value, timeout=None):
|
||||
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
|
||||
|
||||
|
||||
async def async_cache_delete(key):
|
||||
return await sync_to_async(cache.delete, thread_sensitive=True)(key)
|
||||
@@ -4,5 +4,5 @@ from .views import SimditorFileUploadAPIView, SimditorImageUploadAPIView
|
||||
|
||||
urlpatterns = [
|
||||
path("upload_image", SimditorImageUploadAPIView.as_view()),
|
||||
path("upload_file", SimditorFileUploadAPIView.as_view()),
|
||||
path("upload_file", SimditorFileUploadAPIView.as_view()), # DEPRECATED: 前端未调用
|
||||
]
|
||||
|
||||
@@ -46,6 +46,7 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView):
|
||||
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
|
||||
|
||||
|
||||
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||
class SimditorFileUploadAPIView(CSRFExemptAPIView):
|
||||
request_parsers = ()
|
||||
|
||||
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -550,6 +550,7 @@ name = "onlinejudge"
|
||||
version = "2.0.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "asgiref" },
|
||||
{ name = "channels" },
|
||||
{ name = "channels-redis" },
|
||||
{ name = "django" },
|
||||
@@ -582,6 +583,7 @@ dev = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "asgiref", specifier = ">=3.11.1" },
|
||||
{ name = "channels", specifier = ">=4.3.2,<5" },
|
||||
{ name = "channels-redis", specifier = ">=4.3.0,<5" },
|
||||
{ name = "django", specifier = ">=6.0.4,<6.1" },
|
||||
|
||||
Reference in New Issue
Block a user