rename some method and add some decorator

This commit is contained in:
virusdefender
2017-02-10 11:41:01 +08:00
parent 802f27a516
commit 817e5aadae
12 changed files with 58 additions and 41 deletions

View File

@@ -4,7 +4,7 @@ from django.utils.translation import ugettext as _
from utils.api import JSONResponse
from .models import AdminType
from .models import ProblemPermission
class BasePermissionDecorator(object):
@@ -38,11 +38,20 @@ class login_required(BasePermissionDecorator):
class super_admin_required(BasePermissionDecorator):
def check_permission(self):
return self.request.user.is_authenticated() and \
self.request.user.admin_type == AdminType.SUPER_ADMIN
user = self.request.user
return user.is_authenticated() and user.is_super_admin()
class admin_required(BasePermissionDecorator):
class admin_role_required(BasePermissionDecorator):
def check_permission(self):
return self.request.user.is_authenticated() and \
self.request.user.admin_type in [AdminType.SUPER_ADMIN, AdminType.ADMIN]
user = self.request.user
return user.is_authenticated() and user.is_admin_role()
class problem_permission_required(admin_role_required):
def check_permission(self):
if not super(problem_permission_required, self).check_permission():
return False
if self.request.user.problem_permission == ProblemPermission.NONE:
return False
return True

View File

@@ -22,11 +22,11 @@ class SessionSecurityMiddleware(object):
request.session["last_activity"] = time.time()
class AdminRequiredMiddleware(object):
class AdminRoleRequiredMiddleware(object):
def process_request(self, request):
path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not(request.user.is_authenticated() and request.user.is_admin()):
if not(request.user.is_authenticated() and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})

View File

@@ -51,14 +51,14 @@ class User(AbstractBaseUser):
objects = UserManager()
def is_admin(self):
return self.admin_type in [AdminType.ADMIN, AdminType.SUPER_ADMIN]
def is_super_admin(self):
return self.admin_type == AdminType.SUPER_ADMIN
def is_admin_role(self):
return self.admin_type == AdminType.ADMIN
return self.admin_type in [AdminType.ADMIN, AdminType.SUPER_ADMIN]
def is_super_admin_role(self):
return self.admin_type == AdminType.SUPER_ADMIN
def can_mgmt_all_problem(self):
return self.problem_permission == ProblemPermission.ALL
class Meta:
db_table = "user"

View File

@@ -1,6 +1,6 @@
from utils.api import DateTimeTZField, serializers
from .models import AdminType, User, ProblemPermission
from .models import AdminType, ProblemPermission, User
class UserLoginSerializer(serializers.Serializer):

View File

@@ -8,7 +8,7 @@ from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str
from .models import AdminType, User, ProblemPermission
from .models import AdminType, ProblemPermission, User
class PermissionDecoratorTest(APITestCase):

View File

@@ -6,7 +6,7 @@ from utils.api import APIView, validate_serializer
from utils.shortcuts import rand_str
from ..decorators import super_admin_required
from ..models import User, AdminType, ProblemPermission
from ..models import AdminType, ProblemPermission, User
from ..serializers import EditUserSerializer, UserSerializer

View File

@@ -28,7 +28,7 @@ class ContestAPI(APIView):
data = request.data
try:
contest = Contest.objects.get(id=data.pop("id"))
if request.user.is_admin_role():
if request.user.is_admin():
contest = contest.get(created_by=request.user)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
@@ -48,7 +48,7 @@ class ContestAPI(APIView):
if contest_id:
try:
contest = Contest.objects.get(id=contest_id)
if request.user.is_admin_role():
if request.user.is_admin():
contest = contest.get(created_by=request.user)
return self.success(ContestSerializer(contest).data)
except Contest.DoesNotExist:
@@ -60,7 +60,7 @@ class ContestAPI(APIView):
if keyword:
contests = contests.filter(title__contains=keyword)
if request.user.is_admin_role():
if request.user.is_admin():
contests = contests.filter(created_by=request.user)
return self.success(self.paginate_data(request, contests, ContestSerializer))
@@ -71,7 +71,7 @@ class ContestAnnouncementAPI(APIView):
data = request.data
try:
contest = Contest.objects.get(id=data.pop("contest_id"))
if request.user.is_admin_role():
if request.user.is_admin():
contest = contest.get(created_by=request.user)
data["contest"] = contest
data["created_by"] = request.user
@@ -83,7 +83,7 @@ class ContestAnnouncementAPI(APIView):
def delete(self, request):
announcement_id = request.GET.get("id")
if announcement_id:
if request.user.is_admin_role():
if request.user.is_admin():
ContestAnnouncement.objects.filter(id=announcement_id, contest__created_by=request.user).delete()
else:
ContestAnnouncement.objects.filter(id=announcement_id).delete()

View File

@@ -58,7 +58,7 @@ MIDDLEWARE_CLASSES = (
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.security.SecurityMiddleware',
'account.middleware.AdminRequiredMiddleware',
'account.middleware.AdminRoleRequiredMiddleware',
'account.middleware.SessionSecurityMiddleware',
'account.middleware.TimezoneMiddleware'
)

View File

@@ -5,13 +5,13 @@ import zipfile
from django.conf import settings
from account.decorators import admin_required
from account.decorators import problem_permission_required
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
from utils.shortcuts import rand_str
from ..models import Problem, ProblemRuleType, ProblemTag
from ..serializers import (CreateProblemSerializer, ProblemSerializer,
TestCaseUploadForm, EditProblemSerializer)
from ..serializers import (CreateProblemSerializer, EditProblemSerializer,
ProblemSerializer, TestCaseUploadForm)
class TestCaseUploadAPI(CSRFExemptAPIView):
@@ -41,7 +41,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
else:
return sorted(ret)
@admin_required
@problem_permission_required
def post(self, request):
form = TestCaseUploadForm(request.POST, request.FILES)
if form.is_valid():
@@ -109,6 +109,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
class ProblemAPI(APIView):
@validate_serializer(CreateProblemSerializer)
@problem_permission_required
def post(self, request):
data = request.data
@@ -151,29 +152,34 @@ class ProblemAPI(APIView):
problem.tags.add(tag)
return self.success()
@problem_permission_required
def get(self, request):
problem_id = request.GET.get("id")
user = request.user
if problem_id:
try:
problem = Problem.objects.get(id=problem_id)
if request.user.is_admin_role():
if not user.can_mgmt_all_problem():
problem = problem.get(created_by=request.user)
return self.success(ProblemSerializer(problem).data)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
problems = Problem.objects.all().order_by("-create_time")
if request.user.is_admin_role():
if not user.can_mgmt_all_problem():
problems = problems.filter(created_by=request.user)
return self.success(self.paginate_data(request, problems, ProblemSerializer))
@validate_serializer(EditProblemSerializer)
@problem_permission_required
def put(self, request):
data = request.data
id = data.pop("id")
problem_id = data.pop("id")
user = request.user
try:
problem = Problem.objects.get(id=id)
if request.user.is_admin_role():
problem = Problem.objects.get(id=problem_id)
if not user.can_mgmt_all_problem():
problem = problem.get(created_by=request.user)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
@@ -181,12 +187,12 @@ class ProblemAPI(APIView):
_id = data["_id"]
if _id:
try:
Problem.objects.exclude(id=id).get(_id=_id)
Problem.objects.exclude(id=problem_id).get(_id=_id)
return self.error("Display ID already exists")
except Problem.DoesNotExist:
pass
else:
data["_id"] = str(id)
data["_id"] = str(problem_id)
if data["spj"]:
if not data["spj_language"] or not data["spj_code"]:

View File

@@ -90,7 +90,6 @@ class APIView(View):
return self._serializer_error_to_str({_k: _v})
def invalid_serializer(self, serializer):
print(serializer.errors)
k, v = self._serializer_error_to_str(serializer.errors)
if k != "non_field_errors":
return self.error(err="invalid-" + k, msg=k + ": " + v)

View File

@@ -2,14 +2,14 @@ from django.core.urlresolvers import reverse
from django.test.testcases import TestCase
from rest_framework.test import APIClient
from account.models import AdminType, User, UserProfile
from account.models import AdminType, ProblemPermission, User, UserProfile
class APITestCase(TestCase):
client_class = APIClient
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True):
user = User.objects.create(username=username, admin_type=admin_type)
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE):
user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission)
user.set_password(password)
UserProfile.objects.create(user=user, time_zone="Asia/Shanghai")
user.save()
@@ -18,10 +18,12 @@ class APITestCase(TestCase):
return user
def create_admin(self, username="admin", password="admin", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, login=login)
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN,
login=login)
def create_super_admin(self, username="root", password="root", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login)
def reverse(self, url_name):
return reverse(url_name)

View File

@@ -1,6 +1,6 @@
from django.core.management.base import BaseCommand
from account.models import AdminType, User, UserProfile
from account.models import AdminType, ProblemPermission, User, UserProfile
from utils.shortcuts import rand_str # NOQA
@@ -26,7 +26,8 @@ class Command(BaseCommand):
else:
self.stdout.write(self.style.ERROR("User 'root' is not super admin."))
except User.DoesNotExist:
user = User.objects.create(username="root", email="root@oj.com", admin_type=AdminType.SUPER_ADMIN)
user = User.objects.create(username="root", email="root@oj.com", admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL)
# for dev
# rand_password = rand_str(length=6)
rand_password = "rootroot"