完善contest权限控制
This commit is contained in:
@@ -4,7 +4,7 @@ from utils.api import JSONResponse
|
||||
|
||||
from .models import ProblemPermission
|
||||
|
||||
from contest.models import Contest, ContestType, ContestStatus
|
||||
from contest.models import Contest, ContestType, ContestStatus, ContestRuleType
|
||||
|
||||
|
||||
class BasePermissionDecorator(object):
|
||||
@@ -25,7 +25,7 @@ class BasePermissionDecorator(object):
|
||||
return self.error("Your account is disabled")
|
||||
return self.func(*args, **kwargs)
|
||||
else:
|
||||
return self.error("Please login in first")
|
||||
return self.error("Please login first")
|
||||
|
||||
def check_permission(self):
|
||||
raise NotImplementedError()
|
||||
@@ -57,45 +57,54 @@ class problem_permission_required(admin_role_required):
|
||||
return True
|
||||
|
||||
|
||||
def check_contest_permission(func):
|
||||
def check_contest_permission(check_type="details"):
|
||||
"""
|
||||
只供Class based view 使用,检查用户是否有权进入该contest,
|
||||
只供Class based view 使用,检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions
|
||||
若通过验证,在view中可通过self.contest获得该contest
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def _check_permission(*args, **kwargs):
|
||||
self = args[0]
|
||||
request = args[1]
|
||||
user = request.user
|
||||
if kwargs.get("contest_id"):
|
||||
contest_id = kwargs.pop("contest_id")
|
||||
else:
|
||||
contest_id = request.GET.get("contest_id")
|
||||
if not contest_id:
|
||||
return self.error("Parameter contest_id not exist.")
|
||||
|
||||
try:
|
||||
# use self.contest to avoid query contest again in view.
|
||||
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
||||
except Contest.DoesNotExist:
|
||||
return self.error("Contest %s doesn't exist" % contest_id)
|
||||
def decorator(func):
|
||||
def _check_permission(*args, **kwargs):
|
||||
self = args[0]
|
||||
request = args[1]
|
||||
user = request.user
|
||||
if kwargs.get("contest_id"):
|
||||
contest_id = kwargs.pop("contest_id")
|
||||
else:
|
||||
contest_id = request.GET.get("contest_id")
|
||||
if not contest_id:
|
||||
return self.error("Parameter contest_id not exist.")
|
||||
|
||||
try:
|
||||
# use self.contest to avoid query contest again in view.
|
||||
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
||||
except Contest.DoesNotExist:
|
||||
return self.error("Contest %s doesn't exist" % contest_id)
|
||||
|
||||
# creator or owner
|
||||
if self.contest.is_contest_admin(user):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
|
||||
# Anonymous
|
||||
if not user.is_authenticated():
|
||||
return self.error("Please login first.")
|
||||
# password error
|
||||
if ("accessible_contests" not in request.session) or \
|
||||
(self.contest.id not in request.session["accessible_contests"]):
|
||||
return self.error("Password is required.")
|
||||
|
||||
# regular use get contest problems, ranks etc. before contest started
|
||||
if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details":
|
||||
return self.error("Contest has not started yet.")
|
||||
|
||||
# check is user have permission to get ranks, submissions OI Contest
|
||||
if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI:
|
||||
if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"):
|
||||
return self.error(f"No permission to get {check_type}")
|
||||
|
||||
# creator or owner
|
||||
if self.contest.is_contest_admin(user):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if self.contest.status == ContestStatus.CONTEST_NOT_START:
|
||||
return self.error("Contest has not started yet.")
|
||||
return _check_permission
|
||||
|
||||
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
|
||||
# Anonymous
|
||||
if not user.is_authenticated():
|
||||
return self.error("Please login in first.")
|
||||
# password error
|
||||
if ("accessible_contests" not in request.session) or \
|
||||
(self.contest.id not in request.session["accessible_contests"]):
|
||||
return self.error("Password is required.")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _check_permission
|
||||
return decorator
|
||||
|
||||
@@ -8,7 +8,7 @@ from .models import AdminType, ProblemPermission, User, UserProfile
|
||||
class UserLoginSerializer(serializers.Serializer):
|
||||
username = serializers.CharField()
|
||||
password = serializers.CharField()
|
||||
tfa_code = serializers.CharField(required=False, allow_null=True)
|
||||
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||
|
||||
|
||||
class UsernameOrEmailCheckSerializer(serializers.Serializer):
|
||||
@@ -26,6 +26,13 @@ class UserRegisterSerializer(serializers.Serializer):
|
||||
class UserChangePasswordSerializer(serializers.Serializer):
|
||||
old_password = serializers.CharField()
|
||||
new_password = serializers.CharField(min_length=6)
|
||||
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||
|
||||
|
||||
class UserChangeEmailSerializer(serializers.Serializer):
|
||||
password = serializers.CharField()
|
||||
new_email = serializers.EmailField(max_length=64)
|
||||
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||
|
||||
|
||||
class UserSerializer(serializers.ModelSerializer):
|
||||
|
||||
@@ -362,7 +362,7 @@ class UserChangePasswordAPITest(CaptchaTest):
|
||||
|
||||
def test_login_required(self):
|
||||
response = self.client.post(self.url, data=self.data)
|
||||
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"})
|
||||
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"})
|
||||
|
||||
def test_valid_ola_password(self):
|
||||
self.assertTrue(self.client.login(username=self.username, password=self.old_password))
|
||||
@@ -476,13 +476,13 @@ class UserRankAPITest(APITestCase):
|
||||
def test_get_acm_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
data = resp.data["data"]["results"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test1")
|
||||
self.assertEqual(data[1]["user"]["username"], "test2")
|
||||
|
||||
def test_get_oi_rank(self):
|
||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
|
||||
self.assertSuccess(resp)
|
||||
data = resp.data["data"]
|
||||
data = resp.data["data"]["results"]
|
||||
self.assertEqual(data[0]["user"]["username"], "test2")
|
||||
self.assertEqual(data[1]["user"]["username"], "test1")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from django.conf.urls import url
|
||||
|
||||
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
|
||||
UserChangePasswordAPI, UserRegisterAPI,
|
||||
UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI,
|
||||
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
|
||||
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
||||
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
|
||||
@@ -13,6 +13,7 @@ urlpatterns = [
|
||||
url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
|
||||
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
|
||||
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
|
||||
url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email"),
|
||||
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
|
||||
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
|
||||
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
|
||||
|
||||
@@ -21,7 +21,7 @@ from ..models import User, UserProfile
|
||||
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
|
||||
UserChangePasswordSerializer, UserLoginSerializer,
|
||||
UserRegisterSerializer, UsernameOrEmailCheckSerializer,
|
||||
RankInfoSerializer)
|
||||
RankInfoSerializer, UserChangeEmailSerializer)
|
||||
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
|
||||
EditUserProfileSerializer, AvatarUploadForm)
|
||||
from ..tasks import send_email_async
|
||||
@@ -176,11 +176,6 @@ class UserLoginAPI(APIView):
|
||||
else:
|
||||
return self.error("Invalid username or password")
|
||||
|
||||
# todo remove this, only for debug use
|
||||
def get(self, request):
|
||||
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
|
||||
return self.success()
|
||||
|
||||
|
||||
class UserLogoutAPI(APIView):
|
||||
def get(self, request):
|
||||
@@ -233,6 +228,27 @@ class UserRegisterAPI(APIView):
|
||||
return self.success("Succeeded")
|
||||
|
||||
|
||||
class UserChangeEmailAPI(APIView):
|
||||
@validate_serializer(UserChangeEmailSerializer)
|
||||
@login_required
|
||||
def post(self, request):
|
||||
data = request.data
|
||||
user = auth.authenticate(username=request.user.username, password=data["password"])
|
||||
if user:
|
||||
if user.two_factor_auth:
|
||||
if "tfa_code" not in data:
|
||||
return self.error("tfa_required")
|
||||
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
|
||||
return self.error("Invalid two factor verification code")
|
||||
if User.objects.filter(email=data["new_email"]).exists():
|
||||
return self.error("The email is owned by other account")
|
||||
user.email = data["new_email"]
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
else:
|
||||
return self.error("Wrong password")
|
||||
|
||||
|
||||
class UserChangePasswordAPI(APIView):
|
||||
@validate_serializer(UserChangePasswordSerializer)
|
||||
@login_required
|
||||
@@ -244,7 +260,11 @@ class UserChangePasswordAPI(APIView):
|
||||
username = request.user.username
|
||||
user = auth.authenticate(username=username, password=data["old_password"])
|
||||
if user:
|
||||
# TODO: check tfa?
|
||||
if user.two_factor_auth:
|
||||
if "tfa_code" not in data:
|
||||
return self.error("tfa_required")
|
||||
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
|
||||
return self.error("Invalid two factor verification code")
|
||||
user.set_password(data["new_password"])
|
||||
user.save()
|
||||
return self.success("Succeeded")
|
||||
|
||||
Reference in New Issue
Block a user