diff --git a/account/decorators.py b/account/decorators.py index f801f95..c11560d 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -4,7 +4,7 @@ from django.utils.translation import ugettext as _ from utils.api import JSONResponse -from .models import AdminType +from .models import ProblemPermission class BasePermissionDecorator(object): @@ -38,11 +38,20 @@ class login_required(BasePermissionDecorator): class super_admin_required(BasePermissionDecorator): def check_permission(self): - return self.request.user.is_authenticated() and \ - self.request.user.admin_type == AdminType.SUPER_ADMIN + user = self.request.user + return user.is_authenticated() and user.is_super_admin() -class admin_required(BasePermissionDecorator): +class admin_role_required(BasePermissionDecorator): def check_permission(self): - return self.request.user.is_authenticated() and \ - self.request.user.admin_type in [AdminType.SUPER_ADMIN, AdminType.ADMIN] + user = self.request.user + return user.is_authenticated() and user.is_admin_role() + + +class problem_permission_required(admin_role_required): + def check_permission(self): + if not super(problem_permission_required, self).check_permission(): + return False + if self.request.user.problem_permission == ProblemPermission.NONE: + return False + return True diff --git a/account/middleware.py b/account/middleware.py index 9bc6be4..409a55a 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -22,11 +22,11 @@ class SessionSecurityMiddleware(object): request.session["last_activity"] = time.time() -class AdminRequiredMiddleware(object): +class AdminRoleRequiredMiddleware(object): def process_request(self, request): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): - if not(request.user.is_authenticated() and request.user.is_admin()): + if not(request.user.is_authenticated() and request.user.is_admin_role()): return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) diff --git a/account/models.py b/account/models.py index 39be979..807df44 100644 --- a/account/models.py +++ b/account/models.py @@ -51,14 +51,14 @@ class User(AbstractBaseUser): objects = UserManager() - def is_admin(self): - return self.admin_type in [AdminType.ADMIN, AdminType.SUPER_ADMIN] + def is_super_admin(self): + return self.admin_type == AdminType.SUPER_ADMIN def is_admin_role(self): - return self.admin_type == AdminType.ADMIN + return self.admin_type in [AdminType.ADMIN, AdminType.SUPER_ADMIN] - def is_super_admin_role(self): - return self.admin_type == AdminType.SUPER_ADMIN + def can_mgmt_all_problem(self): + return self.problem_permission == ProblemPermission.ALL class Meta: db_table = "user" diff --git a/account/serializers.py b/account/serializers.py index aa274a8..fe128f6 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -1,6 +1,6 @@ from utils.api import DateTimeTZField, serializers -from .models import AdminType, User, ProblemPermission +from .models import AdminType, ProblemPermission, User class UserLoginSerializer(serializers.Serializer): diff --git a/account/tests.py b/account/tests.py index 7a8f6ef..877b317 100644 --- a/account/tests.py +++ b/account/tests.py @@ -8,7 +8,7 @@ from otpauth import OtpAuth from utils.api.tests import APIClient, APITestCase from utils.shortcuts import rand_str -from .models import AdminType, User, ProblemPermission +from .models import AdminType, ProblemPermission, User class PermissionDecoratorTest(APITestCase): diff --git a/account/views/admin.py b/account/views/admin.py index 421e9c2..49e8093 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -6,7 +6,7 @@ from utils.api import APIView, validate_serializer from utils.shortcuts import rand_str from ..decorators import super_admin_required -from ..models import User, AdminType, ProblemPermission +from ..models import AdminType, ProblemPermission, User from ..serializers import EditUserSerializer, UserSerializer diff --git a/contest/views/admin.py b/contest/views/admin.py index 096cb90..0b8e750 100644 --- a/contest/views/admin.py +++ b/contest/views/admin.py @@ -28,7 +28,7 @@ class ContestAPI(APIView): data = request.data try: contest = Contest.objects.get(id=data.pop("id")) - if request.user.is_admin_role(): + if request.user.is_admin(): contest = contest.get(created_by=request.user) except Contest.DoesNotExist: return self.error("Contest does not exist") @@ -48,7 +48,7 @@ class ContestAPI(APIView): if contest_id: try: contest = Contest.objects.get(id=contest_id) - if request.user.is_admin_role(): + if request.user.is_admin(): contest = contest.get(created_by=request.user) return self.success(ContestSerializer(contest).data) except Contest.DoesNotExist: @@ -60,7 +60,7 @@ class ContestAPI(APIView): if keyword: contests = contests.filter(title__contains=keyword) - if request.user.is_admin_role(): + if request.user.is_admin(): contests = contests.filter(created_by=request.user) return self.success(self.paginate_data(request, contests, ContestSerializer)) @@ -71,7 +71,7 @@ class ContestAnnouncementAPI(APIView): data = request.data try: contest = Contest.objects.get(id=data.pop("contest_id")) - if request.user.is_admin_role(): + if request.user.is_admin(): contest = contest.get(created_by=request.user) data["contest"] = contest data["created_by"] = request.user @@ -83,7 +83,7 @@ class ContestAnnouncementAPI(APIView): def delete(self, request): announcement_id = request.GET.get("id") if announcement_id: - if request.user.is_admin_role(): + if request.user.is_admin(): ContestAnnouncement.objects.filter(id=announcement_id, contest__created_by=request.user).delete() else: ContestAnnouncement.objects.filter(id=announcement_id).delete() diff --git a/oj/settings.py b/oj/settings.py index d042b3d..9e23a7b 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -58,7 +58,7 @@ MIDDLEWARE_CLASSES = ( 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.security.SecurityMiddleware', - 'account.middleware.AdminRequiredMiddleware', + 'account.middleware.AdminRoleRequiredMiddleware', 'account.middleware.SessionSecurityMiddleware', 'account.middleware.TimezoneMiddleware' ) diff --git a/problem/views/admin.py b/problem/views/admin.py index 2006841..01f5522 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -5,13 +5,13 @@ import zipfile from django.conf import settings -from account.decorators import admin_required +from account.decorators import problem_permission_required from utils.api import APIView, CSRFExemptAPIView, validate_serializer from utils.shortcuts import rand_str from ..models import Problem, ProblemRuleType, ProblemTag -from ..serializers import (CreateProblemSerializer, ProblemSerializer, - TestCaseUploadForm, EditProblemSerializer) +from ..serializers import (CreateProblemSerializer, EditProblemSerializer, + ProblemSerializer, TestCaseUploadForm) class TestCaseUploadAPI(CSRFExemptAPIView): @@ -41,7 +41,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView): else: return sorted(ret) - @admin_required + @problem_permission_required def post(self, request): form = TestCaseUploadForm(request.POST, request.FILES) if form.is_valid(): @@ -109,6 +109,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView): class ProblemAPI(APIView): @validate_serializer(CreateProblemSerializer) + @problem_permission_required def post(self, request): data = request.data @@ -151,29 +152,34 @@ class ProblemAPI(APIView): problem.tags.add(tag) return self.success() + @problem_permission_required def get(self, request): problem_id = request.GET.get("id") + user = request.user if problem_id: try: problem = Problem.objects.get(id=problem_id) - if request.user.is_admin_role(): + if not user.can_mgmt_all_problem(): problem = problem.get(created_by=request.user) return self.success(ProblemSerializer(problem).data) except Problem.DoesNotExist: return self.error("Problem does not exist") problems = Problem.objects.all().order_by("-create_time") - if request.user.is_admin_role(): + if not user.can_mgmt_all_problem(): problems = problems.filter(created_by=request.user) return self.success(self.paginate_data(request, problems, ProblemSerializer)) @validate_serializer(EditProblemSerializer) + @problem_permission_required def put(self, request): data = request.data - id = data.pop("id") + problem_id = data.pop("id") + user = request.user + try: - problem = Problem.objects.get(id=id) - if request.user.is_admin_role(): + problem = Problem.objects.get(id=problem_id) + if not user.can_mgmt_all_problem(): problem = problem.get(created_by=request.user) except Problem.DoesNotExist: return self.error("Problem does not exist") @@ -181,12 +187,12 @@ class ProblemAPI(APIView): _id = data["_id"] if _id: try: - Problem.objects.exclude(id=id).get(_id=_id) + Problem.objects.exclude(id=problem_id).get(_id=_id) return self.error("Display ID already exists") except Problem.DoesNotExist: pass else: - data["_id"] = str(id) + data["_id"] = str(problem_id) if data["spj"]: if not data["spj_language"] or not data["spj_code"]: diff --git a/utils/api/api.py b/utils/api/api.py index 7400fde..78018cd 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -90,7 +90,6 @@ class APIView(View): return self._serializer_error_to_str({_k: _v}) def invalid_serializer(self, serializer): - print(serializer.errors) k, v = self._serializer_error_to_str(serializer.errors) if k != "non_field_errors": return self.error(err="invalid-" + k, msg=k + ": " + v) diff --git a/utils/api/tests.py b/utils/api/tests.py index d7c20b5..df3bfbe 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -2,14 +2,14 @@ from django.core.urlresolvers import reverse from django.test.testcases import TestCase from rest_framework.test import APIClient -from account.models import AdminType, User, UserProfile +from account.models import AdminType, ProblemPermission, User, UserProfile class APITestCase(TestCase): client_class = APIClient - def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True): - user = User.objects.create(username=username, admin_type=admin_type) + def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE): + user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission) user.set_password(password) UserProfile.objects.create(user=user, time_zone="Asia/Shanghai") user.save() @@ -18,10 +18,12 @@ class APITestCase(TestCase): return user def create_admin(self, username="admin", password="admin", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, login=login) + return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN, + login=login) def create_super_admin(self, username="root", password="root", login=True): - return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login) + return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL, login=login) def reverse(self, url_name): return reverse(url_name) diff --git a/utils/management/commands/initadmin.py b/utils/management/commands/initadmin.py index 59cdf4d..5829178 100644 --- a/utils/management/commands/initadmin.py +++ b/utils/management/commands/initadmin.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand -from account.models import AdminType, User, UserProfile +from account.models import AdminType, ProblemPermission, User, UserProfile from utils.shortcuts import rand_str # NOQA @@ -26,7 +26,8 @@ class Command(BaseCommand): else: self.stdout.write(self.style.ERROR("User 'root' is not super admin.")) except User.DoesNotExist: - user = User.objects.create(username="root", email="root@oj.com", admin_type=AdminType.SUPER_ADMIN) + user = User.objects.create(username="root", email="root@oj.com", admin_type=AdminType.SUPER_ADMIN, + problem_permission=ProblemPermission.ALL) # for dev # rand_password = rand_str(length=6) rand_password = "rootroot"