async
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import inspect
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from contest.models import Contest, ContestRuleType, ContestStatus, ContestType
|
from contest.models import Contest, ContestRuleType, ContestStatus, ContestType
|
||||||
@@ -15,47 +16,58 @@ class BasePermissionDecorator(object):
|
|||||||
self.func = func
|
self.func = func
|
||||||
|
|
||||||
def __get__(self, obj, obj_type):
|
def __get__(self, obj, obj_type):
|
||||||
|
if inspect.iscoroutinefunction(self.func):
|
||||||
|
return functools.partial(self._async_call, obj)
|
||||||
return functools.partial(self.__call__, obj)
|
return functools.partial(self.__call__, obj)
|
||||||
|
|
||||||
def error(self, data):
|
def error(self, data):
|
||||||
return JSONResponse.response({"error": "permission-denied", "data": data})
|
return JSONResponse.response({"error": "permission-denied", "data": data})
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
self.request = args[1]
|
request = args[1]
|
||||||
|
|
||||||
if self.check_permission():
|
if self.check_permission(request):
|
||||||
if self.request.user.is_disabled:
|
if request.user.is_disabled:
|
||||||
return self.error("Your account is disabled")
|
return self.error("Your account is disabled")
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.error("Please login first")
|
return self.error("Please login first")
|
||||||
|
|
||||||
def check_permission(self):
|
async def _async_call(self, *args, **kwargs):
|
||||||
|
request = args[1]
|
||||||
|
|
||||||
|
if self.check_permission(request):
|
||||||
|
if request.user.is_disabled:
|
||||||
|
return self.error("Your account is disabled")
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
return self.error("Please login first")
|
||||||
|
|
||||||
|
def check_permission(self, request):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class login_required(BasePermissionDecorator):
|
class login_required(BasePermissionDecorator):
|
||||||
def check_permission(self):
|
def check_permission(self, request):
|
||||||
return self.request.user.is_authenticated
|
return request.user.is_authenticated
|
||||||
|
|
||||||
|
|
||||||
class super_admin_required(BasePermissionDecorator):
|
class super_admin_required(BasePermissionDecorator):
|
||||||
def check_permission(self):
|
def check_permission(self, request):
|
||||||
user = self.request.user
|
user = request.user
|
||||||
return user.is_authenticated and user.is_super_admin()
|
return user.is_authenticated and user.is_super_admin()
|
||||||
|
|
||||||
|
|
||||||
class admin_role_required(BasePermissionDecorator):
|
class admin_role_required(BasePermissionDecorator):
|
||||||
def check_permission(self):
|
def check_permission(self, request):
|
||||||
user = self.request.user
|
user = request.user
|
||||||
return user.is_authenticated and user.is_admin_role()
|
return user.is_authenticated and user.is_admin_role()
|
||||||
|
|
||||||
|
|
||||||
class problem_permission_required(admin_role_required):
|
class problem_permission_required(admin_role_required):
|
||||||
def check_permission(self):
|
def check_permission(self, request):
|
||||||
if not super(problem_permission_required, self).check_permission():
|
if not super().check_permission(request):
|
||||||
return False
|
return False
|
||||||
if self.request.user.problem_permission == ProblemPermission.NONE:
|
if request.user.problem_permission == ProblemPermission.NONE:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from account.models import UserProfile
|
|||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from submission.models import JudgeStatus
|
from submission.models import JudgeStatus
|
||||||
|
|
||||||
|
|
||||||
ACCEPTED_STATUSES = {JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED}
|
ACCEPTED_STATUSES = {JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ class UserManager(models.Manager):
|
|||||||
def get_by_natural_key(self, username):
|
def get_by_natural_key(self, username):
|
||||||
return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
|
return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
|
||||||
|
|
||||||
|
async def aget_by_natural_key(self, username):
|
||||||
|
return await self.aget(**{f"{self.model.USERNAME_FIELD}__iexact": username})
|
||||||
|
|
||||||
|
|
||||||
class User(AbstractBaseUser):
|
class User(AbstractBaseUser):
|
||||||
username = models.TextField(unique=True)
|
username = models.TextField(unique=True)
|
||||||
|
|||||||
@@ -4,6 +4,6 @@ from ..views.admin import GenerateUserAPI, ResetUserPasswordAPI, UserAdminAPI
|
|||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("user", UserAdminAPI.as_view()),
|
path("user", UserAdminAPI.as_view()),
|
||||||
path("generate_user", GenerateUserAPI.as_view()),
|
path("generate_user", GenerateUserAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("reset_password", ResetUserPasswordAPI.as_view()),
|
path("reset_password", ResetUserPasswordAPI.as_view()),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -29,28 +29,28 @@ urlpatterns = [
|
|||||||
path("login", UserLoginAPI.as_view()),
|
path("login", UserLoginAPI.as_view()),
|
||||||
path("logout", UserLogoutAPI.as_view()),
|
path("logout", UserLogoutAPI.as_view()),
|
||||||
path("register", UserRegisterAPI.as_view()),
|
path("register", UserRegisterAPI.as_view()),
|
||||||
path("change_password", UserChangePasswordAPI.as_view()),
|
path("change_password", UserChangePasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("change_email", UserChangeEmailAPI.as_view()),
|
path("change_email", UserChangeEmailAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("apply_reset_password", ApplyResetPasswordAPI.as_view()),
|
path("apply_reset_password", ApplyResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("reset_password", ResetPasswordAPI.as_view()),
|
path("reset_password", ResetPasswordAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("captcha", CaptchaAPIView.as_view()),
|
path("captcha", CaptchaAPIView.as_view()),
|
||||||
path("check_username_or_email", UsernameOrEmailCheck.as_view()),
|
path("check_username_or_email", UsernameOrEmailCheck.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("profile", UserProfileAPI.as_view(), name="user_profile_api"),
|
path("profile", UserProfileAPI.as_view(), name="user_profile_api"),
|
||||||
path("profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view()),
|
path("profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view()),
|
||||||
path("metrics", Metrics.as_view()),
|
path("metrics", Metrics.as_view()),
|
||||||
path("upload_avatar", AvatarUploadAPI.as_view()),
|
path("upload_avatar", AvatarUploadAPI.as_view()),
|
||||||
path("tfa_required", CheckTFARequiredAPI.as_view()),
|
path("tfa_required", CheckTFARequiredAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path(
|
path(
|
||||||
"two_factor_auth",
|
"two_factor_auth", # DEPRECATED: 前端未调用
|
||||||
TwoFactorAuthAPI.as_view(),
|
TwoFactorAuthAPI.as_view(),
|
||||||
),
|
),
|
||||||
path("user_rank", UserRankAPI.as_view()),
|
path("user_rank", UserRankAPI.as_view()),
|
||||||
path("user_activity_rank", UserActivityRankAPI.as_view()),
|
path("user_activity_rank", UserActivityRankAPI.as_view()),
|
||||||
path("user_problem_rank", UserProblemRankAPI.as_view()),
|
path("user_problem_rank", UserProblemRankAPI.as_view()),
|
||||||
path("sessions", SessionManagementAPI.as_view()),
|
path("sessions", SessionManagementAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path(
|
path(
|
||||||
"open_api_appkey",
|
"open_api_appkey", # DEPRECATED: 前端未调用
|
||||||
OpenAPIAppkeyAPI.as_view(),
|
OpenAPIAppkeyAPI.as_view(),
|
||||||
),
|
),
|
||||||
path("sso", SSOAPI.as_view()),
|
path("sso", SSOAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -191,6 +191,7 @@ class UserAdminAPI(APIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class GenerateUserAPI(APIView):
|
class GenerateUserAPI(APIView):
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
@@ -5,7 +6,6 @@ from importlib import import_module
|
|||||||
import qrcode
|
import qrcode
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib import auth
|
from django.contrib import auth
|
||||||
from django.core.cache import cache
|
|
||||||
from django.db.models import Count, Q
|
from django.db.models import Count, Q
|
||||||
from django.template.loader import render_to_string
|
from django.template.loader import render_to_string
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
@@ -16,8 +16,9 @@ from otpauth import TOTP
|
|||||||
|
|
||||||
from options.options import SysOptions
|
from options.options import SysOptions
|
||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from submission.models import JudgeStatus, Submission, is_accepted
|
from submission.models import JudgeStatus, Submission
|
||||||
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
|
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
|
||||||
|
from utils.async_helpers import async_cache_get, async_cache_set
|
||||||
from utils.captcha import Captcha
|
from utils.captcha import Captcha
|
||||||
from utils.constants import CacheKey, ContestRuleType
|
from utils.constants import CacheKey, ContestRuleType
|
||||||
from utils.shortcuts import datetime2str, img2base64, rand_str
|
from utils.shortcuts import datetime2str, img2base64, rand_str
|
||||||
@@ -58,12 +59,9 @@ def _valid_totp(token, code):
|
|||||||
return _totp(token).verify(code)
|
return _totp(token).verify(code)
|
||||||
|
|
||||||
|
|
||||||
class UserProfileAPI(APIView):
|
class UserProfileAPI(AsyncAPIView):
|
||||||
@method_decorator(ensure_csrf_cookie)
|
@method_decorator(ensure_csrf_cookie)
|
||||||
def get(self, request, **kwargs):
|
async def get(self, request, **kwargs):
|
||||||
"""
|
|
||||||
判断是否登录, 若登录返回用户信息
|
|
||||||
"""
|
|
||||||
user = request.user
|
user = request.user
|
||||||
if not user.is_authenticated:
|
if not user.is_authenticated:
|
||||||
return self.success()
|
return self.success()
|
||||||
@@ -71,52 +69,51 @@ class UserProfileAPI(APIView):
|
|||||||
username = request.GET.get("username")
|
username = request.GET.get("username")
|
||||||
try:
|
try:
|
||||||
if username:
|
if username:
|
||||||
user = User.objects.get(username=username, is_disabled=False)
|
user = await User.objects.aget(username=username, is_disabled=False)
|
||||||
else:
|
else:
|
||||||
user = request.user
|
user = request.user
|
||||||
# api返回的是自己的信息,可以返real_name
|
|
||||||
show_real_name = True
|
show_real_name = True
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
return self.error("User does not exist")
|
return self.error("User does not exist")
|
||||||
return self.success(UserProfileSerializer(user.userprofile, show_real_name=show_real_name).data)
|
profile = await UserProfile.objects.select_related("user").aget(user=user)
|
||||||
|
return self.success(UserProfileSerializer(profile, show_real_name=show_real_name).data)
|
||||||
|
|
||||||
@validate_serializer(EditUserProfileSerializer)
|
@validate_serializer(EditUserProfileSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
def put(self, request):
|
async def put(self, request):
|
||||||
data = request.data
|
data = request.data
|
||||||
user_profile = request.user.userprofile
|
user_profile = await UserProfile.objects.select_related("user").aget(user=request.user)
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
setattr(user_profile, k, v)
|
setattr(user_profile, k, v)
|
||||||
user_profile.save()
|
await user_profile.asave()
|
||||||
return self.success(UserProfileSerializer(user_profile, show_real_name=True).data)
|
return self.success(UserProfileSerializer(user_profile, show_real_name=True).data)
|
||||||
|
|
||||||
|
|
||||||
class Metrics(APIView):
|
class Metrics(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
userid = request.GET.get("userid")
|
userid = request.GET.get("userid")
|
||||||
submissions = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
|
qs = Submission.objects.filter(user_id=userid, contest_id__isnull=True)
|
||||||
if submissions.count() == 0:
|
count, latest, first = await asyncio.gather(
|
||||||
|
qs.acount(),
|
||||||
|
qs.order_by("-create_time").afirst(),
|
||||||
|
qs.order_by("create_time").afirst(),
|
||||||
|
)
|
||||||
|
if count == 0 or not latest or not first:
|
||||||
return self.error("暂无提交")
|
return self.error("暂无提交")
|
||||||
else:
|
return self.success(
|
||||||
latest_submission = submissions.first()
|
{
|
||||||
last_submission = submissions.last()
|
"now": datetime2str(timezone.now()),
|
||||||
if last_submission and latest_submission:
|
"latest": datetime2str(latest.create_time),
|
||||||
return self.success(
|
"first": datetime2str(first.create_time),
|
||||||
{
|
}
|
||||||
"now": datetime2str(timezone.now()),
|
)
|
||||||
"latest": datetime2str(latest_submission.create_time),
|
|
||||||
"first": datetime2str(last_submission.create_time),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.error("暂无提交")
|
|
||||||
|
|
||||||
|
|
||||||
class AvatarUploadAPI(APIView):
|
class AvatarUploadAPI(AsyncAPIView):
|
||||||
request_parsers = ()
|
request_parsers = ()
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
form = ImageUploadForm(request.POST, request.FILES)
|
form = ImageUploadForm(request.POST, request.FILES)
|
||||||
if form.is_valid():
|
if form.is_valid():
|
||||||
avatar = form.cleaned_data["image"]
|
avatar = form.cleaned_data["image"]
|
||||||
@@ -132,13 +129,14 @@ class AvatarUploadAPI(APIView):
|
|||||||
with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
|
with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
|
||||||
for chunk in avatar:
|
for chunk in avatar:
|
||||||
img.write(chunk)
|
img.write(chunk)
|
||||||
user_profile = request.user.userprofile
|
user_profile = await UserProfile.objects.aget(user=request.user)
|
||||||
|
|
||||||
user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
|
user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
|
||||||
user_profile.save()
|
await user_profile.asave()
|
||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class TwoFactorAuthAPI(APIView):
|
class TwoFactorAuthAPI(APIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -186,6 +184,7 @@ class TwoFactorAuthAPI(APIView):
|
|||||||
return self.error("Invalid code")
|
return self.error("Invalid code")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class CheckTFARequiredAPI(APIView):
|
class CheckTFARequiredAPI(APIView):
|
||||||
@validate_serializer(UsernameOrEmailCheckSerializer)
|
@validate_serializer(UsernameOrEmailCheckSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -203,31 +202,26 @@ class CheckTFARequiredAPI(APIView):
|
|||||||
return self.success({"result": result})
|
return self.success({"result": result})
|
||||||
|
|
||||||
|
|
||||||
class UserLoginAPI(APIView):
|
class UserLoginAPI(AsyncAPIView):
|
||||||
@validate_serializer(UserLoginSerializer)
|
@validate_serializer(UserLoginSerializer)
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
"""
|
|
||||||
User login api
|
|
||||||
"""
|
|
||||||
data = request.data
|
data = request.data
|
||||||
user = auth.authenticate(username=data["username"], password=data["password"])
|
user = await auth.aauthenticate(username=data["username"], password=data["password"])
|
||||||
# None is returned if username or password is wrong
|
|
||||||
if user:
|
if user:
|
||||||
if user.is_disabled:
|
if user.is_disabled:
|
||||||
return self.error("Your account has been disabled")
|
return self.error("Your account has been disabled")
|
||||||
if not user.two_factor_auth:
|
if not user.two_factor_auth:
|
||||||
prev_login = user.last_login
|
prev_login = user.last_login
|
||||||
auth.login(request, user)
|
await auth.alogin(request, user)
|
||||||
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
# `tfa_code` not in post data
|
|
||||||
if user.two_factor_auth and "tfa_code" not in data:
|
if user.two_factor_auth and "tfa_code" not in data:
|
||||||
return self.error("tfa_required")
|
return self.error("tfa_required")
|
||||||
|
|
||||||
if _valid_totp(user.tfa_token, data["tfa_code"]):
|
if _valid_totp(user.tfa_token, data["tfa_code"]):
|
||||||
prev_login = user.last_login
|
prev_login = user.last_login
|
||||||
auth.login(request, user)
|
await auth.alogin(request, user)
|
||||||
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
request.session["prev_login"] = datetime2str(prev_login) if prev_login else ""
|
||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
else:
|
else:
|
||||||
@@ -236,12 +230,13 @@ class UserLoginAPI(APIView):
|
|||||||
return self.error("Invalid username or password")
|
return self.error("Invalid username or password")
|
||||||
|
|
||||||
|
|
||||||
class UserLogoutAPI(APIView):
|
class UserLogoutAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
auth.logout(request)
|
await auth.alogout(request)
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class UsernameOrEmailCheck(APIView):
|
class UsernameOrEmailCheck(APIView):
|
||||||
@validate_serializer(UsernameOrEmailCheckSerializer)
|
@validate_serializer(UsernameOrEmailCheckSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -258,13 +253,10 @@ class UsernameOrEmailCheck(APIView):
|
|||||||
return self.success(result)
|
return self.success(result)
|
||||||
|
|
||||||
|
|
||||||
class UserRegisterAPI(APIView):
|
class UserRegisterAPI(AsyncAPIView):
|
||||||
@validate_serializer(UserRegisterSerializer)
|
@validate_serializer(UserRegisterSerializer)
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
"""
|
if not await SysOptions.aget("allow_register"):
|
||||||
User register api
|
|
||||||
"""
|
|
||||||
if not SysOptions.allow_register:
|
|
||||||
return self.error("Register function has been disabled by admin")
|
return self.error("Register function has been disabled by admin")
|
||||||
|
|
||||||
data = request.data
|
data = request.data
|
||||||
@@ -273,17 +265,18 @@ class UserRegisterAPI(APIView):
|
|||||||
captcha = Captcha(request)
|
captcha = Captcha(request)
|
||||||
if not captcha.check(data["captcha"]):
|
if not captcha.check(data["captcha"]):
|
||||||
return self.error("Invalid captcha")
|
return self.error("Invalid captcha")
|
||||||
if User.objects.filter(username=data["username"]).exists():
|
if await User.objects.filter(username=data["username"]).aexists():
|
||||||
return self.error("Username already exists")
|
return self.error("Username already exists")
|
||||||
if User.objects.filter(email=data["email"]).exists():
|
if await User.objects.filter(email=data["email"]).aexists():
|
||||||
return self.error("Email already exists")
|
return self.error("Email already exists")
|
||||||
user = User.objects.create(username=data["username"], email=data["email"])
|
user = await User.objects.acreate(username=data["username"], email=data["email"])
|
||||||
user.set_password(data["password"])
|
user.set_password(data["password"])
|
||||||
user.save()
|
await user.asave()
|
||||||
UserProfile.objects.create(user=user)
|
await UserProfile.objects.acreate(user=user)
|
||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class UserChangeEmailAPI(APIView):
|
class UserChangeEmailAPI(APIView):
|
||||||
@validate_serializer(UserChangeEmailSerializer)
|
@validate_serializer(UserChangeEmailSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
@@ -306,6 +299,7 @@ class UserChangeEmailAPI(APIView):
|
|||||||
return self.error("Wrong password")
|
return self.error("Wrong password")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class UserChangePasswordAPI(APIView):
|
class UserChangePasswordAPI(APIView):
|
||||||
@validate_serializer(UserChangePasswordSerializer)
|
@validate_serializer(UserChangePasswordSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
@@ -329,6 +323,7 @@ class UserChangePasswordAPI(APIView):
|
|||||||
return self.error("Invalid old password")
|
return self.error("Invalid old password")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ApplyResetPasswordAPI(APIView):
|
class ApplyResetPasswordAPI(APIView):
|
||||||
@validate_serializer(ApplyResetPasswordSerializer)
|
@validate_serializer(ApplyResetPasswordSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -363,6 +358,7 @@ class ApplyResetPasswordAPI(APIView):
|
|||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ResetPasswordAPI(APIView):
|
class ResetPasswordAPI(APIView):
|
||||||
@validate_serializer(ResetPasswordSerializer)
|
@validate_serializer(ResetPasswordSerializer)
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -383,6 +379,7 @@ class ResetPasswordAPI(APIView):
|
|||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SessionManagementAPI(APIView):
|
class SessionManagementAPI(APIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -426,8 +423,8 @@ class SessionManagementAPI(APIView):
|
|||||||
return self.error("Invalid session_key")
|
return self.error("Invalid session_key")
|
||||||
|
|
||||||
|
|
||||||
class UserRankAPI(APIView):
|
class UserRankAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
rule_type = request.GET.get("rule")
|
rule_type = request.GET.get("rule")
|
||||||
username = request.GET.get("username", "")
|
username = request.GET.get("username", "")
|
||||||
try:
|
try:
|
||||||
@@ -448,16 +445,16 @@ class UserRankAPI(APIView):
|
|||||||
profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
|
profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
|
||||||
if n > 0:
|
if n > 0:
|
||||||
profiles = profiles[:n]
|
profiles = profiles[:n]
|
||||||
return self.success(self.paginate_data(request, profiles, RankInfoSerializer))
|
return self.success(await self.async_paginate_data(request, profiles, RankInfoSerializer))
|
||||||
|
|
||||||
|
|
||||||
class UserActivityRankAPI(APIView):
|
class UserActivityRankAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
start = request.GET.get("start")
|
start = request.GET.get("start")
|
||||||
if not start:
|
if not start:
|
||||||
return self.error("start time is required")
|
return self.error("start time is required")
|
||||||
cache_key = f"{CacheKey.user_activity_rank}:{start}"
|
cache_key = f"{CacheKey.user_activity_rank}:{start}"
|
||||||
cached = cache.get(cache_key)
|
cached = await async_cache_get(cache_key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return self.success(cached)
|
return self.success(cached)
|
||||||
|
|
||||||
@@ -467,35 +464,40 @@ class UserActivityRankAPI(APIView):
|
|||||||
create_time__gte=start,
|
create_time__gte=start,
|
||||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||||
).exclude(username__in=hidden_names)
|
).exclude(username__in=hidden_names)
|
||||||
data = list(submissions.values("username").annotate(count=Count("problem_id", distinct=True)).order_by("-count")[:10])
|
data = [
|
||||||
cache.set(cache_key, data, 600)
|
row
|
||||||
|
async for row in submissions.values("username")
|
||||||
|
.annotate(count=Count("problem_id", distinct=True))
|
||||||
|
.order_by("-count")[:10]
|
||||||
|
]
|
||||||
|
await async_cache_set(cache_key, data, 600)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class UserProblemRankAPI(APIView):
|
class UserProblemRankAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
user = request.user
|
user = request.user
|
||||||
if not user.is_authenticated:
|
if not user.is_authenticated:
|
||||||
return self.error("User is not authenticated")
|
return self.error("User is not authenticated")
|
||||||
|
|
||||||
problem = Problem.objects.get(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
||||||
submissions = Submission.objects.filter(problem=problem, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED])
|
submissions = Submission.objects.filter(problem=problem, result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED])
|
||||||
|
|
||||||
all_ac_count = submissions.values("user_id").distinct().count()
|
all_ac_count = await submissions.values("user_id").distinct().acount()
|
||||||
|
|
||||||
class_name = user.class_name or ""
|
class_name = user.class_name or ""
|
||||||
class_ac_count = 0
|
class_ac_count = 0
|
||||||
|
|
||||||
if class_name:
|
if class_name:
|
||||||
users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True)
|
users = User.objects.filter(class_name=user.class_name, is_disabled=False).values_list("id", flat=True)
|
||||||
user_ids = list(users)
|
user_ids = [user_id async for user_id in users]
|
||||||
submissions = submissions.filter(user_id__in=user_ids)
|
submissions = submissions.filter(user_id__in=user_ids)
|
||||||
class_ac_count = submissions.values("user_id").distinct().count()
|
class_ac_count = await submissions.values("user_id").distinct().acount()
|
||||||
|
|
||||||
my_submissions = submissions.filter(user_id=user.id)
|
my_submissions = submissions.filter(user_id=user.id)
|
||||||
|
|
||||||
if len(my_submissions) == 0:
|
if not await my_submissions.aexists():
|
||||||
return self.success(
|
return self.success(
|
||||||
{
|
{
|
||||||
"class_name": class_name,
|
"class_name": class_name,
|
||||||
@@ -505,8 +507,8 @@ class UserProblemRankAPI(APIView):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
my_first_submission = my_submissions.order_by("create_time").first()
|
my_first_submission = await my_submissions.order_by("create_time").afirst()
|
||||||
rank = submissions.filter(create_time__lte=my_first_submission.create_time).count()
|
rank = await submissions.filter(create_time__lte=my_first_submission.create_time).acount()
|
||||||
return self.success(
|
return self.success(
|
||||||
{
|
{
|
||||||
"class_name": class_name,
|
"class_name": class_name,
|
||||||
@@ -517,25 +519,26 @@ class UserProblemRankAPI(APIView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProfileProblemDisplayIDRefreshAPI(APIView):
|
class ProfileProblemDisplayIDRefreshAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
profile = request.user.userprofile
|
profile = await UserProfile.objects.aget(user=request.user)
|
||||||
acm_problems = profile.acm_problems_status.get("problems", {})
|
acm_problems = profile.acm_problems_status.get("problems", {})
|
||||||
oi_problems = profile.oi_problems_status.get("problems", {})
|
oi_problems = profile.oi_problems_status.get("problems", {})
|
||||||
ids = list(acm_problems.keys()) + list(oi_problems.keys())
|
ids = list(acm_problems.keys()) + list(oi_problems.keys())
|
||||||
if not ids:
|
if not ids:
|
||||||
return self.success()
|
return self.success()
|
||||||
display_ids = Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)
|
display_ids = [did async for did in Problem.objects.filter(id__in=ids, visible=True).values_list("_id", flat=True)]
|
||||||
id_map = dict(zip(ids, display_ids))
|
id_map = dict(zip(ids, display_ids))
|
||||||
for k, v in acm_problems.items():
|
for k, v in acm_problems.items():
|
||||||
v["_id"] = id_map[k]
|
v["_id"] = id_map[k]
|
||||||
for k, v in oi_problems.items():
|
for k, v in oi_problems.items():
|
||||||
v["_id"] = id_map[k]
|
v["_id"] = id_map[k]
|
||||||
profile.save(update_fields=["acm_problems_status", "oi_problems_status"])
|
await profile.asave(update_fields=["acm_problems_status", "oi_problems_status"])
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class OpenAPIAppkeyAPI(APIView):
|
class OpenAPIAppkeyAPI(APIView):
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
@@ -548,6 +551,7 @@ class OpenAPIAppkeyAPI(APIView):
|
|||||||
return self.success({"appkey": api_appkey})
|
return self.success({"appkey": api_appkey})
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SSOAPI(CSRFExemptAPIView):
|
class SSOAPI(CSRFExemptAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
|
|||||||
@@ -1,19 +1,25 @@
|
|||||||
from announcement.models import Announcement
|
from announcement.models import Announcement
|
||||||
from announcement.serializers import AnnouncementListSerializer, AnnouncementSerializer
|
from announcement.serializers import AnnouncementListSerializer, AnnouncementSerializer
|
||||||
from utils.api import APIView
|
from utils.api import AsyncAPIView
|
||||||
|
|
||||||
|
|
||||||
class AnnouncementAPI(APIView):
|
class AnnouncementAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
id = request.GET.get("id")
|
id = request.GET.get("id")
|
||||||
if id:
|
if id:
|
||||||
try:
|
try:
|
||||||
announcement = Announcement.objects.get(id=id, visible=True)
|
announcement = await (
|
||||||
return self.success(AnnouncementSerializer(announcement).data)
|
Announcement.objects.select_related("created_by")
|
||||||
|
.filter(id=id, visible=True)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if announcement is None:
|
||||||
|
raise Announcement.DoesNotExist
|
||||||
|
return self.success(await self.async_serialize_data(AnnouncementSerializer, announcement))
|
||||||
except Announcement.DoesNotExist:
|
except Announcement.DoesNotExist:
|
||||||
return self.error("Announcement does not exist")
|
return self.error("Announcement does not exist")
|
||||||
|
|
||||||
announcements = Announcement.objects.select_related("created_by").filter(visible=True)
|
announcements = Announcement.objects.select_related("created_by").filter(visible=True)
|
||||||
return self.success(
|
return self.success(
|
||||||
self.paginate_data(request, announcements, AnnouncementListSerializer)
|
await self.async_paginate_data(request, announcements, AnnouncementListSerializer)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from .base import BaseEngine
|
|
||||||
from ast_checker.labels import label
|
from ast_checker.labels import label
|
||||||
|
|
||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
class MustHaveNestingEngine(BaseEngine):
|
class MustHaveNestingEngine(BaseEngine):
|
||||||
def _has_inner_in_subtree(self, node, inner_type):
|
def _has_inner_in_subtree(self, node, inner_type):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from .base import BaseEngine
|
|
||||||
from ast_checker.labels import label
|
from ast_checker.labels import label
|
||||||
|
|
||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
class CountNodeEngine(BaseEngine):
|
class CountNodeEngine(BaseEngine):
|
||||||
def _message(self, rule, count):
|
def _message(self, rule, count):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from .base import BaseEngine
|
|
||||||
from ast_checker.labels import label
|
from ast_checker.labels import label
|
||||||
|
|
||||||
|
from .base import BaseEngine
|
||||||
|
|
||||||
|
|
||||||
class MustExistNodeEngine(BaseEngine):
|
class MustExistNodeEngine(BaseEngine):
|
||||||
def _message(self, rule):
|
def _message(self, rule):
|
||||||
|
|||||||
@@ -1,45 +1,44 @@
|
|||||||
from django.core.cache import cache
|
from django.db.models import Avg, Count
|
||||||
from django.db.models import Avg, Count
|
from django.db.models.functions import Round
|
||||||
from django.db.models.functions import Round
|
|
||||||
|
from account.decorators import login_required
|
||||||
from account.decorators import login_required
|
|
||||||
from comment.models import Comment
|
from comment.models import Comment
|
||||||
from comment.serializers import CommentSerializer, CreateCommentSerializer
|
from comment.serializers import CommentSerializer, CreateCommentSerializer
|
||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from submission.models import JudgeStatus, Submission
|
from submission.models import JudgeStatus, Submission
|
||||||
from utils.api import APIView
|
from utils.api import AsyncAPIView
|
||||||
from utils.api.api import validate_serializer
|
from utils.api.api import validate_serializer
|
||||||
from utils.constants import CacheKey
|
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||||
|
from utils.constants import CacheKey
|
||||||
|
|
||||||
|
|
||||||
class CommentAPI(APIView):
|
class CommentAPI(AsyncAPIView):
|
||||||
@validate_serializer(CreateCommentSerializer)
|
@validate_serializer(CreateCommentSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
data = request.data
|
data = request.data
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(id=data["problem_id"], visible=True)
|
problem = await Problem.objects.aget(id=data["problem_id"], visible=True)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
self.error("problem is not exists")
|
return self.error("problem is not exists")
|
||||||
|
|
||||||
try:
|
submission = await (
|
||||||
submission = (
|
Submission.objects.select_related("problem")
|
||||||
Submission.objects.select_related("problem")
|
.filter(
|
||||||
.filter(
|
user_id=request.user.id,
|
||||||
user_id=request.user.id,
|
problem_id=data["problem_id"],
|
||||||
problem_id=data["problem_id"],
|
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
except Submission.DoesNotExist:
|
.afirst()
|
||||||
self.error("submission is not exists or not accepted")
|
)
|
||||||
|
if not submission:
|
||||||
|
return self.error("submission is not exists or not accepted")
|
||||||
|
|
||||||
language = submission.language
|
language = submission.language
|
||||||
if language == "Python3":
|
if language == "Python3":
|
||||||
language = "Python"
|
language = "Python"
|
||||||
|
|
||||||
Comment.objects.create(
|
await Comment.objects.acreate(
|
||||||
user=request.user,
|
user=request.user,
|
||||||
problem=problem,
|
problem=problem,
|
||||||
submission=submission,
|
submission=submission,
|
||||||
@@ -49,32 +48,35 @@ class CommentAPI(APIView):
|
|||||||
comprehensive_rating=data["comprehensive_rating"],
|
comprehensive_rating=data["comprehensive_rating"],
|
||||||
content=data["content"],
|
content=data["content"],
|
||||||
)
|
)
|
||||||
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
|
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
comment = (
|
comment = await (
|
||||||
Comment.objects.select_related("problem")
|
Comment.objects.select_related("problem")
|
||||||
.filter(user=request.user, problem_id=problem_id)
|
.filter(user=request.user, problem_id=problem_id)
|
||||||
.first()
|
.afirst()
|
||||||
)
|
)
|
||||||
if comment:
|
if comment:
|
||||||
return self.success(CommentSerializer(comment).data)
|
return self.success(await self.async_serialize_data(CommentSerializer, comment))
|
||||||
else:
|
else:
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
class CommentStatisticsAPI(APIView):
|
class CommentStatisticsAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
|
cache_key = f"{CacheKey.comment_stats}:{problem_id}"
|
||||||
cached = cache.get(cache_key)
|
cached = await async_cache_get(cache_key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return self.success(cached)
|
return self.success(cached)
|
||||||
|
|
||||||
agg = Comment.objects.filter(problem_id=problem_id).aggregate(
|
from asgiref.sync import sync_to_async
|
||||||
|
agg = await sync_to_async(
|
||||||
|
Comment.objects.filter(problem_id=problem_id).aggregate
|
||||||
|
)(
|
||||||
count=Count("id"),
|
count=Count("id"),
|
||||||
description=Round(Avg("description_rating"), 2),
|
description=Round(Avg("description_rating"), 2),
|
||||||
difficulty=Round(Avg("difficulty_rating"), 2),
|
difficulty=Round(Avg("difficulty_rating"), 2),
|
||||||
@@ -88,5 +90,5 @@ class CommentStatisticsAPI(APIView):
|
|||||||
"difficulty": agg["difficulty"],
|
"difficulty": agg["difficulty"],
|
||||||
"comprehensive": agg["comprehensive"],
|
"comprehensive": agg["comprehensive"],
|
||||||
}}
|
}}
|
||||||
cache.set(cache_key, data, 3600)
|
await async_cache_set(cache_key, data, 3600)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|||||||
@@ -12,12 +12,12 @@ from ..views import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("smtp", SMTPAPI.as_view()),
|
path("smtp", SMTPAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("smtp_test", SMTPTestAPI.as_view()),
|
path("smtp_test", SMTPTestAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("website", WebsiteConfigAPI.as_view()),
|
path("website", WebsiteConfigAPI.as_view()),
|
||||||
path("random_user", RandomUsernameAPI.as_view()),
|
path("random_user", RandomUsernameAPI.as_view()),
|
||||||
path("judge_server", JudgeServerAPI.as_view()),
|
path("judge_server", JudgeServerAPI.as_view()),
|
||||||
path("prune_test_case", TestCasePruneAPI.as_view()),
|
path("prune_test_case", TestCasePruneAPI.as_view()),
|
||||||
path("versions", ReleaseNotesAPI.as_view()),
|
path("versions", ReleaseNotesAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("dashboard_info", DashboardInfoAPI.as_view()),
|
path("dashboard_info", DashboardInfoAPI.as_view()),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ urlpatterns = [
|
|||||||
path("website", WebsiteConfigAPI.as_view()),
|
path("website", WebsiteConfigAPI.as_view()),
|
||||||
# 这里必须要有 /
|
# 这里必须要有 /
|
||||||
path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()),
|
path("judge_server_heartbeat/", JudgeServerHeartbeatAPI.as_view()),
|
||||||
path("languages", LanguagesAPI.as_view()),
|
path("languages", LanguagesAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("hitokoto", HitokotoAPI.as_view()),
|
path("hitokoto", HitokotoAPI.as_view()),
|
||||||
path("class_usernames", ClassUsernamesAPI.as_view()),
|
path("class_usernames", ClassUsernamesAPI.as_view()),
|
||||||
]
|
]
|
||||||
|
|||||||
128
conf/views.py
128
conf/views.py
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -6,9 +7,10 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import smtplib
|
import smtplib
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import timedelta
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
@@ -20,7 +22,7 @@ from judge.dispatcher import process_pending_task
|
|||||||
from options.options import SysOptions
|
from options.options import SysOptions
|
||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from submission.models import Submission
|
from submission.models import Submission
|
||||||
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
|
from utils.api import APIView, AsyncAPIView, CSRFExemptAPIView, validate_serializer
|
||||||
from utils.cache import JsonDataLoader
|
from utils.cache import JsonDataLoader
|
||||||
from utils.shortcuts import get_env, send_email
|
from utils.shortcuts import get_env, send_email
|
||||||
from utils.websocket import push_config_update
|
from utils.websocket import push_config_update
|
||||||
@@ -38,6 +40,7 @@ from .serializers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SMTPAPI(APIView):
|
class SMTPAPI(APIView):
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -66,6 +69,7 @@ class SMTPAPI(APIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SMTPTestAPI(APIView):
|
class SMTPTestAPI(APIView):
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
@validate_serializer(TestSMTPConfigSerializer)
|
@validate_serializer(TestSMTPConfigSerializer)
|
||||||
@@ -97,35 +101,33 @@ class SMTPTestAPI(APIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
class WebsiteConfigAPI(APIView):
|
class WebsiteConfigAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
ret = {
|
ret = await SysOptions.aget_many(
|
||||||
key: getattr(SysOptions, key)
|
"website_base_url",
|
||||||
for key in [
|
"website_name",
|
||||||
"website_base_url",
|
"website_name_shortcut",
|
||||||
"website_name",
|
"website_footer",
|
||||||
"website_name_shortcut",
|
"allow_register",
|
||||||
"website_footer",
|
"submission_list_show_all",
|
||||||
"allow_register",
|
"class_list",
|
||||||
"submission_list_show_all",
|
"enable_maxkb",
|
||||||
"class_list",
|
)
|
||||||
"enable_maxkb",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
return self.success(ret)
|
return self.success(ret)
|
||||||
|
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
@validate_serializer(CreateEditWebsiteConfigSerializer)
|
@validate_serializer(CreateEditWebsiteConfigSerializer)
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
for k, v in request.data.items():
|
@sync_to_async
|
||||||
if k == "website_footer":
|
def _update_config(data):
|
||||||
with XSSHtml() as parser:
|
for k, v in data.items():
|
||||||
v = parser.clean(v)
|
if k == "website_footer":
|
||||||
setattr(SysOptions, k, v)
|
with XSSHtml() as parser:
|
||||||
|
v = parser.clean(v)
|
||||||
# 推送配置更新到所有连接的客户端
|
setattr(SysOptions, k, v)
|
||||||
push_config_update(k, v)
|
push_config_update(k, v)
|
||||||
|
|
||||||
|
await _update_config(request.data)
|
||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
@@ -206,6 +208,7 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class LanguagesAPI(APIView):
|
class LanguagesAPI(APIView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
return self.success(
|
return self.success(
|
||||||
@@ -255,6 +258,7 @@ class TestCasePruneAPI(APIView):
|
|||||||
shutil.rmtree(test_case_dir, ignore_errors=True)
|
shutil.rmtree(test_case_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ReleaseNotesAPI(APIView):
|
class ReleaseNotesAPI(APIView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
try:
|
try:
|
||||||
@@ -272,24 +276,29 @@ class ReleaseNotesAPI(APIView):
|
|||||||
return self.success(releases)
|
return self.success(releases)
|
||||||
|
|
||||||
|
|
||||||
class DashboardInfoAPI(APIView):
|
class DashboardInfoAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
today = datetime.today()
|
now = timezone.now()
|
||||||
today_submission_count = Submission.objects.filter(
|
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
(
|
||||||
).count()
|
user_count,
|
||||||
recent_contest_count = Contest.objects.exclude(
|
today_submission_count,
|
||||||
end_time__lt=timezone.now()
|
recent_contest_count,
|
||||||
).count()
|
judge_servers,
|
||||||
judge_server_count = len(
|
) = await asyncio.gather(
|
||||||
list(filter(lambda x: x.status == "normal", JudgeServer.objects.all()))
|
User.objects.acount(),
|
||||||
|
Submission.objects.filter(create_time__gte=today_start).acount(),
|
||||||
|
Contest.objects.exclude(end_time__lt=timezone.now()).acount(),
|
||||||
|
JudgeServer.objects.filter(
|
||||||
|
last_heartbeat__gte=timezone.now() - timedelta(seconds=6)
|
||||||
|
).acount(),
|
||||||
)
|
)
|
||||||
return self.success(
|
return self.success(
|
||||||
{
|
{
|
||||||
"user_count": User.objects.count(),
|
"user_count": user_count,
|
||||||
"recent_contest_count": recent_contest_count,
|
"recent_contest_count": recent_contest_count,
|
||||||
"today_submission_count": today_submission_count,
|
"today_submission_count": today_submission_count,
|
||||||
"judge_server_count": judge_server_count,
|
"judge_server_count": judge_servers,
|
||||||
"env": {
|
"env": {
|
||||||
"FORCE_HTTPS": get_env("FORCE_HTTPS", default=False),
|
"FORCE_HTTPS": get_env("FORCE_HTTPS", default=False),
|
||||||
"STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default=""),
|
"STATIC_CDN_HOST": get_env("STATIC_CDN_HOST", default=""),
|
||||||
@@ -298,24 +307,21 @@ class DashboardInfoAPI(APIView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RandomUsernameAPI(APIView):
|
class RandomUsernameAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
classroom = request.GET.get("classroom", "")
|
classroom = request.GET.get("classroom", "")
|
||||||
if not classroom:
|
if not classroom:
|
||||||
return self.error("需要班级号")
|
return self.error("需要班级号")
|
||||||
usernames = (
|
usernames = [
|
||||||
User.objects.filter(username__istartswith=classroom)
|
u async for u in User.objects.filter(username__istartswith=classroom)
|
||||||
.values_list("username", flat=True)
|
.values_list("username", flat=True)
|
||||||
.order_by("?")
|
.order_by("?")[:10]
|
||||||
)
|
]
|
||||||
if len(usernames) > 10:
|
return self.success(usernames)
|
||||||
return self.success(usernames[:10])
|
|
||||||
else:
|
|
||||||
return self.success(usernames)
|
|
||||||
|
|
||||||
|
|
||||||
class HitokotoAPI(APIView):
|
class HitokotoAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
try:
|
try:
|
||||||
categories = JsonDataLoader.load_data(
|
categories = JsonDataLoader.load_data(
|
||||||
settings.HITOKOTO_DIR, "categories.json"
|
settings.HITOKOTO_DIR, "categories.json"
|
||||||
@@ -328,20 +334,14 @@ class HitokotoAPI(APIView):
|
|||||||
return self.error("获取一言失败,请稍后再试")
|
return self.error("获取一言失败,请稍后再试")
|
||||||
|
|
||||||
|
|
||||||
class ClassUsernamesAPI(APIView):
|
class ClassUsernamesAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
classroom = request.GET.get("classroom", "")
|
classroom = request.GET.get("classroom", "")
|
||||||
if not classroom:
|
if not classroom:
|
||||||
return self.error("需要班级号")
|
return self.error("需要班级号")
|
||||||
users = User.objects.filter(class_name=classroom).order_by("-create_time")
|
prefix = f"ks{classroom}"
|
||||||
names = []
|
names = [
|
||||||
for user in users:
|
user.username[len(prefix):] if user.username.startswith(prefix) else user.username
|
||||||
prefix = f"ks{classroom}"
|
async for user in User.objects.filter(class_name=classroom).order_by("-create_time")
|
||||||
result = (
|
]
|
||||||
user.username[len(prefix) :]
|
|
||||||
if user.username.startswith(prefix)
|
|
||||||
else user.username
|
|
||||||
)
|
|
||||||
names.append(result)
|
|
||||||
|
|
||||||
return self.success(names)
|
return self.success(names)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from ..views.admin import ACMContestHelper, ContestAnnouncementAPI, ContestAPI,
|
|||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("contest", ContestAPI.as_view()),
|
path("contest", ContestAPI.as_view()),
|
||||||
path("contest/clone", ContestCloneAPI.as_view()),
|
path("contest/clone", ContestCloneAPI.as_view()),
|
||||||
path("contest/announcement", ContestAnnouncementAPI.as_view()),
|
path("contest/announcement", ContestAnnouncementAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("contest/acm_helper", ACMContestHelper.as_view()),
|
path("contest/acm_helper", ACMContestHelper.as_view()),
|
||||||
path("download_submissions", DownloadContestSubmissions.as_view()),
|
path("download_submissions", DownloadContestSubmissions.as_view()), # DEPRECATED: 前端未调用
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ urlpatterns = [
|
|||||||
path("contests", ContestListAPI.as_view()),
|
path("contests", ContestListAPI.as_view()),
|
||||||
path("contest", ContestAPI.as_view()),
|
path("contest", ContestAPI.as_view()),
|
||||||
path("contest/password", ContestPasswordVerifyAPI.as_view()),
|
path("contest/password", ContestPasswordVerifyAPI.as_view()),
|
||||||
path("contest/announcement", ContestAnnouncementListAPI.as_view()),
|
path("contest/announcement", ContestAnnouncementListAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("contest/access", ContestAccessAPI.as_view()),
|
path("contest/access", ContestAccessAPI.as_view()),
|
||||||
path("contest_rank", ContestRankAPI.as_view()),
|
path("contest_rank", ContestRankAPI.as_view()),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class ContestAPI(APIView):
|
|||||||
return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
|
return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ContestAnnouncementAPI(APIView):
|
class ContestAnnouncementAPI(APIView):
|
||||||
@validate_serializer(CreateContestAnnouncementSerializer)
|
@validate_serializer(CreateContestAnnouncementSerializer)
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
@@ -212,6 +213,7 @@ class ACMContestHelper(APIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class DownloadContestSubmissions(APIView):
|
class DownloadContestSubmissions(APIView):
|
||||||
def _dump_submissions(self, contest, exclude_admin=True):
|
def _dump_submissions(self, contest, exclude_admin=True):
|
||||||
problem_ids = contest.problem_set.all().values_list("id", "_id")
|
problem_ids = contest.problem_set.all().values_list("id", "_id")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from account.decorators import (
|
|||||||
)
|
)
|
||||||
from account.models import AdminType
|
from account.models import AdminType
|
||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from utils.api import APIView, validate_serializer
|
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||||
from utils.constants import CONTEST_PASSWORD_SESSION_KEY, CacheKey, ContestRuleType, ContestStatus
|
from utils.constants import CONTEST_PASSWORD_SESSION_KEY, CacheKey, ContestRuleType, ContestStatus
|
||||||
from utils.shortcuts import check_is_id, datetime2str
|
from utils.shortcuts import check_is_id, datetime2str
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ from ..models import ACMContestRank, Contest, ContestAnnouncement, OIContestRank
|
|||||||
from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerializer, ContestPasswordVerifySerializer, ContestSerializer, OIContestRankSerializer
|
from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerializer, ContestPasswordVerifySerializer, ContestSerializer, OIContestRankSerializer
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ContestAnnouncementListAPI(APIView):
|
class ContestAnnouncementListAPI(APIView):
|
||||||
@check_contest_permission(check_type="announcements")
|
@check_contest_permission(check_type="announcements")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -35,22 +36,28 @@ class ContestAnnouncementListAPI(APIView):
|
|||||||
return self.success(ContestAnnouncementSerializer(data, many=True).data)
|
return self.success(ContestAnnouncementSerializer(data, many=True).data)
|
||||||
|
|
||||||
|
|
||||||
class ContestAPI(APIView):
|
class ContestAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
id = request.GET.get("id")
|
id = request.GET.get("id")
|
||||||
if not id or not check_is_id(id):
|
if not id or not check_is_id(id):
|
||||||
return self.error("Invalid parameter, id is required")
|
return self.error("Invalid parameter, id is required")
|
||||||
try:
|
try:
|
||||||
contest = Contest.objects.get(id=id, visible=True)
|
contest = await (
|
||||||
|
Contest.objects.select_related("created_by")
|
||||||
|
.filter(id=id, visible=True)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if contest is None:
|
||||||
|
raise Contest.DoesNotExist
|
||||||
except Contest.DoesNotExist:
|
except Contest.DoesNotExist:
|
||||||
return self.error("Contest does not exist")
|
return self.error("Contest does not exist")
|
||||||
data = ContestSerializer(contest).data
|
data = await self.async_serialize_data(ContestSerializer, contest)
|
||||||
data["now"] = datetime2str(now())
|
data["now"] = datetime2str(now())
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class ContestListAPI(APIView):
|
class ContestListAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
||||||
keyword = request.GET.get("keyword")
|
keyword = request.GET.get("keyword")
|
||||||
rule_type = request.GET.get("rule_type")
|
rule_type = request.GET.get("rule_type")
|
||||||
@@ -70,7 +77,7 @@ class ContestListAPI(APIView):
|
|||||||
contests = contests.filter(end_time__lt=cur)
|
contests = contests.filter(end_time__lt=cur)
|
||||||
else:
|
else:
|
||||||
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
||||||
return self.success(self.paginate_data(request, contests, ContestSerializer))
|
return self.success(await self.async_paginate_data(request, contests, ContestSerializer))
|
||||||
|
|
||||||
|
|
||||||
class ContestPasswordVerifyAPI(APIView):
|
class ContestPasswordVerifyAPI(APIView):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ asgiref==3.11.1 \
|
|||||||
# channels
|
# channels
|
||||||
# channels-redis
|
# channels-redis
|
||||||
# django
|
# django
|
||||||
|
# onlinejudge
|
||||||
certifi==2026.4.22 \
|
certifi==2026.4.22 \
|
||||||
--hash=sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a \
|
--hash=sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a \
|
||||||
--hash=sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580
|
--hash=sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580
|
||||||
|
|||||||
861
docs/superpowers/plans/2026-05-26-backend-async.md
Normal file
861
docs/superpowers/plans/2026-05-26-backend-async.md
Normal file
@@ -0,0 +1,861 @@
|
|||||||
|
# Backend Async Hardening Implementation Plan
|
||||||
|
|
||||||
|
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||||
|
|
||||||
|
**Goal:** Make the current backend async work correct first, then establish safe patterns for expanding async views without changing API response shapes.
|
||||||
|
|
||||||
|
**Architecture:** Keep the existing custom `APIView`/DRF serializer stack. Add async-safe tests and helpers around `AsyncAPIView`, explicitly preload serializer relations in converted endpoints, and use `sync_to_async(..., thread_sensitive=True)` for synchronous serializer/cache/helper code that remains inside async views.
|
||||||
|
|
||||||
|
**Tech Stack:** Django 6.0.4, custom class-based API views, Django async ORM, DRF serializers, PostgreSQL, Redis cache, Channels ASGI.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Async Rules For This Repository
|
||||||
|
|
||||||
|
1. Async conversion is valid only when the endpoint preserves URL, method, status code, JSON envelope, and permission behavior.
|
||||||
|
2. Every async view that serializes model instances must either preload all serializer relations with `select_related()` / `prefetch_related()` or run serializer `.data` through a sync boundary.
|
||||||
|
3. `asyncio.gather()` is only for independent reads. Do not use it around writes that depend on ordering or transaction semantics.
|
||||||
|
4. Keep file upload/download, SMTP, test-case pruning, judge heartbeat mutation paths, and contest permission-heavy flows sync until the async decorator and middleware work is complete.
|
||||||
|
5. Current sync `MiddlewareMixin` middleware means ASGI requests still cross sync boundaries. The first async milestone is correctness and latency cleanup, not a full event-loop purity claim.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1: Add Regression Coverage For Converted Async Detail Views
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `utils/test_async_view_regressions.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write failing regression tests**
|
||||||
|
|
||||||
|
Create `utils/test_async_view_regressions.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from django.test import AsyncClient, TestCase
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from account.models import UserProfile
|
||||||
|
from announcement.models import Announcement
|
||||||
|
from contest.models import Contest
|
||||||
|
from flowchart.models import FlowchartSubmission
|
||||||
|
from problem.models import Problem, ProblemRuleType
|
||||||
|
from utils.constants import ContestRuleType, Difficulty
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
|
def make_user(username="async_user"):
|
||||||
|
user = User.objects.create(username=username, email=f"{username}@example.com")
|
||||||
|
user.set_password("pass1234")
|
||||||
|
user.save()
|
||||||
|
UserProfile.objects.create(user=user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def make_problem(user):
|
||||||
|
return Problem.objects.create(
|
||||||
|
_id="ASYNC001",
|
||||||
|
title="Async Problem",
|
||||||
|
description="desc",
|
||||||
|
input_description="input",
|
||||||
|
output_description="output",
|
||||||
|
samples=[],
|
||||||
|
test_case_id="async-test-case",
|
||||||
|
test_case_score=[],
|
||||||
|
hint="",
|
||||||
|
languages=["Python3"],
|
||||||
|
template={},
|
||||||
|
created_by=user,
|
||||||
|
time_limit=1000,
|
||||||
|
memory_limit=128,
|
||||||
|
rule_type=ProblemRuleType.ACM,
|
||||||
|
difficulty=Difficulty.LOW,
|
||||||
|
share_submission=False,
|
||||||
|
allow_flowchart=True,
|
||||||
|
show_flowchart=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncConvertedViewRegressionTests(TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpTestData(cls):
|
||||||
|
cls.user = make_user()
|
||||||
|
cls.announcement = Announcement.objects.create(
|
||||||
|
title="Async Announcement",
|
||||||
|
content="content",
|
||||||
|
tag="notice",
|
||||||
|
visible=True,
|
||||||
|
top=False,
|
||||||
|
created_by=cls.user,
|
||||||
|
)
|
||||||
|
cls.contest = Contest.objects.create(
|
||||||
|
title="Async Contest",
|
||||||
|
description="contest desc",
|
||||||
|
tag="weekly",
|
||||||
|
real_time_rank=True,
|
||||||
|
password=None,
|
||||||
|
rule_type=ContestRuleType.ACM,
|
||||||
|
start_time=timezone.now() - timedelta(hours=1),
|
||||||
|
end_time=timezone.now() + timedelta(hours=1),
|
||||||
|
created_by=cls.user,
|
||||||
|
visible=True,
|
||||||
|
allowed_ip_ranges=[],
|
||||||
|
)
|
||||||
|
cls.problem = make_problem(cls.user)
|
||||||
|
cls.flowchart_submission = FlowchartSubmission.objects.create(
|
||||||
|
user=cls.user,
|
||||||
|
problem=cls.problem,
|
||||||
|
mermaid_code="graph TD\nA-->B",
|
||||||
|
flowchart_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_announcement_detail_serializes_created_by(self):
|
||||||
|
response = await AsyncClient().get(f"/api/announcement?id={self.announcement.id}")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = response.json()
|
||||||
|
self.assertIsNone(body["error"])
|
||||||
|
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
|
||||||
|
|
||||||
|
async def test_contest_detail_serializes_created_by(self):
|
||||||
|
response = await AsyncClient().get(f"/api/contest?id={self.contest.id}")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = response.json()
|
||||||
|
self.assertIsNone(body["error"])
|
||||||
|
self.assertEqual(body["data"]["created_by"]["username"], self.user.username)
|
||||||
|
|
||||||
|
async def test_flowchart_detail_serializes_user_and_problem(self):
|
||||||
|
client = AsyncClient()
|
||||||
|
await sync_to_async(client.force_login)(self.user)
|
||||||
|
|
||||||
|
response = await client.get(f"/api/flowchart/submission?id={self.flowchart_submission.id}")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = response.json()
|
||||||
|
self.assertIsNone(body["error"])
|
||||||
|
self.assertEqual(body["data"]["username"], self.user.username)
|
||||||
|
self.assertEqual(body["data"]["problem"], self.problem.id)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run the regression tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected before Task 2: at least one test returns `server-error` or raises async-unsafe lazy relation access.
|
||||||
|
|
||||||
|
- [ ] **Step 3: Commit the failing tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add utils/test_async_view_regressions.py
|
||||||
|
rtk git commit -m "test: cover async view serialization regressions"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2: Preload Relations For Already Converted Async Detail Views
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `announcement/views/oj.py`
|
||||||
|
- Modify: `contest/views/oj.py`
|
||||||
|
- Modify: `flowchart/views/oj.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Fix announcement detail relation loading**
|
||||||
|
|
||||||
|
In `announcement/views/oj.py`, replace the detail query with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
announcement = await (
|
||||||
|
Announcement.objects.select_related("created_by")
|
||||||
|
.filter(id=id, visible=True)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if announcement is None:
|
||||||
|
raise Announcement.DoesNotExist
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Fix contest detail relation loading**
|
||||||
|
|
||||||
|
In `contest/views/oj.py`, replace the detail query with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
contest = await (
|
||||||
|
Contest.objects.select_related("created_by")
|
||||||
|
.filter(id=id, visible=True)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if contest is None:
|
||||||
|
raise Contest.DoesNotExist
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Fix flowchart submission detail relation loading**
|
||||||
|
|
||||||
|
In `flowchart/views/oj.py`, update `FlowchartSubmissionAPI.get()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
submission = await (
|
||||||
|
FlowchartSubmission.objects.select_related("user", "problem")
|
||||||
|
.filter(id=submission_id)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if submission is None:
|
||||||
|
raise FlowchartSubmission.DoesNotExist
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Fix flowchart retry permission relation loading**
|
||||||
|
|
||||||
|
In `flowchart/views/oj.py`, update `FlowchartSubmissionRetryAPI.post()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
submission = await (
|
||||||
|
FlowchartSubmission.objects.select_related("problem")
|
||||||
|
.filter(id=submission_id)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if submission is None:
|
||||||
|
raise FlowchartSubmission.DoesNotExist
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 5: Fix flowchart completed-detail relation loading**
|
||||||
|
|
||||||
|
In `flowchart/views/oj.py`, update `FlowchartSubmissionDetailAPI.get()` before serialization:
|
||||||
|
|
||||||
|
```python
|
||||||
|
submissions = (
|
||||||
|
FlowchartSubmission.objects.select_related("user", "problem")
|
||||||
|
.filter(
|
||||||
|
user=request.user,
|
||||||
|
problem=problem,
|
||||||
|
status=FlowchartSubmissionStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
.order_by("create_time")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 6: Run the regression tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_view_regressions -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all tests pass.
|
||||||
|
|
||||||
|
- [ ] **Step 7: Run Django system checks**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py check
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: `System check identified no issues`.
|
||||||
|
|
||||||
|
- [ ] **Step 8: Commit relation fixes**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add announcement/views/oj.py contest/views/oj.py flowchart/views/oj.py
|
||||||
|
rtk git commit -m "fix: preload relations in async detail views"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3: Add Async Serialization Helpers To `AsyncAPIView`
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `utils/api/api.py`
|
||||||
|
- Modify: `utils/test_async_api.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write tests for async serialization helper**
|
||||||
|
|
||||||
|
Create `utils/test_async_api.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import json
|
||||||
|
|
||||||
|
from django.test import AsyncRequestFactory, SimpleTestCase
|
||||||
|
|
||||||
|
from utils.api import AsyncAPIView, serializers, validate_serializer
|
||||||
|
|
||||||
|
|
||||||
|
class PayloadSerializer(serializers.Serializer):
|
||||||
|
name = serializers.CharField()
|
||||||
|
|
||||||
|
|
||||||
|
class EchoSerializer(serializers.Serializer):
|
||||||
|
name = serializers.CharField()
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncValidatedEchoView(AsyncAPIView):
|
||||||
|
@validate_serializer(PayloadSerializer)
|
||||||
|
async def post(self, request):
|
||||||
|
return self.success({"name": request.data["name"]})
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncSerializationHelperTests(SimpleTestCase):
|
||||||
|
async def test_validate_serializer_supports_async_view_methods(self):
|
||||||
|
request = AsyncRequestFactory().post(
|
||||||
|
"/api/echo",
|
||||||
|
data=json.dumps({"name": "alice"}),
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await AsyncValidatedEchoView.as_view()(request)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.data["error"], None)
|
||||||
|
self.assertEqual(response.data["data"], {"name": "alice"})
|
||||||
|
|
||||||
|
async def test_async_serialize_data_returns_serializer_data(self):
|
||||||
|
view = AsyncAPIView()
|
||||||
|
|
||||||
|
data = await view.async_serialize_data(
|
||||||
|
EchoSerializer,
|
||||||
|
[{"name": "alice"}, {"name": "bob"}],
|
||||||
|
many=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(data, [{"name": "alice"}, {"name": "bob"}])
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run helper tests and confirm failure**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_api -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected before implementation: `AttributeError: 'AsyncAPIView' object has no attribute 'async_serialize_data'`.
|
||||||
|
|
||||||
|
- [ ] **Step 3: Add `sync_to_async` import**
|
||||||
|
|
||||||
|
In `utils/api/api.py`, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Add serializer helper methods**
|
||||||
|
|
||||||
|
In `AsyncAPIView`, before `async_paginate_data()`, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def serialize_data(self, object_serializer, data, **kwargs):
|
||||||
|
return object_serializer(data, **kwargs).data
|
||||||
|
|
||||||
|
async def async_serialize_data(self, object_serializer, data, **kwargs):
|
||||||
|
return await sync_to_async(
|
||||||
|
self.serialize_data,
|
||||||
|
thread_sensitive=True,
|
||||||
|
)(object_serializer, data, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 5: Use helper inside async pagination**
|
||||||
|
|
||||||
|
In `AsyncAPIView.async_paginate_data()`, replace:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if object_serializer:
|
||||||
|
results = object_serializer(results, many=True, context={"request": request}).data
|
||||||
|
```
|
||||||
|
|
||||||
|
with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if object_serializer:
|
||||||
|
results = await self.async_serialize_data(
|
||||||
|
object_serializer,
|
||||||
|
results,
|
||||||
|
many=True,
|
||||||
|
context={"request": request},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 6: Run helper and regression tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_api utils.test_async_view_regressions -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all tests pass.
|
||||||
|
|
||||||
|
- [ ] **Step 7: Commit async serializer helper**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add utils/api/api.py utils/test_async_api.py
|
||||||
|
rtk git commit -m "feat: add async serializer helper"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4: Add Cache Helpers For Async Views
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Create: `utils/async_helpers.py`
|
||||||
|
- Create: `utils/test_async_helpers.py`
|
||||||
|
- Modify: `problem/views/oj.py`
|
||||||
|
- Modify: `comment/views/oj.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write cache helper tests**
|
||||||
|
|
||||||
|
Create `utils/test_async_helpers.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from django.test import SimpleTestCase, override_settings
|
||||||
|
|
||||||
|
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
CACHES={
|
||||||
|
"default": {
|
||||||
|
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
|
||||||
|
"LOCATION": "async-helper-tests",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
class AsyncCacheHelperTests(SimpleTestCase):
|
||||||
|
async def test_async_cache_round_trip(self):
|
||||||
|
await async_cache_set("async:key", {"value": 1}, 30)
|
||||||
|
|
||||||
|
value = await async_cache_get("async:key")
|
||||||
|
|
||||||
|
self.assertEqual(value, {"value": 1})
|
||||||
|
|
||||||
|
async def test_async_cache_delete(self):
|
||||||
|
await async_cache_set("async:delete", "present", 30)
|
||||||
|
await async_cache_delete("async:delete")
|
||||||
|
|
||||||
|
value = await async_cache_get("async:delete")
|
||||||
|
|
||||||
|
self.assertIsNone(value)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run tests and confirm import failure**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_helpers -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected before implementation: `ModuleNotFoundError: No module named 'utils.async_helpers'`.
|
||||||
|
|
||||||
|
- [ ] **Step 3: Create async cache helpers**
|
||||||
|
|
||||||
|
Create `utils/async_helpers.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.core.cache import cache
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_get(key, default=None):
|
||||||
|
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_set(key, value, timeout=None):
|
||||||
|
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_delete(key):
|
||||||
|
return await sync_to_async(cache.delete, thread_sensitive=True)(key)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Convert async problem cache calls**
|
||||||
|
|
||||||
|
In `problem/views/oj.py`, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from utils.async_helpers import async_cache_get, async_cache_set
|
||||||
|
```
|
||||||
|
|
||||||
|
In `ProblemTagAPI.get()`, replace:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cached = cache.get(cache_key)
|
||||||
|
```
|
||||||
|
|
||||||
|
with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cached = await async_cache_get(cache_key)
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cache.set(cache_key, data, 3600)
|
||||||
|
```
|
||||||
|
|
||||||
|
with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await async_cache_set(cache_key, data, 3600)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 5: Convert async comment cache calls**
|
||||||
|
|
||||||
|
In `comment/views/oj.py`, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from utils.async_helpers import async_cache_delete, async_cache_get, async_cache_set
|
||||||
|
```
|
||||||
|
|
||||||
|
In `CommentAPI.post()`, replace:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cache.delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||||
|
```
|
||||||
|
|
||||||
|
with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await async_cache_delete(f"{CacheKey.comment_stats}:{problem.id}")
|
||||||
|
```
|
||||||
|
|
||||||
|
In `CommentStatisticsAPI.get()`, replace `cache.get()` and `cache.set()` with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
cached = await async_cache_get(cache_key)
|
||||||
|
```
|
||||||
|
|
||||||
|
and:
|
||||||
|
|
||||||
|
```python
|
||||||
|
await async_cache_set(cache_key, data, 3600)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 6: Run helper and targeted async tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test utils.test_async_helpers utils.test_async_api utils.test_async_view_regressions -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all tests pass.
|
||||||
|
|
||||||
|
- [ ] **Step 7: Commit cache helper work**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add utils/async_helpers.py utils/test_async_helpers.py problem/views/oj.py comment/views/oj.py
|
||||||
|
rtk git commit -m "feat: add async cache helpers"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 5: Make Permission Decorators Async-Aware Before More Conversions
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `account/decorators.py`
|
||||||
|
- Create: `account/test_async_decorators.py`
|
||||||
|
|
||||||
|
- [ ] **Step 1: Write tests for async `login_required`**
|
||||||
|
|
||||||
|
Create `account/test_async_decorators.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from django.contrib.auth.models import AnonymousUser
|
||||||
|
from django.test import AsyncRequestFactory, SimpleTestCase
|
||||||
|
|
||||||
|
from account.decorators import login_required
|
||||||
|
from utils.api import AsyncAPIView
|
||||||
|
|
||||||
|
|
||||||
|
class DisabledUser:
|
||||||
|
is_authenticated = True
|
||||||
|
is_disabled = True
|
||||||
|
|
||||||
|
|
||||||
|
class ActiveUser:
|
||||||
|
is_authenticated = True
|
||||||
|
is_disabled = False
|
||||||
|
|
||||||
|
|
||||||
|
class ProtectedAsyncView(AsyncAPIView):
|
||||||
|
@login_required
|
||||||
|
async def get(self, request):
|
||||||
|
return self.success("ok")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncPermissionDecoratorTests(SimpleTestCase):
|
||||||
|
async def test_async_login_required_allows_active_user(self):
|
||||||
|
request = AsyncRequestFactory().get("/api/protected")
|
||||||
|
request.user = ActiveUser()
|
||||||
|
|
||||||
|
response = await ProtectedAsyncView.as_view()(request)
|
||||||
|
|
||||||
|
self.assertEqual(response.data["error"], None)
|
||||||
|
self.assertEqual(response.data["data"], "ok")
|
||||||
|
|
||||||
|
async def test_async_login_required_rejects_anonymous_user(self):
|
||||||
|
request = AsyncRequestFactory().get("/api/protected")
|
||||||
|
request.user = AnonymousUser()
|
||||||
|
|
||||||
|
response = await ProtectedAsyncView.as_view()(request)
|
||||||
|
|
||||||
|
self.assertEqual(response.data["error"], "permission-denied")
|
||||||
|
self.assertEqual(response.data["data"], "Please login first")
|
||||||
|
|
||||||
|
async def test_async_login_required_rejects_disabled_user(self):
|
||||||
|
request = AsyncRequestFactory().get("/api/protected")
|
||||||
|
request.user = DisabledUser()
|
||||||
|
|
||||||
|
response = await ProtectedAsyncView.as_view()(request)
|
||||||
|
|
||||||
|
self.assertEqual(response.data["error"], "permission-denied")
|
||||||
|
self.assertEqual(response.data["data"], "Your account is disabled")
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Run decorator tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test account.test_async_decorators -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected before implementation: this may pass because `AsyncAPIView.dispatch()` awaits returned coroutine. Keep the test anyway as a contract before refactoring.
|
||||||
|
|
||||||
|
- [ ] **Step 3: Refactor `BasePermissionDecorator` with explicit async path**
|
||||||
|
|
||||||
|
In `account/decorators.py`, add:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import inspect
|
||||||
|
```
|
||||||
|
|
||||||
|
Replace `BasePermissionDecorator.__get__()` with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __get__(self, obj, obj_type):
|
||||||
|
if inspect.iscoroutinefunction(self.func):
|
||||||
|
return functools.partial(self._async_call, obj)
|
||||||
|
return functools.partial(self.__call__, obj)
|
||||||
|
```
|
||||||
|
|
||||||
|
Add this method to `BasePermissionDecorator`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def _async_call(self, *args, **kwargs):
|
||||||
|
self.request = args[1]
|
||||||
|
|
||||||
|
if self.check_permission():
|
||||||
|
if self.request.user.is_disabled:
|
||||||
|
return self.error("Your account is disabled")
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
return self.error("Please login first")
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Run decorator and async view tests**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test account.test_async_decorators utils.test_async_api utils.test_async_view_regressions -v 2
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all tests pass.
|
||||||
|
|
||||||
|
- [ ] **Step 5: Commit decorator refactor**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add account/decorators.py account/test_async_decorators.py
|
||||||
|
rtk git commit -m "refactor: make permission decorators async-aware"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 6: Convert One Endpoint Family At A Time
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify only the endpoint family being converted in each batch.
|
||||||
|
- Add or update tests in the same app, using `AsyncClient` for converted URLs.
|
||||||
|
|
||||||
|
- [ ] **Step 1: Choose the next low-risk batch**
|
||||||
|
|
||||||
|
Use this order:
|
||||||
|
|
||||||
|
```text
|
||||||
|
1. Pure public GET list/detail endpoints already close to async:
|
||||||
|
- announcement list/detail
|
||||||
|
- contest list/detail
|
||||||
|
- problem tag/list/detail
|
||||||
|
|
||||||
|
2. Authenticated read-only list endpoints:
|
||||||
|
- message list
|
||||||
|
- submission list
|
||||||
|
- flowchart list/detail/current
|
||||||
|
|
||||||
|
3. Simple create/update endpoints with no contest permission decorator:
|
||||||
|
- message create
|
||||||
|
- comment create
|
||||||
|
- flowchart retry/create
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: For each endpoint, write one async smoke test**
|
||||||
|
|
||||||
|
Use this template and replace the URL and assertions with the exact endpoint response:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from django.test import AsyncClient, TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointAsyncSmokeTests(TestCase):
|
||||||
|
async def test_endpoint_returns_success_envelope(self):
|
||||||
|
response = await AsyncClient().get("/api/endpoint?limit=10")
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
body = response.json()
|
||||||
|
self.assertIn("error", body)
|
||||||
|
self.assertIn("data", body)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 3: Convert ORM access only where async ORM exists**
|
||||||
|
|
||||||
|
Use these replacements:
|
||||||
|
|
||||||
|
```python
|
||||||
|
obj = await Model.objects.aget(id=id)
|
||||||
|
count = await queryset.acount()
|
||||||
|
first = await queryset.afirst()
|
||||||
|
last = await queryset.alast()
|
||||||
|
items = [item async for item in queryset[offset:offset + limit]]
|
||||||
|
created = await Model.objects.acreate(field=value)
|
||||||
|
await instance.asave(update_fields=["field"])
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Keep sync-only helpers behind `sync_to_async`**
|
||||||
|
|
||||||
|
Use this pattern:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
result = await sync_to_async(sync_helper, thread_sensitive=True)(arg1, arg2)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 5: Run targeted app tests and system check after each batch**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test <app_label> -v 2
|
||||||
|
rtk uv run python manage.py check
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: targeted tests pass and system check reports no issues.
|
||||||
|
|
||||||
|
- [ ] **Step 6: Commit each endpoint family separately**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk git add <changed-files>
|
||||||
|
rtk git commit -m "refactor: async <endpoint-family> endpoints"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 7: Audit Middleware Before Claiming Full Async Benefit
|
||||||
|
|
||||||
|
**Files:**
|
||||||
|
- Modify: `account/middleware.py`
|
||||||
|
- Add tests only for behavior that changes.
|
||||||
|
|
||||||
|
- [ ] **Step 1: Record current sync middleware boundaries**
|
||||||
|
|
||||||
|
Before changing middleware, note these current sync classes:
|
||||||
|
|
||||||
|
```text
|
||||||
|
account.middleware.APITokenAuthMiddleware
|
||||||
|
account.middleware.AdminRoleRequiredMiddleware
|
||||||
|
account.middleware.SessionRecordMiddleware
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 2: Keep middleware sync during endpoint correctness work**
|
||||||
|
|
||||||
|
Do not convert middleware in the same commit as endpoint conversions. Middleware affects every request and needs its own review.
|
||||||
|
|
||||||
|
- [ ] **Step 3: If middleware conversion is pursued, convert one class per commit**
|
||||||
|
|
||||||
|
Use Django new-style middleware with explicit sync and async handling. `__acall__` is a local helper; `__call__` dispatches to it when Django passes an async `get_response`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from asgiref.sync import iscoroutinefunction, markcoroutinefunction, sync_to_async
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleMiddleware:
|
||||||
|
sync_capable = True
|
||||||
|
async_capable = True
|
||||||
|
|
||||||
|
def __init__(self, get_response):
|
||||||
|
self.get_response = get_response
|
||||||
|
self.is_async = iscoroutinefunction(get_response)
|
||||||
|
if self.is_async:
|
||||||
|
markcoroutinefunction(self)
|
||||||
|
|
||||||
|
def __call__(self, request):
|
||||||
|
if self.is_async:
|
||||||
|
return self.__acall__(request)
|
||||||
|
response = self.process_request(request)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
return self.get_response(request)
|
||||||
|
|
||||||
|
async def __acall__(self, request):
|
||||||
|
response = await self.aprocess_request(request)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
return await self.get_response(request)
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def aprocess_request(self, request):
|
||||||
|
return await sync_to_async(self.process_request, thread_sensitive=True)(request)
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] **Step 4: Run full backend test command after middleware work**
|
||||||
|
|
||||||
|
Run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
rtk uv run python manage.py test -v 2
|
||||||
|
rtk uv run python manage.py check
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected: all tests pass and system check reports no issues.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Definition Of Done
|
||||||
|
|
||||||
|
- `rtk uv run python manage.py check` passes.
|
||||||
|
- Async regression tests cover converted detail serializers that depend on FK relations.
|
||||||
|
- Converted async views do not call DRF serializer `.data` directly unless the data is primitive and relation-free.
|
||||||
|
- Cache access from async views uses async helper wrappers.
|
||||||
|
- New endpoint conversions are committed in endpoint-family-sized commits.
|
||||||
|
- No file upload/download, SMTP, judge heartbeat, test-case prune, or contest permission-heavy endpoint is converted without a separate focused plan.
|
||||||
@@ -11,6 +11,6 @@ class FlowchartEvaluationPromptTests(TestCase):
|
|||||||
self.assertIn("Mermaid节点ID由系统生成", prompt)
|
self.assertIn("Mermaid节点ID由系统生成", prompt)
|
||||||
self.assertIn("不要评价节点ID", prompt)
|
self.assertIn("不要评价节点ID", prompt)
|
||||||
self.assertIn("不要因节点ID扣分", prompt)
|
self.assertIn("不要因节点ID扣分", prompt)
|
||||||
self.assertIn("feedback控制在0字以内", prompt)
|
self.assertIn("feedback控制在100字以内", prompt)
|
||||||
self.assertIn("suggestions最多3条", prompt)
|
self.assertIn("suggestions最多3条", prompt)
|
||||||
self.assertIn("重要建议必须以【重点】开头", prompt)
|
self.assertIn("重要建议必须以【重点】开头", prompt)
|
||||||
|
|||||||
@@ -7,65 +7,63 @@ from flowchart.serializers import (
|
|||||||
)
|
)
|
||||||
from flowchart.tasks import evaluate_flowchart_task
|
from flowchart.tasks import evaluate_flowchart_task
|
||||||
from problem.models import Problem
|
from problem.models import Problem
|
||||||
from utils.api import APIView
|
from utils.api import AsyncAPIView
|
||||||
|
|
||||||
|
|
||||||
class FlowchartSubmissionAPI(APIView):
|
class FlowchartSubmissionAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
"""创建流程图提交"""
|
|
||||||
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
|
serializer = CreateFlowchartSubmissionSerializer(data=request.data)
|
||||||
if not serializer.is_valid():
|
if not serializer.is_valid():
|
||||||
return self.error(serializer.errors)
|
return self.error(serializer.errors)
|
||||||
|
|
||||||
data = serializer.validated_data
|
data = serializer.validated_data
|
||||||
|
|
||||||
# 验证题目存在
|
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(id=data["problem_id"])
|
problem = await Problem.objects.aget(id=data["problem_id"])
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem doesn't exist")
|
return self.error("Problem doesn't exist")
|
||||||
|
|
||||||
# 验证题目是否允许流程图提交
|
|
||||||
if not problem.allow_flowchart:
|
if not problem.allow_flowchart:
|
||||||
return self.error("This problem does not allow flowchart submission")
|
return self.error("This problem does not allow flowchart submission")
|
||||||
|
|
||||||
# 创建提交记录
|
submission = await FlowchartSubmission.objects.acreate(
|
||||||
submission = FlowchartSubmission.objects.create(
|
|
||||||
user=request.user,
|
user=request.user,
|
||||||
problem=problem,
|
problem=problem,
|
||||||
mermaid_code=data["mermaid_code"],
|
mermaid_code=data["mermaid_code"],
|
||||||
flowchart_data=data.get("flowchart_data", {}),
|
flowchart_data=data.get("flowchart_data", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 启动AI评分任务
|
|
||||||
evaluate_flowchart_task.send(submission.id)
|
evaluate_flowchart_task.send(submission.id)
|
||||||
|
|
||||||
return self.success({"submission_id": submission.id, "status": "pending"})
|
return self.success({"submission_id": submission.id, "status": "pending"})
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取流程图提交详情"""
|
|
||||||
submission_id = request.GET.get("id")
|
submission_id = request.GET.get("id")
|
||||||
if not submission_id:
|
if not submission_id:
|
||||||
return self.error("submission_id is required")
|
return self.error("submission_id is required")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
submission = FlowchartSubmission.objects.get(id=submission_id)
|
submission = await (
|
||||||
|
FlowchartSubmission.objects.select_related("user", "problem")
|
||||||
|
.filter(id=submission_id)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if submission is None:
|
||||||
|
raise FlowchartSubmission.DoesNotExist
|
||||||
except FlowchartSubmission.DoesNotExist:
|
except FlowchartSubmission.DoesNotExist:
|
||||||
return self.error("Submission doesn't exist")
|
return self.error("Submission doesn't exist")
|
||||||
|
|
||||||
if not submission.check_user_permission(request.user):
|
if not submission.check_user_permission(request.user):
|
||||||
return self.error("No permission for this submission")
|
return self.error("No permission for this submission")
|
||||||
|
|
||||||
serializer = FlowchartSubmissionSerializer(submission)
|
return self.success(await self.async_serialize_data(FlowchartSubmissionSerializer, submission))
|
||||||
return self.success(serializer.data)
|
|
||||||
|
|
||||||
|
|
||||||
class FlowchartSubmissionListAPI(APIView):
|
class FlowchartSubmissionListAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取流程图提交列表"""
|
|
||||||
username = request.GET.get("username")
|
username = request.GET.get("username")
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
myself = request.GET.get("myself")
|
myself = request.GET.get("myself")
|
||||||
@@ -74,7 +72,7 @@ class FlowchartSubmissionListAPI(APIView):
|
|||||||
|
|
||||||
if problem_id:
|
if problem_id:
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(
|
problem = await Problem.objects.aget(
|
||||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||||
)
|
)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
@@ -88,38 +86,42 @@ class FlowchartSubmissionListAPI(APIView):
|
|||||||
elif request.user.is_regular_user():
|
elif request.user.is_regular_user():
|
||||||
queryset = queryset.filter(user=request.user)
|
queryset = queryset.filter(user=request.user)
|
||||||
|
|
||||||
data = self.paginate_data(request, queryset)
|
data = await self.async_paginate_data(request, queryset)
|
||||||
data["results"] = FlowchartSubmissionListSerializer(
|
data["results"] = await self.async_serialize_data(
|
||||||
data["results"], many=True
|
FlowchartSubmissionListSerializer,
|
||||||
).data
|
data["results"],
|
||||||
|
many=True,
|
||||||
|
)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class FlowchartSubmissionRetryAPI(APIView):
|
class FlowchartSubmissionRetryAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
"""重新触发AI评分"""
|
|
||||||
submission_id = request.data.get("submission_id")
|
submission_id = request.data.get("submission_id")
|
||||||
if not submission_id:
|
if not submission_id:
|
||||||
return self.error("submission_id is required")
|
return self.error("submission_id is required")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
submission = FlowchartSubmission.objects.get(id=submission_id)
|
submission = await (
|
||||||
|
FlowchartSubmission.objects.select_related("problem")
|
||||||
|
.filter(id=submission_id)
|
||||||
|
.afirst()
|
||||||
|
)
|
||||||
|
if submission is None:
|
||||||
|
raise FlowchartSubmission.DoesNotExist
|
||||||
except FlowchartSubmission.DoesNotExist:
|
except FlowchartSubmission.DoesNotExist:
|
||||||
return self.error("Submission doesn't exist")
|
return self.error("Submission doesn't exist")
|
||||||
|
|
||||||
# 检查权限
|
|
||||||
if not submission.check_user_permission(request.user):
|
if not submission.check_user_permission(request.user):
|
||||||
return self.error("No permission for this submission")
|
return self.error("No permission for this submission")
|
||||||
|
|
||||||
# 检查是否可以重新评分
|
|
||||||
if submission.status not in [
|
if submission.status not in [
|
||||||
FlowchartSubmissionStatus.FAILED,
|
FlowchartSubmissionStatus.FAILED,
|
||||||
FlowchartSubmissionStatus.COMPLETED,
|
FlowchartSubmissionStatus.COMPLETED,
|
||||||
]:
|
]:
|
||||||
return self.error("Submission is not in a state that allows retry")
|
return self.error("Submission is not in a state that allows retry")
|
||||||
|
|
||||||
# 重置状态并重新启动AI评分
|
|
||||||
submission.status = FlowchartSubmissionStatus.PENDING
|
submission.status = FlowchartSubmissionStatus.PENDING
|
||||||
submission.ai_score = None
|
submission.ai_score = None
|
||||||
submission.ai_grade = None
|
submission.ai_grade = None
|
||||||
@@ -128,9 +130,8 @@ class FlowchartSubmissionRetryAPI(APIView):
|
|||||||
submission.ai_criteria_details = {}
|
submission.ai_criteria_details = {}
|
||||||
submission.processing_time = None
|
submission.processing_time = None
|
||||||
submission.evaluation_time = None
|
submission.evaluation_time = None
|
||||||
submission.save()
|
await submission.asave()
|
||||||
|
|
||||||
# 重新启动AI评分任务
|
|
||||||
evaluate_flowchart_task.send(submission.id)
|
evaluate_flowchart_task.send(submission.id)
|
||||||
|
|
||||||
return self.success(
|
return self.success(
|
||||||
@@ -142,15 +143,14 @@ class FlowchartSubmissionRetryAPI(APIView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlowchartSubmissionDetailAPI(APIView):
|
class FlowchartSubmissionDetailAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取当前用户对指定题目的流程图提交详情"""
|
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if not problem_id:
|
if not problem_id:
|
||||||
return self.error("problem_id is required")
|
return self.error("problem_id is required")
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(id=problem_id)
|
problem = await Problem.objects.aget(id=problem_id)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem doesn't exist")
|
return self.error("Problem doesn't exist")
|
||||||
|
|
||||||
@@ -158,34 +158,37 @@ class FlowchartSubmissionDetailAPI(APIView):
|
|||||||
page = int(request.GET.get("page", 0))
|
page = int(request.GET.get("page", 0))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return self.error("page must be an integer")
|
return self.error("page must be an integer")
|
||||||
submissions = FlowchartSubmission.objects.filter(
|
submissions = (
|
||||||
user=request.user,
|
FlowchartSubmission.objects.select_related("user", "problem")
|
||||||
problem=problem,
|
.filter(
|
||||||
status=FlowchartSubmissionStatus.COMPLETED,
|
user=request.user,
|
||||||
).order_by("create_time")
|
problem=problem,
|
||||||
count = submissions.count()
|
status=FlowchartSubmissionStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
.order_by("create_time")
|
||||||
|
)
|
||||||
|
count = await submissions.acount()
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return self.success({"submission": None, "count": 0})
|
return self.success({"submission": None, "count": 0})
|
||||||
# page=0 means latest; page=N means the Nth submission (1-indexed, chronological)
|
|
||||||
if page == 0:
|
if page == 0:
|
||||||
submission = submissions.last()
|
submission = await submissions.alast()
|
||||||
else:
|
else:
|
||||||
if page < 0 or page > count:
|
if page < 0 or page > count:
|
||||||
return self.error("Page out of range")
|
return self.error("Page out of range")
|
||||||
submission = submissions[page - 1]
|
result = [s async for s in submissions[page - 1:page]]
|
||||||
serializer = FlowchartSubmissionSerializer(submission)
|
submission = result[0]
|
||||||
return self.success({"submission": serializer.data, "count": count})
|
data = await self.async_serialize_data(FlowchartSubmissionSerializer, submission)
|
||||||
|
return self.success({"submission": data, "count": count})
|
||||||
|
|
||||||
|
|
||||||
class FlowchartSubmissionCurrentAPI(APIView):
|
class FlowchartSubmissionCurrentAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取当前用户对指定题目的最新流程图提交,只返回次数和分数"""
|
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if not problem_id:
|
if not problem_id:
|
||||||
return self.error("problem_id is required")
|
return self.error("problem_id is required")
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(id=problem_id)
|
problem = await Problem.objects.aget(id=problem_id)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem doesn't exist")
|
return self.error("Problem doesn't exist")
|
||||||
submissions = (
|
submissions = (
|
||||||
@@ -197,10 +200,10 @@ class FlowchartSubmissionCurrentAPI(APIView):
|
|||||||
.values("ai_score", "ai_grade")
|
.values("ai_score", "ai_grade")
|
||||||
.order_by("-create_time")
|
.order_by("-create_time")
|
||||||
)
|
)
|
||||||
count = submissions.count()
|
count = await submissions.acount()
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return self.success({"count": 0, "score": 0, "grade": ""})
|
return self.success({"count": 0, "score": 0, "grade": ""})
|
||||||
submission = submissions[0]
|
submission = await submissions.afirst()
|
||||||
return self.success(
|
return self.success(
|
||||||
{
|
{
|
||||||
"count": count,
|
"count": count,
|
||||||
|
|||||||
@@ -3,33 +3,33 @@ from account.models import User
|
|||||||
from message.models import Message
|
from message.models import Message
|
||||||
from message.serializers import CreateMessageSerializer, MessageSerializer
|
from message.serializers import CreateMessageSerializer, MessageSerializer
|
||||||
from submission.models import Submission
|
from submission.models import Submission
|
||||||
from utils.api import APIView
|
from utils.api import AsyncAPIView
|
||||||
from utils.api.api import validate_serializer
|
from utils.api.api import validate_serializer
|
||||||
|
|
||||||
|
|
||||||
class MessageAPI(APIView):
|
class MessageAPI(AsyncAPIView):
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
messages = Message.objects.select_related(
|
messages = Message.objects.select_related(
|
||||||
"recipient", "sender", "submission", "submission__problem"
|
"recipient", "sender", "submission", "submission__problem"
|
||||||
).filter(recipient=request.user)
|
).filter(recipient=request.user)
|
||||||
return self.success(self.paginate_data(request, messages, MessageSerializer))
|
return self.success(await self.async_paginate_data(request, messages, MessageSerializer))
|
||||||
|
|
||||||
@validate_serializer(CreateMessageSerializer)
|
@validate_serializer(CreateMessageSerializer)
|
||||||
@super_admin_required
|
@super_admin_required
|
||||||
def post(self, request):
|
async def post(self, request):
|
||||||
data = request.data
|
data = request.data
|
||||||
if data["recipient"] == request.user.id:
|
if data["recipient"] == request.user.id:
|
||||||
return self.error("Can not send a message to youself")
|
return self.error("Can not send a message to youself")
|
||||||
try:
|
try:
|
||||||
recipient = User.objects.get(id=data["recipient"], is_disabled=False)
|
recipient = await User.objects.aget(id=data["recipient"], is_disabled=False)
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
return self.error("User does not exist")
|
return self.error("User does not exist")
|
||||||
try:
|
try:
|
||||||
submission = Submission.objects.get(id=data["submission"])
|
submission = await Submission.objects.aget(id=data["submission"])
|
||||||
except Submission.DoesNotExist:
|
except Submission.DoesNotExist:
|
||||||
return self.error("Submission does not exist")
|
return self.error("Submission does not exist")
|
||||||
Message.objects.create(
|
await Message.objects.acreate(
|
||||||
submission=submission,
|
submission=submission,
|
||||||
message=data["message"],
|
message=data["message"],
|
||||||
sender=request.user,
|
sender=request.user,
|
||||||
|
|||||||
@@ -292,4 +292,15 @@ class _SysOptionsMeta(type):
|
|||||||
|
|
||||||
|
|
||||||
class SysOptions(metaclass=_SysOptionsMeta):
|
class SysOptions(metaclass=_SysOptionsMeta):
|
||||||
pass
|
@classmethod
|
||||||
|
async def aget(cls, key):
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
return await sync_to_async(getattr)(cls, key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def aget_many(cls, *keys):
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
def _get_all():
|
||||||
|
return {k: getattr(cls, k) for k in keys}
|
||||||
|
return await sync_to_async(_get_all)()
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
import random
|
import random
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from django.core.cache import cache
|
from asgiref.sync import sync_to_async
|
||||||
from django.db.models import BooleanField, Case, Count, Q, Value, When
|
from django.db.models import BooleanField, Case, Count, Q, Value, When
|
||||||
from django.db.models.functions import ExtractYear
|
from django.db.models.functions import ExtractYear
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
from account.decorators import check_contest_permission
|
from account.decorators import check_contest_permission
|
||||||
from account.models import User
|
from account.models import User
|
||||||
from contest.models import ContestRuleType
|
from contest.models import ContestRuleType
|
||||||
from submission.models import JudgeStatus, Submission
|
from submission.models import JudgeStatus, Submission
|
||||||
from utils.api import APIView
|
from utils.api import APIView, AsyncAPIView
|
||||||
|
from utils.async_helpers import async_cache_get, async_cache_set
|
||||||
from utils.constants import CacheKey
|
from utils.constants import CacheKey
|
||||||
|
|
||||||
from ..models import Problem, ProblemTag
|
from ..models import Problem, ProblemTag
|
||||||
@@ -21,11 +22,11 @@ from ..serializers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProblemTagAPI(APIView):
|
class ProblemTagAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
keyword = request.GET.get("keyword", "")
|
keyword = request.GET.get("keyword", "")
|
||||||
cache_key = f"{CacheKey.problem_tags}:{keyword}"
|
cache_key = f"{CacheKey.problem_tags}:{keyword}"
|
||||||
cached = cache.get(cache_key)
|
cached = await async_cache_get(cache_key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return self.success(cached)
|
return self.success(cached)
|
||||||
|
|
||||||
@@ -33,48 +34,48 @@ class ProblemTagAPI(APIView):
|
|||||||
if keyword:
|
if keyword:
|
||||||
qs = ProblemTag.objects.filter(name__icontains=keyword)
|
qs = ProblemTag.objects.filter(name__icontains=keyword)
|
||||||
tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0)
|
tags = qs.annotate(problem_count=Count("problem")).filter(problem_count__gt=0)
|
||||||
data = TagSerializer(tags, many=True).data
|
data = await self.async_serialize_data(TagSerializer, [tag async for tag in tags], many=True)
|
||||||
cache.set(cache_key, data, 3600)
|
await async_cache_set(cache_key, data, 3600)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class PickOneAPI(APIView):
|
class PickOneAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problems = Problem.objects.filter(contest_id__isnull=True, visible=True)
|
ids = Problem.objects.filter(contest_id__isnull=True, visible=True).values_list("_id", flat=True)
|
||||||
count = problems.count()
|
count = await ids.acount()
|
||||||
if count == 0:
|
if count == 0:
|
||||||
return self.error("No problem to pick")
|
return self.error("No problem to pick")
|
||||||
return self.success(problems[random.randint(0, count - 1)]._id)
|
idx = random.randint(0, count - 1)
|
||||||
|
result = [pid async for pid in ids[idx : idx + 1]]
|
||||||
|
return self.success(result[0])
|
||||||
|
|
||||||
|
|
||||||
class ProblemAPI(APIView):
|
class ProblemAPI(AsyncAPIView):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _add_problem_status(request, queryset_values):
|
def _add_problem_status(acm_problems_status, queryset_values):
|
||||||
if request.user.is_authenticated:
|
results = queryset_values.get("results")
|
||||||
profile = request.user.userprofile
|
if results is not None:
|
||||||
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
problems = results
|
||||||
# paginate data
|
else:
|
||||||
results = queryset_values.get("results")
|
problems = [queryset_values]
|
||||||
if results is not None:
|
for problem in problems:
|
||||||
problems = results
|
problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
|
||||||
else:
|
|
||||||
problems = [queryset_values]
|
|
||||||
for problem in problems:
|
|
||||||
problem["my_status"] = acm_problems_status.get(
|
|
||||||
str(problem["id"]), {}
|
|
||||||
).get("status")
|
|
||||||
|
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
# 问题详情页
|
# 问题详情页
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if problem_id:
|
if problem_id:
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.select_related("created_by").get(
|
problem = await Problem.objects.select_related("created_by").prefetch_related("tags").filter(_id__iexact=problem_id, contest_id__isnull=True, visible=True).afirst()
|
||||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
if problem is None:
|
||||||
)
|
raise Problem.DoesNotExist
|
||||||
problem_data = ProblemSerializer(problem).data
|
problem_data = await self.async_serialize_data(ProblemSerializer, problem)
|
||||||
self._add_problem_status(request, problem_data)
|
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
|
from account.models import UserProfile
|
||||||
|
|
||||||
|
profile = await UserProfile.objects.aget(user=request.user)
|
||||||
|
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||||
|
self._add_problem_status(acm_problems_status, problem_data)
|
||||||
failed_statuses = [
|
failed_statuses = [
|
||||||
JudgeStatus.WRONG_ANSWER,
|
JudgeStatus.WRONG_ANSWER,
|
||||||
JudgeStatus.CPU_TIME_LIMIT_EXCEEDED,
|
JudgeStatus.CPU_TIME_LIMIT_EXCEEDED,
|
||||||
@@ -83,11 +84,11 @@ class ProblemAPI(APIView):
|
|||||||
JudgeStatus.RUNTIME_ERROR,
|
JudgeStatus.RUNTIME_ERROR,
|
||||||
JudgeStatus.COMPILE_ERROR,
|
JudgeStatus.COMPILE_ERROR,
|
||||||
]
|
]
|
||||||
problem_data["my_failed_count"] = Submission.objects.filter(
|
problem_data["my_failed_count"] = await Submission.objects.filter(
|
||||||
user_id=request.user.id,
|
user_id=request.user.id,
|
||||||
problem_id=problem.id,
|
problem_id=problem.id,
|
||||||
result__in=failed_statuses,
|
result__in=failed_statuses,
|
||||||
).count()
|
).acount()
|
||||||
else:
|
else:
|
||||||
problem_data["my_failed_count"] = 0
|
problem_data["my_failed_count"] = 0
|
||||||
return self.success(problem_data)
|
return self.success(problem_data)
|
||||||
@@ -98,12 +99,7 @@ class ProblemAPI(APIView):
|
|||||||
if not limit:
|
if not limit:
|
||||||
return self.error("Limit is needed")
|
return self.error("Limit is needed")
|
||||||
|
|
||||||
problems = (
|
problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest_id__isnull=True, visible=True).order_by("-create_time")
|
||||||
Problem.objects.select_related("created_by")
|
|
||||||
.prefetch_related("tags")
|
|
||||||
.filter(contest_id__isnull=True, visible=True)
|
|
||||||
.order_by("-create_time")
|
|
||||||
)
|
|
||||||
|
|
||||||
author = request.GET.get("author")
|
author = request.GET.get("author")
|
||||||
if author:
|
if author:
|
||||||
@@ -117,9 +113,7 @@ class ProblemAPI(APIView):
|
|||||||
# 搜索的情况
|
# 搜索的情况
|
||||||
keyword = request.GET.get("keyword", "").strip()
|
keyword = request.GET.get("keyword", "").strip()
|
||||||
if keyword:
|
if keyword:
|
||||||
problems = problems.filter(
|
problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
|
||||||
Q(title__icontains=keyword) | Q(_id__icontains=keyword)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 难度筛选
|
# 难度筛选
|
||||||
difficulty = request.GET.get("difficulty")
|
difficulty = request.GET.get("difficulty")
|
||||||
@@ -142,8 +136,13 @@ class ProblemAPI(APIView):
|
|||||||
problems = problems.order_by(sort)
|
problems = problems.order_by(sort)
|
||||||
|
|
||||||
# 根据profile 为做过的题目添加标记
|
# 根据profile 为做过的题目添加标记
|
||||||
data = self.paginate_data(request, problems, ProblemListSerializer)
|
data = await self.async_paginate_data(request, problems, ProblemListSerializer)
|
||||||
self._add_problem_status(request, data)
|
if request.user.is_authenticated:
|
||||||
|
from account.models import UserProfile
|
||||||
|
|
||||||
|
profile = await UserProfile.objects.aget(user=request.user)
|
||||||
|
acm_problems_status = profile.acm_problems_status.get("problems", {})
|
||||||
|
self._add_problem_status(acm_problems_status, data)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,24 +151,18 @@ class ContestProblemAPI(APIView):
|
|||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
profile = request.user.userprofile
|
profile = request.user.userprofile
|
||||||
if self.contest.rule_type == ContestRuleType.ACM:
|
if self.contest.rule_type == ContestRuleType.ACM:
|
||||||
problems_status = profile.acm_problems_status.get(
|
problems_status = profile.acm_problems_status.get("contest_problems", {})
|
||||||
"contest_problems", {}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
problems_status = profile.oi_problems_status.get("contest_problems", {})
|
problems_status = profile.oi_problems_status.get("contest_problems", {})
|
||||||
for problem in queryset_values:
|
for problem in queryset_values:
|
||||||
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get(
|
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
|
||||||
"status"
|
|
||||||
)
|
|
||||||
|
|
||||||
@check_contest_permission(check_type="problems")
|
@check_contest_permission(check_type="problems")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if problem_id:
|
if problem_id:
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.select_related("created_by").get(
|
problem = Problem.objects.select_related("created_by").get(_id__iexact=problem_id, contest=self.contest, visible=True)
|
||||||
_id__iexact=problem_id, contest=self.contest, visible=True
|
|
||||||
)
|
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem does not exist.")
|
return self.error("Problem does not exist.")
|
||||||
if self.contest.problem_details_permission(request.user):
|
if self.contest.problem_details_permission(request.user):
|
||||||
@@ -184,9 +177,7 @@ class ContestProblemAPI(APIView):
|
|||||||
problem_data = ProblemSafeSerializer(problem).data
|
problem_data = ProblemSafeSerializer(problem).data
|
||||||
return self.success(problem_data)
|
return self.success(problem_data)
|
||||||
|
|
||||||
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(
|
contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest=self.contest, visible=True)
|
||||||
contest=self.contest, visible=True
|
|
||||||
)
|
|
||||||
if self.contest.problem_details_permission(request.user):
|
if self.contest.problem_details_permission(request.user):
|
||||||
data = ProblemListSerializer(contest_problems, many=True).data
|
data = ProblemListSerializer(contest_problems, many=True).data
|
||||||
self._add_problem_status(request, data)
|
self._add_problem_status(request, data)
|
||||||
@@ -195,59 +186,60 @@ class ContestProblemAPI(APIView):
|
|||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class ProblemSolvedPeopleCount(APIView):
|
class ProblemSolvedPeopleCount(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
rate = "0"
|
rate = "0"
|
||||||
if not request.user.is_authenticated:
|
if not request.user.is_authenticated:
|
||||||
return self.success(rate)
|
return self.success(rate)
|
||||||
submission_count = Submission.objects.filter(
|
submission_count = await Submission.objects.filter(
|
||||||
user_id=request.user.id,
|
user_id=request.user.id,
|
||||||
problem_id=problem_id,
|
problem_id=problem_id,
|
||||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||||
).count()
|
).acount()
|
||||||
if submission_count == 0:
|
if submission_count == 0:
|
||||||
return self.success(rate)
|
return self.success(rate)
|
||||||
today = datetime.today()
|
now = timezone.now()
|
||||||
years_ago = datetime(today.year - 2, today.month, today.day, 0, 0)
|
years_ago = now.replace(year=now.year - 2, hour=0, minute=0, second=0, microsecond=0)
|
||||||
total_count = User.objects.filter(
|
total_count = await User.objects.filter(is_disabled=False, last_login__gte=years_ago).acount()
|
||||||
is_disabled=False, last_login__gte=years_ago
|
accepted_count = (
|
||||||
).count()
|
await sync_to_async(
|
||||||
accepted_count = Submission.objects.filter(
|
Submission.objects.filter(
|
||||||
problem_id=problem_id,
|
problem_id=problem_id,
|
||||||
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
result__in=[JudgeStatus.ACCEPTED, JudgeStatus.AST_CHECK_FAILED],
|
||||||
create_time__gte=years_ago,
|
create_time__gte=years_ago,
|
||||||
).aggregate(user_count=Count("user_id", distinct=True))["user_count"]
|
).aggregate,
|
||||||
if accepted_count < total_count:
|
thread_sensitive=True,
|
||||||
|
)(user_count=Count("user_id", distinct=True))
|
||||||
|
)["user_count"]
|
||||||
|
if total_count and accepted_count < total_count:
|
||||||
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
|
rate = "%.2f" % ((total_count - accepted_count) / total_count * 100)
|
||||||
else:
|
else:
|
||||||
rate = "0"
|
rate = "0"
|
||||||
return self.success(rate)
|
return self.success(rate)
|
||||||
|
|
||||||
|
|
||||||
class SimilarProblemAPI(APIView):
|
class SimilarProblemAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_display_id = request.GET.get("problem_id")
|
problem_display_id = request.GET.get("problem_id")
|
||||||
if not problem_display_id:
|
if not problem_display_id:
|
||||||
return self.error("problem_id is required")
|
return self.error("problem_id is required")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(_id__iexact=problem_display_id, contest__isnull=True)
|
problem = await Problem.objects.aget(_id__iexact=problem_display_id, contest__isnull=True)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem not found")
|
return self.error("Problem not found")
|
||||||
|
|
||||||
tag_ids = list(problem.tags.values_list("id", flat=True))
|
tag_ids = [tag_id async for tag_id in problem.tags.values_list("id", flat=True)]
|
||||||
if not tag_ids:
|
if not tag_ids:
|
||||||
return self.success([])
|
return self.success([])
|
||||||
|
|
||||||
exclude_ids = [problem_display_id]
|
exclude_ids = [problem_display_id]
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
profile = request.user.userprofile
|
from account.models import UserProfile
|
||||||
ac_display_ids = [
|
|
||||||
v["_id"]
|
profile = await UserProfile.objects.aget(user=request.user)
|
||||||
for v in profile.acm_problems_status.get("problems", {}).values()
|
ac_display_ids = [v["_id"] for v in profile.acm_problems_status.get("problems", {}).values() if v.get("status") == JudgeStatus.ACCEPTED]
|
||||||
if v.get("status") == JudgeStatus.ACCEPTED
|
|
||||||
]
|
|
||||||
exclude_ids.extend(ac_display_ids)
|
exclude_ids.extend(ac_display_ids)
|
||||||
|
|
||||||
similar = (
|
similar = (
|
||||||
@@ -258,14 +250,15 @@ class SimilarProblemAPI(APIView):
|
|||||||
.distinct()
|
.distinct()
|
||||||
.order_by("difficulty")[:5]
|
.order_by("difficulty")[:5]
|
||||||
)
|
)
|
||||||
return self.success(ProblemListSerializer(similar, many=True).data)
|
similar_list = [problem async for problem in similar]
|
||||||
|
return self.success(await self.async_serialize_data(ProblemListSerializer, similar_list, many=True))
|
||||||
|
|
||||||
|
|
||||||
class ProblemAuthorAPI(APIView):
|
class ProblemAuthorAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
show_all = request.GET.get("all", "0") == "1"
|
show_all = request.GET.get("all", "0") == "1"
|
||||||
cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
|
cache_key = f"{CacheKey.problem_authors}{'_all' if show_all else '_only_visible'}"
|
||||||
cached_data = cache.get(cache_key)
|
cached_data = await async_cache_get(cache_key)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
return self.success(cached_data)
|
return self.success(cached_data)
|
||||||
|
|
||||||
@@ -273,38 +266,32 @@ class ProblemAuthorAPI(APIView):
|
|||||||
if not show_all:
|
if not show_all:
|
||||||
problem_filter["visible"] = True
|
problem_filter["visible"] = True
|
||||||
|
|
||||||
authors = (
|
authors = Problem.objects.filter(**problem_filter).values("created_by__username").annotate(problem_count=Count("id")).order_by("-problem_count")
|
||||||
Problem.objects.filter(**problem_filter)
|
|
||||||
.values("created_by__username")
|
|
||||||
.annotate(problem_count=Count("id"))
|
|
||||||
.order_by("-problem_count")
|
|
||||||
)
|
|
||||||
result = [
|
result = [
|
||||||
{
|
{
|
||||||
"username": author["created_by__username"],
|
"username": author["created_by__username"],
|
||||||
"problem_count": author["problem_count"],
|
"problem_count": author["problem_count"],
|
||||||
}
|
}
|
||||||
for author in authors
|
async for author in authors
|
||||||
]
|
]
|
||||||
|
|
||||||
cache.set(cache_key, result, 7200)
|
await async_cache_set(cache_key, result, 7200)
|
||||||
|
return self.success(result)
|
||||||
|
|
||||||
|
|
||||||
class ProblemYearlyACRateAPI(APIView):
|
class ProblemYearlyACRateAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if not problem_id:
|
if not problem_id:
|
||||||
return self.error("problem_id is required")
|
return self.error("problem_id is required")
|
||||||
|
|
||||||
cache_key = f"{CacheKey.problem_yearly_ac}:{problem_id}"
|
cache_key = f"{CacheKey.problem_yearly_ac}:{problem_id}"
|
||||||
cached = cache.get(cache_key)
|
cached = await async_cache_get(cache_key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return self.success(cached)
|
return self.success(cached)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(
|
problem = await Problem.objects.aget(_id__iexact=problem_id, contest_id__isnull=True, visible=True)
|
||||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
|
||||||
)
|
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem does not exist")
|
return self.error("Problem does not exist")
|
||||||
|
|
||||||
@@ -328,12 +315,10 @@ class ProblemYearlyACRateAPI(APIView):
|
|||||||
"year": row["year"],
|
"year": row["year"],
|
||||||
"total": row["total"],
|
"total": row["total"],
|
||||||
"accepted": row["accepted"],
|
"accepted": row["accepted"],
|
||||||
"ac_rate": round(row["accepted"] / row["total"] * 100, 2)
|
"ac_rate": round(row["accepted"] / row["total"] * 100, 2) if row["total"] > 0 else 0.0,
|
||||||
if row["total"] > 0
|
|
||||||
else 0.0,
|
|
||||||
}
|
}
|
||||||
for row in rows
|
async for row in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
cache.set(cache_key, data, 3600)
|
await async_cache_set(cache_key, data, 3600)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ urlpatterns = [
|
|||||||
name="admin_problemset_progress_detail_api",
|
name="admin_problemset_progress_detail_api",
|
||||||
),
|
),
|
||||||
# 题单同步管理API
|
# 题单同步管理API
|
||||||
path(
|
path( # DEPRECATED: 前端未调用
|
||||||
"problemset/<int:problem_set_id>/sync",
|
"problemset/<int:problem_set_id>/sync",
|
||||||
ProblemSetSyncAPI.as_view(),
|
ProblemSetSyncAPI.as_view(),
|
||||||
name="admin_problemset_sync_api",
|
name="admin_problemset_sync_api",
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ urlpatterns = [
|
|||||||
ProblemSetProblemAPI.as_view(),
|
ProblemSetProblemAPI.as_view(),
|
||||||
name="problemset_problems_api",
|
name="problemset_problems_api",
|
||||||
),
|
),
|
||||||
path(
|
path( # DEPRECATED: 前端未调用
|
||||||
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
|
"problemset/<int:problem_set_id>/problems/<int:problem_id>",
|
||||||
ProblemSetProblemAPI.as_view(),
|
ProblemSetProblemAPI.as_view(),
|
||||||
name="problemset_problem_detail_api",
|
name="problemset_problem_detail_api",
|
||||||
@@ -35,12 +35,12 @@ urlpatterns = [
|
|||||||
ProblemSetProgressAPI.as_view(),
|
ProblemSetProgressAPI.as_view(),
|
||||||
name="problemset_progress_api",
|
name="problemset_progress_api",
|
||||||
),
|
),
|
||||||
path(
|
path( # DEPRECATED: 前端未调用
|
||||||
"problemset/<int:problem_set_id>/progress",
|
"problemset/<int:problem_set_id>/progress",
|
||||||
ProblemSetProgressAPI.as_view(),
|
ProblemSetProgressAPI.as_view(),
|
||||||
name="problemset_progress_detail_api",
|
name="problemset_progress_detail_api",
|
||||||
),
|
),
|
||||||
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"),
|
path("user/progress", UserProgressAPI.as_view(), name="user_progress_api"), # DEPRECATED: 前端未调用
|
||||||
# 奖章相关API
|
# 奖章相关API
|
||||||
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
|
path("user/badges", UserBadgeAPI.as_view(), name="user_badges_api"),
|
||||||
path(
|
path(
|
||||||
|
|||||||
@@ -332,6 +332,7 @@ class ProblemSetProgressAdminAPI(APIView):
|
|||||||
return self.error("用户未加入该题单")
|
return self.error("用户未加入该题单")
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class ProblemSetSyncAPI(APIView):
|
class ProblemSetSyncAPI(APIView):
|
||||||
"""题单同步管理API"""
|
"""题单同步管理API"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.db.models import Avg, Count, Prefetch, Q
|
from django.db.models import Avg, Count, Prefetch, Q
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
@@ -24,14 +25,14 @@ from problemset.serializers import (
|
|||||||
UpdateProgressSerializer,
|
UpdateProgressSerializer,
|
||||||
UserBadgeSerializer,
|
UserBadgeSerializer,
|
||||||
)
|
)
|
||||||
from submission.models import JudgeStatus, Submission, is_accepted
|
from submission.models import Submission, is_accepted
|
||||||
from utils.api import APIView, validate_serializer
|
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetAPI(APIView):
|
class ProblemSetAPI(AsyncAPIView):
|
||||||
"""题单API - 用户端"""
|
"""题单API - 用户端"""
|
||||||
|
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取题单列表"""
|
"""获取题单列表"""
|
||||||
# 预加载创建者信息
|
# 预加载创建者信息
|
||||||
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status=ProblemSetStatus.DRAFT).select_related("created_by")
|
problem_sets = ProblemSet.objects.filter(visible=True).exclude(status=ProblemSetStatus.DRAFT).select_related("created_by")
|
||||||
@@ -65,16 +66,19 @@ class ProblemSetAPI(APIView):
|
|||||||
user_earned_badge_ids = set()
|
user_earned_badge_ids = set()
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
# 先获取所有题单ID(不应用prefetch_related,只获取ID)
|
# 先获取所有题单ID(不应用prefetch_related,只获取ID)
|
||||||
problem_set_ids = list(problem_sets.values_list("id", flat=True))
|
problem_set_ids = [problem_set_id async for problem_set_id in problem_sets.values_list("id", flat=True)]
|
||||||
|
|
||||||
if problem_set_ids:
|
if problem_set_ids:
|
||||||
# 批量查询用户在这些题单中的进度
|
# 批量查询用户在这些题单中的进度
|
||||||
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
|
user_progresses = ProblemSetProgress.objects.filter(problemset_id__in=problem_set_ids, user=request.user).select_related("problemset")
|
||||||
# 构建映射:题单ID -> 进度对象
|
# 构建映射:题单ID -> 进度对象
|
||||||
user_progress_map = {progress.problemset_id: progress for progress in user_progresses}
|
user_progress_map = {progress.problemset_id: progress async for progress in user_progresses}
|
||||||
|
|
||||||
# 批量查询用户已获得的奖章ID(这些题单相关的)
|
# 批量查询用户已获得的奖章ID(这些题单相关的)
|
||||||
user_earned_badge_ids = set(UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True))
|
user_earned_badge_ids = {
|
||||||
|
badge_id
|
||||||
|
async for badge_id in UserBadge.objects.filter(user=request.user, badge__problemset_id__in=problem_set_ids).values_list("badge_id", flat=True)
|
||||||
|
}
|
||||||
|
|
||||||
# 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载)
|
# 预加载奖章信息(在获取ID之后应用,避免在获取ID时也预加载)
|
||||||
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
|
problem_sets = problem_sets.prefetch_related(Prefetch("problemsetbadge_set", queryset=ProblemSetBadge.objects.all(), to_attr="badges"))
|
||||||
@@ -83,31 +87,35 @@ class ProblemSetAPI(APIView):
|
|||||||
request._user_progress_map = user_progress_map
|
request._user_progress_map = user_progress_map
|
||||||
request._user_earned_badge_ids = user_earned_badge_ids
|
request._user_earned_badge_ids = user_earned_badge_ids
|
||||||
|
|
||||||
data = self.paginate_data(request, problem_sets, ProblemSetListSerializer)
|
data = await self.async_paginate_data(request, problem_sets, ProblemSetListSerializer)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetDetailAPI(APIView):
|
class ProblemSetDetailAPI(AsyncAPIView):
|
||||||
"""题单详情API - 用户端"""
|
"""题单详情API - 用户端"""
|
||||||
|
|
||||||
def get(self, request, problem_set_id):
|
async def get(self, request, problem_set_id):
|
||||||
"""获取题单详情"""
|
"""获取题单详情"""
|
||||||
try:
|
try:
|
||||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
problem_set = await (
|
||||||
|
ProblemSet.objects.select_related("created_by")
|
||||||
|
.filter(id=problem_set_id, visible=True)
|
||||||
|
.exclude(status=ProblemSetStatus.DRAFT)
|
||||||
|
.aget()
|
||||||
|
)
|
||||||
except ProblemSet.DoesNotExist:
|
except ProblemSet.DoesNotExist:
|
||||||
return self.error("题单不存在")
|
return self.error("题单不存在")
|
||||||
|
|
||||||
serializer = ProblemSetSerializer(problem_set, context={"request": request})
|
return self.success(await self.async_serialize_data(ProblemSetSerializer, problem_set, context={"request": request}))
|
||||||
return self.success(serializer.data)
|
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetProblemAPI(APIView):
|
class ProblemSetProblemAPI(AsyncAPIView):
|
||||||
"""题单题目API - 用户端"""
|
"""题单题目API - 用户端"""
|
||||||
|
|
||||||
def get(self, request, problem_set_id):
|
async def get(self, request, problem_set_id):
|
||||||
"""获取题单中的题目列表"""
|
"""获取题单中的题目列表"""
|
||||||
try:
|
try:
|
||||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||||
except ProblemSet.DoesNotExist:
|
except ProblemSet.DoesNotExist:
|
||||||
return self.error("题单不存在")
|
return self.error("题单不存在")
|
||||||
|
|
||||||
@@ -115,12 +123,16 @@ class ProblemSetProblemAPI(APIView):
|
|||||||
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
|
# 预取当前用户的题单进度,供 get_is_completed 使用,避免 N+1
|
||||||
user_progress = None
|
user_progress = None
|
||||||
if request.user.is_authenticated:
|
if request.user.is_authenticated:
|
||||||
try:
|
user_progress = await ProblemSetProgress.objects.filter(problemset=problem_set, user=request.user).afirst()
|
||||||
user_progress = ProblemSetProgress.objects.get(problemset=problem_set, user=request.user)
|
problem_list = [problem async for problem in problems]
|
||||||
except ProblemSetProgress.DoesNotExist:
|
return self.success(
|
||||||
pass
|
await self.async_serialize_data(
|
||||||
serializer = ProblemSetProblemSerializer(problems, many=True, context={"request": request, "user_progress": user_progress})
|
ProblemSetProblemSerializer,
|
||||||
return self.success(serializer.data)
|
problem_list,
|
||||||
|
many=True,
|
||||||
|
context={"request": request, "user_progress": user_progress},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetProgressAPI(APIView):
|
class ProblemSetProgressAPI(APIView):
|
||||||
@@ -236,6 +248,7 @@ class ProblemSetProgressAPI(APIView):
|
|||||||
UserBadge.objects.create(user=progress.user, badge=badge)
|
UserBadge.objects.create(user=progress.user, badge=badge)
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class UserProgressAPI(APIView):
|
class UserProgressAPI(APIView):
|
||||||
"""用户进度API"""
|
"""用户进度API"""
|
||||||
|
|
||||||
@@ -247,10 +260,10 @@ class UserProgressAPI(APIView):
|
|||||||
return self.success(serializer.data)
|
return self.success(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
class UserBadgeAPI(APIView):
|
class UserBadgeAPI(AsyncAPIView):
|
||||||
"""用户奖章API"""
|
"""用户奖章API"""
|
||||||
|
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
"""获取用户的奖章列表"""
|
"""获取用户的奖章列表"""
|
||||||
# 支持通过username参数获取指定用户的徽章
|
# 支持通过username参数获取指定用户的徽章
|
||||||
username = request.GET.get("username")
|
username = request.GET.get("username")
|
||||||
@@ -258,41 +271,41 @@ class UserBadgeAPI(APIView):
|
|||||||
if username:
|
if username:
|
||||||
# 获取指定用户的徽章
|
# 获取指定用户的徽章
|
||||||
try:
|
try:
|
||||||
target_user = User.objects.get(username=username, is_disabled=False)
|
target_user = await User.objects.aget(username=username, is_disabled=False)
|
||||||
badges = UserBadge.objects.filter(user=target_user).order_by("-earned_time")
|
badges = UserBadge.objects.select_related("badge").filter(user=target_user).order_by("-earned_time")
|
||||||
except User.DoesNotExist:
|
except User.DoesNotExist:
|
||||||
return self.error("用户不存在")
|
return self.error("用户不存在")
|
||||||
else:
|
else:
|
||||||
# 获取当前用户的徽章
|
# 获取当前用户的徽章
|
||||||
badges = UserBadge.objects.filter(user=request.user).order_by("-earned_time")
|
badges = UserBadge.objects.select_related("badge").filter(user=request.user).order_by("-earned_time")
|
||||||
|
|
||||||
serializer = UserBadgeSerializer(badges, many=True)
|
badge_list = [badge async for badge in badges]
|
||||||
return self.success(serializer.data)
|
return self.success(await self.async_serialize_data(UserBadgeSerializer, badge_list, many=True))
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetBadgeAPI(APIView):
|
class ProblemSetBadgeAPI(AsyncAPIView):
|
||||||
"""题单奖章API - 用户端"""
|
"""题单奖章API - 用户端"""
|
||||||
|
|
||||||
def get(self, request, problem_set_id):
|
async def get(self, request, problem_set_id):
|
||||||
"""获取题单的奖章列表"""
|
"""获取题单的奖章列表"""
|
||||||
try:
|
try:
|
||||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||||
except ProblemSet.DoesNotExist:
|
except ProblemSet.DoesNotExist:
|
||||||
return self.error("题单不存在")
|
return self.error("题单不存在")
|
||||||
|
|
||||||
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
|
badges = ProblemSetBadge.objects.filter(problemset=problem_set)
|
||||||
serializer = ProblemSetBadgeSerializer(badges, many=True)
|
badge_list = [badge async for badge in badges]
|
||||||
return self.success(serializer.data)
|
return self.success(await self.async_serialize_data(ProblemSetBadgeSerializer, badge_list, many=True))
|
||||||
|
|
||||||
|
|
||||||
class ProblemSetUserProgressAPI(APIView):
|
class ProblemSetUserProgressAPI(AsyncAPIView):
|
||||||
"""题单用户进度列表API"""
|
"""题单用户进度列表API"""
|
||||||
|
|
||||||
@admin_role_required
|
@admin_role_required
|
||||||
def get(self, request, problem_set_id: int):
|
async def get(self, request, problem_set_id: int):
|
||||||
"""获取题单的用户进度列表"""
|
"""获取题单的用户进度列表"""
|
||||||
try:
|
try:
|
||||||
problem_set = ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).get()
|
problem_set = await ProblemSet.objects.filter(id=problem_set_id, visible=True).exclude(status=ProblemSetStatus.DRAFT).aget()
|
||||||
except ProblemSet.DoesNotExist:
|
except ProblemSet.DoesNotExist:
|
||||||
return self.error("题单不存在")
|
return self.error("题单不存在")
|
||||||
|
|
||||||
@@ -321,7 +334,7 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
|
|
||||||
# 计算统计数据(基于所有数据,而非分页数据)
|
# 计算统计数据(基于所有数据,而非分页数据)
|
||||||
# 使用一次查询获取所有统计数据
|
# 使用一次查询获取所有统计数据
|
||||||
stats = progresses.aggregate(
|
stats = await sync_to_async(progresses.aggregate, thread_sensitive=True)(
|
||||||
total=Count("id"),
|
total=Count("id"),
|
||||||
completed=Count("id", filter=Q(is_completed=True)),
|
completed=Count("id", filter=Q(is_completed=True)),
|
||||||
avg_progress=Avg("progress_percentage"),
|
avg_progress=Avg("progress_percentage"),
|
||||||
@@ -351,7 +364,7 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
# 构建题单所有题目的数据结构和映射
|
# 构建题单所有题目的数据结构和映射
|
||||||
all_problems_list = []
|
all_problems_list = []
|
||||||
all_problems_map = {}
|
all_problems_map = {}
|
||||||
for psp in all_problemset_problems:
|
async for psp in all_problemset_problems:
|
||||||
problem_data = {
|
problem_data = {
|
||||||
"id": psp.problem.id,
|
"id": psp.problem.id,
|
||||||
"_id": psp.problem._id,
|
"_id": psp.problem._id,
|
||||||
@@ -362,7 +375,7 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
all_problems_map[str(psp.problem.id)] = psp.problem
|
all_problems_map[str(psp.problem.id)] = psp.problem
|
||||||
|
|
||||||
# 从当前页的数据中收集已完成的问题ID,用于序列化器
|
# 从当前页的数据中收集已完成的问题ID,用于序列化器
|
||||||
paginated_progresses = list(progresses[offset : offset + limit])
|
paginated_progresses = [progress async for progress in progresses[offset : offset + limit]]
|
||||||
completed_problem_ids = set()
|
completed_problem_ids = set()
|
||||||
for progress in paginated_progresses:
|
for progress in paginated_progresses:
|
||||||
if progress.progress_detail:
|
if progress.progress_detail:
|
||||||
@@ -376,7 +389,7 @@ class ProblemSetUserProgressAPI(APIView):
|
|||||||
request._problems_dict_cache = problems_dict
|
request._problems_dict_cache = problems_dict
|
||||||
|
|
||||||
# 使用分页
|
# 使用分页
|
||||||
data = self.paginate_data(request, progresses, ProblemSetProgressSerializer)
|
data = await self.async_paginate_data(request, progresses, ProblemSetProgressSerializer)
|
||||||
|
|
||||||
# 添加统计数据
|
# 添加统计数据
|
||||||
data["statistics"] = {
|
data["statistics"] = {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ dependencies = [
|
|||||||
"tree-sitter-c>=0.24.2",
|
"tree-sitter-c>=0.24.2",
|
||||||
"tree-sitter-python>=0.25.0",
|
"tree-sitter-python>=0.25.0",
|
||||||
"xlsxwriter>=3.2.9,<4",
|
"xlsxwriter>=3.2.9,<4",
|
||||||
|
"asgiref>=3.11.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
@@ -35,6 +36,10 @@ dev = [
|
|||||||
"ruff>=0.15.11",
|
"ruff>=0.15.11",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
venvPath = "."
|
||||||
|
venv = ".venv"
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 180
|
line-length = 180
|
||||||
exclude = ["*/migrations/*", "*settings.py", "*/apps.py", ".venv"]
|
exclude = ["*/migrations/*", "*settings.py", "*/apps.py", ".venv"]
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ urlpatterns = [
|
|||||||
path("submission", SubmissionAPI.as_view()),
|
path("submission", SubmissionAPI.as_view()),
|
||||||
path("submissions", SubmissionListAPI.as_view()),
|
path("submissions", SubmissionListAPI.as_view()),
|
||||||
path("submissions/today_count", SubmissionsTodayCount.as_view()),
|
path("submissions/today_count", SubmissionsTodayCount.as_view()),
|
||||||
path("submission_exists", SubmissionExistsAPI.as_view()),
|
path("submission_exists", SubmissionExistsAPI.as_view()), # DEPRECATED: 前端未调用
|
||||||
path("contest_submissions", ContestSubmissionListAPI.as_view()),
|
path("contest_submissions", ContestSubmissionListAPI.as_view()),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
from datetime import datetime
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
from account.decorators import check_contest_permission, login_required
|
from account.decorators import check_contest_permission, login_required
|
||||||
from contest.models import ContestRuleType, ContestStatus
|
from contest.models import ContestRuleType, ContestStatus
|
||||||
@@ -8,7 +10,7 @@ from options.options import SysOptions
|
|||||||
|
|
||||||
# from judge.dispatcher import JudgeDispatcher
|
# from judge.dispatcher import JudgeDispatcher
|
||||||
from problem.models import Problem, ProblemRuleType
|
from problem.models import Problem, ProblemRuleType
|
||||||
from utils.api import APIView, validate_serializer
|
from utils.api import APIView, AsyncAPIView, validate_serializer
|
||||||
from utils.cache import cache
|
from utils.cache import cache
|
||||||
from utils.captcha import Captcha
|
from utils.captcha import Captcha
|
||||||
from utils.throttling import TokenBucket
|
from utils.throttling import TokenBucket
|
||||||
@@ -154,8 +156,8 @@ class SubmissionAPI(APIView):
|
|||||||
return self.success()
|
return self.success()
|
||||||
|
|
||||||
|
|
||||||
class SubmissionListAPI(APIView):
|
class SubmissionListAPI(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
if not request.GET.get("limit"):
|
if not request.GET.get("limit"):
|
||||||
return self.error("Limit is needed")
|
return self.error("Limit is needed")
|
||||||
if request.GET.get("contest_id"):
|
if request.GET.get("contest_id"):
|
||||||
@@ -171,14 +173,15 @@ class SubmissionListAPI(APIView):
|
|||||||
language = request.GET.get("language")
|
language = request.GET.get("language")
|
||||||
if problem_id:
|
if problem_id:
|
||||||
try:
|
try:
|
||||||
problem = Problem.objects.get(
|
problem = await Problem.objects.aget(
|
||||||
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
_id__iexact=problem_id, contest_id__isnull=True, visible=True
|
||||||
)
|
)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem doesn't exist")
|
return self.error("Problem doesn't exist")
|
||||||
submissions = submissions.filter(problem=problem)
|
submissions = submissions.filter(problem=problem)
|
||||||
|
|
||||||
if not SysOptions.submission_list_show_all and request.user.is_regular_user():
|
show_all = await SysOptions.aget("submission_list_show_all")
|
||||||
|
if not show_all and request.user.is_regular_user():
|
||||||
return self.success({"results": [], "total": 0})
|
return self.success({"results": [], "total": 0})
|
||||||
|
|
||||||
if myself and myself == "1":
|
if myself and myself == "1":
|
||||||
@@ -190,21 +193,25 @@ class SubmissionListAPI(APIView):
|
|||||||
if language:
|
if language:
|
||||||
submissions = submissions.filter(language=language)
|
submissions = submissions.filter(language=language)
|
||||||
if request.GET.get("today") == "1":
|
if request.GET.get("today") == "1":
|
||||||
today = datetime.today()
|
now = timezone.now()
|
||||||
submissions = submissions.filter(
|
submissions = submissions.filter(
|
||||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
data = self.paginate_data(request, submissions)
|
data = await self.async_paginate_data(request, submissions)
|
||||||
results = data["results"]
|
results = data["results"]
|
||||||
if request.user.is_authenticated and request.user.is_regular_user():
|
if request.user.is_authenticated and request.user.is_regular_user():
|
||||||
problem_ids = list({s.problem_id for s in results})
|
problem_ids = list({s.problem_id for s in results})
|
||||||
progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids)
|
progress_cache = await sync_to_async(bulk_fetch_problemset_progress)(request.user, problem_ids)
|
||||||
else:
|
else:
|
||||||
progress_cache = {}
|
progress_cache = {}
|
||||||
data["results"] = SubmissionListSerializer(
|
data["results"] = await self.async_serialize_data(
|
||||||
results, many=True, user=request.user, problemset_progress_cache=progress_cache
|
SubmissionListSerializer,
|
||||||
).data
|
results,
|
||||||
|
many=True,
|
||||||
|
user=request.user,
|
||||||
|
problemset_progress_cache=progress_cache,
|
||||||
|
)
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
@@ -262,6 +269,7 @@ class ContestSubmissionListAPI(APIView):
|
|||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SubmissionExistsAPI(APIView):
|
class SubmissionExistsAPI(APIView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
if not request.GET.get("problem_id"):
|
if not request.GET.get("problem_id"):
|
||||||
@@ -274,10 +282,10 @@ class SubmissionExistsAPI(APIView):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SubmissionsTodayCount(APIView):
|
class SubmissionsTodayCount(AsyncAPIView):
|
||||||
def get(self, request):
|
async def get(self, request):
|
||||||
today = datetime.today()
|
now = timezone.now()
|
||||||
count = Submission.objects.filter(
|
count = await Submission.objects.filter(
|
||||||
create_time__gte=datetime(today.year, today.month, today.day, 0, 0)
|
create_time__gte=now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
).count()
|
).acount()
|
||||||
return self.success(count)
|
return self.success(count)
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
from django.http import HttpResponse, QueryDict
|
from django.http import HttpResponse, QueryDict
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.views.decorators.csrf import csrf_exempt
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
@@ -162,6 +165,77 @@ class CSRFExemptAPIView(APIView):
|
|||||||
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncAPIView(APIView):
|
||||||
|
view_is_async = True
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls.view_is_async = True
|
||||||
|
|
||||||
|
async def dispatch(self, request, *args, **kwargs):
|
||||||
|
if self.request_parsers:
|
||||||
|
try:
|
||||||
|
request.data = self._get_request_data(self.request)
|
||||||
|
except ValueError as e:
|
||||||
|
return self.error(err="invalid-request", msg=str(e))
|
||||||
|
try:
|
||||||
|
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
|
||||||
|
response = handler(request, *args, **kwargs)
|
||||||
|
if asyncio.iscoroutine(response):
|
||||||
|
response = await response
|
||||||
|
return response
|
||||||
|
except APIError as e:
|
||||||
|
ret = {"msg": e.msg}
|
||||||
|
if e.err:
|
||||||
|
ret["err"] = e.err
|
||||||
|
return self.error(**ret)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
return self.server_error()
|
||||||
|
|
||||||
|
def serialize_data(self, object_serializer, data, **kwargs):
|
||||||
|
return object_serializer(data, **kwargs).data
|
||||||
|
|
||||||
|
async def async_serialize_data(self, object_serializer, data, **kwargs):
|
||||||
|
return await sync_to_async(
|
||||||
|
self.serialize_data,
|
||||||
|
thread_sensitive=True,
|
||||||
|
)(object_serializer, data, **kwargs)
|
||||||
|
|
||||||
|
async def async_paginate_data(self, request, query_set, object_serializer=None):
|
||||||
|
try:
|
||||||
|
limit = int(request.GET.get("limit", "10"))
|
||||||
|
except ValueError:
|
||||||
|
limit = 10
|
||||||
|
if limit < 0 or limit > 250:
|
||||||
|
limit = 10
|
||||||
|
try:
|
||||||
|
offset = int(request.GET.get("offset", "0"))
|
||||||
|
except ValueError:
|
||||||
|
offset = 0
|
||||||
|
if offset < 0:
|
||||||
|
offset = 0
|
||||||
|
count, results = await asyncio.gather(
|
||||||
|
query_set.acount(),
|
||||||
|
sync_to_async(lambda: list(query_set[offset:offset + limit]), thread_sensitive=True)(),
|
||||||
|
)
|
||||||
|
if object_serializer:
|
||||||
|
results = await self.async_serialize_data(
|
||||||
|
object_serializer,
|
||||||
|
results,
|
||||||
|
many=True,
|
||||||
|
context={"request": request},
|
||||||
|
)
|
||||||
|
data = {"results": results, "total": count}
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFExemptAsyncAPIView(AsyncAPIView):
|
||||||
|
@method_decorator(csrf_exempt)
|
||||||
|
async def dispatch(self, request, *args, **kwargs):
|
||||||
|
return await super().dispatch(request, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def validate_serializer(serializer):
|
def validate_serializer(serializer):
|
||||||
"""
|
"""
|
||||||
@validate_serializer(TestSerializer)
|
@validate_serializer(TestSerializer)
|
||||||
@@ -169,6 +243,20 @@ def validate_serializer(serializer):
|
|||||||
return self.success(request.data)
|
return self.success(request.data)
|
||||||
"""
|
"""
|
||||||
def validate(view_method):
|
def validate(view_method):
|
||||||
|
if inspect.iscoroutinefunction(view_method):
|
||||||
|
@functools.wraps(view_method)
|
||||||
|
async def async_handle(*args, **kwargs):
|
||||||
|
self = args[0]
|
||||||
|
request = args[1]
|
||||||
|
s = serializer(data=request.data)
|
||||||
|
if s.is_valid():
|
||||||
|
request.data = s.data
|
||||||
|
request.serializer = s
|
||||||
|
return await view_method(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.invalid_serializer(s)
|
||||||
|
return async_handle
|
||||||
|
|
||||||
@functools.wraps(view_method)
|
@functools.wraps(view_method)
|
||||||
def handle(*args, **kwargs):
|
def handle(*args, **kwargs):
|
||||||
self = args[0]
|
self = args[0]
|
||||||
@@ -180,7 +268,6 @@ def validate_serializer(serializer):
|
|||||||
return view_method(*args, **kwargs)
|
return view_method(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.invalid_serializer(s)
|
return self.invalid_serializer(s)
|
||||||
|
|
||||||
return handle
|
return handle
|
||||||
|
|
||||||
return validate
|
return validate
|
||||||
|
|||||||
14
utils/async_helpers.py
Normal file
14
utils/async_helpers.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
from django.core.cache import cache
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_get(key, default=None):
|
||||||
|
return await sync_to_async(cache.get, thread_sensitive=True)(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_set(key, value, timeout=None):
|
||||||
|
return await sync_to_async(cache.set, thread_sensitive=True)(key, value, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_cache_delete(key):
|
||||||
|
return await sync_to_async(cache.delete, thread_sensitive=True)(key)
|
||||||
@@ -4,5 +4,5 @@ from .views import SimditorFileUploadAPIView, SimditorImageUploadAPIView
|
|||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path("upload_image", SimditorImageUploadAPIView.as_view()),
|
path("upload_image", SimditorImageUploadAPIView.as_view()),
|
||||||
path("upload_file", SimditorFileUploadAPIView.as_view()),
|
path("upload_file", SimditorFileUploadAPIView.as_view()), # DEPRECATED: 前端未调用
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView):
|
|||||||
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
|
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
|
||||||
|
|
||||||
|
|
||||||
|
# DEPRECATED: 前端未调用 (2026-05-26)
|
||||||
class SimditorFileUploadAPIView(CSRFExemptAPIView):
|
class SimditorFileUploadAPIView(CSRFExemptAPIView):
|
||||||
request_parsers = ()
|
request_parsers = ()
|
||||||
|
|
||||||
|
|||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -550,6 +550,7 @@ name = "onlinejudge"
|
|||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
source = { virtual = "." }
|
source = { virtual = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "asgiref" },
|
||||||
{ name = "channels" },
|
{ name = "channels" },
|
||||||
{ name = "channels-redis" },
|
{ name = "channels-redis" },
|
||||||
{ name = "django" },
|
{ name = "django" },
|
||||||
@@ -582,6 +583,7 @@ dev = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "asgiref", specifier = ">=3.11.1" },
|
||||||
{ name = "channels", specifier = ">=4.3.2,<5" },
|
{ name = "channels", specifier = ">=4.3.2,<5" },
|
||||||
{ name = "channels-redis", specifier = ">=4.3.0,<5" },
|
{ name = "channels-redis", specifier = ">=4.3.0,<5" },
|
||||||
{ name = "django", specifier = ">=6.0.4,<6.1" },
|
{ name = "django", specifier = ">=6.0.4,<6.1" },
|
||||||
|
|||||||
Reference in New Issue
Block a user