diff --git a/account/decorators.py b/account/decorators.py index 839893f..36a2b44 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -1,10 +1,7 @@ import functools - -from utils.api import JSONResponse - -from .models import ProblemPermission - from contest.models import Contest, ContestType, ContestStatus, ContestRuleType +from utils.api import JSONResponse, APIError +from .models import ProblemPermission class BasePermissionDecorator(object): @@ -90,8 +87,7 @@ def check_contest_permission(check_type="details"): if not user.is_authenticated(): return self.error("Please login first.") # password error - if ("accessible_contests" not in request.session) or \ - (self.contest.id not in request.session["accessible_contests"]): + if self.contest.id not in request.session.get("accessible_contests", []): return self.error("Password is required.") # regular user get contest problems, ranks etc. before contest started @@ -104,7 +100,10 @@ def check_contest_permission(check_type="details"): return self.error(f"No permission to get {check_type}") return func(*args, **kwargs) - return _check_permission - return decorator + + +def ensure_created_by(obj, user): + if not user.is_admin_role() or (user.is_admin() and obj.created_by != user): + raise APIError(msg=f"{obj.__class__.__name__} does not exist") diff --git a/conf/views.py b/conf/views.py index 972312d..5e82645 100644 --- a/conf/views.py +++ b/conf/views.py @@ -30,14 +30,14 @@ class SMTPAPI(APIView): smtp.pop("password") return self.success(smtp) - @validate_serializer(CreateSMTPConfigSerializer) @super_admin_required + @validate_serializer(CreateSMTPConfigSerializer) def post(self, request): SysOptions.smtp_config = request.data return self.success() - @validate_serializer(EditSMTPConfigSerializer) @super_admin_required + @validate_serializer(EditSMTPConfigSerializer) def put(self, request): smtp = SysOptions.smtp_config data = request.data @@ -81,8 +81,8 @@ class WebsiteConfigAPI(APIView): "website_footer", "allow_register", "submission_list_show_all"]} return self.success(ret) - @validate_serializer(CreateEditWebsiteConfigSerializer) @super_admin_required + @validate_serializer(CreateEditWebsiteConfigSerializer) def post(self, request): for k, v in request.data.items(): if k == "website_footer": diff --git a/contest/views/admin.py b/contest/views/admin.py index d4066b5..a44e9da 100644 --- a/contest/views/admin.py +++ b/contest/views/admin.py @@ -5,7 +5,7 @@ from utils.api import APIView, validate_serializer from utils.cache import cache from utils.constants import CacheKey -from account.decorators import check_contest_permission +from account.decorators import check_contest_permission, ensure_created_by from ..models import Contest, ContestAnnouncement, ACMContestRank from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer, CreateConetestSeriaizer, CreateContestAnnouncementSerializer, @@ -37,8 +37,7 @@ class ContestAPI(APIView): data = request.data try: contest = Contest.objects.get(id=data.pop("id")) - if request.user.is_admin() and contest.created_by != request.user: - return self.error("Contest does not exist") + ensure_created_by(contest, request.user) except Contest.DoesNotExist: return self.error("Contest does not exist") data["start_time"] = dateutil.parser.parse(data["start_time"]) @@ -66,20 +65,18 @@ class ContestAPI(APIView): if contest_id: try: contest = Contest.objects.get(id=contest_id) - if request.user.is_admin() and contest.created_by != request.user: - return self.error("Contest does not exist") + ensure_created_by(contest, request.user) return self.success(ContestAdminSerializer(contest).data) except Contest.DoesNotExist: return self.error("Contest does not exist") contests = Contest.objects.all().order_by("-create_time") + if request.user.is_admin(): + contests = contests.filter(created_by=request.user) keyword = request.GET.get("keyword") if keyword: contests = contests.filter(title__contains=keyword) - - if request.user.is_admin(): - contests = contests.filter(created_by=request.user) return self.success(self.paginate_data(request, contests, ContestAdminSerializer)) @@ -92,8 +89,7 @@ class ContestAnnouncementAPI(APIView): data = request.data try: contest = Contest.objects.get(id=data.pop("contest_id")) - if request.user.is_admin() and contest.created_by != request.user: - return self.error("Contest does not exist") + ensure_created_by(contest, request.user) data["contest"] = contest data["created_by"] = request.user except Contest.DoesNotExist: @@ -109,8 +105,7 @@ class ContestAnnouncementAPI(APIView): data = request.data try: contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id")) - if request.user.is_admin() and contest_announcement.created_by != request.user: - return self.error("Contest announcement does not exist") + ensure_created_by(contest_announcement, request.user) except ContestAnnouncement.DoesNotExist: return self.error("Contest announcement does not exist") for k, v in data.items(): @@ -139,15 +134,14 @@ class ContestAnnouncementAPI(APIView): if contest_announcement_id: try: contest_announcement = ContestAnnouncement.objects.get(id=contest_announcement_id) - if request.user.is_admin() and contest_announcement.created_by != request.user: - return self.error("Contest announcement does not exist") + ensure_created_by(contest_announcement, request.user) return self.success(ContestAnnouncementSerializer(contest_announcement).data) except ContestAnnouncement.DoesNotExist: return self.error("Contest announcement does not exist") contest_id = request.GET.get("contest_id") if not contest_id: - return self.error("Paramater error") + return self.error("Parameter error") contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id) if request.user.is_admin(): contest_announcements = contest_announcements.filter(created_by=request.user) @@ -177,12 +171,10 @@ class ACMContestHelper(APIView): results.sort(key=lambda x: -x["ac_info"]["ac_time"]) return self.success(results) - @validate_serializer(ACMContesHelperSerializer) @check_contest_permission(check_type="ranks") + @validate_serializer(ACMContesHelperSerializer) def put(self, request): data = request.data - if not request.user.is_contest_admin(self.contest): - return self.error("You are not contest admin") try: rank = ACMContestRank.objects.get(pk=data["rank_id"]) except ACMContestRank.DoesNotExist: diff --git a/problem/views/admin.py b/problem/views/admin.py index 352aa09..5846e35 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -8,7 +8,7 @@ from wsgiref.util import FileWrapper from django.conf import settings from django.http import StreamingHttpResponse, HttpResponse -from account.decorators import problem_permission_required +from account.decorators import problem_permission_required, ensure_created_by from judge.dispatcher import SPJCompiler from contest.models import Contest, ContestStatus from submission.models import Submission @@ -49,7 +49,6 @@ class TestCaseAPI(CSRFExemptAPIView): else: return sorted(ret, key=natural_sort_key) - @problem_permission_required def get(self, request): problem_id = request.GET.get("problem_id") if not problem_id: @@ -59,6 +58,11 @@ class TestCaseAPI(CSRFExemptAPIView): except Problem.DoesNotExist: return self.error("Problem does not exists") + if problem.contest: + ensure_created_by(problem.contest, request.user) + else: + ensure_created_by(problem, request.user) + test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if not os.path.isdir(test_case_dir): return self.error("Test case does not exists") @@ -79,7 +83,6 @@ class TestCaseAPI(CSRFExemptAPIView): response["Content-Length"] = os.path.getsize(file_name) return response - @problem_permission_required def post(self, request): form = TestCaseUploadForm(request.POST, request.FILES) if form.is_valid(): @@ -147,7 +150,6 @@ class TestCaseAPI(CSRFExemptAPIView): class CompileSPJAPI(APIView): @validate_serializer(CompileSPJSerializer) - @problem_permission_required def post(self, request): data = request.data spj_version = rand_str(8) @@ -186,11 +188,12 @@ class ProblemBase(APIView): def delete(self, request): id = request.GET.get("id") if not id: - return self.error("Invalid parameter, id is requred") + return self.error("Invalid parameter, id is required") try: - problem = Problem.objects.get(id=id) + problem = Problem.objects.get(id=id, contest_id__isnull=True) except Problem.DoesNotExist: return self.error("Problem does not exists") + ensure_created_by(problem, request.user) if Submission.objects.filter(problem=problem).exists(): return self.error("Can't delete the problem as it has submissions") d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) @@ -201,11 +204,10 @@ class ProblemBase(APIView): class ProblemAPI(ProblemBase): - @validate_serializer(CreateProblemSerializer) @problem_permission_required + @validate_serializer(CreateProblemSerializer) def post(self, request): data = request.data - _id = data["_id"] if not _id: return self.error("Display ID is required") @@ -236,8 +238,7 @@ class ProblemAPI(ProblemBase): if problem_id: try: problem = Problem.objects.get(id=problem_id) - if not user.can_mgmt_all_problem() and problem.created_by != user: - return self.error("Problem does not exist") + ensure_created_by(problem, request.user) return self.success(ProblemAdminSerializer(problem).data) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -256,17 +257,15 @@ class ProblemAPI(ProblemBase): problems = problems.filter(title__contains=keyword) return self.success(self.paginate_data(request, problems, ProblemAdminSerializer)) - @validate_serializer(EditProblemSerializer) @problem_permission_required + @validate_serializer(EditProblemSerializer) def put(self, request): data = request.data problem_id = data.pop("id") - user = request.user try: problem = Problem.objects.get(id=problem_id) - if not user.can_mgmt_all_problem() and problem.created_by != user: - return self.error("Problem does not exist") + ensure_created_by(problem, request.user) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -300,13 +299,11 @@ class ProblemAPI(ProblemBase): class ContestProblemAPI(ProblemBase): @validate_serializer(CreateContestProblemSerializer) - @problem_permission_required def post(self, request): data = request.data try: contest = Contest.objects.get(id=data.pop("contest_id")) - if request.user.is_admin() and contest.created_by != request.user: - return self.error("Contest does not exist") + ensure_created_by(contest, request.user) except Contest.DoesNotExist: return self.error("Contest does not exist") @@ -345,8 +342,7 @@ class ContestProblemAPI(ProblemBase): if problem_id: try: problem = Problem.objects.get(id=problem_id) - if user.is_admin() and problem.contest.created_by != user: - return self.error("Problem does not exist") + ensure_created_by(problem, user) except Problem.DoesNotExist: return self.error("Problem does not exist") return self.success(ProblemAdminSerializer(problem).data) @@ -366,10 +362,11 @@ class ContestProblemAPI(ProblemBase): @problem_permission_required def put(self, request): data = request.data + user = request.user + try: contest = Contest.objects.get(id=data.pop("contest_id")) - if request.user.is_admin() and contest.created_by != request.user: - return self.error("Contest does not exist") + ensure_created_by(contest, user) except Contest.DoesNotExist: return self.error("Contest does not exist") @@ -377,12 +374,10 @@ class ContestProblemAPI(ProblemBase): return self.error("Invalid rule type") problem_id = data.pop("id") - user = request.user try: problem = Problem.objects.get(id=problem_id) - if not user.can_mgmt_all_problem() and problem.created_by != user: - return self.error("Problem does not exist") + ensure_created_by(problem, user) except Problem.DoesNotExist: return self.error("Problem does not exist") diff --git a/utils/api/api.py b/utils/api/api.py index e33daaf..5b7a231 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -11,6 +11,13 @@ from django.views.generic import View logger = logging.getLogger("") +class APIError(Exception): + def __init__(self, msg, err=None): + self.err = err + self.msg = msg + super().__init__(err, msg) + + class ContentType(object): json_request = "application/json" json_response = "application/json;charset=UTF-8" @@ -137,6 +144,11 @@ class APIView(View): return self.error(err="invalid-request", msg=str(e)) try: return super(APIView, self).dispatch(request, *args, **kwargs) + except APIError as e: + ret = {"msg": e.msg} + if e.err: + ret["err"] = e.err + return self.error(**ret) except Exception as e: logger.exception(e) return self.server_error()