完善contest权限控制
This commit is contained in:
@@ -4,7 +4,7 @@ from utils.api import JSONResponse
|
|||||||
|
|
||||||
from .models import ProblemPermission
|
from .models import ProblemPermission
|
||||||
|
|
||||||
from contest.models import Contest, ContestType, ContestStatus
|
from contest.models import Contest, ContestType, ContestStatus, ContestRuleType
|
||||||
|
|
||||||
|
|
||||||
class BasePermissionDecorator(object):
|
class BasePermissionDecorator(object):
|
||||||
@@ -25,7 +25,7 @@ class BasePermissionDecorator(object):
|
|||||||
return self.error("Your account is disabled")
|
return self.error("Your account is disabled")
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return self.error("Please login in first")
|
return self.error("Please login first")
|
||||||
|
|
||||||
def check_permission(self):
|
def check_permission(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -57,45 +57,54 @@ class problem_permission_required(admin_role_required):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_contest_permission(func):
|
def check_contest_permission(check_type="details"):
|
||||||
"""
|
"""
|
||||||
只供Class based view 使用,检查用户是否有权进入该contest,
|
只供Class based view 使用,检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions
|
||||||
若通过验证,在view中可通过self.contest获得该contest
|
若通过验证,在view中可通过self.contest获得该contest
|
||||||
"""
|
"""
|
||||||
@functools.wraps(func)
|
|
||||||
def _check_permission(*args, **kwargs):
|
|
||||||
self = args[0]
|
|
||||||
request = args[1]
|
|
||||||
user = request.user
|
|
||||||
if kwargs.get("contest_id"):
|
|
||||||
contest_id = kwargs.pop("contest_id")
|
|
||||||
else:
|
|
||||||
contest_id = request.GET.get("contest_id")
|
|
||||||
if not contest_id:
|
|
||||||
return self.error("Parameter contest_id not exist.")
|
|
||||||
|
|
||||||
try:
|
def decorator(func):
|
||||||
# use self.contest to avoid query contest again in view.
|
def _check_permission(*args, **kwargs):
|
||||||
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
self = args[0]
|
||||||
except Contest.DoesNotExist:
|
request = args[1]
|
||||||
return self.error("Contest %s doesn't exist" % contest_id)
|
user = request.user
|
||||||
|
if kwargs.get("contest_id"):
|
||||||
|
contest_id = kwargs.pop("contest_id")
|
||||||
|
else:
|
||||||
|
contest_id = request.GET.get("contest_id")
|
||||||
|
if not contest_id:
|
||||||
|
return self.error("Parameter contest_id not exist.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# use self.contest to avoid query contest again in view.
|
||||||
|
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
||||||
|
except Contest.DoesNotExist:
|
||||||
|
return self.error("Contest %s doesn't exist" % contest_id)
|
||||||
|
|
||||||
|
# creator or owner
|
||||||
|
if self.contest.is_contest_admin(user):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
|
||||||
|
# Anonymous
|
||||||
|
if not user.is_authenticated():
|
||||||
|
return self.error("Please login first.")
|
||||||
|
# password error
|
||||||
|
if ("accessible_contests" not in request.session) or \
|
||||||
|
(self.contest.id not in request.session["accessible_contests"]):
|
||||||
|
return self.error("Password is required.")
|
||||||
|
|
||||||
|
# regular use get contest problems, ranks etc. before contest started
|
||||||
|
if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details":
|
||||||
|
return self.error("Contest has not started yet.")
|
||||||
|
|
||||||
|
# check is user have permission to get ranks, submissions OI Contest
|
||||||
|
if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI:
|
||||||
|
if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"):
|
||||||
|
return self.error(f"No permission to get {check_type}")
|
||||||
|
|
||||||
# creator or owner
|
|
||||||
if self.contest.is_contest_admin(user):
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
if self.contest.status == ContestStatus.CONTEST_NOT_START:
|
return _check_permission
|
||||||
return self.error("Contest has not started yet.")
|
|
||||||
|
|
||||||
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
|
return decorator
|
||||||
# Anonymous
|
|
||||||
if not user.is_authenticated():
|
|
||||||
return self.error("Please login in first.")
|
|
||||||
# password error
|
|
||||||
if ("accessible_contests" not in request.session) or \
|
|
||||||
(self.contest.id not in request.session["accessible_contests"]):
|
|
||||||
return self.error("Password is required.")
|
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return _check_permission
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from .models import AdminType, ProblemPermission, User, UserProfile
|
|||||||
class UserLoginSerializer(serializers.Serializer):
|
class UserLoginSerializer(serializers.Serializer):
|
||||||
username = serializers.CharField()
|
username = serializers.CharField()
|
||||||
password = serializers.CharField()
|
password = serializers.CharField()
|
||||||
tfa_code = serializers.CharField(required=False, allow_null=True)
|
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||||
|
|
||||||
|
|
||||||
class UsernameOrEmailCheckSerializer(serializers.Serializer):
|
class UsernameOrEmailCheckSerializer(serializers.Serializer):
|
||||||
@@ -26,6 +26,13 @@ class UserRegisterSerializer(serializers.Serializer):
|
|||||||
class UserChangePasswordSerializer(serializers.Serializer):
|
class UserChangePasswordSerializer(serializers.Serializer):
|
||||||
old_password = serializers.CharField()
|
old_password = serializers.CharField()
|
||||||
new_password = serializers.CharField(min_length=6)
|
new_password = serializers.CharField(min_length=6)
|
||||||
|
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UserChangeEmailSerializer(serializers.Serializer):
|
||||||
|
password = serializers.CharField()
|
||||||
|
new_email = serializers.EmailField(max_length=64)
|
||||||
|
tfa_code = serializers.CharField(required=False, allow_blank=True)
|
||||||
|
|
||||||
|
|
||||||
class UserSerializer(serializers.ModelSerializer):
|
class UserSerializer(serializers.ModelSerializer):
|
||||||
|
|||||||
@@ -362,7 +362,7 @@ class UserChangePasswordAPITest(CaptchaTest):
|
|||||||
|
|
||||||
def test_login_required(self):
|
def test_login_required(self):
|
||||||
response = self.client.post(self.url, data=self.data)
|
response = self.client.post(self.url, data=self.data)
|
||||||
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"})
|
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"})
|
||||||
|
|
||||||
def test_valid_ola_password(self):
|
def test_valid_ola_password(self):
|
||||||
self.assertTrue(self.client.login(username=self.username, password=self.old_password))
|
self.assertTrue(self.client.login(username=self.username, password=self.old_password))
|
||||||
@@ -476,13 +476,13 @@ class UserRankAPITest(APITestCase):
|
|||||||
def test_get_acm_rank(self):
|
def test_get_acm_rank(self):
|
||||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
|
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
data = resp.data["data"]
|
data = resp.data["data"]["results"]
|
||||||
self.assertEqual(data[0]["user"]["username"], "test1")
|
self.assertEqual(data[0]["user"]["username"], "test1")
|
||||||
self.assertEqual(data[1]["user"]["username"], "test2")
|
self.assertEqual(data[1]["user"]["username"], "test2")
|
||||||
|
|
||||||
def test_get_oi_rank(self):
|
def test_get_oi_rank(self):
|
||||||
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
|
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
data = resp.data["data"]
|
data = resp.data["data"]["results"]
|
||||||
self.assertEqual(data[0]["user"]["username"], "test2")
|
self.assertEqual(data[0]["user"]["username"], "test2")
|
||||||
self.assertEqual(data[1]["user"]["username"], "test1")
|
self.assertEqual(data[1]["user"]["username"], "test1")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
|
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
|
||||||
UserChangePasswordAPI, UserRegisterAPI,
|
UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI,
|
||||||
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
|
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
|
||||||
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
|
||||||
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
|
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
|
||||||
@@ -13,6 +13,7 @@ urlpatterns = [
|
|||||||
url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
|
url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
|
||||||
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
|
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
|
||||||
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
|
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
|
||||||
|
url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email"),
|
||||||
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
|
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
|
||||||
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
|
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
|
||||||
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
|
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from ..models import User, UserProfile
|
|||||||
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
|
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
|
||||||
UserChangePasswordSerializer, UserLoginSerializer,
|
UserChangePasswordSerializer, UserLoginSerializer,
|
||||||
UserRegisterSerializer, UsernameOrEmailCheckSerializer,
|
UserRegisterSerializer, UsernameOrEmailCheckSerializer,
|
||||||
RankInfoSerializer)
|
RankInfoSerializer, UserChangeEmailSerializer)
|
||||||
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
|
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
|
||||||
EditUserProfileSerializer, AvatarUploadForm)
|
EditUserProfileSerializer, AvatarUploadForm)
|
||||||
from ..tasks import send_email_async
|
from ..tasks import send_email_async
|
||||||
@@ -176,11 +176,6 @@ class UserLoginAPI(APIView):
|
|||||||
else:
|
else:
|
||||||
return self.error("Invalid username or password")
|
return self.error("Invalid username or password")
|
||||||
|
|
||||||
# todo remove this, only for debug use
|
|
||||||
def get(self, request):
|
|
||||||
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
|
|
||||||
return self.success()
|
|
||||||
|
|
||||||
|
|
||||||
class UserLogoutAPI(APIView):
|
class UserLogoutAPI(APIView):
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -233,6 +228,27 @@ class UserRegisterAPI(APIView):
|
|||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|
||||||
|
|
||||||
|
class UserChangeEmailAPI(APIView):
|
||||||
|
@validate_serializer(UserChangeEmailSerializer)
|
||||||
|
@login_required
|
||||||
|
def post(self, request):
|
||||||
|
data = request.data
|
||||||
|
user = auth.authenticate(username=request.user.username, password=data["password"])
|
||||||
|
if user:
|
||||||
|
if user.two_factor_auth:
|
||||||
|
if "tfa_code" not in data:
|
||||||
|
return self.error("tfa_required")
|
||||||
|
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
|
||||||
|
return self.error("Invalid two factor verification code")
|
||||||
|
if User.objects.filter(email=data["new_email"]).exists():
|
||||||
|
return self.error("The email is owned by other account")
|
||||||
|
user.email = data["new_email"]
|
||||||
|
user.save()
|
||||||
|
return self.success("Succeeded")
|
||||||
|
else:
|
||||||
|
return self.error("Wrong password")
|
||||||
|
|
||||||
|
|
||||||
class UserChangePasswordAPI(APIView):
|
class UserChangePasswordAPI(APIView):
|
||||||
@validate_serializer(UserChangePasswordSerializer)
|
@validate_serializer(UserChangePasswordSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
@@ -244,7 +260,11 @@ class UserChangePasswordAPI(APIView):
|
|||||||
username = request.user.username
|
username = request.user.username
|
||||||
user = auth.authenticate(username=username, password=data["old_password"])
|
user = auth.authenticate(username=username, password=data["old_password"])
|
||||||
if user:
|
if user:
|
||||||
# TODO: check tfa?
|
if user.two_factor_auth:
|
||||||
|
if "tfa_code" not in data:
|
||||||
|
return self.error("tfa_required")
|
||||||
|
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
|
||||||
|
return self.error("Invalid two factor verification code")
|
||||||
user.set_password(data["new_password"])
|
user.set_password(data["new_password"])
|
||||||
user.save()
|
user.save()
|
||||||
return self.success("Succeeded")
|
return self.success("Succeeded")
|
||||||
|
|||||||
@@ -45,13 +45,12 @@ class Contest(models.Model):
|
|||||||
def is_contest_admin(self, user):
|
def is_contest_admin(self, user):
|
||||||
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
|
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
|
||||||
|
|
||||||
def check_oi_permission(self, user):
|
# 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
|
||||||
if self.status != ContestStatus.CONTEST_ENDED and not self.real_time_rank:
|
def problem_details_permission(self, user):
|
||||||
if self.is_contest_admin(user):
|
return self.rule_type == ContestRuleType.ACM or \
|
||||||
return True
|
self.status == ContestStatus.CONTEST_ENDED or \
|
||||||
else:
|
self.is_contest_admin(user) or \
|
||||||
return False
|
self.real_time_rank
|
||||||
return True
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "contest"
|
db_table = "contest"
|
||||||
|
|||||||
@@ -58,43 +58,40 @@ class ContestAdminAPITest(APITestCase):
|
|||||||
class ContestAPITest(APITestCase):
|
class ContestAPITest(APITestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.create_admin()
|
self.create_admin()
|
||||||
self.url = self.reverse("contest_api")
|
|
||||||
|
|
||||||
def create_contest(self):
|
|
||||||
url = self.reverse("contest_admin_api")
|
url = self.reverse("contest_admin_api")
|
||||||
return self.client.post(url, data=DEFAULT_CONTEST_DATA)
|
self.contest = self.client.post(url, data=DEFAULT_CONTEST_DATA).data["data"]
|
||||||
|
self.url = self.reverse("contest_api") + "?contest_id=" + str(self.contest["id"])
|
||||||
|
|
||||||
def test_get_contest_list(self):
|
def test_get_contest_list(self):
|
||||||
self.create_contest()
|
url = self.reverse("contest_list_api")
|
||||||
response = self.client.get(self.url)
|
response = self.client.get(url + "?limit=10")
|
||||||
self.assertSuccess(response)
|
self.assertSuccess(response)
|
||||||
|
self.assertEqual(len(response.data["data"]["results"]), 1)
|
||||||
|
|
||||||
def test_get_one_contest(self):
|
def test_get_one_contest(self):
|
||||||
contest_id = self.create_contest().data["data"]["id"]
|
resp = self.client.get(self.url)
|
||||||
response = self.client.get("{}?id={}".format(self.url, contest_id))
|
self.assertSuccess(resp)
|
||||||
self.assertSuccess(response)
|
|
||||||
|
|
||||||
def test_regular_user_validate_contest_password(self):
|
def test_regular_user_validate_contest_password(self):
|
||||||
contest_id = self.create_contest().data["data"]["id"]
|
|
||||||
self.create_user("test", "test123")
|
self.create_user("test", "test123")
|
||||||
url = self.reverse("contest_password_api")
|
url = self.reverse("contest_password_api")
|
||||||
resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"})
|
resp = self.client.post(url, {"contest_id": self.contest["id"], "password": "error_password"})
|
||||||
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
|
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
|
||||||
|
|
||||||
resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
|
resp = self.client.post(url, {"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]})
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
|
|
||||||
def test_regular_user_access_contest(self):
|
def test_regular_user_access_contest(self):
|
||||||
contest_id = self.create_contest().data["data"]["id"]
|
|
||||||
self.create_user("test", "test123")
|
self.create_user("test", "test123")
|
||||||
url = self.reverse("contest_access_api")
|
url = self.reverse("contest_access_api")
|
||||||
resp = self.client.get(url + "?contest_id=" + str(contest_id))
|
resp = self.client.get(url + "?contest_id=" + str(self.contest["id"]))
|
||||||
self.assertFalse(resp.data["data"]["access"])
|
self.assertFalse(resp.data["data"]["access"])
|
||||||
|
|
||||||
password_url = self.reverse("contest_password_api")
|
password_url = self.reverse("contest_password_api")
|
||||||
resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]})
|
resp = self.client.post(password_url,
|
||||||
|
{"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]})
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
resp = self.client.get(url + "?contest_id=" + str(contest_id))
|
resp = self.client.get(self.url)
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from ..views.oj import ContestAnnouncementListAPI, ContestAPI
|
from ..views.oj import ContestAnnouncementListAPI
|
||||||
from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI
|
from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI
|
||||||
|
from ..views.oj import ContestListAPI, ContestAPI
|
||||||
from ..views.oj import ContestRankAPI
|
from ..views.oj import ContestRankAPI
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
|
url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"),
|
||||||
url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
|
url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
|
||||||
url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"),
|
url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"),
|
||||||
url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"),
|
url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
|
|||||||
|
|
||||||
|
|
||||||
class ContestAnnouncementListAPI(APIView):
|
class ContestAnnouncementListAPI(APIView):
|
||||||
|
@check_contest_permission(check_type="announcements")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
contest_id = request.GET.get("contest_id")
|
contest_id = request.GET.get("contest_id")
|
||||||
if not contest_id:
|
if not contest_id:
|
||||||
@@ -24,15 +25,13 @@ class ContestAnnouncementListAPI(APIView):
|
|||||||
|
|
||||||
|
|
||||||
class ContestAPI(APIView):
|
class ContestAPI(APIView):
|
||||||
|
@check_contest_permission(check_type="details")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
contest_id = request.GET.get("id")
|
return self.success(ContestSerializer(self.contest).data)
|
||||||
if contest_id:
|
|
||||||
try:
|
|
||||||
contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
|
|
||||||
except Contest.DoesNotExist:
|
|
||||||
return self.error("Contest does not exist")
|
|
||||||
return self.success(ContestSerializer(contest).data)
|
|
||||||
|
|
||||||
|
|
||||||
|
class ContestListAPI(APIView):
|
||||||
|
def get(self, request):
|
||||||
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
contests = Contest.objects.select_related("created_by").filter(visible=True)
|
||||||
keyword = request.GET.get("keyword")
|
keyword = request.GET.get("keyword")
|
||||||
rule_type = request.GET.get("rule_type")
|
rule_type = request.GET.get("rule_type")
|
||||||
@@ -49,7 +48,8 @@ class ContestAPI(APIView):
|
|||||||
contests = contests.filter(end_time__lt=cur)
|
contests = contests.filter(end_time__lt=cur)
|
||||||
else:
|
else:
|
||||||
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
|
||||||
return self.success(self.paginate_data(request, contests, ContestSerializer))
|
data = self.paginate_data(request, contests, ContestSerializer)
|
||||||
|
return self.success(data)
|
||||||
|
|
||||||
|
|
||||||
class ContestPasswordVerifyAPI(APIView):
|
class ContestPasswordVerifyAPI(APIView):
|
||||||
@@ -91,11 +91,9 @@ class ContestRankAPI(APIView):
|
|||||||
return OIContestRank.objects.filter(contest=self.contest). \
|
return OIContestRank.objects.filter(contest=self.contest). \
|
||||||
select_related("user").order_by("-total_score")
|
select_related("user").order_by("-total_score")
|
||||||
|
|
||||||
@check_contest_permission
|
@check_contest_permission(check_type="ranks")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
if self.contest.rule_type == ContestRuleType.OI:
|
if self.contest.rule_type == ContestRuleType.OI:
|
||||||
if not self.contest.check_oi_permission(request.user):
|
|
||||||
return self.error("You have no permission for ranks now")
|
|
||||||
serializer = OIContestRankSerializer
|
serializer = OIContestRankSerializer
|
||||||
else:
|
else:
|
||||||
serializer = ACMContestRankSerializer
|
serializer = ACMContestRankSerializer
|
||||||
@@ -105,5 +103,4 @@ class ContestRankAPI(APIView):
|
|||||||
if not qs:
|
if not qs:
|
||||||
qs = self.get_rank()
|
qs = self.get_rank()
|
||||||
cache.set(cache_key, qs)
|
cache.set(cache_key, qs)
|
||||||
|
|
||||||
return self.success(self.paginate_data(request, qs, serializer))
|
return self.success(self.paginate_data(request, qs, serializer))
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ cd $BASE
|
|||||||
find . -name "*.pyc" -delete
|
find . -name "*.pyc" -delete
|
||||||
|
|
||||||
# wait for postgresql start
|
# wait for postgresql start
|
||||||
sleep 5
|
sleep 6
|
||||||
|
|
||||||
n=0
|
n=0
|
||||||
while [ $n -lt 3 ]
|
while [ $n -lt 3 ]
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ CELERY_ACCEPT_CONTENT = ["json"]
|
|||||||
CELERY_TASK_SERIALIZER = "json"
|
CELERY_TASK_SERIALIZER = "json"
|
||||||
|
|
||||||
# 用于限制用户恶意提交大量代码
|
# 用于限制用户恶意提交大量代码
|
||||||
TOKEN_BUCKET_DEFAULT_CAPACITY = 50
|
TOKEN_BUCKET_DEFAULT_CAPACITY = 20
|
||||||
|
|
||||||
# 单位:每分钟
|
# 单位:每分钟
|
||||||
TOKEN_BUCKET_FILL_RATE = 2
|
TOKEN_BUCKET_FILL_RATE = 2
|
||||||
|
|||||||
@@ -107,4 +107,11 @@ class ProblemSerializer(BaseProblemSerializer):
|
|||||||
class ContestProblemSerializer(BaseProblemSerializer):
|
class ContestProblemSerializer(BaseProblemSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Problem
|
model = Problem
|
||||||
exclude = ("test_case_score", "test_case_id", "visible", "is_public")
|
exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty")
|
||||||
|
|
||||||
|
|
||||||
|
class ContestProblemSafeSerializer(BaseProblemSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = Problem
|
||||||
|
exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty"
|
||||||
|
"submission_number", "accepted_number", "statistic_info")
|
||||||
|
|||||||
@@ -196,30 +196,26 @@ class ContestProblemAdminTest(APITestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.url = self.reverse("contest_problem_admin_api")
|
self.url = self.reverse("contest_problem_admin_api")
|
||||||
self.create_admin()
|
self.create_admin()
|
||||||
|
self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
|
||||||
def create_contest(self):
|
|
||||||
url = self.reverse("contest_admin_api")
|
|
||||||
return self.client.post(url, data=DEFAULT_CONTEST_DATA)
|
|
||||||
|
|
||||||
def test_create_contest_problem(self):
|
def test_create_contest_problem(self):
|
||||||
contest = self.create_contest()
|
|
||||||
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
|
data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
|
||||||
data["contest_id"] = contest.data["data"]["id"]
|
data["contest_id"] = self.contest["id"]
|
||||||
resp = self.client.post(self.url, data=data)
|
resp = self.client.post(self.url, data=data)
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
return contest, resp
|
return resp.data["data"]
|
||||||
|
|
||||||
def test_get_contest_problem(self):
|
def test_get_contest_problem(self):
|
||||||
contest, contest_problem = self.test_create_contest_problem()
|
self.test_create_contest_problem()
|
||||||
contest_id = contest.data["data"]["id"]
|
contest_id = self.contest["id"]
|
||||||
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
|
resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
self.assertEqual(len(resp.data["data"]), 1)
|
self.assertEqual(len(resp.data["data"]["results"]), 1)
|
||||||
|
|
||||||
def test_get_one_contest_problem(self):
|
def test_get_one_contest_problem(self):
|
||||||
contest, contest_problem = self.test_create_contest_problem()
|
contest_problem = self.test_create_contest_problem()
|
||||||
contest_id = contest.data["data"]["id"]
|
contest_id = self.contest["id"]
|
||||||
problem_id = contest_problem.data["data"]["id"]
|
problem_id = contest_problem["id"]
|
||||||
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
|
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
|
||||||
self.assertSuccess(resp)
|
self.assertSuccess(resp)
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from utils.api import APIView
|
|||||||
from account.decorators import check_contest_permission
|
from account.decorators import check_contest_permission
|
||||||
from ..models import ProblemTag, Problem, ProblemRuleType
|
from ..models import ProblemTag, Problem, ProblemRuleType
|
||||||
from ..serializers import ProblemSerializer, TagSerializer
|
from ..serializers import ProblemSerializer, TagSerializer
|
||||||
from ..serializers import ContestProblemSerializer
|
from ..serializers import ContestProblemSerializer, ContestProblemSafeSerializer
|
||||||
from contest.models import ContestRuleType
|
from contest.models import ContestRuleType
|
||||||
|
|
||||||
|
|
||||||
@@ -81,8 +81,6 @@ class ProblemAPI(APIView):
|
|||||||
|
|
||||||
class ContestProblemAPI(APIView):
|
class ContestProblemAPI(APIView):
|
||||||
def _add_problem_status(self, request, queryset_values):
|
def _add_problem_status(self, request, queryset_values):
|
||||||
if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user):
|
|
||||||
return
|
|
||||||
if request.user.is_authenticated():
|
if request.user.is_authenticated():
|
||||||
profile = request.user.userprofile
|
profile = request.user.userprofile
|
||||||
if self.contest.rule_type == ContestRuleType.ACM:
|
if self.contest.rule_type == ContestRuleType.ACM:
|
||||||
@@ -92,7 +90,7 @@ class ContestProblemAPI(APIView):
|
|||||||
for problem in queryset_values:
|
for problem in queryset_values:
|
||||||
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
|
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
|
||||||
|
|
||||||
@check_contest_permission
|
@check_contest_permission(check_type="problems")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
if problem_id:
|
if problem_id:
|
||||||
@@ -102,11 +100,17 @@ class ContestProblemAPI(APIView):
|
|||||||
visible=True)
|
visible=True)
|
||||||
except Problem.DoesNotExist:
|
except Problem.DoesNotExist:
|
||||||
return self.error("Problem does not exist.")
|
return self.error("Problem does not exist.")
|
||||||
problem_data = ContestProblemSerializer(problem).data
|
if self.contest.problem_details_permission(request.user):
|
||||||
self._add_problem_status(request, [problem_data, ])
|
problem_data = ContestProblemSerializer(problem).data
|
||||||
|
self._add_problem_status(request, [problem_data, ])
|
||||||
|
else:
|
||||||
|
problem_data = ContestProblemSafeSerializer(problem).data
|
||||||
return self.success(problem_data)
|
return self.success(problem_data)
|
||||||
|
|
||||||
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
|
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
|
||||||
# 根据profile, 为做过的题目添加标记
|
if self.contest.problem_details_permission(request.user):
|
||||||
data = ContestProblemSerializer(contest_problems, many=True).data
|
data = ContestProblemSerializer(contest_problems, many=True).data
|
||||||
self._add_problem_status(request, data)
|
self._add_problem_status(request, data)
|
||||||
|
else:
|
||||||
|
data = ContestProblemSafeSerializer(contest_problems, many=True).data
|
||||||
return self.success(data)
|
return self.success(data)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ class CreateSubmissionSerializer(serializers.Serializer):
|
|||||||
language = serializers.ChoiceField(choices=language_names)
|
language = serializers.ChoiceField(choices=language_names)
|
||||||
code = serializers.CharField(max_length=20000)
|
code = serializers.CharField(max_length=20000)
|
||||||
contest_id = serializers.IntegerField(required=False)
|
contest_id = serializers.IntegerField(required=False)
|
||||||
|
captcha = serializers.CharField(required=False)
|
||||||
|
|
||||||
|
|
||||||
class ShareSubmissionSerializer(serializers.Serializer):
|
class ShareSubmissionSerializer(serializers.Serializer):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from django.conf import settings
|
||||||
from account.decorators import login_required, check_contest_permission
|
from account.decorators import login_required, check_contest_permission
|
||||||
from judge.tasks import judge_task
|
from judge.tasks import judge_task
|
||||||
# from judge.dispatcher import JudgeDispatcher
|
# from judge.dispatcher import JudgeDispatcher
|
||||||
@@ -5,6 +6,7 @@ from problem.models import Problem, ProblemRuleType
|
|||||||
from contest.models import Contest, ContestStatus, ContestRuleType
|
from contest.models import Contest, ContestStatus, ContestRuleType
|
||||||
from utils.api import APIView, validate_serializer
|
from utils.api import APIView, validate_serializer
|
||||||
from utils.throttling import TokenBucket, BucketController
|
from utils.throttling import TokenBucket, BucketController
|
||||||
|
from utils.captcha import Captcha
|
||||||
from utils.cache import cache
|
from utils.cache import cache
|
||||||
from ..models import Submission
|
from ..models import Submission
|
||||||
from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
|
from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
|
||||||
@@ -12,43 +14,38 @@ from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer
|
|||||||
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
|
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
|
||||||
|
|
||||||
|
|
||||||
def _submit(response, user, problem_id, language, code, contest_id):
|
|
||||||
# TODO: 预设默认值,需修改
|
|
||||||
controller = BucketController(user_id=user.id,
|
|
||||||
redis_conn=cache,
|
|
||||||
default_capacity=30)
|
|
||||||
bucket = TokenBucket(fill_rate=10, capacity=20,
|
|
||||||
last_capacity=controller.last_capacity,
|
|
||||||
last_timestamp=controller.last_timestamp)
|
|
||||||
if bucket.consume():
|
|
||||||
controller.last_capacity -= 1
|
|
||||||
else:
|
|
||||||
return response.error("Please wait %d seconds" % int(bucket.expected_time() + 1))
|
|
||||||
|
|
||||||
try:
|
|
||||||
problem = Problem.objects.get(id=problem_id,
|
|
||||||
contest_id=contest_id,
|
|
||||||
visible=True)
|
|
||||||
except Problem.DoesNotExist:
|
|
||||||
return response.error("Problem not exist")
|
|
||||||
|
|
||||||
submission = Submission.objects.create(user_id=user.id,
|
|
||||||
username=user.username,
|
|
||||||
language=language,
|
|
||||||
code=code,
|
|
||||||
problem_id=problem.id,
|
|
||||||
contest_id=contest_id)
|
|
||||||
# use this for debug
|
|
||||||
# JudgeDispatcher(submission.id, problem.id).judge()
|
|
||||||
judge_task.delay(submission.id, problem.id)
|
|
||||||
return response.success({"submission_id": submission.id})
|
|
||||||
|
|
||||||
|
|
||||||
class SubmissionAPI(APIView):
|
class SubmissionAPI(APIView):
|
||||||
|
def throttling(self, request):
|
||||||
|
user_controller = BucketController(factor=request.user.id,
|
||||||
|
redis_conn=cache,
|
||||||
|
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY)
|
||||||
|
user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE,
|
||||||
|
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY,
|
||||||
|
last_capacity=user_controller.last_capacity,
|
||||||
|
last_timestamp=user_controller.last_timestamp)
|
||||||
|
if user_bucket.consume():
|
||||||
|
user_controller.last_capacity -= 1
|
||||||
|
else:
|
||||||
|
return "Please wait %d seconds" % int(user_bucket.expected_time() + 1)
|
||||||
|
|
||||||
|
ip_controller = BucketController(factor=request.session["ip"],
|
||||||
|
redis_conn=cache,
|
||||||
|
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3)
|
||||||
|
|
||||||
|
ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3,
|
||||||
|
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3,
|
||||||
|
last_capacity=ip_controller.last_capacity,
|
||||||
|
last_timestamp=ip_controller.last_timestamp)
|
||||||
|
if ip_bucket.consume():
|
||||||
|
ip_controller.last_capacity -= 1
|
||||||
|
else:
|
||||||
|
return "Captcha is required"
|
||||||
|
|
||||||
@validate_serializer(CreateSubmissionSerializer)
|
@validate_serializer(CreateSubmissionSerializer)
|
||||||
@login_required
|
@login_required
|
||||||
def post(self, request):
|
def post(self, request):
|
||||||
data = request.data
|
data = request.data
|
||||||
|
hide_id = False
|
||||||
if data.get("contest_id"):
|
if data.get("contest_id"):
|
||||||
try:
|
try:
|
||||||
contest = Contest.objects.get(id=data["contest_id"])
|
contest = Contest.objects.get(id=data["contest_id"])
|
||||||
@@ -56,9 +53,39 @@ class SubmissionAPI(APIView):
|
|||||||
return self.error("Contest doesn't exist.")
|
return self.error("Contest doesn't exist.")
|
||||||
if contest.status == ContestStatus.CONTEST_ENDED:
|
if contest.status == ContestStatus.CONTEST_ENDED:
|
||||||
return self.error("The contest have ended")
|
return self.error("The contest have ended")
|
||||||
if contest.status == ContestStatus.CONTEST_NOT_START and request.user != contest.created_by:
|
if contest.status == ContestStatus.CONTEST_NOT_START and not contest.is_contest_admin(request.user):
|
||||||
return self.error("Contest have not started")
|
return self.error("Contest have not started")
|
||||||
return _submit(self, request.user, data["problem_id"], data["language"], data["code"], data.get("contest_id"))
|
if not contest.problem_details_permission():
|
||||||
|
hide_id = True
|
||||||
|
|
||||||
|
if data.get("captcha"):
|
||||||
|
if not Captcha(request).check(data["captcha"]):
|
||||||
|
return self.error("Invalid captcha")
|
||||||
|
|
||||||
|
error = self.throttling(request)
|
||||||
|
if error:
|
||||||
|
return self.error(error)
|
||||||
|
|
||||||
|
try:
|
||||||
|
problem = Problem.objects.get(id=data["problem_id"],
|
||||||
|
contest_id=data.get("contest_id"),
|
||||||
|
visible=True)
|
||||||
|
except Problem.DoesNotExist:
|
||||||
|
return self.error("Problem not exist")
|
||||||
|
|
||||||
|
submission = Submission.objects.create(user_id=request.user.id,
|
||||||
|
username=request.user.username,
|
||||||
|
language=data["language"],
|
||||||
|
code=data["code"],
|
||||||
|
problem_id=problem.id,
|
||||||
|
contest_id=data.get("contest_id"))
|
||||||
|
# use this for debug
|
||||||
|
# JudgeDispatcher(submission.id, problem.id).judge()
|
||||||
|
judge_task.delay(submission.id, problem.id)
|
||||||
|
if hide_id:
|
||||||
|
return self.success()
|
||||||
|
else:
|
||||||
|
return self.success({"submission_id": submission.id})
|
||||||
|
|
||||||
@login_required
|
@login_required
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
@@ -123,15 +150,12 @@ class SubmissionListAPI(APIView):
|
|||||||
|
|
||||||
|
|
||||||
class ContestSubmissionListAPI(APIView):
|
class ContestSubmissionListAPI(APIView):
|
||||||
@check_contest_permission
|
@check_contest_permission(check_type="submissions")
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
if not request.GET.get("limit"):
|
if not request.GET.get("limit"):
|
||||||
return self.error("Limit is needed")
|
return self.error("Limit is needed")
|
||||||
|
|
||||||
contest = self.contest
|
contest = self.contest
|
||||||
if not contest.check_oi_permission(request.user):
|
|
||||||
return self.error("No permission for OI contest submissions")
|
|
||||||
|
|
||||||
submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
|
submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
|
||||||
problem_id = request.GET.get("problem_id")
|
problem_id = request.GET.get("problem_id")
|
||||||
myself = request.GET.get("myself")
|
myself = request.GET.get("myself")
|
||||||
|
|||||||
@@ -107,18 +107,12 @@ class APIView(View):
|
|||||||
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
|
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
need_paginate = request.GET.get("limit", None)
|
|
||||||
if need_paginate is None:
|
|
||||||
if object_serializer:
|
|
||||||
return object_serializer(query_set, many=True).data
|
|
||||||
else:
|
|
||||||
return {"results": query_set, "total": query_set.count()}
|
|
||||||
try:
|
try:
|
||||||
limit = int(request.GET.get("limit", "100"))
|
limit = int(request.GET.get("limit", "10"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
limit = 100
|
limit = 10
|
||||||
if limit < 0:
|
if limit < 0 or limit > 100:
|
||||||
limit = 100
|
limit = 10
|
||||||
try:
|
try:
|
||||||
offset = int(request.GET.get("offset", "0"))
|
offset = int(request.GET.get("offset", "0"))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ class APITestCase(TestCase):
|
|||||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
|
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
|
||||||
problem_permission=ProblemPermission.ALL, login=login)
|
problem_permission=ProblemPermission.ALL, login=login)
|
||||||
|
|
||||||
def reverse(self, url_name):
|
def reverse(self, url_name, *args, **kwargs):
|
||||||
return reverse(url_name)
|
return reverse(url_name, *args, **kwargs)
|
||||||
|
|
||||||
def assertSuccess(self, response):
|
def assertSuccess(self, response):
|
||||||
if not response.data["error"] is None:
|
if not response.data["error"] is None:
|
||||||
|
|||||||
@@ -31,11 +31,10 @@ class TokenBucket:
|
|||||||
|
|
||||||
|
|
||||||
class BucketController:
|
class BucketController:
|
||||||
def __init__(self, user_id, redis_conn, default_capacity):
|
def __init__(self, factor, redis_conn, default_capacity):
|
||||||
self.user_id = user_id
|
|
||||||
self.default_capacity = default_capacity
|
self.default_capacity = default_capacity
|
||||||
self.redis = redis_conn
|
self.redis = redis_conn
|
||||||
self.key = "bucket_" + str(self.user_id)
|
self.key = "bucket_" + str(factor)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_capacity(self):
|
def last_capacity(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user