diff --git a/account/decorators.py b/account/decorators.py index 6b10896..01cdd68 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -20,5 +20,13 @@ def login_required(func): return check -def admin_required(): - pass +def admin_required(func): + def check(*args, **kwargs): + request = args[-1] + if request.user.is_authenticated() and request.user.admin_type: + return func(*args, **kwargs) + if request.is_ajax(): + return error_response(u"需要管理员权限") + else: + return render(request, "utils/error.html", {"error": "需要管理员权限"}) + return check diff --git a/account/models.py b/account/models.py index 1f02eef..7c5afb9 100644 --- a/account/models.py +++ b/account/models.py @@ -21,7 +21,8 @@ class User(AbstractBaseUser): real_name = models.CharField(max_length=30, blank=True, null=True) # 用户邮箱 email = models.EmailField(max_length=254, blank=True, null=True) - admin_group = models.ForeignKey(AdminGroup, null=True, on_delete=models.SET_NULL) + # 0代表不是管理员 1是普通管理员 2是超级管理员 + admin_type = models.IntegerField(default=0) USERNAME_FIELD = 'username' REQUIRED_FIELDS = [] diff --git a/account/test_urls.py b/account/test_urls.py index 4066721..9757d90 100644 --- a/account/test_urls.py +++ b/account/test_urls.py @@ -1,12 +1,18 @@ # coding=utf-8 from django.conf.urls import include, url -from .tests import LoginRequiredCBVTestWithArgs, LoginRequiredCBVTestWithoutArgs +from .tests import (LoginRequiredCBVTestWithArgs, LoginRequiredCBVTestWithoutArgs, + AdminRequiredCBVTestWithArgs, AdminRequiredCBVTestWithoutArgs) urlpatterns = [ - url(r'^test/fbv/1/$', "account.tests.login_required_FBV_test_without_args"), - url(r'^test/fbv/(?P\d+)/$', "account.tests.login_required_FBC_test_with_args"), - url(r'^test/cbv/1/$', LoginRequiredCBVTestWithoutArgs.as_view()), - url(r'^test/cbv/(?P\d+)/$', LoginRequiredCBVTestWithArgs.as_view()), + url(r'^login_required_test/fbv/1/$', "account.tests.login_required_FBV_test_without_args"), + url(r'^login_required_test/fbv/(?P\d+)/$', "account.tests.login_required_FBC_test_with_args"), + url(r'^login_required_test/cbv/1/$', LoginRequiredCBVTestWithoutArgs.as_view()), + url(r'^login_required_test/cbv/(?P\d+)/$', LoginRequiredCBVTestWithArgs.as_view()), + + url(r'^admin_required_test/fbv/1/$', "account.tests.admin_required_FBV_test_without_args"), + url(r'^admin_required_test/fbv/(?P\d+)/$', "account.tests.admin_required_FBC_test_with_args"), + url(r'^admin_required_test/cbv/1/$', AdminRequiredCBVTestWithoutArgs.as_view()), + url(r'^admin_required_test/cbv/(?P\d+)/$', AdminRequiredCBVTestWithArgs.as_view()), ] diff --git a/account/tests.py b/account/tests.py index 89410da..3529bb6 100644 --- a/account/tests.py +++ b/account/tests.py @@ -10,7 +10,7 @@ from rest_framework.views import APIView from rest_framework.response import Response from .models import User -from .decorators import login_required +from .decorators import login_required, admin_required class UserLoginTest(TestCase): @@ -159,6 +159,7 @@ class LoginRequiredCBVTestWithoutArgs(APIView): def get(self, request): return HttpResponse("class based view login required test1") + class LoginRequiredCBVTestWithArgs(APIView): @login_required def get(self, request, problem_id): @@ -176,40 +177,113 @@ class LoginRequiredDecoratorTest(TestCase): def test_fbv_without_args(self): # 没登陆 - response = self.client.get("/test/fbv/1/") + response = self.client.get("/login_required_test/fbv/1/") self.assertTemplateUsed(response, "utils/error.html") # 登陆后 self.client.login(username="test", password="test") - response = self.client.get("/test/fbv/1/") + response = self.client.get("/login_required_test/fbv/1/") self.assertEqual(response.content, "function based view test1") def test_fbv_with_args(self): # 没登陆 - response = self.client.get("/test/fbv/1024/") + response = self.client.get("/login_required_test/fbv/1024/") self.assertTemplateUsed(response, "utils/error.html") # 登陆后 self.client.login(username="test", password="test") - response = self.client.get("/test/fbv/1024/") + response = self.client.get("/login_required_test/fbv/1024/") self.assertEqual(response.content, "1024") def test_cbv_without_args(self): # 没登陆 - response = self.client.get("/test/cbv/1/") + response = self.client.get("/login_required_test/cbv/1/") self.assertTemplateUsed(response, "utils/error.html") # 登陆后 self.client.login(username="test", password="test") - response = self.client.get("/test/cbv/1/") + response = self.client.get("/login_required_test/cbv/1/") self.assertEqual(response.content, "class based view login required test1") def test_cbv_with_args(self): # 没登陆 - response = self.client.get("/test/cbv/1024/", HTTP_X_REQUESTED_WITH='XMLHttpRequest') + response = self.client.get("/login_required_test/cbv/1024/", HTTP_X_REQUESTED_WITH='XMLHttpRequest') self.assertEqual(json.loads(response.content), {"code": 1, "data": u"请先登录"}) # 登陆后 self.client.login(username="test", password="test") - response = self.client.get("/test/cbv/1024/") + response = self.client.get("/login_required_test/cbv/1024/") + self.assertEqual(response.content, "1024") + + +@admin_required +def admin_required_FBV_test_without_args(request): + return HttpResponse("function based view test1") + + +@admin_required +def admin_required_FBC_test_with_args(request, problem_id): + return HttpResponse(problem_id) + + +class AdminRequiredCBVTestWithoutArgs(APIView): + @admin_required + def get(self, request): + return HttpResponse("class based view login required test1") + + +class AdminRequiredCBVTestWithArgs(APIView): + @admin_required + def get(self, request, problem_id): + return HttpResponse(problem_id) + + +class AdminRequiredDecoratorTest(TestCase): + urls = 'account.test_urls' + + def setUp(self): + self.client = Client() + user = User.objects.create(username="test") + user.admin_type = 1 + user.set_password("test") + user.save() + + def test_fbv_without_args(self): + # 没登陆 + response = self.client.get("/admin_required_test/fbv/1/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/admin_required_test/fbv/1/") + self.assertEqual(response.content, "function based view test1") + + def test_fbv_with_args(self): + # 没登陆 + response = self.client.get("/admin_required_test/fbv/1024/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/admin_required_test/fbv/1024/") + self.assertEqual(response.content, "1024") + + def test_cbv_without_args(self): + # 没登陆 + response = self.client.get("/admin_required_test/cbv/1/") + self.assertTemplateUsed(response, "utils/error.html") + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/admin_required_test/cbv/1/") + self.assertEqual(response.content, "class based view login required test1") + + def test_cbv_with_args(self): + # 没登陆 + response = self.client.get("/admin_required_test/cbv/1024/", HTTP_X_REQUESTED_WITH='XMLHttpRequest') + self.assertEqual(json.loads(response.content), {"code": 1, "data": u"需要管理员权限"}) + + # 登陆后 + self.client.login(username="test", password="test") + response = self.client.get("/admin_required_test/cbv/1024/") self.assertEqual(response.content, "1024")