diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..96ae5ca --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +venv +.idea +.git +.DS_Store diff --git a/.flake8 b/.flake8 index d64dbd5..0fc0318 100644 --- a/.flake8 +++ b/.flake8 @@ -4,6 +4,7 @@ exclude = */migrations/, *settings.py */apps.py + venv/ max-line-length = 180 inline-quotes = " no-accept-encodings = True diff --git a/Dockerfile b/Dockerfile index 41cabf8..f66305b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.6-alpine3.6 +FROM python:3.7-alpine3.9 ENV OJ_ENV production diff --git a/account/decorators.py b/account/decorators.py index e57b331..08aaa81 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -31,19 +31,19 @@ class BasePermissionDecorator(object): class login_required(BasePermissionDecorator): def check_permission(self): - return self.request.user.is_authenticated() + return self.request.user.is_authenticated class super_admin_required(BasePermissionDecorator): def check_permission(self): user = self.request.user - return user.is_authenticated() and user.is_super_admin() + return user.is_authenticated and user.is_super_admin() class admin_role_required(BasePermissionDecorator): def check_permission(self): user = self.request.user - return user.is_authenticated() and user.is_admin_role() + return user.is_authenticated and user.is_admin_role() class problem_permission_required(admin_role_required): @@ -80,7 +80,7 @@ def check_contest_permission(check_type="details"): return self.error("Contest %s doesn't exist" % contest_id) # Anonymous - if not user.is_authenticated(): + if not user.is_authenticated: return self.error("Please login first.") # creator or owner diff --git a/account/middleware.py b/account/middleware.py index 8834b15..91eed27 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -22,7 +22,7 @@ class APITokenAuthMiddleware(MiddlewareMixin): class SessionRecordMiddleware(MiddlewareMixin): def process_request(self, request): request.ip = request.META.get(settings.IP_HEADER, request.META.get("REMOTE_ADDR")) - if request.user.is_authenticated(): + if request.user.is_authenticated: session = request.session session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") session["ip"] = request.ip @@ -37,7 +37,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin): def process_request(self, request): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): - if not (request.user.is_authenticated() and request.user.is_admin_role()): + if not (request.user.is_authenticated and request.user.is_admin_role()): return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) diff --git a/account/models.py b/account/models.py index eac1a62..afd644e 100644 --- a/account/models.py +++ b/account/models.py @@ -60,7 +60,7 @@ class User(AbstractBaseUser): return self.problem_permission == ProblemPermission.ALL def is_contest_admin(self, contest): - return self.is_authenticated() and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN) + return self.is_authenticated and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN) class Meta: db_table = "user" diff --git a/account/serializers.py b/account/serializers.py index 21ae24a..31ebd09 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -131,6 +131,10 @@ class ImageUploadForm(forms.Form): image = forms.FileField() +class FileUploadForm(forms.Form): + file = forms.FileField() + + class RankInfoSerializer(serializers.ModelSerializer): user = UsernameSerializer() diff --git a/account/tasks.py b/account/tasks.py index 12c1587..5135999 100644 --- a/account/tasks.py +++ b/account/tasks.py @@ -1,14 +1,13 @@ import logging - -from celery import shared_task +import dramatiq from options.options import SysOptions -from utils.shortcuts import send_email +from utils.shortcuts import send_email, DRAMATIQ_WORKER_ARGS logger = logging.getLogger(__name__) -@shared_task +@dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3)) def send_email_async(from_name, to_email, to_name, subject, content): if not SysOptions.smtp_config: return diff --git a/account/tests.py b/account/tests.py index 65e765d..4941eb2 100644 --- a/account/tests.py +++ b/account/tests.py @@ -101,13 +101,13 @@ class UserLoginAPITest(APITestCase): self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) user = auth.get_user(self.client) - self.assertTrue(user.is_authenticated()) + self.assertTrue(user.is_authenticated) def test_login_with_correct_info_upper_username(self): resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password}) self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"}) user = auth.get_user(self.client) - self.assertTrue(user.is_authenticated()) + self.assertTrue(user.is_authenticated) def test_login_with_wrong_info(self): response = self.client.post(self.login_url, @@ -115,7 +115,7 @@ class UserLoginAPITest(APITestCase): self.assertDictEqual(response.data, {"error": "error", "data": "Invalid username or password"}) user = auth.get_user(self.client) - self.assertFalse(user.is_authenticated()) + self.assertFalse(user.is_authenticated) def test_tfa_login(self): token = self._set_tfa() @@ -129,7 +129,7 @@ class UserLoginAPITest(APITestCase): self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) user = auth.get_user(self.client) - self.assertTrue(user.is_authenticated()) + self.assertTrue(user.is_authenticated) def test_tfa_login_wrong_code(self): self._set_tfa() @@ -140,7 +140,7 @@ class UserLoginAPITest(APITestCase): self.assertDictEqual(response.data, {"error": "error", "data": "Invalid two factor verification code"}) user = auth.get_user(self.client) - self.assertFalse(user.is_authenticated()) + self.assertFalse(user.is_authenticated) def test_tfa_login_without_code(self): self._set_tfa() @@ -150,7 +150,7 @@ class UserLoginAPITest(APITestCase): self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"}) user = auth.get_user(self.client) - self.assertFalse(user.is_authenticated()) + self.assertFalse(user.is_authenticated) def test_user_disabled(self): self.user.is_disabled = True @@ -304,7 +304,7 @@ class TwoFactorAuthAPITest(APITestCase): self.assertEqual(user.two_factor_auth, False) -@mock.patch("account.views.oj.send_email_async.delay") +@mock.patch("account.views.oj.send_email_async.send") class ApplyResetPasswordAPITest(CaptchaTest): def setUp(self): self.create_user("test", "test123", login=False) @@ -317,20 +317,20 @@ class ApplyResetPasswordAPITest(CaptchaTest): def _refresh_captcha(self): self.data["captcha"] = self._set_captcha(self.client.session) - def test_apply_reset_password(self, send_email_delay): + def test_apply_reset_password(self, send_email_send): resp = self.client.post(self.url, data=self.data) self.assertSuccess(resp) - send_email_delay.assert_called() + send_email_send.assert_called() - def test_apply_reset_password_twice_in_20_mins(self, send_email_delay): + def test_apply_reset_password_twice_in_20_mins(self, send_email_send): self.test_apply_reset_password() - send_email_delay.reset_mock() + send_email_send.reset_mock() self._refresh_captcha() resp = self.client.post(self.url, data=self.data) self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"}) - send_email_delay.assert_not_called() + send_email_send.assert_not_called() - def test_apply_reset_password_again_after_20_mins(self, send_email_delay): + def test_apply_reset_password_again_after_20_mins(self, send_email_send): self.test_apply_reset_password() user = User.objects.first() user.reset_password_token_expire_time = now() - timedelta(minutes=21) diff --git a/account/views/admin.py b/account/views/admin.py index 04e21e0..581a607 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -121,25 +121,15 @@ class UserAdminAPI(APIView): Q(email__icontains=keyword)) return self.success(self.paginate_data(request, user, UserAdminSerializer)) - def delete_one(self, user_id): - try: - user = User.objects.get(id=user_id) - except User.DoesNotExist: - return f"User {user_id} does not exist" - if Submission.objects.filter(user_id=user_id).exists(): - return f"Can't delete the user {user_id} as he/she has submissions" - user.delete() - @super_admin_required def delete(self, request): id = request.GET.get("id") if not id: return self.error("Invalid Parameter, id is required") - for user_id in id.split(","): - if user_id: - error = self.delete_one(user_id) - if error: - return self.error(error) + ids = id.split(",") + if str(request.user.id) in ids: + return self.error("Current user can not be deleted") + User.objects.filter(id__in=ids).delete() return self.success() diff --git a/account/views/oj.py b/account/views/oj.py index d7e2a64..2d26e48 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -35,7 +35,7 @@ class UserProfileAPI(APIView): 判断是否登录, 若登录返回用户信息 """ user = request.user - if not user.is_authenticated(): + if not user.is_authenticated: return self.success() show_real_name = False username = request.GET.get("username") @@ -280,7 +280,7 @@ class UserChangePasswordAPI(APIView): class ApplyResetPasswordAPI(APIView): @validate_serializer(ApplyResetPasswordSerializer) def post(self, request): - if request.user.is_authenticated(): + if request.user.is_authenticated: return self.error("You have already logged in, are you kidding me? ") data = request.data captcha = Captcha(request) @@ -302,11 +302,11 @@ class ApplyResetPasswordAPI(APIView): "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" } email_html = render_to_string("reset_password_email.html", render_data) - send_email_async.delay(from_name=SysOptions.website_name_shortcut, - to_email=user.email, - to_name=user.username, - subject=f"Reset your password", - content=email_html) + send_email_async.send(from_name=SysOptions.website_name_shortcut, + to_email=user.email, + to_name=user.username, + subject=f"Reset your password", + content=email_html) return self.success("Succeeded") diff --git a/announcement/models.py b/announcement/models.py index 37b8d59..441c4c1 100644 --- a/announcement/models.py +++ b/announcement/models.py @@ -9,7 +9,7 @@ class Announcement(models.Model): # HTML content = RichTextField() create_time = models.DateTimeField(auto_now_add=True) - created_by = models.ForeignKey(User) + created_by = models.ForeignKey(User, on_delete=models.CASCADE) last_update_time = models.DateTimeField(auto_now=True) visible = models.BooleanField(default=True) diff --git a/contest/migrations/0010_auto_20190326_0201.py b/contest/migrations/0010_auto_20190326_0201.py new file mode 100644 index 0000000..8700e1f --- /dev/null +++ b/contest/migrations/0010_auto_20190326_0201.py @@ -0,0 +1,23 @@ +# Generated by Django 2.1.7 on 2019-03-26 02:01 + +from django.conf import settings +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('contest', '0009_auto_20180501_0436'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='acmcontestrank', + unique_together={('user', 'contest')}, + ), + migrations.AlterUniqueTogether( + name='oicontestrank', + unique_together={('user', 'contest')}, + ), + ] diff --git a/contest/models.py b/contest/models.py index 77dcf8e..4616bef 100644 --- a/contest/models.py +++ b/contest/models.py @@ -20,7 +20,7 @@ class Contest(models.Model): end_time = models.DateTimeField() create_time = models.DateTimeField(auto_now_add=True) last_update_time = models.DateTimeField(auto_now=True) - created_by = models.ForeignKey(User) + created_by = models.ForeignKey(User, on_delete=models.CASCADE) # 是否可见 false的话相当于删除 visible = models.BooleanField(default=True) allowed_ip_ranges = JSONField(default=list) @@ -47,7 +47,7 @@ class Contest(models.Model): def problem_details_permission(self, user): return self.rule_type == ContestRuleType.ACM or \ self.status == ContestStatus.CONTEST_ENDED or \ - user.is_authenticated() and user.is_contest_admin(self) or \ + user.is_authenticated and user.is_contest_admin(self) or \ self.real_time_rank class Meta: @@ -56,8 +56,8 @@ class Contest(models.Model): class AbstractContestRank(models.Model): - user = models.ForeignKey(User) - contest = models.ForeignKey(Contest) + user = models.ForeignKey(User, on_delete=models.CASCADE) + contest = models.ForeignKey(Contest, on_delete=models.CASCADE) submission_number = models.IntegerField(default=0) class Meta: @@ -74,6 +74,7 @@ class ACMContestRank(AbstractContestRank): class Meta: db_table = "acm_contest_rank" + unique_together = (("user", "contest"),) class OIContestRank(AbstractContestRank): @@ -84,13 +85,14 @@ class OIContestRank(AbstractContestRank): class Meta: db_table = "oi_contest_rank" + unique_together = (("user", "contest"),) class ContestAnnouncement(models.Model): - contest = models.ForeignKey(Contest) + contest = models.ForeignKey(Contest, on_delete=models.CASCADE) title = models.TextField() content = RichTextField() - created_by = models.ForeignKey(User) + created_by = models.ForeignKey(User, on_delete=models.CASCADE) visible = models.BooleanField(default=True) create_time = models.DateTimeField(auto_now_add=True) diff --git a/contest/tests.py b/contest/tests.py index 8a21977..4c16004 100644 --- a/contest/tests.py +++ b/contest/tests.py @@ -45,7 +45,7 @@ class ContestAdminAPITest(APITestCase): response_data = response.data["data"] for k in data.keys(): if isinstance(data[k], datetime): - continue + continue self.assertEqual(response_data[k], data[k]) def test_get_contests(self): diff --git a/contest/views/admin.py b/contest/views/admin.py index 9adb45a..66addb1 100644 --- a/contest/views/admin.py +++ b/contest/views/admin.py @@ -234,7 +234,7 @@ class DownloadContestSubmissions(APIView): exclude_admin = request.GET.get("exclude_admin") == "1" zip_path = self._dump_submissions(contest, exclude_admin) - delete_files.apply_async((zip_path,), countdown=300) + delete_files.send_with_options(args=(zip_path,), delay=300_000) resp = FileResponse(open(zip_path, "rb")) resp["Content-Type"] = "application/zip" resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}" diff --git a/contest/views/oj.py b/contest/views/oj.py index 985507f..3aeb681 100644 --- a/contest/views/oj.py +++ b/contest/views/oj.py @@ -8,7 +8,7 @@ from django.core.cache import cache from problem.models import Problem from utils.api import APIView, validate_serializer from utils.constants import CacheKey -from utils.shortcuts import datetime2str +from utils.shortcuts import datetime2str, check_is_id from account.models import AdminType from account.decorators import login_required, check_contest_permission @@ -35,7 +35,7 @@ class ContestAnnouncementListAPI(APIView): class ContestAPI(APIView): def get(self, request): id = request.GET.get("id") - if not id: + if not id or not check_is_id(id): return self.error("Invalid parameter, id is required") try: contest = Contest.objects.get(id=id, visible=True) @@ -121,7 +121,7 @@ class ContestRankAPI(APIView): def get(self, request): download_csv = request.GET.get("download_csv") force_refresh = request.GET.get("force_refresh") - is_contest_admin = request.user.is_authenticated() and request.user.is_contest_admin(self.contest) + is_contest_admin = request.user.is_authenticated and request.user.is_contest_admin(self.contest) if self.contest.rule_type == ContestRuleType.OI: serializer = OIContestRankSerializer else: diff --git a/deploy/requirements.txt b/deploy/requirements.txt index 75827ff..b0c6fe9 100644 --- a/deploy/requirements.txt +++ b/deploy/requirements.txt @@ -1,19 +1,32 @@ -django==1.11.4 -djangorestframework==3.4.0 -pillow -otpauth -flake8-quotes -pytz -coverage -python-dateutil -celery -Envelopes -qrcode -flake8-coding -requests -django-redis -psycopg2-binary -gunicorn -jsonfield -XlsxWriter -raven +certifi==2019.3.9 +chardet==3.0.4 +coverage==4.5.3 +Django==2.1.7 +django-redis==4.10.0 +djangorestframework==3.8.2 +entrypoints==0.3 +Envelopes==0.4 +flake8==3.7.7 +flake8-coding==1.3.1 +flake8-quotes==1.0.0 +gunicorn==19.9.0 +idna==2.8 +jsonfield==2.0.2 +mccabe==0.6.1 +otpauth==1.0.1 +Pillow==5.4.1 +psycopg2-binary==2.7.7 +pycodestyle==2.5.0 +pyflakes==2.1.1 +python-dateutil==2.8.0 +pytz==2018.9 +qrcode==6.1 +raven==6.10.0 +redis==3.2.0 +requests==2.21.0 +six==1.12.0 +urllib3==1.24.1 +XlsxWriter==1.1.5 +django-dramatiq==0.5.0 +dramatiq==1.3.0 +django-dbconn-retry==0.1.5 diff --git a/deploy/supervisord.conf b/deploy/supervisord.conf index 0eca721..b174046 100644 --- a/deploy/supervisord.conf +++ b/deploy/supervisord.conf @@ -38,12 +38,12 @@ startsecs=5 stopwaitsecs = 5 killasgroup=true -[program:celery] -command=celery -A oj worker -l warning --autoscale 2,%(ENV_MAX_WORKER_NUM)s +[program:dramatiq] +command=python3 manage.py rundramatiq --no-reload --processes %(ENV_MAX_WORKER_NUM)s --threads 4 directory=/app/ user=nobody -stdout_logfile=/data/log/celery.log -stderr_logfile=/data/log/celery.log +stdout_logfile=/data/log/dramatiq.log +stderr_logfile=/data/log/dramatiq.log autostart=true autorestart=true startsecs=5 diff --git a/docs/data.json b/docs/data.json index 318c912..f8d5a63 100644 --- a/docs/data.json +++ b/docs/data.json @@ -1,5 +1,19 @@ { "update": [ + { + "version": "2019-03-25", + "level": "Recommend", + "title": "2019-03-25", + "details": [ + "Update Django to version 2.1 and Python to version 3.7", + "Replace celery with dramatiq", + "Add problem file IO Mode", + "You can add attachments in all editor", + "You can upload source code file in submission editor", + "Frontend and UI improvements", + "Fixed a lot of bugs" + ] + }, { "version": "2018-12-15", "level": "Recommend", diff --git a/judge/dispatcher.py b/judge/dispatcher.py index 6f3a943..9204e54 100644 --- a/judge/dispatcher.py +++ b/judge/dispatcher.py @@ -4,7 +4,7 @@ import logging from urllib.parse import urljoin import requests -from django.db import transaction +from django.db import transaction, IntegrityError from django.db.models import F from account.models import User @@ -26,7 +26,28 @@ def process_pending_task(): # 防止循环引入 from judge.tasks import judge_task data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8")) - judge_task.delay(**data) + judge_task.send(**data) + + +class ChooseJudgeServer: + def __init__(self): + self.server = None + + def __enter__(self) -> [JudgeServer, None]: + with transaction.atomic(): + servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number") + servers = [s for s in servers if s.status == "normal"] + for server in servers: + if server.task_number <= server.cpu_core * 2: + server.task_number = F("task_number") + 1 + server.save() + self.server = server + return server + return None + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.server: + JudgeServer.objects.filter(id=self.server.id).update(task_number=F("task_number") - 1) class DispatcherBase(object): @@ -42,25 +63,6 @@ class DispatcherBase(object): except Exception as e: logger.exception(e) - @staticmethod - def choose_judge_server(): - with transaction.atomic(): - servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number") - servers = [s for s in servers if s.status == "normal"] - for server in servers: - if server.task_number <= server.cpu_core * 2: - server.task_number = F("task_number") + 1 - server.save() - return server - - @staticmethod - def release_judge_server(judge_server_id): - with transaction.atomic(): - # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下 - server = JudgeServer.objects.get(id=judge_server_id) - server.task_number = F("task_number") - 1 - server.save() - class SPJCompiler(DispatcherBase): def __init__(self, spj_code, spj_version, spj_language): @@ -74,13 +76,14 @@ class SPJCompiler(DispatcherBase): } def compile_spj(self): - server = self.choose_judge_server() - if not server: - return "No available judge_server" - result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data) - self.release_judge_server(server.id) - if result["err"]: - return result["data"] + with ChooseJudgeServer() as server: + if not server: + return "No available judge_server" + result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data) + if not result: + return "Failed to call judge server" + if result["err"]: + return result["data"] class JudgeDispatcher(DispatcherBase): @@ -118,12 +121,6 @@ class JudgeDispatcher(DispatcherBase): self.submission.statistic_info["score"] = score def judge(self): - server = self.choose_judge_server() - if not server: - data = {"submission_id": self.submission.id, "problem_id": self.problem.id} - cache.lpush(CacheKey.waiting_queue, json.dumps(data)) - return - language = self.submission.language sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0] spj_config = {} @@ -149,12 +146,22 @@ class JudgeDispatcher(DispatcherBase): "spj_version": self.problem.spj_version, "spj_config": spj_config.get("config"), "spj_compile_config": spj_config.get("compile"), - "spj_src": self.problem.spj_code + "spj_src": self.problem.spj_code, + "io_mode": self.problem.io_mode } - Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) + with ChooseJudgeServer() as server: + if not server: + data = {"submission_id": self.submission.id, "problem_id": self.problem.id} + cache.lpush(CacheKey.waiting_queue, json.dumps(data)) + return + Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) + resp = self._request(urljoin(server.service_url, "/judge"), data=data) + + if not resp: + Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR) + return - resp = self._request(urljoin(server.service_url, "/judge"), data=data) if resp["err"]: self.submission.result = JudgeStatus.COMPILE_ERROR self.submission.statistic_info["err_info"] = resp["data"] @@ -173,7 +180,6 @@ class JudgeDispatcher(DispatcherBase): else: self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.save() - self.release_judge_server(server.id) if self.contest_id: if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \ @@ -181,8 +187,9 @@ class JudgeDispatcher(DispatcherBase): logger.info( "Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id) return - self.update_contest_problem_status() - self.update_contest_rank() + with transaction.atomic(): + self.update_contest_problem_status() + self.update_contest_rank() else: if self.last_result: self.update_problem_status_rejudge() @@ -322,20 +329,31 @@ class JudgeDispatcher(DispatcherBase): def update_contest_rank(self): if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank: cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}") - with transaction.atomic(): - if self.contest.rule_type == ContestRuleType.ACM: - acm_rank, _ = ACMContestRank.objects.select_for_update(). \ - get_or_create(user_id=self.submission.user_id, contest=self.contest) - self._update_acm_contest_rank(acm_rank) - else: - oi_rank, _ = OIContestRank.objects.select_for_update(). \ - get_or_create(user_id=self.submission.user_id, contest=self.contest) - self._update_oi_contest_rank(oi_rank) + + def get_rank(model): + return model.objects.select_for_update().get(user_id=self.submission.user_id, contest=self.contest) + + if self.contest.rule_type == ContestRuleType.ACM: + model = ACMContestRank + func = self._update_acm_contest_rank + else: + model = OIContestRank + func = self._update_oi_contest_rank + + try: + rank = get_rank(model) + except model.DoesNotExist: + try: + model.objects.create(user_id=self.submission.user_id, contest=self.contest) + rank = get_rank(model) + except IntegrityError: + rank = get_rank(model) + func(rank) def _update_acm_contest_rank(self, rank): info = rank.submission_info.get(str(self.submission.problem_id)) # 因前面更改过,这里需要重新获取 - problem = Problem.objects.get(contest_id=self.contest_id, id=self.problem.id) + problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id) # 此题提交过 if info: if info["is_ac"]: diff --git a/judge/languages.py b/judge/languages.py index 6849508..2769fcd 100644 --- a/judge/languages.py +++ b/judge/languages.py @@ -1,3 +1,6 @@ +from problem.models import ProblemIOMode + + default_env = ["LANG=en_US.UTF-8", "LANGUAGE=en_US:en", "LC_ALL=en_US.UTF-8"] _c_lang_config = { @@ -28,7 +31,7 @@ int main() { }, "run": { "command": "{exe_path}", - "seccomp_rule": "c_cpp", + "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"}, "env": default_env } } diff --git a/judge/tasks.py b/judge/tasks.py index a5c1818..8a1794a 100644 --- a/judge/tasks.py +++ b/judge/tasks.py @@ -1,12 +1,12 @@ -from __future__ import absolute_import, unicode_literals -from celery import shared_task +import dramatiq from account.models import User from submission.models import Submission from judge.dispatcher import JudgeDispatcher +from utils.shortcuts import DRAMATIQ_WORKER_ARGS -@shared_task +@dramatiq.actor(**DRAMATIQ_WORKER_ARGS()) def judge_task(submission_id, problem_id): uid = Submission.objects.get(id=submission_id).user_id if User.objects.get(id=uid).is_disabled: diff --git a/oj/__init__.py b/oj/__init__.py index 23fc183..e69de29 100644 --- a/oj/__init__.py +++ b/oj/__init__.py @@ -1,6 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -# Django starts so that shared_task will use this app. -from .celery import app as celery_app - -__all__ = ["celery_app"] diff --git a/oj/celery.py b/oj/celery.py deleted file mode 100644 index 4f24c7e..0000000 --- a/oj/celery.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import absolute_import, unicode_literals -import os -from celery import Celery -from django.conf import settings - -# set the default Django settings module for the "celery" program. -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings") - - -app = Celery("oj") - -# Using a string here means the worker will not have to -# pickle the object when using Windows. -app.config_from_object("django.conf:settings") - -# load task modules from all registered Django app configs. -app.autodiscover_tasks(lambda: settings.INSTALLED_APPS) -# app.autodiscover_tasks() diff --git a/oj/settings.py b/oj/settings.py index 6ba3987..c1141a0 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -26,16 +26,22 @@ with open(os.path.join(DATA_DIR, "config", "secret.key"), "r") as f: BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Applications -VENDOR_APPS = ( +VENDOR_APPS = [ 'django.contrib.auth', 'django.contrib.sessions', 'django.contrib.contenttypes', 'django.contrib.messages', 'django.contrib.staticfiles', 'rest_framework', - 'raven.contrib.django.raven_compat' -) -LOCAL_APPS = ( + 'django_dramatiq', + 'django_dbconn_retry', +] + +if production_env: + VENDOR_APPS.append('raven.contrib.django.raven_compat') + + +LOCAL_APPS = [ 'account', 'announcement', 'conf', @@ -45,11 +51,11 @@ LOCAL_APPS = ( 'submission', 'options', 'judge', -) +] INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS -MIDDLEWARE_CLASSES = ( +MIDDLEWARE = ( 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', @@ -164,6 +170,11 @@ LOGGING = { 'level': 'ERROR', 'propagate': True, }, + 'dramatiq': { + 'handlers': LOGGING_HANDLERS, + 'level': 'DEBUG', + 'propagate': False, + }, '': { 'handlers': LOGGING_HANDLERS, 'level': 'WARNING', @@ -202,11 +213,32 @@ CACHES = { SESSION_ENGINE = "django.contrib.sessions.backends.cache" SESSION_CACHE_ALIAS = "default" -CELERY_RESULT_BACKEND = f"{REDIS_URL}/2" -BROKER_URL = f"{REDIS_URL}/3" -CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180 -CELERY_ACCEPT_CONTENT = ["json"] -CELERY_TASK_SERIALIZER = "json" +DRAMATIQ_BROKER = { + "BROKER": "dramatiq.brokers.redis.RedisBroker", + "OPTIONS": { + "url": f"{REDIS_URL}/4", + }, + "MIDDLEWARE": [ + # "dramatiq.middleware.Prometheus", + "dramatiq.middleware.AgeLimit", + "dramatiq.middleware.TimeLimit", + "dramatiq.middleware.Callbacks", + "dramatiq.middleware.Retries", + # "django_dramatiq.middleware.AdminMiddleware", + "django_dramatiq.middleware.DbConnectionsMiddleware" + ] +} + +DRAMATIQ_RESULT_BACKEND = { + "BACKEND": "dramatiq.results.backends.redis.RedisBackend", + "BACKEND_OPTIONS": { + "url": f"{REDIS_URL}/4", + }, + "MIDDLEWARE_OPTIONS": { + "result_ttl": None + } +} + RAVEN_CONFIG = { 'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057' } diff --git a/options/migrations/0003_migrate_languages_options.py b/options/migrations/0003_migrate_languages_options.py new file mode 100644 index 0000000..0f8f281 --- /dev/null +++ b/options/migrations/0003_migrate_languages_options.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.3 on 2018-05-01 04:36 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('options', '0002_auto_20180501_0436'), + ] + + operations = [ + migrations.RunSQL(""" + DELETE FROM options_sysoptions WHERE key = 'languages'; + """) + ] diff --git a/options/options.py b/options/options.py index 605c746..4bc0515 100644 --- a/options/options.py +++ b/options/options.py @@ -1,13 +1,92 @@ +import functools import os -from django.core.cache import cache +import threading +import time + from django.db import transaction, IntegrityError -from utils.constants import CacheKey from utils.shortcuts import rand_str from judge.languages import languages from .models import SysOptions as SysOptionsModel +class my_property: + """ + 在 metaclass 中使用,以实现: + 1. ttl = None,不缓存 + 2. ttl is callable,条件缓存 + 3. 缓存 ttl 秒 + """ + def __init__(self, func=None, fset=None, ttl=None): + self.fset = fset + self.local = threading.local() + self.ttl = ttl + self._check_ttl(ttl) + self.func = func + functools.update_wrapper(self, func) + + def _check_ttl(self, value): + if value is None or callable(value): + return + return self._check_timeout(value) + + def _check_timeout(self, value): + if not isinstance(value, int): + raise ValueError(f"Invalid timeout type: {type(value)}") + if value < 0: + raise ValueError("Invalid timeout value, it must >= 0") + + def __get__(self, obj, cls): + if obj is None: + return self + + now = time.time() + if self.ttl: + if hasattr(self.local, "value"): + value, expire_at = self.local.value + if now < expire_at: + return value + + value = self.func(obj) + + # 如果定义了条件缓存, ttl 是一个函数,返回要缓存多久;返回 0 代表不要缓存 + if callable(self.ttl): + # 而且条件缓存说不要缓存,那就直接返回,不要设置 local + timeout = self.ttl(value) + self._check_timeout(timeout) + + if timeout == 0: + return value + elif timeout > 0: + self.local.value = (value, now + timeout) + else: + # ttl 是一个数字 + self.local.value = (value, now + self.ttl) + return value + else: + return self.func(obj) + + def __set__(self, obj, value): + if not self.fset: + raise AttributeError("can't set attribute") + self.fset(obj, value) + if hasattr(self.local, "value"): + del self.local.value + + def setter(self, func): + self.fset = func + return self + + def __call__(self, func, *args, **kwargs) -> "my_property": + if self.func is None: + self.func = func + functools.update_wrapper(self, func) + return self + + +DEFAULT_SHORT_TTL = 2 + + def default_token(): token = os.environ.get("JUDGE_SERVER_TOKEN") return token if token else rand_str() @@ -41,23 +120,10 @@ class OptionDefaultValue: class _SysOptionsMeta(type): - @classmethod - def _set_cache(mcs, option_key, option_value): - cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60) - - @classmethod - def _del_cache(mcs, option_key): - cache.delete(f"{CacheKey.option}:{option_key}") - @classmethod def _get_keys(cls): return [key for key in OptionKeys.__dict__ if not key.startswith("__")] - def rebuild_cache(cls): - for key in cls._get_keys(): - # get option 的时候会写 cache 的 - cls._get_option(key, use_cache=False) - @classmethod def _init_option(mcs): for item in mcs._get_keys(): @@ -71,19 +137,14 @@ class _SysOptionsMeta(type): pass @classmethod - def _get_option(mcs, option_key, use_cache=True): + def _get_option(mcs, option_key): try: - if use_cache: - option = cache.get(f"{CacheKey.option}:{option_key}") - if option: - return option option = SysOptionsModel.objects.get(key=option_key) value = option.value - mcs._set_cache(option_key, value) return value except SysOptionsModel.DoesNotExist: mcs._init_option() - return mcs._get_option(option_key, use_cache=use_cache) + return mcs._get_option(option_key) @classmethod def _set_option(mcs, option_key: str, option_value): @@ -92,7 +153,6 @@ class _SysOptionsMeta(type): option = SysOptionsModel.objects.select_for_update().get(key=option_key) option.value = option_value option.save() - mcs._del_cache(option_key) except SysOptionsModel.DoesNotExist: mcs._init_option() mcs._set_option(option_key, option_value) @@ -105,7 +165,6 @@ class _SysOptionsMeta(type): value = option.value + 1 option.value = value option.save() - mcs._del_cache(option_key) except SysOptionsModel.DoesNotExist: mcs._init_option() return mcs._increment(option_key) @@ -122,7 +181,7 @@ class _SysOptionsMeta(type): result[key] = mcs._get_option(key) return result - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def website_base_url(cls): return cls._get_option(OptionKeys.website_base_url) @@ -130,7 +189,7 @@ class _SysOptionsMeta(type): def website_base_url(cls, value): cls._set_option(OptionKeys.website_base_url, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def website_name(cls): return cls._get_option(OptionKeys.website_name) @@ -138,7 +197,7 @@ class _SysOptionsMeta(type): def website_name(cls, value): cls._set_option(OptionKeys.website_name, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def website_name_shortcut(cls): return cls._get_option(OptionKeys.website_name_shortcut) @@ -146,7 +205,7 @@ class _SysOptionsMeta(type): def website_name_shortcut(cls, value): cls._set_option(OptionKeys.website_name_shortcut, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def website_footer(cls): return cls._get_option(OptionKeys.website_footer) @@ -154,7 +213,7 @@ class _SysOptionsMeta(type): def website_footer(cls, value): cls._set_option(OptionKeys.website_footer, value) - @property + @my_property def allow_register(cls): return cls._get_option(OptionKeys.allow_register) @@ -162,7 +221,7 @@ class _SysOptionsMeta(type): def allow_register(cls, value): cls._set_option(OptionKeys.allow_register, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def submission_list_show_all(cls): return cls._get_option(OptionKeys.submission_list_show_all) @@ -170,7 +229,7 @@ class _SysOptionsMeta(type): def submission_list_show_all(cls, value): cls._set_option(OptionKeys.submission_list_show_all, value) - @property + @my_property def smtp_config(cls): return cls._get_option(OptionKeys.smtp_config) @@ -178,7 +237,7 @@ class _SysOptionsMeta(type): def smtp_config(cls, value): cls._set_option(OptionKeys.smtp_config, value) - @property + @my_property def judge_server_token(cls): return cls._get_option(OptionKeys.judge_server_token) @@ -186,7 +245,7 @@ class _SysOptionsMeta(type): def judge_server_token(cls, value): cls._set_option(OptionKeys.judge_server_token, value) - @property + @my_property def throttling(cls): return cls._get_option(OptionKeys.throttling) @@ -194,7 +253,7 @@ class _SysOptionsMeta(type): def throttling(cls, value): cls._set_option(OptionKeys.throttling, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def languages(cls): return cls._get_option(OptionKeys.languages) @@ -202,15 +261,15 @@ class _SysOptionsMeta(type): def languages(cls, value): cls._set_option(OptionKeys.languages, value) - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def spj_languages(cls): return [item for item in cls.languages if "spj" in item] - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def language_names(cls): return [item["name"] for item in languages] - @property + @my_property(ttl=DEFAULT_SHORT_TTL) def spj_language_names(cls): return [item["name"] for item in cls.languages if "spj" in item] diff --git a/problem/migrations/0013_problem_io_mode.py b/problem/migrations/0013_problem_io_mode.py new file mode 100644 index 0000000..17136c8 --- /dev/null +++ b/problem/migrations/0013_problem_io_mode.py @@ -0,0 +1,20 @@ +# Generated by Django 2.1.7 on 2019-03-12 07:13 + +import django.contrib.postgres.fields.jsonb +from django.db import migrations +import problem.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('problem', '0012_auto_20180501_0436'), + ] + + operations = [ + migrations.AddField( + model_name='problem', + name='io_mode', + field=django.contrib.postgres.fields.jsonb.JSONField(default=problem.models._default_io_mode), + ), + ] diff --git a/problem/migrations/0014_problem_share_submission.py b/problem/migrations/0014_problem_share_submission.py new file mode 100644 index 0000000..c764d7c --- /dev/null +++ b/problem/migrations/0014_problem_share_submission.py @@ -0,0 +1,18 @@ +# Generated by Django 2.1.7 on 2019-03-13 09:38 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('problem', '0013_problem_io_mode'), + ] + + operations = [ + migrations.AddField( + model_name='problem', + name='share_submission', + field=models.BooleanField(default=False), + ), + ] diff --git a/problem/models.py b/problem/models.py index 29e317b..c16aead 100644 --- a/problem/models.py +++ b/problem/models.py @@ -25,10 +25,19 @@ class ProblemDifficulty(object): Low = "Low" +class ProblemIOMode(Choices): + standard = "Standard IO" + file = "File IO" + + +def _default_io_mode(): + return {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"} + + class Problem(models.Model): # display ID _id = models.TextField(db_index=True) - contest = models.ForeignKey(Contest, null=True) + contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE) # for contest problem is_public = models.BooleanField(default=False) title = models.TextField() @@ -47,11 +56,13 @@ class Problem(models.Model): create_time = models.DateTimeField(auto_now_add=True) # we can not use auto_now here last_update_time = models.DateTimeField(null=True) - created_by = models.ForeignKey(User) + created_by = models.ForeignKey(User, on_delete=models.CASCADE) # ms time_limit = models.IntegerField() # MB memory_limit = models.IntegerField() + # io mode + io_mode = JSONField(default=_default_io_mode) # special judge related spj = models.BooleanField(default=False) spj_language = models.TextField(null=True) @@ -69,6 +80,7 @@ class Problem(models.Model): accepted_number = models.BigIntegerField(default=0) # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count statistic_info = JSONField(default=dict) + share_submission = models.BooleanField(default=False) class Meta: db_table = "problem" diff --git a/problem/serializers.py b/problem/serializers.py index 592c758..2c09a6c 100644 --- a/problem/serializers.py +++ b/problem/serializers.py @@ -1,3 +1,5 @@ +import re + from django import forms from options.options import SysOptions @@ -5,7 +7,7 @@ from utils.api import UsernameSerializer, serializers from utils.constants import Difficulty from utils.serializers import LanguageNameMultiChoiceField, SPJLanguageNameChoiceField, LanguageNameChoiceField -from .models import Problem, ProblemRuleType, ProblemTag +from .models import Problem, ProblemRuleType, ProblemTag, ProblemIOMode from .utils import parse_problem_template @@ -29,6 +31,20 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer): pass +class ProblemIOModeSerializer(serializers.Serializer): + io_mode = serializers.ChoiceField(choices=ProblemIOMode.choices()) + input = serializers.CharField() + output = serializers.CharField() + + def validate(self, attrs): + if attrs["input"] == attrs["output"]: + raise serializers.ValidationError("Invalid io mode") + for item in (attrs["input"], attrs["output"]): + if not re.match("^[a-zA-Z0-9.]+$", item): + raise serializers.ValidationError("Invalid io file name format") + return attrs + + class CreateOrEditProblemSerializer(serializers.Serializer): _id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) title = serializers.CharField(max_length=1024) @@ -43,6 +59,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer): languages = LanguageNameMultiChoiceField() template = serializers.DictField(child=serializers.CharField(min_length=1)) rule_type = serializers.ChoiceField(choices=[ProblemRuleType.ACM, ProblemRuleType.OI]) + io_mode = ProblemIOModeSerializer() spj = serializers.BooleanField() spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True) spj_code = serializers.CharField(allow_blank=True, allow_null=True) @@ -52,6 +69,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer): tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) hint = serializers.CharField(allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) + share_submission = serializers.BooleanField() class CreateProblemSerializer(CreateOrEditProblemSerializer): diff --git a/problem/tests.py b/problem/tests.py index f4e7321..7074d2f 100644 --- a/problem/tests.py +++ b/problem/tests.py @@ -9,7 +9,7 @@ from django.conf import settings from utils.api.tests import APITestCase -from .models import ProblemTag +from .models import ProblemTag, ProblemIOMode from .models import Problem, ProblemRuleType from contest.models import Contest from contest.tests import DEFAULT_CONTEST_DATA @@ -25,6 +25,8 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "
test "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", "input_size": 0, "score": 0}], + "io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"}, + "share_submission": False, "rule_type": "ACM", "hint": "
test
", "source": "test"} diff --git a/problem/utils.py b/problem/utils.py index df70f9b..c530263 100644 --- a/problem/utils.py +++ b/problem/utils.py @@ -1,4 +1,6 @@ import re +from functools import lru_cache + TEMPLATE_BASE = """//PREPEND BEGIN {} @@ -13,6 +15,7 @@ TEMPLATE_BASE = """//PREPEND BEGIN //APPEND END""" +@lru_cache(maxsize=100) def parse_problem_template(template_str): prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str) @@ -22,5 +25,6 @@ def parse_problem_template(template_str): "append": append[0] if append else ""} +@lru_cache(maxsize=100) def build_problem_template(prepend, template, append): return TEMPLATE_BASE.format(prepend, template, append) diff --git a/problem/views/admin.py b/problem/views/admin.py index e06ff61..5a5268a 100644 --- a/problem/views/admin.py +++ b/problem/views/admin.py @@ -300,8 +300,6 @@ class ProblemAPI(ProblemBase): except Problem.DoesNotExist: return self.error("Problem does not exists") ensure_created_by(problem, request.user) - if Submission.objects.filter(problem=problem).exists(): - return self.error("Can't delete the problem as it has submissions") d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) if os.path.isdir(d): shutil.rmtree(d, ignore_errors=True) @@ -541,7 +539,7 @@ class ExportProblemAPI(APIView): with zipfile.ZipFile(path, "w") as zip_file: for index, problem in enumerate(problems): self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1) - delete_files.apply_async((path,), countdown=300) + delete_files.send_with_options(args=(path,), delay=300_000) resp = FileResponse(open(path, "rb")) resp["Content-Type"] = "application/zip" resp["Content-Disposition"] = f"attachment;filename=problem-export.zip" diff --git a/problem/views/oj.py b/problem/views/oj.py index ed39dda..b091ff3 100644 --- a/problem/views/oj.py +++ b/problem/views/oj.py @@ -25,7 +25,7 @@ class PickOneAPI(APIView): class ProblemAPI(APIView): @staticmethod def _add_problem_status(request, queryset_values): - if request.user.is_authenticated(): + if request.user.is_authenticated: profile = request.user.userprofile acm_problems_status = profile.acm_problems_status.get("problems", {}) oi_problems_status = profile.oi_problems_status.get("problems", {}) @@ -81,7 +81,7 @@ class ProblemAPI(APIView): class ContestProblemAPI(APIView): def _add_problem_status(self, request, queryset_values): - if request.user.is_authenticated(): + if request.user.is_authenticated: profile = request.user.userprofile if self.contest.rule_type == ContestRuleType.ACM: problems_status = profile.acm_problems_status.get("contest_problems", {}) diff --git a/submission/models.py b/submission/models.py index 261052e..2835fe0 100644 --- a/submission/models.py +++ b/submission/models.py @@ -22,8 +22,8 @@ class JudgeStatus: class Submission(models.Model): id = models.TextField(default=rand_str, primary_key=True, db_index=True) - contest = models.ForeignKey(Contest, null=True) - problem = models.ForeignKey(Problem) + contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE) + problem = models.ForeignKey(Problem, on_delete=models.CASCADE) create_time = models.DateTimeField(auto_now_add=True) user_id = models.IntegerField(db_index=True) username = models.TextField() @@ -41,6 +41,7 @@ class Submission(models.Model): def check_user_permission(self, user, check_share=True): return self.user_id == user.id or \ (check_share and self.shared is True) or \ + (check_share and self.problem.share_submission) or \ user.is_super_admin() or \ user.can_mgmt_all_problem() or \ self.problem.created_by_id == user.id diff --git a/submission/serializers.py b/submission/serializers.py index b814bc7..5e48f3e 100644 --- a/submission/serializers.py +++ b/submission/serializers.py @@ -46,6 +46,6 @@ class SubmissionListSerializer(serializers.ModelSerializer): def get_show_link(self, obj): # 没传user或为匿名user - if self.user is None or not self.user.is_authenticated(): + if self.user is None or not self.user.is_authenticated: return False return obj.check_user_permission(self.user) diff --git a/submission/tests.py b/submission/tests.py index 42417ea..08ccbd4 100644 --- a/submission/tests.py +++ b/submission/tests.py @@ -57,7 +57,7 @@ class SubmissionListTest(SubmissionPrepare): self.assertSuccess(resp) -@mock.patch("submission.views.oj.judge_task.delay") +@mock.patch("submission.views.oj.judge_task.send") class SubmissionAPITest(SubmissionPrepare): def setUp(self): self._create_problem_and_submission() diff --git a/submission/views/admin.py b/submission/views/admin.py index d2312c0..8797256 100644 --- a/submission/views/admin.py +++ b/submission/views/admin.py @@ -18,5 +18,5 @@ class SubmissionRejudgeAPI(APIView): submission.statistic_info = {} submission.save() - judge_task.delay(submission.id, submission.problem.id) + judge_task.send(submission.id, submission.problem.id) return self.success() diff --git a/submission/views/oj.py b/submission/views/oj.py index 88a794b..0f7c4a5 100644 --- a/submission/views/oj.py +++ b/submission/views/oj.py @@ -80,7 +80,7 @@ class SubmissionAPI(APIView): contest_id=data.get("contest_id")) # use this for debug # JudgeDispatcher(submission.id, problem.id).judge() - judge_task.delay(submission.id, problem.id) + judge_task.send(submission.id, problem.id) if hide_id: return self.success() else: @@ -198,6 +198,6 @@ class SubmissionExistsAPI(APIView): def get(self, request): if not request.GET.get("problem_id"): return self.error("Parameter error, problem_id is required") - return self.success(request.user.is_authenticated() and + return self.success(request.user.is_authenticated and Submission.objects.filter(problem_id=request.GET["problem_id"], user_id=request.user.id).exists()) diff --git a/utils/api/api.py b/utils/api/api.py index 5b7a231..a603bfb 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -1,7 +1,6 @@ import functools import json import logging -from collections import OrderedDict from django.http import HttpResponse, QueryDict from django.utils.decorators import method_decorator @@ -89,20 +88,24 @@ class APIView(View): def error(self, msg="error", err="error"): return self.response({"error": err, "data": msg}) - def _serializer_error_to_str(self, errors): - for k, v in errors.items(): - if isinstance(v, list): - return k, v[0] - elif isinstance(v, OrderedDict): - for _k, _v in v.items(): - return self._serializer_error_to_str({_k: _v}) + def extract_errors(self, errors, key="field"): + if isinstance(errors, dict): + if not errors: + return key, "Invalid field" + key = list(errors.keys())[0] + return self.extract_errors(errors.pop(key), key) + elif isinstance(errors, list): + return self.extract_errors(errors[0], key) + + return key, errors def invalid_serializer(self, serializer): - k, v = self._serializer_error_to_str(serializer.errors) - if k != "non_field_errors": - return self.error(err="invalid-" + k, msg=k + ": " + v) + key, error = self.extract_errors(serializer.errors) + if key == "non_field_errors": + msg = error else: - return self.error(err="invalid-field", msg=v) + msg = f"{key}: {error}" + return self.error(err=f"invalid-{key}", msg=msg) def server_error(self): return self.error(err="server-error", msg="server error") diff --git a/utils/api/tests.py b/utils/api/tests.py index b47ceae..0f0a79b 100644 --- a/utils/api/tests.py +++ b/utils/api/tests.py @@ -1,4 +1,4 @@ -from django.core.urlresolvers import reverse +from django.urls import reverse from django.test.testcases import TestCase from rest_framework.test import APIClient diff --git a/utils/constants.py b/utils/constants.py index 50068f1..004b801 100644 --- a/utils/constants.py +++ b/utils/constants.py @@ -25,7 +25,6 @@ class CacheKey: waiting_queue = "waiting_queue" contest_rank_cache = "contest_rank_cache" website_config = "website_config" - option = "option" class Difficulty(Choices): diff --git a/utils/shortcuts.py b/utils/shortcuts.py index ea0d094..84e14fd 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -81,3 +81,14 @@ def send_email(smtp_config, from_name, to_email, to_name, subject, content): def get_env(name, default=""): return os.environ.get(name, default) + + +def DRAMATIQ_WORKER_ARGS(time_limit=3600_000, max_retries=0, max_age=7200_000): + return {"max_retries": max_retries, "time_limit": time_limit, "max_age": max_age} + + +def check_is_id(value): + try: + return int(value) > 0 + except Exception: + return False diff --git a/utils/tasks.py b/utils/tasks.py index 442b0bc..26a2180 100644 --- a/utils/tasks.py +++ b/utils/tasks.py @@ -1,8 +1,10 @@ import os -from celery import shared_task +import dramatiq + +from utils.shortcuts import DRAMATIQ_WORKER_ARGS -@shared_task +@dramatiq.actor(**DRAMATIQ_WORKER_ARGS()) def delete_files(*args): for item in args: try: diff --git a/utils/urls.py b/utils/urls.py index ca9fb0f..7e0128e 100644 --- a/utils/urls.py +++ b/utils/urls.py @@ -1,7 +1,8 @@ from django.conf.urls import url -from .views import SimditorImageUploadAPIView +from .views import SimditorImageUploadAPIView, SimditorFileUploadAPIView urlpatterns = [ - url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image") + url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image"), + url(r"^upload_file/?$", SimditorFileUploadAPIView.as_view(), name="upload_file") ] diff --git a/utils/views.py b/utils/views.py index c3e3861..ce30c71 100644 --- a/utils/views.py +++ b/utils/views.py @@ -1,6 +1,6 @@ import os from django.conf import settings -from account.serializers import ImageUploadForm +from account.serializers import ImageUploadForm, FileUploadForm from utils.shortcuts import rand_str from utils.api import CSRFExemptAPIView import logging @@ -35,10 +35,41 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView): except IOError as e: logger.error(e) return self.response({ - "success": True, + "success": False, "msg": "Upload Error", - "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) + "file_path": ""}) return self.response({ "success": True, "msg": "Success", "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) + + +class SimditorFileUploadAPIView(CSRFExemptAPIView): + request_parsers = () + + def post(self, request): + form = FileUploadForm(request.POST, request.FILES) + if form.is_valid(): + file = form.cleaned_data["file"] + else: + return self.response({ + "success": False, + "msg": "Upload failed" + }) + + suffix = os.path.splitext(file.name)[-1].lower() + file_name = rand_str(10) + suffix + try: + with open(os.path.join(settings.UPLOAD_DIR, file_name), "wb") as f: + for chunk in file: + f.write(chunk) + except IOError as e: + logger.error(e) + return self.response({ + "success": False, + "msg": "Upload Error"}) + return self.response({ + "success": True, + "msg": "Success", + "file_path": f"{settings.UPLOAD_PREFIX}/{file_name}", + "file_name": file.name}) diff --git a/utils/xss_filter.py b/utils/xss_filter.py index 1b45d89..fe4f7aa 100644 --- a/utils/xss_filter.py +++ b/utils/xss_filter.py @@ -142,7 +142,7 @@ class XSSHtml(HTMLParser): return attrs def _true_url(self, url): - prog = re.compile(r"^(http|https|ftp)://.+", re.I | re.S) + prog = re.compile(r"(^(http|https|ftp)://.+)|(^/)", re.I | re.S) if prog.match(url): return url else: