完善contest权限控制

This commit is contained in:
zema1
2017-10-27 18:36:29 +08:00
parent b694000ab9
commit 728373a5ff
19 changed files with 219 additions and 162 deletions

View File

@@ -4,7 +4,7 @@ from utils.api import JSONResponse
from .models import ProblemPermission from .models import ProblemPermission
from contest.models import Contest, ContestType, ContestStatus from contest.models import Contest, ContestType, ContestStatus, ContestRuleType
class BasePermissionDecorator(object): class BasePermissionDecorator(object):
@@ -25,7 +25,7 @@ class BasePermissionDecorator(object):
return self.error("Your account is disabled") return self.error("Your account is disabled")
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
else: else:
return self.error("Please login in first") return self.error("Please login first")
def check_permission(self): def check_permission(self):
raise NotImplementedError() raise NotImplementedError()
@@ -57,45 +57,54 @@ class problem_permission_required(admin_role_required):
return True return True
def check_contest_permission(func): def check_contest_permission(check_type="details"):
""" """
只供Class based view 使用检查用户是否有权进入该contest 只供Class based view 使用检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions
若通过验证在view中可通过self.contest获得该contest 若通过验证在view中可通过self.contest获得该contest
""" """
@functools.wraps(func)
def _check_permission(*args, **kwargs):
self = args[0]
request = args[1]
user = request.user
if kwargs.get("contest_id"):
contest_id = kwargs.pop("contest_id")
else:
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter contest_id not exist.")
try: def decorator(func):
# use self.contest to avoid query contest again in view. def _check_permission(*args, **kwargs):
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True) self = args[0]
except Contest.DoesNotExist: request = args[1]
return self.error("Contest %s doesn't exist" % contest_id) user = request.user
if kwargs.get("contest_id"):
contest_id = kwargs.pop("contest_id")
else:
contest_id = request.GET.get("contest_id")
if not contest_id:
return self.error("Parameter contest_id not exist.")
try:
# use self.contest to avoid query contest again in view.
self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest %s doesn't exist" % contest_id)
# creator or owner
if self.contest.is_contest_admin(user):
return func(*args, **kwargs)
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
# Anonymous
if not user.is_authenticated():
return self.error("Please login first.")
# password error
if ("accessible_contests" not in request.session) or \
(self.contest.id not in request.session["accessible_contests"]):
return self.error("Password is required.")
# regular use get contest problems, ranks etc. before contest started
if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details":
return self.error("Contest has not started yet.")
# check is user have permission to get ranks, submissions OI Contest
if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI:
if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"):
return self.error(f"No permission to get {check_type}")
# creator or owner
if self.contest.is_contest_admin(user):
return func(*args, **kwargs) return func(*args, **kwargs)
if self.contest.status == ContestStatus.CONTEST_NOT_START: return _check_permission
return self.error("Contest has not started yet.")
if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST: return decorator
# Anonymous
if not user.is_authenticated():
return self.error("Please login in first.")
# password error
if ("accessible_contests" not in request.session) or \
(self.contest.id not in request.session["accessible_contests"]):
return self.error("Password is required.")
return func(*args, **kwargs)
return _check_permission

View File

@@ -8,7 +8,7 @@ from .models import AdminType, ProblemPermission, User, UserProfile
class UserLoginSerializer(serializers.Serializer): class UserLoginSerializer(serializers.Serializer):
username = serializers.CharField() username = serializers.CharField()
password = serializers.CharField() password = serializers.CharField()
tfa_code = serializers.CharField(required=False, allow_null=True) tfa_code = serializers.CharField(required=False, allow_blank=True)
class UsernameOrEmailCheckSerializer(serializers.Serializer): class UsernameOrEmailCheckSerializer(serializers.Serializer):
@@ -26,6 +26,13 @@ class UserRegisterSerializer(serializers.Serializer):
class UserChangePasswordSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField() old_password = serializers.CharField()
new_password = serializers.CharField(min_length=6) new_password = serializers.CharField(min_length=6)
tfa_code = serializers.CharField(required=False, allow_blank=True)
class UserChangeEmailSerializer(serializers.Serializer):
password = serializers.CharField()
new_email = serializers.EmailField(max_length=64)
tfa_code = serializers.CharField(required=False, allow_blank=True)
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):

View File

@@ -362,7 +362,7 @@ class UserChangePasswordAPITest(CaptchaTest):
def test_login_required(self): def test_login_required(self):
response = self.client.post(self.url, data=self.data) response = self.client.post(self.url, data=self.data)
self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"}) self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"})
def test_valid_ola_password(self): def test_valid_ola_password(self):
self.assertTrue(self.client.login(username=self.username, password=self.old_password)) self.assertTrue(self.client.login(username=self.username, password=self.old_password))
@@ -476,13 +476,13 @@ class UserRankAPITest(APITestCase):
def test_get_acm_rank(self): def test_get_acm_rank(self):
resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM}) resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
self.assertSuccess(resp) self.assertSuccess(resp)
data = resp.data["data"] data = resp.data["data"]["results"]
self.assertEqual(data[0]["user"]["username"], "test1") self.assertEqual(data[0]["user"]["username"], "test1")
self.assertEqual(data[1]["user"]["username"], "test2") self.assertEqual(data[1]["user"]["username"], "test2")
def test_get_oi_rank(self): def test_get_oi_rank(self):
resp = self.client.get(self.url, data={"rule": ContestRuleType.OI}) resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
self.assertSuccess(resp) self.assertSuccess(resp)
data = resp.data["data"] data = resp.data["data"]["results"]
self.assertEqual(data[0]["user"]["username"], "test2") self.assertEqual(data[0]["user"]["username"], "test2")
self.assertEqual(data[1]["user"]["username"], "test1") self.assertEqual(data[1]["user"]["username"], "test1")

View File

@@ -1,7 +1,7 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI, UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
@@ -13,6 +13,7 @@ urlpatterns = [
url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"), url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"), url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"), url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email"),
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"), url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"), url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"), url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),

View File

@@ -21,7 +21,7 @@ from ..models import User, UserProfile
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer, UserChangePasswordSerializer, UserLoginSerializer,
UserRegisterSerializer, UsernameOrEmailCheckSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer,
RankInfoSerializer) RankInfoSerializer, UserChangeEmailSerializer)
from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer, from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
EditUserProfileSerializer, AvatarUploadForm) EditUserProfileSerializer, AvatarUploadForm)
from ..tasks import send_email_async from ..tasks import send_email_async
@@ -176,11 +176,6 @@ class UserLoginAPI(APIView):
else: else:
return self.error("Invalid username or password") return self.error("Invalid username or password")
# todo remove this, only for debug use
def get(self, request):
auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
return self.success()
class UserLogoutAPI(APIView): class UserLogoutAPI(APIView):
def get(self, request): def get(self, request):
@@ -233,6 +228,27 @@ class UserRegisterAPI(APIView):
return self.success("Succeeded") return self.success("Succeeded")
class UserChangeEmailAPI(APIView):
@validate_serializer(UserChangeEmailSerializer)
@login_required
def post(self, request):
data = request.data
user = auth.authenticate(username=request.user.username, password=data["password"])
if user:
if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
return self.error("Invalid two factor verification code")
if User.objects.filter(email=data["new_email"]).exists():
return self.error("The email is owned by other account")
user.email = data["new_email"]
user.save()
return self.success("Succeeded")
else:
return self.error("Wrong password")
class UserChangePasswordAPI(APIView): class UserChangePasswordAPI(APIView):
@validate_serializer(UserChangePasswordSerializer) @validate_serializer(UserChangePasswordSerializer)
@login_required @login_required
@@ -244,7 +260,11 @@ class UserChangePasswordAPI(APIView):
username = request.user.username username = request.user.username
user = auth.authenticate(username=username, password=data["old_password"]) user = auth.authenticate(username=username, password=data["old_password"])
if user: if user:
# TODO: check tfa? if user.two_factor_auth:
if "tfa_code" not in data:
return self.error("tfa_required")
if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
return self.error("Invalid two factor verification code")
user.set_password(data["new_password"]) user.set_password(data["new_password"])
user.save() user.save()
return self.success("Succeeded") return self.success("Succeeded")

View File

@@ -45,13 +45,12 @@ class Contest(models.Model):
def is_contest_admin(self, user): def is_contest_admin(self, user):
return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN) return user.is_authenticated() and (self.created_by == user or user.admin_type == AdminType.SUPER_ADMIN)
def check_oi_permission(self, user): # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
if self.status != ContestStatus.CONTEST_ENDED and not self.real_time_rank: def problem_details_permission(self, user):
if self.is_contest_admin(user): return self.rule_type == ContestRuleType.ACM or \
return True self.status == ContestStatus.CONTEST_ENDED or \
else: self.is_contest_admin(user) or \
return False self.real_time_rank
return True
class Meta: class Meta:
db_table = "contest" db_table = "contest"

View File

@@ -58,43 +58,40 @@ class ContestAdminAPITest(APITestCase):
class ContestAPITest(APITestCase): class ContestAPITest(APITestCase):
def setUp(self): def setUp(self):
self.create_admin() self.create_admin()
self.url = self.reverse("contest_api")
def create_contest(self):
url = self.reverse("contest_admin_api") url = self.reverse("contest_admin_api")
return self.client.post(url, data=DEFAULT_CONTEST_DATA) self.contest = self.client.post(url, data=DEFAULT_CONTEST_DATA).data["data"]
self.url = self.reverse("contest_api") + "?contest_id=" + str(self.contest["id"])
def test_get_contest_list(self): def test_get_contest_list(self):
self.create_contest() url = self.reverse("contest_list_api")
response = self.client.get(self.url) response = self.client.get(url + "?limit=10")
self.assertSuccess(response) self.assertSuccess(response)
self.assertEqual(len(response.data["data"]["results"]), 1)
def test_get_one_contest(self): def test_get_one_contest(self):
contest_id = self.create_contest().data["data"]["id"] resp = self.client.get(self.url)
response = self.client.get("{}?id={}".format(self.url, contest_id)) self.assertSuccess(resp)
self.assertSuccess(response)
def test_regular_user_validate_contest_password(self): def test_regular_user_validate_contest_password(self):
contest_id = self.create_contest().data["data"]["id"]
self.create_user("test", "test123") self.create_user("test", "test123")
url = self.reverse("contest_password_api") url = self.reverse("contest_password_api")
resp = self.client.post(url, {"contest_id": contest_id, "password": "error_password"}) resp = self.client.post(url, {"contest_id": self.contest["id"], "password": "error_password"})
self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
resp = self.client.post(url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) resp = self.client.post(url, {"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp) self.assertSuccess(resp)
def test_regular_user_access_contest(self): def test_regular_user_access_contest(self):
contest_id = self.create_contest().data["data"]["id"]
self.create_user("test", "test123") self.create_user("test", "test123")
url = self.reverse("contest_access_api") url = self.reverse("contest_access_api")
resp = self.client.get(url + "?contest_id=" + str(contest_id)) resp = self.client.get(url + "?contest_id=" + str(self.contest["id"]))
self.assertFalse(resp.data["data"]["access"]) self.assertFalse(resp.data["data"]["access"])
password_url = self.reverse("contest_password_api") password_url = self.reverse("contest_password_api")
resp = self.client.post(password_url, {"contest_id": contest_id, "password": DEFAULT_CONTEST_DATA["password"]}) resp = self.client.post(password_url,
{"contest_id": self.contest["id"], "password": DEFAULT_CONTEST_DATA["password"]})
self.assertSuccess(resp) self.assertSuccess(resp)
resp = self.client.get(url + "?contest_id=" + str(contest_id)) resp = self.client.get(self.url)
self.assertSuccess(resp) self.assertSuccess(resp)

View File

@@ -1,10 +1,12 @@
from django.conf.urls import url from django.conf.urls import url
from ..views.oj import ContestAnnouncementListAPI, ContestAPI from ..views.oj import ContestAnnouncementListAPI
from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI
from ..views.oj import ContestListAPI, ContestAPI
from ..views.oj import ContestRankAPI from ..views.oj import ContestRankAPI
urlpatterns = [ urlpatterns = [
url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"),
url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"), url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"), url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"),
url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"), url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"),

View File

@@ -12,6 +12,7 @@ from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
class ContestAnnouncementListAPI(APIView): class ContestAnnouncementListAPI(APIView):
@check_contest_permission(check_type="announcements")
def get(self, request): def get(self, request):
contest_id = request.GET.get("contest_id") contest_id = request.GET.get("contest_id")
if not contest_id: if not contest_id:
@@ -24,15 +25,13 @@ class ContestAnnouncementListAPI(APIView):
class ContestAPI(APIView): class ContestAPI(APIView):
@check_contest_permission(check_type="details")
def get(self, request): def get(self, request):
contest_id = request.GET.get("id") return self.success(ContestSerializer(self.contest).data)
if contest_id:
try:
contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
return self.success(ContestSerializer(contest).data)
class ContestListAPI(APIView):
def get(self, request):
contests = Contest.objects.select_related("created_by").filter(visible=True) contests = Contest.objects.select_related("created_by").filter(visible=True)
keyword = request.GET.get("keyword") keyword = request.GET.get("keyword")
rule_type = request.GET.get("rule_type") rule_type = request.GET.get("rule_type")
@@ -49,7 +48,8 @@ class ContestAPI(APIView):
contests = contests.filter(end_time__lt=cur) contests = contests.filter(end_time__lt=cur)
else: else:
contests = contests.filter(start_time__lte=cur, end_time__gte=cur) contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
return self.success(self.paginate_data(request, contests, ContestSerializer)) data = self.paginate_data(request, contests, ContestSerializer)
return self.success(data)
class ContestPasswordVerifyAPI(APIView): class ContestPasswordVerifyAPI(APIView):
@@ -91,11 +91,9 @@ class ContestRankAPI(APIView):
return OIContestRank.objects.filter(contest=self.contest). \ return OIContestRank.objects.filter(contest=self.contest). \
select_related("user").order_by("-total_score") select_related("user").order_by("-total_score")
@check_contest_permission @check_contest_permission(check_type="ranks")
def get(self, request): def get(self, request):
if self.contest.rule_type == ContestRuleType.OI: if self.contest.rule_type == ContestRuleType.OI:
if not self.contest.check_oi_permission(request.user):
return self.error("You have no permission for ranks now")
serializer = OIContestRankSerializer serializer = OIContestRankSerializer
else: else:
serializer = ACMContestRankSerializer serializer = ACMContestRankSerializer
@@ -105,5 +103,4 @@ class ContestRankAPI(APIView):
if not qs: if not qs:
qs = self.get_rank() qs = self.get_rank()
cache.set(cache_key, qs) cache.set(cache_key, qs)
return self.success(self.paginate_data(request, qs, serializer)) return self.success(self.paginate_data(request, qs, serializer))

View File

@@ -14,7 +14,7 @@ cd $BASE
find . -name "*.pyc" -delete find . -name "*.pyc" -delete
# wait for postgresql start # wait for postgresql start
sleep 5 sleep 6
n=0 n=0
while [ $n -lt 3 ] while [ $n -lt 3 ]

View File

@@ -192,7 +192,7 @@ CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json" CELERY_TASK_SERIALIZER = "json"
# 用于限制用户恶意提交大量代码 # 用于限制用户恶意提交大量代码
TOKEN_BUCKET_DEFAULT_CAPACITY = 50 TOKEN_BUCKET_DEFAULT_CAPACITY = 20
# 单位:每分钟 # 单位:每分钟
TOKEN_BUCKET_FILL_RATE = 2 TOKEN_BUCKET_FILL_RATE = 2

View File

@@ -107,4 +107,11 @@ class ProblemSerializer(BaseProblemSerializer):
class ContestProblemSerializer(BaseProblemSerializer): class ContestProblemSerializer(BaseProblemSerializer):
class Meta: class Meta:
model = Problem model = Problem
exclude = ("test_case_score", "test_case_id", "visible", "is_public") exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty")
class ContestProblemSafeSerializer(BaseProblemSerializer):
class Meta:
model = Problem
exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty"
"submission_number", "accepted_number", "statistic_info")

View File

@@ -196,30 +196,26 @@ class ContestProblemAdminTest(APITestCase):
def setUp(self): def setUp(self):
self.url = self.reverse("contest_problem_admin_api") self.url = self.reverse("contest_problem_admin_api")
self.create_admin() self.create_admin()
self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
def create_contest(self):
url = self.reverse("contest_admin_api")
return self.client.post(url, data=DEFAULT_CONTEST_DATA)
def test_create_contest_problem(self): def test_create_contest_problem(self):
contest = self.create_contest()
data = copy.deepcopy(DEFAULT_PROBLEM_DATA) data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
data["contest_id"] = contest.data["data"]["id"] data["contest_id"] = self.contest["id"]
resp = self.client.post(self.url, data=data) resp = self.client.post(self.url, data=data)
self.assertSuccess(resp) self.assertSuccess(resp)
return contest, resp return resp.data["data"]
def test_get_contest_problem(self): def test_get_contest_problem(self):
contest, contest_problem = self.test_create_contest_problem() self.test_create_contest_problem()
contest_id = contest.data["data"]["id"] contest_id = self.contest["id"]
resp = self.client.get(self.url + "?contest_id=" + str(contest_id)) resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
self.assertSuccess(resp) self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]), 1) self.assertEqual(len(resp.data["data"]["results"]), 1)
def test_get_one_contest_problem(self): def test_get_one_contest_problem(self):
contest, contest_problem = self.test_create_contest_problem() contest_problem = self.test_create_contest_problem()
contest_id = contest.data["data"]["id"] contest_id = self.contest["id"]
problem_id = contest_problem.data["data"]["id"] problem_id = contest_problem["id"]
resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}") resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
self.assertSuccess(resp) self.assertSuccess(resp)

View File

@@ -4,7 +4,7 @@ from utils.api import APIView
from account.decorators import check_contest_permission from account.decorators import check_contest_permission
from ..models import ProblemTag, Problem, ProblemRuleType from ..models import ProblemTag, Problem, ProblemRuleType
from ..serializers import ProblemSerializer, TagSerializer from ..serializers import ProblemSerializer, TagSerializer
from ..serializers import ContestProblemSerializer from ..serializers import ContestProblemSerializer, ContestProblemSafeSerializer
from contest.models import ContestRuleType from contest.models import ContestRuleType
@@ -81,8 +81,6 @@ class ProblemAPI(APIView):
class ContestProblemAPI(APIView): class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values): def _add_problem_status(self, request, queryset_values):
if self.contest.rule_type == ContestRuleType.OI and not self.contest.check_oi_permission(request.user):
return
if request.user.is_authenticated(): if request.user.is_authenticated():
profile = request.user.userprofile profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM: if self.contest.rule_type == ContestRuleType.ACM:
@@ -92,7 +90,7 @@ class ContestProblemAPI(APIView):
for problem in queryset_values: for problem in queryset_values:
problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status") problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
@check_contest_permission @check_contest_permission(check_type="problems")
def get(self, request): def get(self, request):
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
if problem_id: if problem_id:
@@ -102,11 +100,17 @@ class ContestProblemAPI(APIView):
visible=True) visible=True)
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exist.") return self.error("Problem does not exist.")
problem_data = ContestProblemSerializer(problem).data if self.contest.problem_details_permission(request.user):
self._add_problem_status(request, [problem_data, ]) problem_data = ContestProblemSerializer(problem).data
self._add_problem_status(request, [problem_data, ])
else:
problem_data = ContestProblemSafeSerializer(problem).data
return self.success(problem_data) return self.success(problem_data)
contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True) contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
# 根据profile 为做过的题目添加标记 if self.contest.problem_details_permission(request.user):
data = ContestProblemSerializer(contest_problems, many=True).data data = ContestProblemSerializer(contest_problems, many=True).data
self._add_problem_status(request, data) self._add_problem_status(request, data)
else:
data = ContestProblemSafeSerializer(contest_problems, many=True).data
return self.success(data) return self.success(data)

View File

@@ -8,6 +8,7 @@ class CreateSubmissionSerializer(serializers.Serializer):
language = serializers.ChoiceField(choices=language_names) language = serializers.ChoiceField(choices=language_names)
code = serializers.CharField(max_length=20000) code = serializers.CharField(max_length=20000)
contest_id = serializers.IntegerField(required=False) contest_id = serializers.IntegerField(required=False)
captcha = serializers.CharField(required=False)
class ShareSubmissionSerializer(serializers.Serializer): class ShareSubmissionSerializer(serializers.Serializer):

View File

@@ -1,3 +1,4 @@
from django.conf import settings
from account.decorators import login_required, check_contest_permission from account.decorators import login_required, check_contest_permission
from judge.tasks import judge_task from judge.tasks import judge_task
# from judge.dispatcher import JudgeDispatcher # from judge.dispatcher import JudgeDispatcher
@@ -5,6 +6,7 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType from contest.models import Contest, ContestStatus, ContestRuleType
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController from utils.throttling import TokenBucket, BucketController
from utils.captcha import Captcha
from utils.cache import cache from utils.cache import cache
from ..models import Submission from ..models import Submission
from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer, from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
@@ -12,43 +14,38 @@ from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer
from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
def _submit(response, user, problem_id, language, code, contest_id):
# TODO: 预设默认值,需修改
controller = BucketController(user_id=user.id,
redis_conn=cache,
default_capacity=30)
bucket = TokenBucket(fill_rate=10, capacity=20,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
if bucket.consume():
controller.last_capacity -= 1
else:
return response.error("Please wait %d seconds" % int(bucket.expected_time() + 1))
try:
problem = Problem.objects.get(id=problem_id,
contest_id=contest_id,
visible=True)
except Problem.DoesNotExist:
return response.error("Problem not exist")
submission = Submission.objects.create(user_id=user.id,
username=user.username,
language=language,
code=code,
problem_id=problem.id,
contest_id=contest_id)
# use this for debug
# JudgeDispatcher(submission.id, problem.id).judge()
judge_task.delay(submission.id, problem.id)
return response.success({"submission_id": submission.id})
class SubmissionAPI(APIView): class SubmissionAPI(APIView):
def throttling(self, request):
user_controller = BucketController(factor=request.user.id,
redis_conn=cache,
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY)
user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY,
last_capacity=user_controller.last_capacity,
last_timestamp=user_controller.last_timestamp)
if user_bucket.consume():
user_controller.last_capacity -= 1
else:
return "Please wait %d seconds" % int(user_bucket.expected_time() + 1)
ip_controller = BucketController(factor=request.session["ip"],
redis_conn=cache,
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3)
ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3,
last_capacity=ip_controller.last_capacity,
last_timestamp=ip_controller.last_timestamp)
if ip_bucket.consume():
ip_controller.last_capacity -= 1
else:
return "Captcha is required"
@validate_serializer(CreateSubmissionSerializer) @validate_serializer(CreateSubmissionSerializer)
@login_required @login_required
def post(self, request): def post(self, request):
data = request.data data = request.data
hide_id = False
if data.get("contest_id"): if data.get("contest_id"):
try: try:
contest = Contest.objects.get(id=data["contest_id"]) contest = Contest.objects.get(id=data["contest_id"])
@@ -56,9 +53,39 @@ class SubmissionAPI(APIView):
return self.error("Contest doesn't exist.") return self.error("Contest doesn't exist.")
if contest.status == ContestStatus.CONTEST_ENDED: if contest.status == ContestStatus.CONTEST_ENDED:
return self.error("The contest have ended") return self.error("The contest have ended")
if contest.status == ContestStatus.CONTEST_NOT_START and request.user != contest.created_by: if contest.status == ContestStatus.CONTEST_NOT_START and not contest.is_contest_admin(request.user):
return self.error("Contest have not started") return self.error("Contest have not started")
return _submit(self, request.user, data["problem_id"], data["language"], data["code"], data.get("contest_id")) if not contest.problem_details_permission():
hide_id = True
if data.get("captcha"):
if not Captcha(request).check(data["captcha"]):
return self.error("Invalid captcha")
error = self.throttling(request)
if error:
return self.error(error)
try:
problem = Problem.objects.get(id=data["problem_id"],
contest_id=data.get("contest_id"),
visible=True)
except Problem.DoesNotExist:
return self.error("Problem not exist")
submission = Submission.objects.create(user_id=request.user.id,
username=request.user.username,
language=data["language"],
code=data["code"],
problem_id=problem.id,
contest_id=data.get("contest_id"))
# use this for debug
# JudgeDispatcher(submission.id, problem.id).judge()
judge_task.delay(submission.id, problem.id)
if hide_id:
return self.success()
else:
return self.success({"submission_id": submission.id})
@login_required @login_required
def get(self, request): def get(self, request):
@@ -123,15 +150,12 @@ class SubmissionListAPI(APIView):
class ContestSubmissionListAPI(APIView): class ContestSubmissionListAPI(APIView):
@check_contest_permission @check_contest_permission(check_type="submissions")
def get(self, request): def get(self, request):
if not request.GET.get("limit"): if not request.GET.get("limit"):
return self.error("Limit is needed") return self.error("Limit is needed")
contest = self.contest contest = self.contest
if not contest.check_oi_permission(request.user):
return self.error("No permission for OI contest submissions")
submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by") submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
problem_id = request.GET.get("problem_id") problem_id = request.GET.get("problem_id")
myself = request.GET.get("myself") myself = request.GET.get("myself")

View File

@@ -107,18 +107,12 @@ class APIView(View):
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片 :param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
:return: :return:
""" """
need_paginate = request.GET.get("limit", None)
if need_paginate is None:
if object_serializer:
return object_serializer(query_set, many=True).data
else:
return {"results": query_set, "total": query_set.count()}
try: try:
limit = int(request.GET.get("limit", "100")) limit = int(request.GET.get("limit", "10"))
except ValueError: except ValueError:
limit = 100 limit = 10
if limit < 0: if limit < 0 or limit > 100:
limit = 100 limit = 10
try: try:
offset = int(request.GET.get("offset", "0")) offset = int(request.GET.get("offset", "0"))
except ValueError: except ValueError:

View File

@@ -27,8 +27,8 @@ class APITestCase(TestCase):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login) problem_permission=ProblemPermission.ALL, login=login)
def reverse(self, url_name): def reverse(self, url_name, *args, **kwargs):
return reverse(url_name) return reverse(url_name, *args, **kwargs)
def assertSuccess(self, response): def assertSuccess(self, response):
if not response.data["error"] is None: if not response.data["error"] is None:

View File

@@ -31,11 +31,10 @@ class TokenBucket:
class BucketController: class BucketController:
def __init__(self, user_id, redis_conn, default_capacity): def __init__(self, factor, redis_conn, default_capacity):
self.user_id = user_id
self.default_capacity = default_capacity self.default_capacity = default_capacity
self.redis = redis_conn self.redis = redis_conn
self.key = "bucket_" + str(self.user_id) self.key = "bucket_" + str(factor)
@property @property
def last_capacity(self): def last_capacity(self):