From d1fdbcf52b53b10c6fd8919e975fa80ee4d9c31d Mon Sep 17 00:00:00 2001 From: yuetsh <517252939@qq.com> Date: Tue, 2 Jun 2026 23:13:06 -0600 Subject: [PATCH] fix --- account/decorators.py | 74 ++++++++++++------------ account/middleware.py | 6 +- contest/views/oj.py | 116 ++++++++++++++++++++------------------ problem/views/oj.py | 23 +++----- submission/views/admin.py | 4 +- submission/views/oj.py | 71 ++++++++++------------- 6 files changed, 144 insertions(+), 150 deletions(-) diff --git a/account/decorators.py b/account/decorators.py index c4ed886..275db5e 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -20,27 +20,32 @@ class BasePermissionDecorator(object): return functools.partial(self._async_call, obj) return functools.partial(self.__call__, obj) - def error(self, data): - return JSONResponse.response({"error": "permission-denied", "data": data}) + def error(self, data, err="permission-denied"): + return JSONResponse.response({"error": err, "data": data}) + + def _permission_error(self, request): + if not request.user.is_authenticated: + return self.error("请先登录", err="login-required") + return self.error("权限不足", err="permission-denied") def __call__(self, *args, **kwargs): request = args[1] if self.check_permission(request): if request.user.is_disabled: - return self.error("Your account is disabled") + return self.error("账号已禁用") return self.func(*args, **kwargs) else: - return self.error("Please login first") + return self._permission_error(request) async def _async_call(self, *args, **kwargs): request = args[1] if self.check_permission(request): if request.user.is_disabled: - return self.error("Your account is disabled") + return self.error("账号已禁用") return await self.func(*args, **kwargs) - return self.error("Please login first") + return self._permission_error(request) def check_permission(self, request): raise NotImplementedError() @@ -110,43 +115,42 @@ def check_contest_permission(check_type="details"): 若通过验证,在view中可通过self.contest获得该contest """ + def _get_contest_id(request): + return request.data.get("contest_id") or request.GET.get("contest_id") + + def _check_access(self, request, user): + if not user.is_authenticated: + return self.error("请先登录", err="login-required") + + if user.is_contest_admin(self.contest): + return None + + if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: + if not check_contest_password(request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get(self.contest.id), self.contest.password): + return self.error("Wrong password or password expired") + + if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details": + return self.error("Contest has not started yet.") + + return None + def decorator(func): - def _check_permission(*args, **kwargs): + @functools.wraps(func) + async def _wrapper(*args, **kwargs): self = args[0] request = args[1] - user = request.user - if request.data.get("contest_id"): - contest_id = request.data["contest_id"] - else: - contest_id = request.GET.get("contest_id") + contest_id = _get_contest_id(request) if not contest_id: return self.error("Parameter error, contest_id is required") - try: - # use self.contest to avoid query contest again in view. - self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) + self.contest = await Contest.objects.select_related("created_by").aget(id=contest_id, visible=True) except Contest.DoesNotExist: return self.error("Contest %s doesn't exist" % contest_id) - - # Anonymous - if not user.is_authenticated: - return self.error("Please login first.") - - # creator or owner - if user.is_contest_admin(self.contest): - return func(*args, **kwargs) - - if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: - # password error - if not check_contest_password(request.session.get(CONTEST_PASSWORD_SESSION_KEY, {}).get(self.contest.id), self.contest.password): - return self.error("Wrong password or password expired") - - # regular user get contest problems, ranks etc. before contest started - if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details": - return self.error("Contest has not started yet.") - - return func(*args, **kwargs) - return _check_permission + error = _check_access(self, request, request.user) + if error: + return error + return await func(*args, **kwargs) + return _wrapper return decorator diff --git a/account/middleware.py b/account/middleware.py index 3fc218c..8e8bec1 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -37,8 +37,10 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin): def process_request(self, request): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): - if not (request.user.is_authenticated and request.user.is_admin_role()): - return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) + if not request.user.is_authenticated: + return JSONResponse.response({"error": "login-required", "data": "请先登录"}) + if not request.user.is_admin_role(): + return JSONResponse.response({"error": "permission-denied", "data": "权限不足"}) class LogSqlMiddleware(MiddlewareMixin): diff --git a/contest/views/oj.py b/contest/views/oj.py index bb92064..789aeac 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -1,6 +1,7 @@ import io import xlsxwriter +from asgiref.sync import sync_to_async from django.http import HttpResponse from django.utils.timezone import now @@ -11,7 +12,7 @@ from account.decorators import ( ) from account.models import AdminType from problem.models import Problem -from utils.api import APIView, AsyncAPIView, validate_serializer +from utils.api import AsyncAPIView, validate_serializer from utils.constants import CONTEST_PASSWORD_SESSION_KEY, ContestStatus from utils.shortcuts import check_is_id, datetime2str @@ -20,19 +21,20 @@ from ..serializers import ACMContestRankSerializer, ContestAnnouncementSerialize # DEPRECATED: 前端未调用 (2026-05-26) -class ContestAnnouncementListAPI(APIView): +class ContestAnnouncementListAPI(AsyncAPIView): @check_contest_permission(check_type="announcements") - def get(self, request): + async def get(self, request): contest_id = request.GET.get("contest_id") if not contest_id: return self.error("Invalid parameter, contest_id is required") - data = ContestAnnouncement.objects.select_related("created_by").filter( + qs = ContestAnnouncement.objects.select_related("created_by").filter( contest_id=contest_id, visible=True ) max_id = request.GET.get("max_id") if max_id: - data = data.filter(id__gt=max_id) - return self.success(ContestAnnouncementSerializer(data, many=True).data) + qs = qs.filter(id__gt=max_id) + data = await self.async_serialize_data(ContestAnnouncementSerializer, [item async for item in qs], many=True) + return self.success(data) class ContestAPI(AsyncAPIView): @@ -76,13 +78,13 @@ class ContestListAPI(AsyncAPIView): return self.success(await self.async_paginate_data(request, contests, ContestSerializer)) -class ContestPasswordVerifyAPI(APIView): +class ContestPasswordVerifyAPI(AsyncAPIView): @validate_serializer(ContestPasswordVerifySerializer) @login_required - def post(self, request): + async def post(self, request): data = request.data try: - contest = Contest.objects.get( + contest = await Contest.objects.aget( id=data["contest_id"], visible=True, password__isnull=False ) except Contest.DoesNotExist: @@ -90,23 +92,21 @@ class ContestPasswordVerifyAPI(APIView): if not check_contest_password(data["password"], contest.password): return self.error("Wrong password or password expired") - # password verify OK. if CONTEST_PASSWORD_SESSION_KEY not in request.session: request.session[CONTEST_PASSWORD_SESSION_KEY] = {} request.session[CONTEST_PASSWORD_SESSION_KEY][contest.id] = data["password"] - # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved request.session.modified = True return self.success(True) -class ContestAccessAPI(APIView): +class ContestAccessAPI(AsyncAPIView): @login_required - def get(self, request): + async def get(self, request): contest_id = request.GET.get("contest_id") if not contest_id: return self.error() try: - contest = Contest.objects.get( + contest = await Contest.objects.aget( id=contest_id, visible=True, password__isnull=False ) except Contest.DoesNotExist: @@ -119,7 +119,7 @@ class ContestAccessAPI(APIView): ) -class ContestRankAPI(APIView): +class ContestRankAPI(AsyncAPIView): def get_rank(self): return ( ACMContestRank.objects.filter( @@ -138,8 +138,40 @@ class ContestRankAPI(APIView): string = chr(65 + remainder) + string return string + def _build_xlsx(self, data, contest_problems): + problem_id_to_col = {p.id: i for i, p in enumerate(contest_problems)} + f = io.BytesIO() + workbook = xlsxwriter.Workbook(f) + worksheet = workbook.add_worksheet() + worksheet.write("A1", "User ID") + worksheet.write("B1", "Username") + worksheet.write("C1", "Real Name") + worksheet.write("D1", "AC") + worksheet.write("E1", "Total Submission") + worksheet.write("F1", "Total Time") + for i, p in enumerate(contest_problems): + worksheet.write(self.column_string(7 + i) + "1", p.title) + + for index, item in enumerate(data): + worksheet.write_string(index + 1, 0, str(item["user"]["id"])) + worksheet.write_string(index + 1, 1, item["user"]["username"]) + worksheet.write_string( + index + 1, 2, item["user"]["real_name"] or "" + ) + worksheet.write_string(index + 1, 3, str(item["accepted_number"])) + worksheet.write_string(index + 1, 4, str(item["submission_number"])) + worksheet.write_string(index + 1, 5, str(item["total_time"])) + for k, v in item["submission_info"].items(): + worksheet.write_string( + index + 1, 6 + problem_id_to_col[int(k)], str(v["is_ac"]) + ) + + workbook.close() + f.seek(0) + return f.read() + @check_contest_permission(check_type="ranks") - def get(self, request): + async def get(self, request): download_csv = request.GET.get("download_csv") is_contest_admin = ( request.user.is_authenticated @@ -149,50 +181,24 @@ class ContestRankAPI(APIView): qs = self.get_rank() if download_csv: - data = ACMContestRankSerializer(qs, many=True, is_contest_admin=is_contest_admin).data - contest_problems = list(Problem.objects.filter( - contest=self.contest, visible=True - ).order_by("_id")) - # 预建 problem_id → 列索引 的字典,避免循环中 O(n) list.index() - problem_id_to_col = {p.id: i for i, p in enumerate(contest_problems)} - - f = io.BytesIO() - workbook = xlsxwriter.Workbook(f) - worksheet = workbook.add_worksheet() - worksheet.write("A1", "User ID") - worksheet.write("B1", "Username") - worksheet.write("C1", "Real Name") - worksheet.write("D1", "AC") - worksheet.write("E1", "Total Submission") - worksheet.write("F1", "Total Time") - for i, p in enumerate(contest_problems): - worksheet.write(self.column_string(7 + i) + "1", p.title) - - for index, item in enumerate(data): - worksheet.write_string(index + 1, 0, str(item["user"]["id"])) - worksheet.write_string(index + 1, 1, item["user"]["username"]) - worksheet.write_string( - index + 1, 2, item["user"]["real_name"] or "" - ) - worksheet.write_string(index + 1, 3, str(item["accepted_number"])) - worksheet.write_string(index + 1, 4, str(item["submission_number"])) - worksheet.write_string(index + 1, 5, str(item["total_time"])) - for k, v in item["submission_info"].items(): - worksheet.write_string( - index + 1, 6 + problem_id_to_col[int(k)], str(v["is_ac"]) - ) - - workbook.close() - f.seek(0) - response = HttpResponse(f.read()) + rank_list = [item async for item in qs] + data = await self.async_serialize_data( + ACMContestRankSerializer, rank_list, many=True, is_contest_admin=is_contest_admin + ) + contest_problems = await sync_to_async( + lambda: list(Problem.objects.filter(contest=self.contest, visible=True).order_by("_id")) + )() + xlsx_bytes = await sync_to_async(self._build_xlsx)(data, contest_problems) + response = HttpResponse(xlsx_bytes) response["Content-Disposition"] = ( f"attachment; filename=content-{self.contest.id}-rank.xlsx" ) response["Content-Type"] = "application/xlsx" return response - page_qs = self.paginate_data(request, qs) - page_qs["results"] = ACMContestRankSerializer( + page_qs = await self.async_paginate_data(request, qs) + page_qs["results"] = await self.async_serialize_data( + ACMContestRankSerializer, page_qs["results"], many=True, is_contest_admin=is_contest_admin - ).data + ) return self.success(page_qs) diff --git a/problem/views/oj.py b/problem/views/oj.py index 8d01d6e..40d1635 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -7,7 +7,7 @@ from django.utils import timezone from account.decorators import check_contest_permission from account.models import User from submission.models import JudgeStatus, Submission -from utils.api import APIView, AsyncAPIView +from utils.api import AsyncAPIView from utils.async_helpers import async_cache_get, async_cache_set from utils.constants import CacheKey @@ -144,7 +144,7 @@ class ProblemAPI(AsyncAPIView): return self.success(data) -class ContestProblemAPI(APIView): +class ContestProblemAPI(AsyncAPIView): def _add_problem_status(self, request, queryset_values): if request.user.is_authenticated: profile = request.user.userprofile @@ -153,31 +153,26 @@ class ContestProblemAPI(APIView): problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") @check_contest_permission(check_type="problems") - def get(self, request): + async def get(self, request): problem_id = request.GET.get("problem_id") if problem_id: try: - problem = Problem.objects.select_related("created_by").get(_id__iexact=problem_id, contest=self.contest, visible=True) + problem = await Problem.objects.select_related("created_by").aget(_id__iexact=problem_id, contest=self.contest, visible=True) except Problem.DoesNotExist: return self.error("Problem does not exist.") if self.contest.problem_details_permission(request.user): - problem_data = ProblemSerializer(problem).data - self._add_problem_status( - request, - [ - problem_data, - ], - ) + problem_data = await self.async_serialize_data(ProblemSerializer, problem) + self._add_problem_status(request, [problem_data]) else: - problem_data = ProblemSafeSerializer(problem).data + problem_data = await self.async_serialize_data(ProblemSafeSerializer, problem) return self.success(problem_data) contest_problems = Problem.objects.select_related("created_by").prefetch_related("tags").filter(contest=self.contest, visible=True) if self.contest.problem_details_permission(request.user): - data = ProblemListSerializer(contest_problems, many=True).data + data = await self.async_serialize_data(ProblemListSerializer, [p async for p in contest_problems], many=True) self._add_problem_status(request, data) else: - data = ProblemSafeSerializer(contest_problems, many=True).data + data = await self.async_serialize_data(ProblemSafeSerializer, [p async for p in contest_problems], many=True) return self.success(data) diff --git a/submission/views/admin.py b/submission/views/admin.py index d98d4f5..65ba85e 100644 --- a/submission/views/admin.py +++ b/submission/views/admin.py @@ -1,6 +1,6 @@ from django.db.models import Count, Q -from account.decorators import super_admin_required +from account.decorators import super_admin_required, teacher_admin_required from account.models import AdminType, User from judge.tasks import judge_task from problem.models import Problem @@ -35,7 +35,7 @@ class SubmissionRejudgeAPI(APIView): class SubmissionStatisticsAPI(APIView): - @super_admin_required + @teacher_admin_required def get(self, request): start = request.GET.get("start") end = request.GET.get("end") diff --git a/submission/views/oj.py b/submission/views/oj.py index 385bb88..a2442da 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -10,7 +10,7 @@ from options.options import SysOptions # from judge.dispatcher import JudgeDispatcher from problem.models import Problem, ProblemRuleType -from utils.api import APIView, AsyncAPIView, validate_serializer +from utils.api import AsyncAPIView, validate_serializer from utils.cache import cache from utils.captcha import Captcha from utils.throttling import TokenBucket @@ -26,9 +26,8 @@ from ..serializers import ( ) -class SubmissionAPI(APIView): +class SubmissionAPI(AsyncAPIView): def throttling(self, request): - # 使用 open_api 的请求暂不做限制 auth_method = getattr(request, "auth_method", "") if auth_method == "api_key": return @@ -39,14 +38,8 @@ class SubmissionAPI(APIView): if not can_consume: return "Please wait %d seconds" % (int(wait)) - # ip_bucket = TokenBucket(key=request.session["ip"], - # redis_conn=cache, **SysOptions.throttling["ip"]) - # can_consume, wait = ip_bucket.consume() - # if not can_consume: - # return "Captcha is required" - @check_contest_permission(check_type="problems") - def check_contest_permission(self, request): + async def check_contest_permission(self, request): contest = self.contest if contest.status == ContestStatus.CONTEST_ENDED: return self.error("The contest have ended") @@ -61,11 +54,11 @@ class SubmissionAPI(APIView): @validate_serializer(CreateSubmissionSerializer) @login_required - def post(self, request): + async def post(self, request): data = request.data hide_id = False if data.get("contest_id"): - error = self.check_contest_permission(request) + error = await self.check_contest_permission(request) if error: return error contest = self.contest @@ -75,19 +68,19 @@ class SubmissionAPI(APIView): if data.get("captcha"): if not Captcha(request).check(data["captcha"]): return self.error("Invalid captcha") - error = self.throttling(request) + error = await sync_to_async(self.throttling)(request) if error: return self.error(error) try: - problem = Problem.objects.get( + problem = await Problem.objects.aget( id=data["problem_id"], contest_id=data.get("contest_id"), visible=True ) except Problem.DoesNotExist: return self.error("Problem not exist") if data["language"] not in problem.languages: return self.error(f"{data['language']} is not allowed in the problem") - submission = Submission.objects.create( + submission = await Submission.objects.acreate( user_id=request.user.id, username=request.user.username, language=data["language"], @@ -97,8 +90,6 @@ class SubmissionAPI(APIView): contest_id=data.get("contest_id"), ) - # use this for debug - # JudgeDispatcher(submission.id, problem.id).judge() judge_task.send(submission.id, problem.id) if hide_id: return self.success() @@ -106,12 +97,12 @@ class SubmissionAPI(APIView): return self.success({"submission_id": submission.id}) @login_required - def get(self, request): + async def get(self, request): submission_id = request.GET.get("id") if not submission_id: return self.error("Parameter id doesn't exist") try: - submission = Submission.objects.select_related("problem").get( + submission = await Submission.objects.select_related("problem").aget( id=submission_id ) except Submission.DoesNotExist: @@ -123,10 +114,9 @@ class SubmissionAPI(APIView): submission.problem.rule_type == ProblemRuleType.OI or request.user.is_admin_role() ): - submission_data = SubmissionModelSerializer(submission).data + submission_data = await self.async_serialize_data(SubmissionModelSerializer, submission) else: - submission_data = SubmissionSafeModelSerializer(submission).data - # 是否有权限取消共享 + submission_data = await self.async_serialize_data(SubmissionSafeModelSerializer, submission) submission_data["can_unshare"] = submission.check_user_permission( request.user, check_share=False ) @@ -134,12 +124,9 @@ class SubmissionAPI(APIView): @validate_serializer(ShareSubmissionSerializer) @login_required - def put(self, request): - """ - share submission - """ + async def put(self, request): try: - submission = Submission.objects.select_related("problem").get( + submission = await Submission.objects.select_related("problem").aget( id=request.data["id"] ) except Submission.DoesNotExist: @@ -152,7 +139,7 @@ class SubmissionAPI(APIView): ): return self.error("Can not share submission now") submission.shared = request.data["shared"] - submission.save(update_fields=["shared"]) + await submission.asave(update_fields=["shared"]) return self.success() @@ -215,9 +202,9 @@ class SubmissionListAPI(AsyncAPIView): return self.success(data) -class ContestSubmissionListAPI(APIView): +class ContestSubmissionListAPI(AsyncAPIView): @check_contest_permission(check_type="submissions") - def get(self, request): + async def get(self, request): if not request.GET.get("limit"): return self.error("Limit is needed") @@ -231,7 +218,7 @@ class ContestSubmissionListAPI(APIView): username = request.GET.get("username") if problem_id: try: - problem = Problem.objects.get( + problem = await Problem.objects.aget( _id__iexact=problem_id, contest_id=contest.id, visible=True ) except Problem.DoesNotExist: @@ -245,35 +232,35 @@ class ContestSubmissionListAPI(APIView): if result: submissions = submissions.filter(result=result) - # filter the test submissions submitted before contest start if contest.status != ContestStatus.CONTEST_NOT_START: submissions = submissions.filter(create_time__gte=contest.start_time) - - data = self.paginate_data(request, submissions) + data = await self.async_paginate_data(request, submissions) results = data["results"] if request.user.is_authenticated and request.user.is_regular_user(): problem_ids = list({s.problem_id for s in results}) - progress_cache = bulk_fetch_problemset_progress(request.user, problem_ids) + progress_cache = await sync_to_async(bulk_fetch_problemset_progress)(request.user, problem_ids) else: progress_cache = {} - data["results"] = SubmissionListSerializer( + data["results"] = await self.async_serialize_data( + SubmissionListSerializer, results, many=True, user=request.user, problemset_progress_cache=progress_cache - ).data + ) return self.success(data) # DEPRECATED: 前端未调用 (2026-05-26) -class SubmissionExistsAPI(APIView): - def get(self, request): +class SubmissionExistsAPI(AsyncAPIView): + async def get(self, request): if not request.GET.get("problem_id"): return self.error("Parameter error, problem_id is required") - return self.success( + exists = ( request.user.is_authenticated - and Submission.objects.filter( + and await Submission.objects.filter( problem_id=request.GET["problem_id"], user_id=request.user.id - ).exists() + ).aexists() ) + return self.success(exists) class SubmissionsTodayCount(AsyncAPIView):