Merge pull request #233 from QingdaoU/feature/django20

Feature/django20
This commit is contained in:
李扬
2019-03-26 10:52:34 +08:00
committed by GitHub
49 changed files with 521 additions and 245 deletions

4
.dockerignore Normal file
View File

@@ -0,0 +1,4 @@
venv
.idea
.git
.DS_Store

View File

@@ -4,6 +4,7 @@ exclude =
*/migrations/, */migrations/,
*settings.py *settings.py
*/apps.py */apps.py
venv/
max-line-length = 180 max-line-length = 180
inline-quotes = " inline-quotes = "
no-accept-encodings = True no-accept-encodings = True

View File

@@ -1,4 +1,4 @@
FROM python:3.6-alpine3.6 FROM python:3.7-alpine3.9
ENV OJ_ENV production ENV OJ_ENV production

View File

@@ -31,19 +31,19 @@ class BasePermissionDecorator(object):
class login_required(BasePermissionDecorator): class login_required(BasePermissionDecorator):
def check_permission(self): def check_permission(self):
return self.request.user.is_authenticated() return self.request.user.is_authenticated
class super_admin_required(BasePermissionDecorator): class super_admin_required(BasePermissionDecorator):
def check_permission(self): def check_permission(self):
user = self.request.user user = self.request.user
return user.is_authenticated() and user.is_super_admin() return user.is_authenticated and user.is_super_admin()
class admin_role_required(BasePermissionDecorator): class admin_role_required(BasePermissionDecorator):
def check_permission(self): def check_permission(self):
user = self.request.user user = self.request.user
return user.is_authenticated() and user.is_admin_role() return user.is_authenticated and user.is_admin_role()
class problem_permission_required(admin_role_required): class problem_permission_required(admin_role_required):
@@ -80,7 +80,7 @@ def check_contest_permission(check_type="details"):
return self.error("Contest %s doesn't exist" % contest_id) return self.error("Contest %s doesn't exist" % contest_id)
# Anonymous # Anonymous
if not user.is_authenticated(): if not user.is_authenticated:
return self.error("Please login first.") return self.error("Please login first.")
# creator or owner # creator or owner

View File

@@ -22,7 +22,7 @@ class APITokenAuthMiddleware(MiddlewareMixin):
class SessionRecordMiddleware(MiddlewareMixin): class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
request.ip = request.META.get(settings.IP_HEADER, request.META.get("REMOTE_ADDR")) request.ip = request.META.get(settings.IP_HEADER, request.META.get("REMOTE_ADDR"))
if request.user.is_authenticated(): if request.user.is_authenticated:
session = request.session session = request.session
session["user_agent"] = request.META.get("HTTP_USER_AGENT", "") session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
session["ip"] = request.ip session["ip"] = request.ip
@@ -37,7 +37,7 @@ class AdminRoleRequiredMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
path = request.path_info path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"): if path.startswith("/admin/") or path.startswith("/api/admin/"):
if not (request.user.is_authenticated() and request.user.is_admin_role()): if not (request.user.is_authenticated and request.user.is_admin_role()):
return JSONResponse.response({"error": "login-required", "data": "Please login in first"}) return JSONResponse.response({"error": "login-required", "data": "Please login in first"})

View File

@@ -60,7 +60,7 @@ class User(AbstractBaseUser):
return self.problem_permission == ProblemPermission.ALL return self.problem_permission == ProblemPermission.ALL
def is_contest_admin(self, contest): def is_contest_admin(self, contest):
return self.is_authenticated() and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN) return self.is_authenticated and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN)
class Meta: class Meta:
db_table = "user" db_table = "user"

View File

@@ -131,6 +131,10 @@ class ImageUploadForm(forms.Form):
image = forms.FileField() image = forms.FileField()
class FileUploadForm(forms.Form):
file = forms.FileField()
class RankInfoSerializer(serializers.ModelSerializer): class RankInfoSerializer(serializers.ModelSerializer):
user = UsernameSerializer() user = UsernameSerializer()

View File

@@ -1,14 +1,13 @@
import logging import logging
import dramatiq
from celery import shared_task
from options.options import SysOptions from options.options import SysOptions
from utils.shortcuts import send_email from utils.shortcuts import send_email, DRAMATIQ_WORKER_ARGS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@shared_task @dramatiq.actor(**DRAMATIQ_WORKER_ARGS(max_retries=3))
def send_email_async(from_name, to_email, to_name, subject, content): def send_email_async(from_name, to_email, to_name, subject, content):
if not SysOptions.smtp_config: if not SysOptions.smtp_config:
return return

View File

@@ -101,13 +101,13 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated()) self.assertTrue(user.is_authenticated)
def test_login_with_correct_info_upper_username(self): def test_login_with_correct_info_upper_username(self):
resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password}) resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password})
self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"}) self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated()) self.assertTrue(user.is_authenticated)
def test_login_with_wrong_info(self): def test_login_with_wrong_info(self):
response = self.client.post(self.login_url, response = self.client.post(self.login_url,
@@ -115,7 +115,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "Invalid username or password"}) self.assertDictEqual(response.data, {"error": "error", "data": "Invalid username or password"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated()) self.assertFalse(user.is_authenticated)
def test_tfa_login(self): def test_tfa_login(self):
token = self._set_tfa() token = self._set_tfa()
@@ -129,7 +129,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"}) self.assertDictEqual(response.data, {"error": None, "data": "Succeeded"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated()) self.assertTrue(user.is_authenticated)
def test_tfa_login_wrong_code(self): def test_tfa_login_wrong_code(self):
self._set_tfa() self._set_tfa()
@@ -140,7 +140,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "Invalid two factor verification code"}) self.assertDictEqual(response.data, {"error": "error", "data": "Invalid two factor verification code"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated()) self.assertFalse(user.is_authenticated)
def test_tfa_login_without_code(self): def test_tfa_login_without_code(self):
self._set_tfa() self._set_tfa()
@@ -150,7 +150,7 @@ class UserLoginAPITest(APITestCase):
self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"}) self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client) user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated()) self.assertFalse(user.is_authenticated)
def test_user_disabled(self): def test_user_disabled(self):
self.user.is_disabled = True self.user.is_disabled = True
@@ -304,7 +304,7 @@ class TwoFactorAuthAPITest(APITestCase):
self.assertEqual(user.two_factor_auth, False) self.assertEqual(user.two_factor_auth, False)
@mock.patch("account.views.oj.send_email_async.delay") @mock.patch("account.views.oj.send_email_async.send")
class ApplyResetPasswordAPITest(CaptchaTest): class ApplyResetPasswordAPITest(CaptchaTest):
def setUp(self): def setUp(self):
self.create_user("test", "test123", login=False) self.create_user("test", "test123", login=False)
@@ -317,20 +317,20 @@ class ApplyResetPasswordAPITest(CaptchaTest):
def _refresh_captcha(self): def _refresh_captcha(self):
self.data["captcha"] = self._set_captcha(self.client.session) self.data["captcha"] = self._set_captcha(self.client.session)
def test_apply_reset_password(self, send_email_delay): def test_apply_reset_password(self, send_email_send):
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
self.assertSuccess(resp) self.assertSuccess(resp)
send_email_delay.assert_called() send_email_send.assert_called()
def test_apply_reset_password_twice_in_20_mins(self, send_email_delay): def test_apply_reset_password_twice_in_20_mins(self, send_email_send):
self.test_apply_reset_password() self.test_apply_reset_password()
send_email_delay.reset_mock() send_email_send.reset_mock()
self._refresh_captcha() self._refresh_captcha()
resp = self.client.post(self.url, data=self.data) resp = self.client.post(self.url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"}) self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
send_email_delay.assert_not_called() send_email_send.assert_not_called()
def test_apply_reset_password_again_after_20_mins(self, send_email_delay): def test_apply_reset_password_again_after_20_mins(self, send_email_send):
self.test_apply_reset_password() self.test_apply_reset_password()
user = User.objects.first() user = User.objects.first()
user.reset_password_token_expire_time = now() - timedelta(minutes=21) user.reset_password_token_expire_time = now() - timedelta(minutes=21)

View File

@@ -121,25 +121,15 @@ class UserAdminAPI(APIView):
Q(email__icontains=keyword)) Q(email__icontains=keyword))
return self.success(self.paginate_data(request, user, UserAdminSerializer)) return self.success(self.paginate_data(request, user, UserAdminSerializer))
def delete_one(self, user_id):
try:
user = User.objects.get(id=user_id)
except User.DoesNotExist:
return f"User {user_id} does not exist"
if Submission.objects.filter(user_id=user_id).exists():
return f"Can't delete the user {user_id} as he/she has submissions"
user.delete()
@super_admin_required @super_admin_required
def delete(self, request): def delete(self, request):
id = request.GET.get("id") id = request.GET.get("id")
if not id: if not id:
return self.error("Invalid Parameter, id is required") return self.error("Invalid Parameter, id is required")
for user_id in id.split(","): ids = id.split(",")
if user_id: if str(request.user.id) in ids:
error = self.delete_one(user_id) return self.error("Current user can not be deleted")
if error: User.objects.filter(id__in=ids).delete()
return self.error(error)
return self.success() return self.success()

View File

@@ -35,7 +35,7 @@ class UserProfileAPI(APIView):
判断是否登录, 若登录返回用户信息 判断是否登录, 若登录返回用户信息
""" """
user = request.user user = request.user
if not user.is_authenticated(): if not user.is_authenticated:
return self.success() return self.success()
show_real_name = False show_real_name = False
username = request.GET.get("username") username = request.GET.get("username")
@@ -280,7 +280,7 @@ class UserChangePasswordAPI(APIView):
class ApplyResetPasswordAPI(APIView): class ApplyResetPasswordAPI(APIView):
@validate_serializer(ApplyResetPasswordSerializer) @validate_serializer(ApplyResetPasswordSerializer)
def post(self, request): def post(self, request):
if request.user.is_authenticated(): if request.user.is_authenticated:
return self.error("You have already logged in, are you kidding me? ") return self.error("You have already logged in, are you kidding me? ")
data = request.data data = request.data
captcha = Captcha(request) captcha = Captcha(request)
@@ -302,11 +302,11 @@ class ApplyResetPasswordAPI(APIView):
"link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}" "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
} }
email_html = render_to_string("reset_password_email.html", render_data) email_html = render_to_string("reset_password_email.html", render_data)
send_email_async.delay(from_name=SysOptions.website_name_shortcut, send_email_async.send(from_name=SysOptions.website_name_shortcut,
to_email=user.email, to_email=user.email,
to_name=user.username, to_name=user.username,
subject=f"Reset your password", subject=f"Reset your password",
content=email_html) content=email_html)
return self.success("Succeeded") return self.success("Succeeded")

View File

@@ -9,7 +9,7 @@ class Announcement(models.Model):
# HTML # HTML
content = RichTextField() content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User, on_delete=models.CASCADE)
last_update_time = models.DateTimeField(auto_now=True) last_update_time = models.DateTimeField(auto_now=True)
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)

View File

@@ -0,0 +1,23 @@
# Generated by Django 2.1.7 on 2019-03-26 02:01
from django.conf import settings
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('contest', '0009_auto_20180501_0436'),
]
operations = [
migrations.AlterUniqueTogether(
name='acmcontestrank',
unique_together={('user', 'contest')},
),
migrations.AlterUniqueTogether(
name='oicontestrank',
unique_together={('user', 'contest')},
),
]

View File

@@ -20,7 +20,7 @@ class Contest(models.Model):
end_time = models.DateTimeField() end_time = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
last_update_time = models.DateTimeField(auto_now=True) last_update_time = models.DateTimeField(auto_now=True)
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User, on_delete=models.CASCADE)
# 是否可见 false的话相当于删除 # 是否可见 false的话相当于删除
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)
allowed_ip_ranges = JSONField(default=list) allowed_ip_ranges = JSONField(default=list)
@@ -47,7 +47,7 @@ class Contest(models.Model):
def problem_details_permission(self, user): def problem_details_permission(self, user):
return self.rule_type == ContestRuleType.ACM or \ return self.rule_type == ContestRuleType.ACM or \
self.status == ContestStatus.CONTEST_ENDED or \ self.status == ContestStatus.CONTEST_ENDED or \
user.is_authenticated() and user.is_contest_admin(self) or \ user.is_authenticated and user.is_contest_admin(self) or \
self.real_time_rank self.real_time_rank
class Meta: class Meta:
@@ -56,8 +56,8 @@ class Contest(models.Model):
class AbstractContestRank(models.Model): class AbstractContestRank(models.Model):
user = models.ForeignKey(User) user = models.ForeignKey(User, on_delete=models.CASCADE)
contest = models.ForeignKey(Contest) contest = models.ForeignKey(Contest, on_delete=models.CASCADE)
submission_number = models.IntegerField(default=0) submission_number = models.IntegerField(default=0)
class Meta: class Meta:
@@ -74,6 +74,7 @@ class ACMContestRank(AbstractContestRank):
class Meta: class Meta:
db_table = "acm_contest_rank" db_table = "acm_contest_rank"
unique_together = (("user", "contest"),)
class OIContestRank(AbstractContestRank): class OIContestRank(AbstractContestRank):
@@ -84,13 +85,14 @@ class OIContestRank(AbstractContestRank):
class Meta: class Meta:
db_table = "oi_contest_rank" db_table = "oi_contest_rank"
unique_together = (("user", "contest"),)
class ContestAnnouncement(models.Model): class ContestAnnouncement(models.Model):
contest = models.ForeignKey(Contest) contest = models.ForeignKey(Contest, on_delete=models.CASCADE)
title = models.TextField() title = models.TextField()
content = RichTextField() content = RichTextField()
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User, on_delete=models.CASCADE)
visible = models.BooleanField(default=True) visible = models.BooleanField(default=True)
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)

View File

@@ -45,7 +45,7 @@ class ContestAdminAPITest(APITestCase):
response_data = response.data["data"] response_data = response.data["data"]
for k in data.keys(): for k in data.keys():
if isinstance(data[k], datetime): if isinstance(data[k], datetime):
continue continue
self.assertEqual(response_data[k], data[k]) self.assertEqual(response_data[k], data[k])
def test_get_contests(self): def test_get_contests(self):

View File

@@ -234,7 +234,7 @@ class DownloadContestSubmissions(APIView):
exclude_admin = request.GET.get("exclude_admin") == "1" exclude_admin = request.GET.get("exclude_admin") == "1"
zip_path = self._dump_submissions(contest, exclude_admin) zip_path = self._dump_submissions(contest, exclude_admin)
delete_files.apply_async((zip_path,), countdown=300) delete_files.send_with_options(args=(zip_path,), delay=300_000)
resp = FileResponse(open(zip_path, "rb")) resp = FileResponse(open(zip_path, "rb"))
resp["Content-Type"] = "application/zip" resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}" resp["Content-Disposition"] = f"attachment;filename={os.path.basename(zip_path)}"

View File

@@ -8,7 +8,7 @@ from django.core.cache import cache
from problem.models import Problem from problem.models import Problem
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.constants import CacheKey from utils.constants import CacheKey
from utils.shortcuts import datetime2str from utils.shortcuts import datetime2str, check_is_id
from account.models import AdminType from account.models import AdminType
from account.decorators import login_required, check_contest_permission from account.decorators import login_required, check_contest_permission
@@ -35,7 +35,7 @@ class ContestAnnouncementListAPI(APIView):
class ContestAPI(APIView): class ContestAPI(APIView):
def get(self, request): def get(self, request):
id = request.GET.get("id") id = request.GET.get("id")
if not id: if not id or not check_is_id(id):
return self.error("Invalid parameter, id is required") return self.error("Invalid parameter, id is required")
try: try:
contest = Contest.objects.get(id=id, visible=True) contest = Contest.objects.get(id=id, visible=True)
@@ -121,7 +121,7 @@ class ContestRankAPI(APIView):
def get(self, request): def get(self, request):
download_csv = request.GET.get("download_csv") download_csv = request.GET.get("download_csv")
force_refresh = request.GET.get("force_refresh") force_refresh = request.GET.get("force_refresh")
is_contest_admin = request.user.is_authenticated() and request.user.is_contest_admin(self.contest) is_contest_admin = request.user.is_authenticated and request.user.is_contest_admin(self.contest)
if self.contest.rule_type == ContestRuleType.OI: if self.contest.rule_type == ContestRuleType.OI:
serializer = OIContestRankSerializer serializer = OIContestRankSerializer
else: else:

View File

@@ -1,19 +1,32 @@
django==1.11.4 certifi==2019.3.9
djangorestframework==3.4.0 chardet==3.0.4
pillow coverage==4.5.3
otpauth Django==2.1.7
flake8-quotes django-redis==4.10.0
pytz djangorestframework==3.8.2
coverage entrypoints==0.3
python-dateutil Envelopes==0.4
celery flake8==3.7.7
Envelopes flake8-coding==1.3.1
qrcode flake8-quotes==1.0.0
flake8-coding gunicorn==19.9.0
requests idna==2.8
django-redis jsonfield==2.0.2
psycopg2-binary mccabe==0.6.1
gunicorn otpauth==1.0.1
jsonfield Pillow==5.4.1
XlsxWriter psycopg2-binary==2.7.7
raven pycodestyle==2.5.0
pyflakes==2.1.1
python-dateutil==2.8.0
pytz==2018.9
qrcode==6.1
raven==6.10.0
redis==3.2.0
requests==2.21.0
six==1.12.0
urllib3==1.24.1
XlsxWriter==1.1.5
django-dramatiq==0.5.0
dramatiq==1.3.0
django-dbconn-retry==0.1.5

View File

@@ -38,12 +38,12 @@ startsecs=5
stopwaitsecs = 5 stopwaitsecs = 5
killasgroup=true killasgroup=true
[program:celery] [program:dramatiq]
command=celery -A oj worker -l warning --autoscale 2,%(ENV_MAX_WORKER_NUM)s command=python3 manage.py rundramatiq --no-reload --processes %(ENV_MAX_WORKER_NUM)s --threads 4
directory=/app/ directory=/app/
user=nobody user=nobody
stdout_logfile=/data/log/celery.log stdout_logfile=/data/log/dramatiq.log
stderr_logfile=/data/log/celery.log stderr_logfile=/data/log/dramatiq.log
autostart=true autostart=true
autorestart=true autorestart=true
startsecs=5 startsecs=5

View File

@@ -1,5 +1,19 @@
{ {
"update": [ "update": [
{
"version": "2019-03-25",
"level": "Recommend",
"title": "2019-03-25",
"details": [
"Update Django to version 2.1 and Python to version 3.7",
"Replace celery with dramatiq",
"Add problem file IO Mode",
"You can add attachments in all editor",
"You can upload source code file in submission editor",
"Frontend and UI improvements",
"Fixed a lot of bugs"
]
},
{ {
"version": "2018-12-15", "version": "2018-12-15",
"level": "Recommend", "level": "Recommend",

View File

@@ -4,7 +4,7 @@ import logging
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import requests
from django.db import transaction from django.db import transaction, IntegrityError
from django.db.models import F from django.db.models import F
from account.models import User from account.models import User
@@ -26,7 +26,28 @@ def process_pending_task():
# 防止循环引入 # 防止循环引入
from judge.tasks import judge_task from judge.tasks import judge_task
data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8")) data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data) judge_task.send(**data)
class ChooseJudgeServer:
def __init__(self):
self.server = None
def __enter__(self) -> [JudgeServer, None]:
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
for server in servers:
if server.task_number <= server.cpu_core * 2:
server.task_number = F("task_number") + 1
server.save()
self.server = server
return server
return None
def __exit__(self, exc_type, exc_val, exc_tb):
if self.server:
JudgeServer.objects.filter(id=self.server.id).update(task_number=F("task_number") - 1)
class DispatcherBase(object): class DispatcherBase(object):
@@ -42,25 +63,6 @@ class DispatcherBase(object):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
@staticmethod
def choose_judge_server():
with transaction.atomic():
servers = JudgeServer.objects.select_for_update().filter(is_disabled=False).order_by("task_number")
servers = [s for s in servers if s.status == "normal"]
for server in servers:
if server.task_number <= server.cpu_core * 2:
server.task_number = F("task_number") + 1
server.save()
return server
@staticmethod
def release_judge_server(judge_server_id):
with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.get(id=judge_server_id)
server.task_number = F("task_number") - 1
server.save()
class SPJCompiler(DispatcherBase): class SPJCompiler(DispatcherBase):
def __init__(self, spj_code, spj_version, spj_language): def __init__(self, spj_code, spj_version, spj_language):
@@ -74,13 +76,14 @@ class SPJCompiler(DispatcherBase):
} }
def compile_spj(self): def compile_spj(self):
server = self.choose_judge_server() with ChooseJudgeServer() as server:
if not server: if not server:
return "No available judge_server" return "No available judge_server"
result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data) result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
self.release_judge_server(server.id) if not result:
if result["err"]: return "Failed to call judge server"
return result["data"] if result["err"]:
return result["data"]
class JudgeDispatcher(DispatcherBase): class JudgeDispatcher(DispatcherBase):
@@ -118,12 +121,6 @@ class JudgeDispatcher(DispatcherBase):
self.submission.statistic_info["score"] = score self.submission.statistic_info["score"] = score
def judge(self): def judge(self):
server = self.choose_judge_server()
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
language = self.submission.language language = self.submission.language
sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0] sub_config = list(filter(lambda item: language == item["name"], SysOptions.languages))[0]
spj_config = {} spj_config = {}
@@ -149,12 +146,22 @@ class JudgeDispatcher(DispatcherBase):
"spj_version": self.problem.spj_version, "spj_version": self.problem.spj_version,
"spj_config": spj_config.get("config"), "spj_config": spj_config.get("config"),
"spj_compile_config": spj_config.get("compile"), "spj_compile_config": spj_config.get("compile"),
"spj_src": self.problem.spj_code "spj_src": self.problem.spj_code,
"io_mode": self.problem.io_mode
} }
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING) with ChooseJudgeServer() as server:
if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if not resp:
Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.SYSTEM_ERROR)
return
resp = self._request(urljoin(server.service_url, "/judge"), data=data)
if resp["err"]: if resp["err"]:
self.submission.result = JudgeStatus.COMPILE_ERROR self.submission.result = JudgeStatus.COMPILE_ERROR
self.submission.statistic_info["err_info"] = resp["data"] self.submission.statistic_info["err_info"] = resp["data"]
@@ -173,7 +180,6 @@ class JudgeDispatcher(DispatcherBase):
else: else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save() self.submission.save()
self.release_judge_server(server.id)
if self.contest_id: if self.contest_id:
if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \ if self.contest.status != ContestStatus.CONTEST_UNDERWAY or \
@@ -181,8 +187,9 @@ class JudgeDispatcher(DispatcherBase):
logger.info( logger.info(
"Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id) "Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id)
return return
self.update_contest_problem_status() with transaction.atomic():
self.update_contest_rank() self.update_contest_problem_status()
self.update_contest_rank()
else: else:
if self.last_result: if self.last_result:
self.update_problem_status_rejudge() self.update_problem_status_rejudge()
@@ -322,20 +329,31 @@ class JudgeDispatcher(DispatcherBase):
def update_contest_rank(self): def update_contest_rank(self):
if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank: if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}") cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM: def get_rank(model):
acm_rank, _ = ACMContestRank.objects.select_for_update(). \ return model.objects.select_for_update().get(user_id=self.submission.user_id, contest=self.contest)
get_or_create(user_id=self.submission.user_id, contest=self.contest)
self._update_acm_contest_rank(acm_rank) if self.contest.rule_type == ContestRuleType.ACM:
else: model = ACMContestRank
oi_rank, _ = OIContestRank.objects.select_for_update(). \ func = self._update_acm_contest_rank
get_or_create(user_id=self.submission.user_id, contest=self.contest) else:
self._update_oi_contest_rank(oi_rank) model = OIContestRank
func = self._update_oi_contest_rank
try:
rank = get_rank(model)
except model.DoesNotExist:
try:
model.objects.create(user_id=self.submission.user_id, contest=self.contest)
rank = get_rank(model)
except IntegrityError:
rank = get_rank(model)
func(rank)
def _update_acm_contest_rank(self, rank): def _update_acm_contest_rank(self, rank):
info = rank.submission_info.get(str(self.submission.problem_id)) info = rank.submission_info.get(str(self.submission.problem_id))
# 因前面更改过,这里需要重新获取 # 因前面更改过,这里需要重新获取
problem = Problem.objects.get(contest_id=self.contest_id, id=self.problem.id) problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
# 此题提交过 # 此题提交过
if info: if info:
if info["is_ac"]: if info["is_ac"]:

View File

@@ -1,3 +1,6 @@
from problem.models import ProblemIOMode
default_env = ["LANG=en_US.UTF-8", "LANGUAGE=en_US:en", "LC_ALL=en_US.UTF-8"] default_env = ["LANG=en_US.UTF-8", "LANGUAGE=en_US:en", "LC_ALL=en_US.UTF-8"]
_c_lang_config = { _c_lang_config = {
@@ -28,7 +31,7 @@ int main() {
}, },
"run": { "run": {
"command": "{exe_path}", "command": "{exe_path}",
"seccomp_rule": "c_cpp", "seccomp_rule": {ProblemIOMode.standard: "c_cpp", ProblemIOMode.file: "c_cpp_file_io"},
"env": default_env "env": default_env
} }
} }

View File

@@ -1,12 +1,12 @@
from __future__ import absolute_import, unicode_literals import dramatiq
from celery import shared_task
from account.models import User from account.models import User
from submission.models import Submission from submission.models import Submission
from judge.dispatcher import JudgeDispatcher from judge.dispatcher import JudgeDispatcher
from utils.shortcuts import DRAMATIQ_WORKER_ARGS
@shared_task @dramatiq.actor(**DRAMATIQ_WORKER_ARGS())
def judge_task(submission_id, problem_id): def judge_task(submission_id, problem_id):
uid = Submission.objects.get(id=submission_id).user_id uid = Submission.objects.get(id=submission_id).user_id
if User.objects.get(id=uid).is_disabled: if User.objects.get(id=uid).is_disabled:

View File

@@ -1,6 +0,0 @@
from __future__ import absolute_import, unicode_literals
# Django starts so that shared_task will use this app.
from .celery import app as celery_app
__all__ = ["celery_app"]

View File

@@ -1,18 +0,0 @@
from __future__ import absolute_import, unicode_literals
import os
from celery import Celery
from django.conf import settings
# set the default Django settings module for the "celery" program.
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
app = Celery("oj")
# Using a string here means the worker will not have to
# pickle the object when using Windows.
app.config_from_object("django.conf:settings")
# load task modules from all registered Django app configs.
app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
# app.autodiscover_tasks()

View File

@@ -26,16 +26,22 @@ with open(os.path.join(DATA_DIR, "config", "secret.key"), "r") as f:
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Applications # Applications
VENDOR_APPS = ( VENDOR_APPS = [
'django.contrib.auth', 'django.contrib.auth',
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.contenttypes', 'django.contrib.contenttypes',
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework', 'rest_framework',
'raven.contrib.django.raven_compat' 'django_dramatiq',
) 'django_dbconn_retry',
LOCAL_APPS = ( ]
if production_env:
VENDOR_APPS.append('raven.contrib.django.raven_compat')
LOCAL_APPS = [
'account', 'account',
'announcement', 'announcement',
'conf', 'conf',
@@ -45,11 +51,11 @@ LOCAL_APPS = (
'submission', 'submission',
'options', 'options',
'judge', 'judge',
) ]
INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
MIDDLEWARE_CLASSES = ( MIDDLEWARE = (
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
@@ -164,6 +170,11 @@ LOGGING = {
'level': 'ERROR', 'level': 'ERROR',
'propagate': True, 'propagate': True,
}, },
'dramatiq': {
'handlers': LOGGING_HANDLERS,
'level': 'DEBUG',
'propagate': False,
},
'': { '': {
'handlers': LOGGING_HANDLERS, 'handlers': LOGGING_HANDLERS,
'level': 'WARNING', 'level': 'WARNING',
@@ -202,11 +213,32 @@ CACHES = {
SESSION_ENGINE = "django.contrib.sessions.backends.cache" SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default" SESSION_CACHE_ALIAS = "default"
CELERY_RESULT_BACKEND = f"{REDIS_URL}/2" DRAMATIQ_BROKER = {
BROKER_URL = f"{REDIS_URL}/3" "BROKER": "dramatiq.brokers.redis.RedisBroker",
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180 "OPTIONS": {
CELERY_ACCEPT_CONTENT = ["json"] "url": f"{REDIS_URL}/4",
CELERY_TASK_SERIALIZER = "json" },
"MIDDLEWARE": [
# "dramatiq.middleware.Prometheus",
"dramatiq.middleware.AgeLimit",
"dramatiq.middleware.TimeLimit",
"dramatiq.middleware.Callbacks",
"dramatiq.middleware.Retries",
# "django_dramatiq.middleware.AdminMiddleware",
"django_dramatiq.middleware.DbConnectionsMiddleware"
]
}
DRAMATIQ_RESULT_BACKEND = {
"BACKEND": "dramatiq.results.backends.redis.RedisBackend",
"BACKEND_OPTIONS": {
"url": f"{REDIS_URL}/4",
},
"MIDDLEWARE_OPTIONS": {
"result_ttl": None
}
}
RAVEN_CONFIG = { RAVEN_CONFIG = {
'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057' 'dsn': 'https://b200023b8aed4d708fb593c5e0a6ad3d:1fddaba168f84fcf97e0d549faaeaff0@sentry.io/263057'
} }

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# Generated by Django 1.11.3 on 2018-05-01 04:36
from __future__ import unicode_literals
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('options', '0002_auto_20180501_0436'),
]
operations = [
migrations.RunSQL("""
DELETE FROM options_sysoptions WHERE key = 'languages';
""")
]

View File

@@ -1,13 +1,92 @@
import functools
import os import os
from django.core.cache import cache import threading
import time
from django.db import transaction, IntegrityError from django.db import transaction, IntegrityError
from utils.constants import CacheKey
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from judge.languages import languages from judge.languages import languages
from .models import SysOptions as SysOptionsModel from .models import SysOptions as SysOptionsModel
class my_property:
"""
在 metaclass 中使用,以实现:
1. ttl = None不缓存
2. ttl is callable条件缓存
3. 缓存 ttl 秒
"""
def __init__(self, func=None, fset=None, ttl=None):
self.fset = fset
self.local = threading.local()
self.ttl = ttl
self._check_ttl(ttl)
self.func = func
functools.update_wrapper(self, func)
def _check_ttl(self, value):
if value is None or callable(value):
return
return self._check_timeout(value)
def _check_timeout(self, value):
if not isinstance(value, int):
raise ValueError(f"Invalid timeout type: {type(value)}")
if value < 0:
raise ValueError("Invalid timeout value, it must >= 0")
def __get__(self, obj, cls):
if obj is None:
return self
now = time.time()
if self.ttl:
if hasattr(self.local, "value"):
value, expire_at = self.local.value
if now < expire_at:
return value
value = self.func(obj)
# 如果定义了条件缓存, ttl 是一个函数,返回要缓存多久;返回 0 代表不要缓存
if callable(self.ttl):
# 而且条件缓存说不要缓存,那就直接返回,不要设置 local
timeout = self.ttl(value)
self._check_timeout(timeout)
if timeout == 0:
return value
elif timeout > 0:
self.local.value = (value, now + timeout)
else:
# ttl 是一个数字
self.local.value = (value, now + self.ttl)
return value
else:
return self.func(obj)
def __set__(self, obj, value):
if not self.fset:
raise AttributeError("can't set attribute")
self.fset(obj, value)
if hasattr(self.local, "value"):
del self.local.value
def setter(self, func):
self.fset = func
return self
def __call__(self, func, *args, **kwargs) -> "my_property":
if self.func is None:
self.func = func
functools.update_wrapper(self, func)
return self
DEFAULT_SHORT_TTL = 2
def default_token(): def default_token():
token = os.environ.get("JUDGE_SERVER_TOKEN") token = os.environ.get("JUDGE_SERVER_TOKEN")
return token if token else rand_str() return token if token else rand_str()
@@ -41,23 +120,10 @@ class OptionDefaultValue:
class _SysOptionsMeta(type): class _SysOptionsMeta(type):
@classmethod
def _set_cache(mcs, option_key, option_value):
cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
@classmethod
def _del_cache(mcs, option_key):
cache.delete(f"{CacheKey.option}:{option_key}")
@classmethod @classmethod
def _get_keys(cls): def _get_keys(cls):
return [key for key in OptionKeys.__dict__ if not key.startswith("__")] return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
def rebuild_cache(cls):
for key in cls._get_keys():
# get option 的时候会写 cache 的
cls._get_option(key, use_cache=False)
@classmethod @classmethod
def _init_option(mcs): def _init_option(mcs):
for item in mcs._get_keys(): for item in mcs._get_keys():
@@ -71,19 +137,14 @@ class _SysOptionsMeta(type):
pass pass
@classmethod @classmethod
def _get_option(mcs, option_key, use_cache=True): def _get_option(mcs, option_key):
try: try:
if use_cache:
option = cache.get(f"{CacheKey.option}:{option_key}")
if option:
return option
option = SysOptionsModel.objects.get(key=option_key) option = SysOptionsModel.objects.get(key=option_key)
value = option.value value = option.value
mcs._set_cache(option_key, value)
return value return value
except SysOptionsModel.DoesNotExist: except SysOptionsModel.DoesNotExist:
mcs._init_option() mcs._init_option()
return mcs._get_option(option_key, use_cache=use_cache) return mcs._get_option(option_key)
@classmethod @classmethod
def _set_option(mcs, option_key: str, option_value): def _set_option(mcs, option_key: str, option_value):
@@ -92,7 +153,6 @@ class _SysOptionsMeta(type):
option = SysOptionsModel.objects.select_for_update().get(key=option_key) option = SysOptionsModel.objects.select_for_update().get(key=option_key)
option.value = option_value option.value = option_value
option.save() option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist: except SysOptionsModel.DoesNotExist:
mcs._init_option() mcs._init_option()
mcs._set_option(option_key, option_value) mcs._set_option(option_key, option_value)
@@ -105,7 +165,6 @@ class _SysOptionsMeta(type):
value = option.value + 1 value = option.value + 1
option.value = value option.value = value
option.save() option.save()
mcs._del_cache(option_key)
except SysOptionsModel.DoesNotExist: except SysOptionsModel.DoesNotExist:
mcs._init_option() mcs._init_option()
return mcs._increment(option_key) return mcs._increment(option_key)
@@ -122,7 +181,7 @@ class _SysOptionsMeta(type):
result[key] = mcs._get_option(key) result[key] = mcs._get_option(key)
return result return result
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def website_base_url(cls): def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url) return cls._get_option(OptionKeys.website_base_url)
@@ -130,7 +189,7 @@ class _SysOptionsMeta(type):
def website_base_url(cls, value): def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value) cls._set_option(OptionKeys.website_base_url, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def website_name(cls): def website_name(cls):
return cls._get_option(OptionKeys.website_name) return cls._get_option(OptionKeys.website_name)
@@ -138,7 +197,7 @@ class _SysOptionsMeta(type):
def website_name(cls, value): def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value) cls._set_option(OptionKeys.website_name, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def website_name_shortcut(cls): def website_name_shortcut(cls):
return cls._get_option(OptionKeys.website_name_shortcut) return cls._get_option(OptionKeys.website_name_shortcut)
@@ -146,7 +205,7 @@ class _SysOptionsMeta(type):
def website_name_shortcut(cls, value): def website_name_shortcut(cls, value):
cls._set_option(OptionKeys.website_name_shortcut, value) cls._set_option(OptionKeys.website_name_shortcut, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def website_footer(cls): def website_footer(cls):
return cls._get_option(OptionKeys.website_footer) return cls._get_option(OptionKeys.website_footer)
@@ -154,7 +213,7 @@ class _SysOptionsMeta(type):
def website_footer(cls, value): def website_footer(cls, value):
cls._set_option(OptionKeys.website_footer, value) cls._set_option(OptionKeys.website_footer, value)
@property @my_property
def allow_register(cls): def allow_register(cls):
return cls._get_option(OptionKeys.allow_register) return cls._get_option(OptionKeys.allow_register)
@@ -162,7 +221,7 @@ class _SysOptionsMeta(type):
def allow_register(cls, value): def allow_register(cls, value):
cls._set_option(OptionKeys.allow_register, value) cls._set_option(OptionKeys.allow_register, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def submission_list_show_all(cls): def submission_list_show_all(cls):
return cls._get_option(OptionKeys.submission_list_show_all) return cls._get_option(OptionKeys.submission_list_show_all)
@@ -170,7 +229,7 @@ class _SysOptionsMeta(type):
def submission_list_show_all(cls, value): def submission_list_show_all(cls, value):
cls._set_option(OptionKeys.submission_list_show_all, value) cls._set_option(OptionKeys.submission_list_show_all, value)
@property @my_property
def smtp_config(cls): def smtp_config(cls):
return cls._get_option(OptionKeys.smtp_config) return cls._get_option(OptionKeys.smtp_config)
@@ -178,7 +237,7 @@ class _SysOptionsMeta(type):
def smtp_config(cls, value): def smtp_config(cls, value):
cls._set_option(OptionKeys.smtp_config, value) cls._set_option(OptionKeys.smtp_config, value)
@property @my_property
def judge_server_token(cls): def judge_server_token(cls):
return cls._get_option(OptionKeys.judge_server_token) return cls._get_option(OptionKeys.judge_server_token)
@@ -186,7 +245,7 @@ class _SysOptionsMeta(type):
def judge_server_token(cls, value): def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value) cls._set_option(OptionKeys.judge_server_token, value)
@property @my_property
def throttling(cls): def throttling(cls):
return cls._get_option(OptionKeys.throttling) return cls._get_option(OptionKeys.throttling)
@@ -194,7 +253,7 @@ class _SysOptionsMeta(type):
def throttling(cls, value): def throttling(cls, value):
cls._set_option(OptionKeys.throttling, value) cls._set_option(OptionKeys.throttling, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def languages(cls): def languages(cls):
return cls._get_option(OptionKeys.languages) return cls._get_option(OptionKeys.languages)
@@ -202,15 +261,15 @@ class _SysOptionsMeta(type):
def languages(cls, value): def languages(cls, value):
cls._set_option(OptionKeys.languages, value) cls._set_option(OptionKeys.languages, value)
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def spj_languages(cls): def spj_languages(cls):
return [item for item in cls.languages if "spj" in item] return [item for item in cls.languages if "spj" in item]
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def language_names(cls): def language_names(cls):
return [item["name"] for item in languages] return [item["name"] for item in languages]
@property @my_property(ttl=DEFAULT_SHORT_TTL)
def spj_language_names(cls): def spj_language_names(cls):
return [item["name"] for item in cls.languages if "spj" in item] return [item["name"] for item in cls.languages if "spj" in item]

View File

@@ -0,0 +1,20 @@
# Generated by Django 2.1.7 on 2019-03-12 07:13
import django.contrib.postgres.fields.jsonb
from django.db import migrations
import problem.models
class Migration(migrations.Migration):
dependencies = [
('problem', '0012_auto_20180501_0436'),
]
operations = [
migrations.AddField(
model_name='problem',
name='io_mode',
field=django.contrib.postgres.fields.jsonb.JSONField(default=problem.models._default_io_mode),
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 2.1.7 on 2019-03-13 09:38
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('problem', '0013_problem_io_mode'),
]
operations = [
migrations.AddField(
model_name='problem',
name='share_submission',
field=models.BooleanField(default=False),
),
]

View File

@@ -25,10 +25,19 @@ class ProblemDifficulty(object):
Low = "Low" Low = "Low"
class ProblemIOMode(Choices):
standard = "Standard IO"
file = "File IO"
def _default_io_mode():
return {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"}
class Problem(models.Model): class Problem(models.Model):
# display ID # display ID
_id = models.TextField(db_index=True) _id = models.TextField(db_index=True)
contest = models.ForeignKey(Contest, null=True) contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE)
# for contest problem # for contest problem
is_public = models.BooleanField(default=False) is_public = models.BooleanField(default=False)
title = models.TextField() title = models.TextField()
@@ -47,11 +56,13 @@ class Problem(models.Model):
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
# we can not use auto_now here # we can not use auto_now here
last_update_time = models.DateTimeField(null=True) last_update_time = models.DateTimeField(null=True)
created_by = models.ForeignKey(User) created_by = models.ForeignKey(User, on_delete=models.CASCADE)
# ms # ms
time_limit = models.IntegerField() time_limit = models.IntegerField()
# MB # MB
memory_limit = models.IntegerField() memory_limit = models.IntegerField()
# io mode
io_mode = JSONField(default=_default_io_mode)
# special judge related # special judge related
spj = models.BooleanField(default=False) spj = models.BooleanField(default=False)
spj_language = models.TextField(null=True) spj_language = models.TextField(null=True)
@@ -69,6 +80,7 @@ class Problem(models.Model):
accepted_number = models.BigIntegerField(default=0) accepted_number = models.BigIntegerField(default=0)
# {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
statistic_info = JSONField(default=dict) statistic_info = JSONField(default=dict)
share_submission = models.BooleanField(default=False)
class Meta: class Meta:
db_table = "problem" db_table = "problem"

View File

@@ -1,3 +1,5 @@
import re
from django import forms from django import forms
from options.options import SysOptions from options.options import SysOptions
@@ -5,7 +7,7 @@ from utils.api import UsernameSerializer, serializers
from utils.constants import Difficulty from utils.constants import Difficulty
from utils.serializers import LanguageNameMultiChoiceField, SPJLanguageNameChoiceField, LanguageNameChoiceField from utils.serializers import LanguageNameMultiChoiceField, SPJLanguageNameChoiceField, LanguageNameChoiceField
from .models import Problem, ProblemRuleType, ProblemTag from .models import Problem, ProblemRuleType, ProblemTag, ProblemIOMode
from .utils import parse_problem_template from .utils import parse_problem_template
@@ -29,6 +31,20 @@ class CreateProblemCodeTemplateSerializer(serializers.Serializer):
pass pass
class ProblemIOModeSerializer(serializers.Serializer):
io_mode = serializers.ChoiceField(choices=ProblemIOMode.choices())
input = serializers.CharField()
output = serializers.CharField()
def validate(self, attrs):
if attrs["input"] == attrs["output"]:
raise serializers.ValidationError("Invalid io mode")
for item in (attrs["input"], attrs["output"]):
if not re.match("^[a-zA-Z0-9.]+$", item):
raise serializers.ValidationError("Invalid io file name format")
return attrs
class CreateOrEditProblemSerializer(serializers.Serializer): class CreateOrEditProblemSerializer(serializers.Serializer):
_id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True) _id = serializers.CharField(max_length=32, allow_blank=True, allow_null=True)
title = serializers.CharField(max_length=1024) title = serializers.CharField(max_length=1024)
@@ -43,6 +59,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
languages = LanguageNameMultiChoiceField() languages = LanguageNameMultiChoiceField()
template = serializers.DictField(child=serializers.CharField(min_length=1)) template = serializers.DictField(child=serializers.CharField(min_length=1))
rule_type = serializers.ChoiceField(choices=[ProblemRuleType.ACM, ProblemRuleType.OI]) rule_type = serializers.ChoiceField(choices=[ProblemRuleType.ACM, ProblemRuleType.OI])
io_mode = ProblemIOModeSerializer()
spj = serializers.BooleanField() spj = serializers.BooleanField()
spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True) spj_language = SPJLanguageNameChoiceField(allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True) spj_code = serializers.CharField(allow_blank=True, allow_null=True)
@@ -52,6 +69,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False) tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False)
hint = serializers.CharField(allow_blank=True, allow_null=True) hint = serializers.CharField(allow_blank=True, allow_null=True)
source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True) source = serializers.CharField(max_length=256, allow_blank=True, allow_null=True)
share_submission = serializers.BooleanField()
class CreateProblemSerializer(CreateOrEditProblemSerializer): class CreateProblemSerializer(CreateOrEditProblemSerializer):

View File

@@ -9,7 +9,7 @@ from django.conf import settings
from utils.api.tests import APITestCase from utils.api.tests import APITestCase
from .models import ProblemTag from .models import ProblemTag, ProblemIOMode
from .models import Problem, ProblemRuleType from .models import Problem, ProblemRuleType
from contest.models import Contest from contest.models import Contest
from contest.tests import DEFAULT_CONTEST_DATA from contest.tests import DEFAULT_CONTEST_DATA
@@ -25,6 +25,8 @@ DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "<p>test
"test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0, "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
"stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e", "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
"input_size": 0, "score": 0}], "input_size": 0, "score": 0}],
"io_mode": {"io_mode": ProblemIOMode.standard, "input": "input.txt", "output": "output.txt"},
"share_submission": False,
"rule_type": "ACM", "hint": "<p>test</p>", "source": "test"} "rule_type": "ACM", "hint": "<p>test</p>", "source": "test"}

View File

@@ -1,4 +1,6 @@
import re import re
from functools import lru_cache
TEMPLATE_BASE = """//PREPEND BEGIN TEMPLATE_BASE = """//PREPEND BEGIN
{} {}
@@ -13,6 +15,7 @@ TEMPLATE_BASE = """//PREPEND BEGIN
//APPEND END""" //APPEND END"""
@lru_cache(maxsize=100)
def parse_problem_template(template_str): def parse_problem_template(template_str):
prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str) prepend = re.findall(r"//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str)
template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str) template = re.findall(r"//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str)
@@ -22,5 +25,6 @@ def parse_problem_template(template_str):
"append": append[0] if append else ""} "append": append[0] if append else ""}
@lru_cache(maxsize=100)
def build_problem_template(prepend, template, append): def build_problem_template(prepend, template, append):
return TEMPLATE_BASE.format(prepend, template, append) return TEMPLATE_BASE.format(prepend, template, append)

View File

@@ -300,8 +300,6 @@ class ProblemAPI(ProblemBase):
except Problem.DoesNotExist: except Problem.DoesNotExist:
return self.error("Problem does not exists") return self.error("Problem does not exists")
ensure_created_by(problem, request.user) ensure_created_by(problem, request.user)
if Submission.objects.filter(problem=problem).exists():
return self.error("Can't delete the problem as it has submissions")
d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id) d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
if os.path.isdir(d): if os.path.isdir(d):
shutil.rmtree(d, ignore_errors=True) shutil.rmtree(d, ignore_errors=True)
@@ -541,7 +539,7 @@ class ExportProblemAPI(APIView):
with zipfile.ZipFile(path, "w") as zip_file: with zipfile.ZipFile(path, "w") as zip_file:
for index, problem in enumerate(problems): for index, problem in enumerate(problems):
self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1) self.process_one_problem(zip_file=zip_file, user=request.user, problem=problem, index=index + 1)
delete_files.apply_async((path,), countdown=300) delete_files.send_with_options(args=(path,), delay=300_000)
resp = FileResponse(open(path, "rb")) resp = FileResponse(open(path, "rb"))
resp["Content-Type"] = "application/zip" resp["Content-Type"] = "application/zip"
resp["Content-Disposition"] = f"attachment;filename=problem-export.zip" resp["Content-Disposition"] = f"attachment;filename=problem-export.zip"

View File

@@ -25,7 +25,7 @@ class PickOneAPI(APIView):
class ProblemAPI(APIView): class ProblemAPI(APIView):
@staticmethod @staticmethod
def _add_problem_status(request, queryset_values): def _add_problem_status(request, queryset_values):
if request.user.is_authenticated(): if request.user.is_authenticated:
profile = request.user.userprofile profile = request.user.userprofile
acm_problems_status = profile.acm_problems_status.get("problems", {}) acm_problems_status = profile.acm_problems_status.get("problems", {})
oi_problems_status = profile.oi_problems_status.get("problems", {}) oi_problems_status = profile.oi_problems_status.get("problems", {})
@@ -81,7 +81,7 @@ class ProblemAPI(APIView):
class ContestProblemAPI(APIView): class ContestProblemAPI(APIView):
def _add_problem_status(self, request, queryset_values): def _add_problem_status(self, request, queryset_values):
if request.user.is_authenticated(): if request.user.is_authenticated:
profile = request.user.userprofile profile = request.user.userprofile
if self.contest.rule_type == ContestRuleType.ACM: if self.contest.rule_type == ContestRuleType.ACM:
problems_status = profile.acm_problems_status.get("contest_problems", {}) problems_status = profile.acm_problems_status.get("contest_problems", {})

View File

@@ -22,8 +22,8 @@ class JudgeStatus:
class Submission(models.Model): class Submission(models.Model):
id = models.TextField(default=rand_str, primary_key=True, db_index=True) id = models.TextField(default=rand_str, primary_key=True, db_index=True)
contest = models.ForeignKey(Contest, null=True) contest = models.ForeignKey(Contest, null=True, on_delete=models.CASCADE)
problem = models.ForeignKey(Problem) problem = models.ForeignKey(Problem, on_delete=models.CASCADE)
create_time = models.DateTimeField(auto_now_add=True) create_time = models.DateTimeField(auto_now_add=True)
user_id = models.IntegerField(db_index=True) user_id = models.IntegerField(db_index=True)
username = models.TextField() username = models.TextField()
@@ -41,6 +41,7 @@ class Submission(models.Model):
def check_user_permission(self, user, check_share=True): def check_user_permission(self, user, check_share=True):
return self.user_id == user.id or \ return self.user_id == user.id or \
(check_share and self.shared is True) or \ (check_share and self.shared is True) or \
(check_share and self.problem.share_submission) or \
user.is_super_admin() or \ user.is_super_admin() or \
user.can_mgmt_all_problem() or \ user.can_mgmt_all_problem() or \
self.problem.created_by_id == user.id self.problem.created_by_id == user.id

View File

@@ -46,6 +46,6 @@ class SubmissionListSerializer(serializers.ModelSerializer):
def get_show_link(self, obj): def get_show_link(self, obj):
# 没传user或为匿名user # 没传user或为匿名user
if self.user is None or not self.user.is_authenticated(): if self.user is None or not self.user.is_authenticated:
return False return False
return obj.check_user_permission(self.user) return obj.check_user_permission(self.user)

View File

@@ -57,7 +57,7 @@ class SubmissionListTest(SubmissionPrepare):
self.assertSuccess(resp) self.assertSuccess(resp)
@mock.patch("submission.views.oj.judge_task.delay") @mock.patch("submission.views.oj.judge_task.send")
class SubmissionAPITest(SubmissionPrepare): class SubmissionAPITest(SubmissionPrepare):
def setUp(self): def setUp(self):
self._create_problem_and_submission() self._create_problem_and_submission()

View File

@@ -18,5 +18,5 @@ class SubmissionRejudgeAPI(APIView):
submission.statistic_info = {} submission.statistic_info = {}
submission.save() submission.save()
judge_task.delay(submission.id, submission.problem.id) judge_task.send(submission.id, submission.problem.id)
return self.success() return self.success()

View File

@@ -80,7 +80,7 @@ class SubmissionAPI(APIView):
contest_id=data.get("contest_id")) contest_id=data.get("contest_id"))
# use this for debug # use this for debug
# JudgeDispatcher(submission.id, problem.id).judge() # JudgeDispatcher(submission.id, problem.id).judge()
judge_task.delay(submission.id, problem.id) judge_task.send(submission.id, problem.id)
if hide_id: if hide_id:
return self.success() return self.success()
else: else:
@@ -198,6 +198,6 @@ class SubmissionExistsAPI(APIView):
def get(self, request): def get(self, request):
if not request.GET.get("problem_id"): if not request.GET.get("problem_id"):
return self.error("Parameter error, problem_id is required") return self.error("Parameter error, problem_id is required")
return self.success(request.user.is_authenticated() and return self.success(request.user.is_authenticated and
Submission.objects.filter(problem_id=request.GET["problem_id"], Submission.objects.filter(problem_id=request.GET["problem_id"],
user_id=request.user.id).exists()) user_id=request.user.id).exists())

View File

@@ -1,7 +1,6 @@
import functools import functools
import json import json
import logging import logging
from collections import OrderedDict
from django.http import HttpResponse, QueryDict from django.http import HttpResponse, QueryDict
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
@@ -89,20 +88,24 @@ class APIView(View):
def error(self, msg="error", err="error"): def error(self, msg="error", err="error"):
return self.response({"error": err, "data": msg}) return self.response({"error": err, "data": msg})
def _serializer_error_to_str(self, errors): def extract_errors(self, errors, key="field"):
for k, v in errors.items(): if isinstance(errors, dict):
if isinstance(v, list): if not errors:
return k, v[0] return key, "Invalid field"
elif isinstance(v, OrderedDict): key = list(errors.keys())[0]
for _k, _v in v.items(): return self.extract_errors(errors.pop(key), key)
return self._serializer_error_to_str({_k: _v}) elif isinstance(errors, list):
return self.extract_errors(errors[0], key)
return key, errors
def invalid_serializer(self, serializer): def invalid_serializer(self, serializer):
k, v = self._serializer_error_to_str(serializer.errors) key, error = self.extract_errors(serializer.errors)
if k != "non_field_errors": if key == "non_field_errors":
return self.error(err="invalid-" + k, msg=k + ": " + v) msg = error
else: else:
return self.error(err="invalid-field", msg=v) msg = f"{key}: {error}"
return self.error(err=f"invalid-{key}", msg=msg)
def server_error(self): def server_error(self):
return self.error(err="server-error", msg="server error") return self.error(err="server-error", msg="server error")

View File

@@ -1,4 +1,4 @@
from django.core.urlresolvers import reverse from django.urls import reverse
from django.test.testcases import TestCase from django.test.testcases import TestCase
from rest_framework.test import APIClient from rest_framework.test import APIClient

View File

@@ -25,7 +25,6 @@ class CacheKey:
waiting_queue = "waiting_queue" waiting_queue = "waiting_queue"
contest_rank_cache = "contest_rank_cache" contest_rank_cache = "contest_rank_cache"
website_config = "website_config" website_config = "website_config"
option = "option"
class Difficulty(Choices): class Difficulty(Choices):

View File

@@ -81,3 +81,14 @@ def send_email(smtp_config, from_name, to_email, to_name, subject, content):
def get_env(name, default=""): def get_env(name, default=""):
return os.environ.get(name, default) return os.environ.get(name, default)
def DRAMATIQ_WORKER_ARGS(time_limit=3600_000, max_retries=0, max_age=7200_000):
return {"max_retries": max_retries, "time_limit": time_limit, "max_age": max_age}
def check_is_id(value):
try:
return int(value) > 0
except Exception:
return False

View File

@@ -1,8 +1,10 @@
import os import os
from celery import shared_task import dramatiq
from utils.shortcuts import DRAMATIQ_WORKER_ARGS
@shared_task @dramatiq.actor(**DRAMATIQ_WORKER_ARGS())
def delete_files(*args): def delete_files(*args):
for item in args: for item in args:
try: try:

View File

@@ -1,7 +1,8 @@
from django.conf.urls import url from django.conf.urls import url
from .views import SimditorImageUploadAPIView from .views import SimditorImageUploadAPIView, SimditorFileUploadAPIView
urlpatterns = [ urlpatterns = [
url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image") url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image"),
url(r"^upload_file/?$", SimditorFileUploadAPIView.as_view(), name="upload_file")
] ]

View File

@@ -1,6 +1,6 @@
import os import os
from django.conf import settings from django.conf import settings
from account.serializers import ImageUploadForm from account.serializers import ImageUploadForm, FileUploadForm
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from utils.api import CSRFExemptAPIView from utils.api import CSRFExemptAPIView
import logging import logging
@@ -35,10 +35,41 @@ class SimditorImageUploadAPIView(CSRFExemptAPIView):
except IOError as e: except IOError as e:
logger.error(e) logger.error(e)
return self.response({ return self.response({
"success": True, "success": False,
"msg": "Upload Error", "msg": "Upload Error",
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) "file_path": ""})
return self.response({ return self.response({
"success": True, "success": True,
"msg": "Success", "msg": "Success",
"file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"}) "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
class SimditorFileUploadAPIView(CSRFExemptAPIView):
request_parsers = ()
def post(self, request):
form = FileUploadForm(request.POST, request.FILES)
if form.is_valid():
file = form.cleaned_data["file"]
else:
return self.response({
"success": False,
"msg": "Upload failed"
})
suffix = os.path.splitext(file.name)[-1].lower()
file_name = rand_str(10) + suffix
try:
with open(os.path.join(settings.UPLOAD_DIR, file_name), "wb") as f:
for chunk in file:
f.write(chunk)
except IOError as e:
logger.error(e)
return self.response({
"success": False,
"msg": "Upload Error"})
return self.response({
"success": True,
"msg": "Success",
"file_path": f"{settings.UPLOAD_PREFIX}/{file_name}",
"file_name": file.name})

View File

@@ -142,7 +142,7 @@ class XSSHtml(HTMLParser):
return attrs return attrs
def _true_url(self, url): def _true_url(self, url):
prog = re.compile(r"^(http|https|ftp)://.+", re.I | re.S) prog = re.compile(r"(^(http|https|ftp)://.+)|(^/)", re.I | re.S)
if prog.match(url): if prog.match(url):
return url return url
else: else: