diff --git a/.flake8 b/.flake8
index 2e0788e..d64dbd5 100644
--- a/.flake8
+++ b/.flake8
@@ -1,8 +1,9 @@
[flake8]
exclude =
xss_filter.py,
- migrations/,
+ */migrations/,
*settings.py
+ */apps.py
max-line-length = 180
inline-quotes = "
no-accept-encodings = True
diff --git a/.gitignore b/.gitignore
index e9a0324..3a8cb90 100644
--- a/.gitignore
+++ b/.gitignore
@@ -54,21 +54,18 @@ db.db
#*.out
*.sqlite3
.DS_Store
-log/
-static/release/css
-static/release/js
-static/release/img
-static/src/upload_image/*
build.txt
tmp/
-test_case/
-release/
-upload/
custom_settings.py
-docker-compose.yml
*.zip
-rsyncd.passwd
-node_modules/
-update.sh
-ssh.sh
+data/log/*
+!data/log/.gitkeep
+data/test_case/*
+!data/test_case/.gitkeep
+data/ssl/*
+!data/ssl/.gitkeep
+data/public/upload/*
+!data/public/upload/.gitkeep
+data/public/avatar/*
+!data/public/avatar/default.png
diff --git a/.python-version b/.python-version
index 1545d96..b727628 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.5.0
+3.6.2
diff --git a/.travis.yml b/.travis.yml
index f51b0b5..5a01bfa 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,14 +1,19 @@
language: python
python:
- - "3.5"
+ - "3.6"
+services:
+ - redis-server
+ - docker
+before_install:
+ - docker pull postgres:10
+ - docker run -it -d -e POSTGRES_DB=onlinejudge -e POSTGRES_USER=onlinejudge -e POSTGRES_PASSWORD=onlinejudge -p 127.0.0.1:5433:5432 postgres:10
install:
- pip install -r deploy/requirements.txt
- - mkdir log test_case upload
- cp oj/custom_settings.example.py oj/custom_settings.py
- echo "SECRET_KEY=\"`cat /dev/urandom | head -1 | md5sum | head -c 32`\"" >> oj/custom_settings.py
- python manage.py migrate
- - python manage.py initadmin
script:
+ - docker ps -a
- flake8 .
- coverage run --include="$PWD/*" manage.py test
- coverage report
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..9e3f69c
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,15 @@
+FROM python:3.6-alpine3.6
+
+ENV OJ_ENV production
+
+ADD . /app
+WORKDIR /app
+
+RUN printf "https://mirrors.tuna.tsinghua.edu.cn/alpine/v3.6/community/\nhttps://mirrors.tuna.tsinghua.edu.cn/alpine/v3.6/main/" > /etc/apk/repositories && \
+ apk add --update --no-cache build-base nginx openssl curl unzip supervisor jpeg-dev zlib-dev postgresql-dev freetype-dev && \
+ pip install --no-cache-dir -r /app/deploy/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple && \
+ apk del build-base --purge
+RUN curl -L $(curl -s https://api.github.com/repos/QingdaoU/OnlineJudgeFE/releases/latest | grep /dist.zip | cut -d '"' -f 4) -o dist.zip && \
+ unzip dist.zip && \
+ rm dist.zip
+CMD sh /app/deploy/run.sh
diff --git a/account/decorators.py b/account/decorators.py
index 9df97e5..4cbb15a 100644
--- a/account/decorators.py
+++ b/account/decorators.py
@@ -4,6 +4,8 @@ from utils.api import JSONResponse
from .models import ProblemPermission
+from contest.models import Contest, ContestType, ContestStatus, ContestRuleType
+
class BasePermissionDecorator(object):
def __init__(self, func):
@@ -23,7 +25,7 @@ class BasePermissionDecorator(object):
return self.error("Your account is disabled")
return self.func(*args, **kwargs)
else:
- return self.error("Please login in first")
+ return self.error("Please login first")
def check_permission(self):
raise NotImplementedError()
@@ -53,3 +55,56 @@ class problem_permission_required(admin_role_required):
if self.request.user.problem_permission == ProblemPermission.NONE:
return False
return True
+
+
+def check_contest_permission(check_type="details"):
+ """
+ 只供Class based view 使用,检查用户是否有权进入该contest, check_type 可选 details, problems, ranks, submissions
+ 若通过验证,在view中可通过self.contest获得该contest
+ """
+
+ def decorator(func):
+ def _check_permission(*args, **kwargs):
+ self = args[0]
+ request = args[1]
+ user = request.user
+ if kwargs.get("contest_id"):
+ contest_id = kwargs.pop("contest_id")
+ else:
+ contest_id = request.GET.get("contest_id")
+ if not contest_id:
+ return self.error("Parameter contest_id doesn't exist.")
+
+ try:
+ # use self.contest to avoid query contest again in view.
+ self.contest = Contest.objects.select_related("created_by").get(id=contest_id, visible=True)
+ except Contest.DoesNotExist:
+ return self.error("Contest %s doesn't exist" % contest_id)
+
+ # creator or owner
+ if user.is_authenticated() and user.is_contest_admin(self.contest):
+ return func(*args, **kwargs)
+
+ if self.contest.contest_type == ContestType.PASSWORD_PROTECTED_CONTEST:
+ # Anonymous
+ if not user.is_authenticated():
+ return self.error("Please login first.")
+ # password error
+ if ("accessible_contests" not in request.session) or \
+ (self.contest.id not in request.session["accessible_contests"]):
+ return self.error("Password is required.")
+
+ # regular user get contest problems, ranks etc. before contest started
+ if self.contest.status == ContestStatus.CONTEST_NOT_START and check_type != "details":
+ return self.error("Contest has not started yet.")
+
+ # check does user have permission to get ranks, submissions in OI Contest
+ if self.contest.status == ContestStatus.CONTEST_UNDERWAY and self.contest.rule_type == ContestRuleType.OI:
+ if not self.contest.real_time_rank and (check_type == "ranks" or check_type == "submissions"):
+ return self.error(f"No permission to get {check_type}")
+
+ return func(*args, **kwargs)
+
+ return _check_permission
+
+ return decorator
diff --git a/account/middleware.py b/account/middleware.py
index f510fd2..245c32a 100644
--- a/account/middleware.py
+++ b/account/middleware.py
@@ -1,34 +1,50 @@
-import time
-
-import pytz
-from django.contrib import auth
-from django.utils import timezone
-from django.utils.translation import ugettext as _
+from django.db import connection
+from django.utils.timezone import now
+from django.utils.deprecation import MiddlewareMixin
from utils.api import JSONResponse
+from account.models import User
-class SessionSecurityMiddleware(object):
+class APITokenAuthMiddleware(MiddlewareMixin):
def process_request(self, request):
- if request.user.is_authenticated() and request.user.is_admin_role():
- if "last_activity" in request.session:
- # 24 hours passed since last visit
- if time.time() - request.session["last_activity"] >= 24 * 60 * 60:
- auth.logout(request)
- return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
- # update last active time
- request.session["last_activity"] = time.time()
+ appkey = request.META.get("HTTP_APPKEY")
+ if appkey:
+ try:
+ request.user = User.objects.get(open_api_appkey=appkey, open_api=True, is_disabled=False)
+ request.csrf_processing_done = True
+ except User.DoesNotExist:
+ pass
-class AdminRoleRequiredMiddleware(object):
+class SessionRecordMiddleware(MiddlewareMixin):
+ def process_request(self, request):
+ if request.user.is_authenticated():
+ session = request.session
+ session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
+ session["ip"] = request.META.get("HTTP_X_REAL_IP", request.META.get("REMOTE_ADDR"))
+ session["last_activity"] = now()
+ user_sessions = request.user.session_keys
+ if session.session_key not in user_sessions:
+ user_sessions.append(session.session_key)
+ request.user.save()
+
+
+class AdminRoleRequiredMiddleware(MiddlewareMixin):
def process_request(self, request):
path = request.path_info
if path.startswith("/admin/") or path.startswith("/api/admin/"):
- if not(request.user.is_authenticated() and request.user.is_admin_role()):
- return JSONResponse.response({"error": "login-required", "data": _("Please login in first")})
+ if not (request.user.is_authenticated() and request.user.is_admin_role()):
+ return JSONResponse.response({"error": "login-required", "data": "Please login in first"})
-class TimezoneMiddleware(object):
- def process_request(self, request):
- if request.user.is_authenticated():
- timezone.activate(pytz.timezone(request.user.userprofile.time_zone))
+class LogSqlMiddleware(MiddlewareMixin):
+ def process_response(self, request, response):
+ print("\033[94m", "#" * 30, "\033[0m")
+ time_threshold = 0.03
+ for query in connection.queries:
+ if float(query["time"]) > time_threshold:
+ print("\033[93m", query, "\n", "-" * 30, "\033[0m")
+ else:
+ print(query, "\n", "-" * 30)
+ return response
diff --git a/account/migrations/0001_initial.py b/account/migrations/0001_initial.py
index c6de9c3..e1e588e 100644
--- a/account/migrations/0001_initial.py
+++ b/account/migrations/0001_initial.py
@@ -50,7 +50,7 @@ class Migration(migrations.Migration):
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('problems_status', jsonfield.fields.JSONField(default={})),
- ('avatar', models.CharField(default=account.models._random_avatar, max_length=50)),
+ ('avatar', models.CharField(default="default.png", max_length=50)),
('blog', models.URLField(blank=True, null=True)),
('mood', models.CharField(blank=True, max_length=200, null=True)),
('accepted_problem_number', models.IntegerField(default=0)),
diff --git a/account/migrations/0003_userprofile_total_score.py b/account/migrations/0003_userprofile_total_score.py
new file mode 100644
index 0000000..f7efe88
--- /dev/null
+++ b/account/migrations/0003_userprofile_total_score.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-08-20 02:03
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('account', '0002_auto_20170209_1028'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='userprofile',
+ name='total_score',
+ field=models.BigIntegerField(default=0),
+ ),
+ migrations.RenameField(
+ model_name='userprofile',
+ old_name='accepted_problem_number',
+ new_name='accepted_number',
+ ),
+ migrations.RemoveField(
+ model_name='userprofile',
+ name='time_zone',
+ )
+ ]
diff --git a/account/migrations/0005_auto_20170830_1154.py b/account/migrations/0005_auto_20170830_1154.py
new file mode 100644
index 0000000..1ba8a94
--- /dev/null
+++ b/account/migrations/0005_auto_20170830_1154.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-08-30 11:54
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+import jsonfield.fields
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('account', '0003_userprofile_total_score'),
+ ]
+
+ operations = [
+ migrations.RenameField(
+ model_name='userprofile',
+ old_name='problems_status',
+ new_name='acm_problems_status',
+ ),
+ migrations.AddField(
+ model_name='userprofile',
+ name='oi_problems_status',
+ field=jsonfield.fields.JSONField(default={}),
+ ),
+ migrations.RemoveField(
+ model_name='user',
+ name='real_name',
+ ),
+ migrations.RemoveField(
+ model_name='userprofile',
+ name='student_id',
+ ),
+ migrations.AddField(
+ model_name='userprofile',
+ name='real_name',
+ field=models.CharField(max_length=30, blank=True, null=True),
+ ),
+ ]
diff --git a/account/migrations/0006_user_session_keys.py b/account/migrations/0006_user_session_keys.py
new file mode 100644
index 0000000..6dc991a
--- /dev/null
+++ b/account/migrations/0006_user_session_keys.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-09-16 06:22
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+import jsonfield.fields
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('account', '0005_auto_20170830_1154'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='user',
+ name='session_keys',
+ field=jsonfield.fields.JSONField(default=[]),
+ ),
+ migrations.RenameField(
+ model_name='userprofile',
+ old_name='phone_number',
+ new_name='github',
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='avatar',
+ field=models.CharField(default='/static/avatar/default.png', max_length=50),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='github',
+ field=models.CharField(blank=True, max_length=50, null=True),
+ ),
+ ]
diff --git a/account/migrations/0008_auto_20171011_1214.py b/account/migrations/0008_auto_20171011_1214.py
new file mode 100644
index 0000000..f27cac8
--- /dev/null
+++ b/account/migrations/0008_auto_20171011_1214.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-11 12:14
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('account', '0006_user_session_keys'),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name='userprofile',
+ name='language',
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='admin_type',
+ field=models.CharField(default='Regular User', max_length=32),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='auth_token',
+ field=models.CharField(max_length=32, null=True),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='email',
+ field=models.EmailField(max_length=64, null=True),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='open_api_appkey',
+ field=models.CharField(max_length=32, null=True),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='problem_permission',
+ field=models.CharField(default='None', max_length=32),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='reset_password_token',
+ field=models.CharField(max_length=32, null=True),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='session_keys',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='tfa_token',
+ field=models.CharField(max_length=32, null=True),
+ ),
+ migrations.AlterField(
+ model_name='user',
+ name='username',
+ field=models.CharField(max_length=32, unique=True),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='acm_problems_status',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='avatar',
+ field=models.CharField(default='/static/avatar/default.png', max_length=256),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='github',
+ field=models.CharField(blank=True, max_length=64, null=True),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='major',
+ field=models.CharField(blank=True, max_length=64, null=True),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='mood',
+ field=models.CharField(blank=True, max_length=256, null=True),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='oi_problems_status',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='real_name',
+ field=models.CharField(blank=True, max_length=32, null=True),
+ ),
+ migrations.AlterField(
+ model_name='userprofile',
+ name='school',
+ field=models.CharField(blank=True, max_length=64, null=True),
+ ),
+ ]
diff --git a/account/migrations/0009_auto_20171125_1514.py b/account/migrations/0009_auto_20171125_1514.py
new file mode 100644
index 0000000..b476b78
--- /dev/null
+++ b/account/migrations/0009_auto_20171125_1514.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-11-25 15:14
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('account', '0008_auto_20171011_1214'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='userprofile',
+ name='avatar',
+ field=models.CharField(default='/public/avatar/default.png', max_length=256),
+ ),
+ ]
diff --git a/account/models.py b/account/models.py
index 0e53c36..ceac9af 100644
--- a/account/models.py
+++ b/account/models.py
@@ -1,6 +1,7 @@
from django.contrib.auth.models import AbstractBaseUser
+from django.conf import settings
from django.db import models
-from jsonfield import JSONField
+from utils.models import JSONField
class AdminType(object):
@@ -9,11 +10,6 @@ class AdminType(object):
SUPER_ADMIN = "Super Admin"
-class ProblemSolutionStatus(object):
- ACCEPTED = 1
- PENDING = 2
-
-
class ProblemPermission(object):
NONE = "None"
OWN = "Own"
@@ -24,26 +20,26 @@ class UserManager(models.Manager):
use_in_migrations = True
def get_by_natural_key(self, username):
- return self.get(**{self.model.USERNAME_FIELD: username})
+ return self.get(**{f"{self.model.USERNAME_FIELD}__iexact": username})
class User(AbstractBaseUser):
- username = models.CharField(max_length=30, unique=True)
- real_name = models.CharField(max_length=30, null=True)
- email = models.EmailField(max_length=254, null=True)
+ username = models.CharField(max_length=32, unique=True)
+ email = models.EmailField(max_length=64, null=True)
create_time = models.DateTimeField(auto_now_add=True, null=True)
# One of UserType
- admin_type = models.CharField(max_length=24, default=AdminType.REGULAR_USER)
- problem_permission = models.CharField(max_length=24, default=ProblemPermission.NONE)
- reset_password_token = models.CharField(max_length=40, null=True)
+ admin_type = models.CharField(max_length=32, default=AdminType.REGULAR_USER)
+ problem_permission = models.CharField(max_length=32, default=ProblemPermission.NONE)
+ reset_password_token = models.CharField(max_length=32, null=True)
reset_password_token_expire_time = models.DateTimeField(null=True)
# SSO auth token
- auth_token = models.CharField(max_length=40, null=True)
+ auth_token = models.CharField(max_length=32, null=True)
two_factor_auth = models.BooleanField(default=False)
- tfa_token = models.CharField(max_length=40, null=True)
+ tfa_token = models.CharField(max_length=32, null=True)
+ session_keys = JSONField(default=list)
# open api key
open_api = models.BooleanField(default=False)
- open_api_appkey = models.CharField(max_length=35, null=True)
+ open_api_appkey = models.CharField(max_length=32, null=True)
is_disabled = models.BooleanField(default=False)
USERNAME_FIELD = "username"
@@ -63,42 +59,59 @@ class User(AbstractBaseUser):
def can_mgmt_all_problem(self):
return self.problem_permission == ProblemPermission.ALL
+ def is_contest_admin(self, contest):
+ return self.is_authenticated() and (contest.created_by == self or self.admin_type == AdminType.SUPER_ADMIN)
+
class Meta:
db_table = "user"
-def _random_avatar():
- import random
- return "/static/img/avatar/avatar-" + str(random.randint(1, 20)) + ".png"
-
-
class UserProfile(models.Model):
- user = models.OneToOneField(User)
- # Store user problem solution status with json string format
- # {"problems": {1: ProblemSolutionStatus.ACCEPTED}, "contest_problems": {20: ProblemSolutionStatus.PENDING)}
- problems_status = JSONField(default={})
- avatar = models.CharField(max_length=50, default=_random_avatar)
+ user = models.OneToOneField(User, on_delete=models.CASCADE)
+ # acm_problems_status examples:
+ # {
+ # "problems": {
+ # "1": {
+ # "status": JudgeStatus.ACCEPTED,
+ # "_id": "1000"
+ # }
+ # },
+ # "contest_problems": {
+ # "1": {
+ # "status": JudgeStatus.ACCEPTED,
+ # "_id": "1000"
+ # }
+ # }
+ # }
+ acm_problems_status = JSONField(default=dict)
+ # like acm_problems_status, merely add "score" field
+ oi_problems_status = JSONField(default=dict)
+
+ real_name = models.CharField(max_length=32, blank=True, null=True)
+ avatar = models.CharField(max_length=256, default=f"{settings.AVATAR_URI_PREFIX}/default.png")
blog = models.URLField(blank=True, null=True)
- mood = models.CharField(max_length=200, blank=True, null=True)
- accepted_problem_number = models.IntegerField(default=0)
+ mood = models.CharField(max_length=256, blank=True, null=True)
+ github = models.CharField(max_length=64, blank=True, null=True)
+ school = models.CharField(max_length=64, blank=True, null=True)
+ major = models.CharField(max_length=64, blank=True, null=True)
+ # for ACM
+ accepted_number = models.IntegerField(default=0)
+ # for OI
+ total_score = models.BigIntegerField(default=0)
submission_number = models.IntegerField(default=0)
- phone_number = models.CharField(max_length=15, blank=True, null=True)
- school = models.CharField(max_length=200, blank=True, null=True)
- major = models.CharField(max_length=200, blank=True, null=True)
- student_id = models.CharField(max_length=15, blank=True, null=True)
- time_zone = models.CharField(max_length=32, blank=True, null=True)
- language = models.CharField(max_length=32, blank=True, null=True)
def add_accepted_problem_number(self):
- self.accepted_problem_number = models.F("accepted_problem_number") + 1
+ self.accepted_number = models.F("accepted_number") + 1
self.save()
def add_submission_number(self):
self.submission_number = models.F("submission_number") + 1
self.save()
- def minus_accepted_problem_number(self):
- self.accepted_problem_number = models.F("accepted_problem_number") - 1
+ # 计算总分时, 应先减掉上次该题所得分数, 然后再加上本次所得分数
+ def add_score(self, this_time_score, last_time_score=None):
+ last_time_score = last_time_score or 0
+ self.total_score = models.F("total_score") - last_time_score + this_time_score
self.save()
class Meta:
diff --git a/account/serializers.py b/account/serializers.py
index 112d59d..d346677 100644
--- a/account/serializers.py
+++ b/account/serializers.py
@@ -1,25 +1,51 @@
-from utils.api import DateTimeTZField, serializers
+from django import forms
-from .models import AdminType, ProblemPermission, User
+from utils.api import DateTimeTZField, serializers, UsernameSerializer
+
+from .models import AdminType, ProblemPermission, User, UserProfile
class UserLoginSerializer(serializers.Serializer):
- username = serializers.CharField(max_length=30)
- password = serializers.CharField(max_length=30)
- tfa_code = serializers.CharField(min_length=6, max_length=6, required=False, allow_null=True)
+ username = serializers.CharField()
+ password = serializers.CharField()
+ tfa_code = serializers.CharField(required=False, allow_blank=True)
+
+
+class UsernameOrEmailCheckSerializer(serializers.Serializer):
+ username = serializers.CharField(required=False)
+ email = serializers.EmailField(required=False)
class UserRegisterSerializer(serializers.Serializer):
- username = serializers.CharField(max_length=30)
- password = serializers.CharField(max_length=30, min_length=6)
- email = serializers.EmailField(max_length=254)
- captcha = serializers.CharField(max_length=4, min_length=4)
+ username = serializers.CharField(max_length=32)
+ password = serializers.CharField(min_length=6)
+ email = serializers.EmailField(max_length=64)
+ captcha = serializers.CharField()
class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField()
- new_password = serializers.CharField(max_length=30, min_length=6)
- captcha = serializers.CharField(max_length=4, min_length=4)
+ new_password = serializers.CharField(min_length=6)
+ tfa_code = serializers.CharField(required=False, allow_blank=True)
+
+
+class UserChangeEmailSerializer(serializers.Serializer):
+ password = serializers.CharField()
+ new_email = serializers.EmailField(max_length=64)
+ tfa_code = serializers.CharField(required=False, allow_blank=True)
+
+
+class GenerateUserSerializer(serializers.Serializer):
+ prefix = serializers.CharField(max_length=16, allow_blank=True)
+ suffix = serializers.CharField(max_length=16, allow_blank=True)
+ number_from = serializers.IntegerField()
+ number_to = serializers.IntegerField()
+ password_length = serializers.IntegerField(max_value=16, default=8)
+
+
+class ImportUserSeralizer(serializers.Serializer):
+ users = serializers.ListField(
+ child=serializers.ListField(child=serializers.CharField(max_length=64)))
class UserSerializer(serializers.ModelSerializer):
@@ -28,16 +54,33 @@ class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
- fields = ["id", "username", "real_name", "email", "admin_type", "problem_permission",
+ fields = ["id", "username", "email", "admin_type", "problem_permission",
"create_time", "last_login", "two_factor_auth", "open_api", "is_disabled"]
+class UserProfileSerializer(serializers.ModelSerializer):
+ user = UserSerializer()
+ acm_problems_status = serializers.JSONField()
+ oi_problems_status = serializers.JSONField()
+
+ class Meta:
+ model = UserProfile
+ fields = "__all__"
+
+
+class UserInfoSerializer(serializers.ModelSerializer):
+ acm_problems_status = serializers.JSONField()
+ oi_problems_status = serializers.JSONField()
+
+ class Meta:
+ model = UserProfile
+
+
class EditUserSerializer(serializers.Serializer):
id = serializers.IntegerField()
- username = serializers.CharField(max_length=30)
- real_name = serializers.CharField(max_length=30)
- password = serializers.CharField(max_length=30, min_length=6, allow_blank=True, required=False, default=None)
- email = serializers.EmailField(max_length=254)
+ username = serializers.CharField(max_length=32)
+ password = serializers.CharField(min_length=6, allow_blank=True, required=False, default=None)
+ email = serializers.EmailField(max_length=64)
admin_type = serializers.ChoiceField(choices=(AdminType.REGULAR_USER, AdminType.ADMIN, AdminType.SUPER_ADMIN))
problem_permission = serializers.ChoiceField(choices=(ProblemPermission.NONE, ProblemPermission.OWN,
ProblemPermission.ALL))
@@ -46,21 +89,42 @@ class EditUserSerializer(serializers.Serializer):
is_disabled = serializers.BooleanField()
+class EditUserProfileSerializer(serializers.Serializer):
+ real_name = serializers.CharField(max_length=32, allow_null=True, required=False)
+ avatar = serializers.CharField(max_length=256, allow_null=True, allow_blank=True, required=False)
+ blog = serializers.URLField(max_length=256, allow_null=True, allow_blank=True, required=False)
+ mood = serializers.CharField(max_length=256, allow_null=True, allow_blank=True, required=False)
+ github = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
+ school = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
+ major = serializers.CharField(max_length=64, allow_null=True, allow_blank=True, required=False)
+
+
class ApplyResetPasswordSerializer(serializers.Serializer):
email = serializers.EmailField()
- captcha = serializers.CharField(max_length=4, min_length=4)
+ captcha = serializers.CharField()
class ResetPasswordSerializer(serializers.Serializer):
- token = serializers.CharField(min_length=1, max_length=40)
- password = serializers.CharField(min_length=6, max_length=30)
- captcha = serializers.CharField(max_length=4, min_length=4)
+ token = serializers.CharField()
+ password = serializers.CharField(min_length=6)
+ captcha = serializers.CharField()
class SSOSerializer(serializers.Serializer):
- appkey = serializers.CharField(max_length=35)
- token = serializers.CharField(max_length=40)
+ appkey = serializers.CharField()
+ token = serializers.CharField()
class TwoFactorAuthCodeSerializer(serializers.Serializer):
code = serializers.IntegerField()
+
+
+class ImageUploadForm(forms.Form):
+ image = forms.FileField()
+
+
+class RankInfoSerializer(serializers.ModelSerializer):
+ user = UsernameSerializer()
+
+ class Meta:
+ model = UserProfile
diff --git a/account/tasks.py b/account/tasks.py
index 0aacec9..3e7c1d2 100644
--- a/account/tasks.py
+++ b/account/tasks.py
@@ -1,6 +1,31 @@
-from celery import shared_task
+import logging
-from utils.shortcuts import send_email
+from celery import shared_task
+from envelopes import Envelope
+
+from options.options import SysOptions
+
+logger = logging.getLogger(__name__)
+
+
+def send_email(from_name, to_email, to_name, subject, content):
+ smtp = SysOptions.smtp_config
+ if not smtp:
+ return
+ envlope = Envelope(from_addr=(smtp["email"], from_name),
+ to_addr=(to_email, to_name),
+ subject=subject,
+ html_body=content)
+ try:
+ envlope.send(smtp["server"],
+ login=smtp["email"],
+ password=smtp["password"],
+ port=smtp["port"],
+ tls=smtp["tls"])
+ return True
+ except Exception as e:
+ logger.exception(e)
+ return False
@shared_task
diff --git a/account/templates/reset_password_email.html b/account/templates/reset_password_email.html
index 5a0b591..b2f76f8 100644
--- a/account/templates/reset_password_email.html
+++ b/account/templates/reset_password_email.html
@@ -8,7 +8,7 @@
|
- {{ website_name }} 登录信息找回
+ {{ website_name }}
|
@@ -32,18 +32,18 @@
|
- 您刚刚在 {{ website_name }} 申请了找回登录信息服务。
+ We received a request to reset your password for {{ website_name }}.
|
|
- 请在30分钟内点击下面链接设置您的新密码:
+ You can use the following link to reset your password in 20 minutes.
|
|
重置密码
+ style="color: rgb(255,255,255);text-decoration: none;display: block;min-height: 39px;width: 158px;line-height: 39px;background-color:rgb(80,165,230);font-size:20px;text-align:center;">Reset Password
|
@@ -51,7 +51,7 @@
|
- 如果上面的链接点击无效,请复制以下链接至浏览器的地址栏直接打开。
+ If the button above doesn't work, please copy the following link to your browser and press enter.
|
@@ -63,8 +63,7 @@
|
- 如果您没有提出过该申请,请忽略此邮件。有可能是其他用户误填了您的邮件地址,我们不会对你的帐户进行任何修改。
- 请不要向他人透露本邮件的内容,否则可能会导致您的账号被盗。
+ If you did not ask that, please ignore this email. It will expire and become useless in 20 minutes.
|
diff --git a/account/tests.py b/account/tests.py
index d4addd5..97880aa 100644
--- a/account/tests.py
+++ b/account/tests.py
@@ -1,13 +1,19 @@
import time
+
from unittest import mock
+from datetime import timedelta
+from copy import deepcopy
from django.contrib import auth
+from django.utils.timezone import now
from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str
+from options.options import SysOptions
from .models import AdminType, ProblemPermission, User
+from utils.constants import ContestRuleType
class PermissionDecoratorTest(APITestCase):
@@ -28,6 +34,54 @@ class PermissionDecoratorTest(APITestCase):
pass
+class DuplicateUserCheckAPITest(APITestCase):
+ def setUp(self):
+ user = self.create_user("test", "test123", login=False)
+ user.email = "test@test.com"
+ user.save()
+ self.url = self.reverse("check_username_or_email")
+
+ def test_duplicate_username(self):
+ resp = self.client.post(self.url, data={"username": "test"})
+ data = resp.data["data"]
+ self.assertEqual(data["username"], True)
+ resp = self.client.post(self.url, data={"username": "Test"})
+ self.assertEqual(resp.data["data"]["username"], True)
+
+ def test_ok_username(self):
+ resp = self.client.post(self.url, data={"username": "test1"})
+ data = resp.data["data"]
+ self.assertFalse(data["username"])
+
+ def test_duplicate_email(self):
+ resp = self.client.post(self.url, data={"email": "test@test.com"})
+ self.assertEqual(resp.data["data"]["email"], True)
+ resp = self.client.post(self.url, data={"email": "Test@Test.com"})
+ self.assertTrue(resp.data["data"]["email"])
+
+ def test_ok_email(self):
+ resp = self.client.post(self.url, data={"email": "aa@test.com"})
+ self.assertFalse(resp.data["data"]["email"])
+
+
+class TFARequiredCheckAPITest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("tfa_required_check")
+ self.create_user("test", "test123", login=False)
+
+ def test_not_required_tfa(self):
+ resp = self.client.post(self.url, data={"username": "test"})
+ self.assertSuccess(resp)
+ self.assertEqual(resp.data["data"]["result"], False)
+
+ def test_required_tfa(self):
+ user = User.objects.first()
+ user.two_factor_auth = True
+ user.save()
+ resp = self.client.post(self.url, data={"username": "test"})
+ self.assertEqual(resp.data["data"]["result"], True)
+
+
class UserLoginAPITest(APITestCase):
def setUp(self):
self.username = self.password = "test"
@@ -49,6 +103,12 @@ class UserLoginAPITest(APITestCase):
user = auth.get_user(self.client)
self.assertTrue(user.is_authenticated())
+ def test_login_with_correct_info_upper_username(self):
+ resp = self.client.post(self.login_url, data={"username": self.username.upper(), "password": self.password})
+ self.assertDictEqual(resp.data, {"error": None, "data": "Succeeded"})
+ user = auth.get_user(self.client)
+ self.assertTrue(user.is_authenticated())
+
def test_login_with_wrong_info(self):
response = self.client.post(self.login_url,
data={"username": self.username, "password": "invalid_password"})
@@ -87,11 +147,18 @@ class UserLoginAPITest(APITestCase):
response = self.client.post(self.login_url,
data={"username": self.username,
"password": self.password})
- self.assertDictEqual(response.data, {"error": None, "data": "tfa_required"})
+ self.assertDictEqual(response.data, {"error": "error", "data": "tfa_required"})
user = auth.get_user(self.client)
self.assertFalse(user.is_authenticated())
+ def test_user_disabled(self):
+ self.user.is_disabled = True
+ self.user.save()
+ resp = self.client.post(self.login_url, data={"username": self.username,
+ "password": self.password})
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Your account has been disabled"})
+
class CaptchaTest(APITestCase):
def _set_captcha(self, session):
@@ -112,6 +179,11 @@ class UserRegisterAPITest(CaptchaTest):
"real_name": "real_name", "email": "test@qduoj.com",
"captcha": self._set_captcha(self.client.session)}
+ def test_website_config_limit(self):
+ SysOptions.allow_register = False
+ resp = self.client.post(self.register_url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Register function has been disabled by admin"})
+
def test_invalid_captcha(self):
self.data["captcha"] = "****"
response = self.client.post(self.register_url, data=self.data)
@@ -142,23 +214,206 @@ class UserRegisterAPITest(CaptchaTest):
self.assertDictEqual(response.data, {"error": "error", "data": "Email already exists"})
-class UserChangePasswordAPITest(CaptchaTest):
+class SessionManagementAPITest(APITestCase):
+ def setUp(self):
+ self.create_user("test", "test123")
+ self.url = self.reverse("session_management_api")
+ # launch a request to provide session data
+ login_url = self.reverse("user_login_api")
+ self.client.post(login_url, data={"username": "test", "password": "test123"})
+
+ def test_get_sessions(self):
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
+ data = resp.data["data"]
+ self.assertEqual(len(data), 1)
+
+ # def test_delete_session_key(self):
+ # resp = self.client.delete(self.url + "?session_key=" + self.session_key)
+ # self.assertSuccess(resp)
+
+ def test_delete_session_with_invalid_key(self):
+ resp = self.client.delete(self.url + "?session_key=aaaaaaaaaa")
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid session_key"})
+
+
+class UserProfileAPITest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("user_profile_api")
+
+ def test_get_profile_without_login(self):
+ resp = self.client.get(self.url)
+ self.assertDictEqual(resp.data, {"error": None, "data": None})
+
+ def test_get_profile(self):
+ self.create_user("test", "test123")
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
+
+ def test_update_profile(self):
+ self.create_user("test", "test123")
+ update_data = {"real_name": "zemal", "submission_number": 233}
+ resp = self.client.put(self.url, data=update_data)
+ self.assertSuccess(resp)
+ data = resp.data["data"]
+ self.assertEqual(data["real_name"], "zemal")
+ self.assertEqual(data["submission_number"], 0)
+
+
+class TwoFactorAuthAPITest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("two_factor_auth_api")
+ self.create_user("test", "test123")
+
+ def _get_tfa_code(self):
+ user = User.objects.first()
+ code = OtpAuth(user.tfa_token).totp()
+ if len(str(code)) < 6:
+ code = (6 - len(str(code))) * "0" + str(code)
+ return code
+
+ def test_get_image(self):
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
+
+ def test_open_tfa_with_invalid_code(self):
+ self.test_get_image()
+ resp = self.client.post(self.url, data={"code": "000000"})
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
+
+ def test_open_tfa_with_correct_code(self):
+ self.test_get_image()
+ code = self._get_tfa_code()
+ resp = self.client.post(self.url, data={"code": code})
+ self.assertSuccess(resp)
+ user = User.objects.first()
+ self.assertEqual(user.two_factor_auth, True)
+
+ def test_close_tfa_with_invalid_code(self):
+ self.test_open_tfa_with_correct_code()
+ resp = self.client.post(self.url, data={"code": "000000"})
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Invalid code"})
+
+ def test_close_tfa_with_correct_code(self):
+ self.test_open_tfa_with_correct_code()
+ code = self._get_tfa_code()
+ resp = self.client.put(self.url, data={"code": code})
+ self.assertSuccess(resp)
+ user = User.objects.first()
+ self.assertEqual(user.two_factor_auth, False)
+
+
+@mock.patch("account.views.oj.send_email_async.delay")
+class ApplyResetPasswordAPITest(CaptchaTest):
+ def setUp(self):
+ self.create_user("test", "test123", login=False)
+ user = User.objects.first()
+ user.email = "test@oj.com"
+ user.save()
+ self.url = self.reverse("apply_reset_password_api")
+ self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
+
+ def _refresh_captcha(self):
+ self.data["captcha"] = self._set_captcha(self.client.session)
+
+ def test_apply_reset_password(self, send_email_delay):
+ resp = self.client.post(self.url, data=self.data)
+ self.assertSuccess(resp)
+ send_email_delay.assert_called()
+
+ def test_apply_reset_password_twice_in_20_mins(self, send_email_delay):
+ self.test_apply_reset_password()
+ send_email_delay.reset_mock()
+ self._refresh_captcha()
+ resp = self.client.post(self.url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "You can only reset password once per 20 minutes"})
+ send_email_delay.assert_not_called()
+
+ def test_apply_reset_password_again_after_20_mins(self, send_email_delay):
+ self.test_apply_reset_password()
+ user = User.objects.first()
+ user.reset_password_token_expire_time = now() - timedelta(minutes=21)
+ user.save()
+ self._refresh_captcha()
+ self.test_apply_reset_password()
+
+
+class ResetPasswordAPITest(CaptchaTest):
+ def setUp(self):
+ self.create_user("test", "test123", login=False)
+ self.url = self.reverse("reset_password_api")
+ user = User.objects.first()
+ user.reset_password_token = "online_judge?"
+ user.reset_password_token_expire_time = now() + timedelta(minutes=20)
+ user.save()
+ self.data = {"token": user.reset_password_token,
+ "captcha": self._set_captcha(self.client.session),
+ "password": "test456"}
+
+ def test_reset_password_with_correct_token(self):
+ resp = self.client.post(self.url, data=self.data)
+ self.assertSuccess(resp)
+ self.assertTrue(self.client.login(username="test", password="test456"))
+
+ def test_reset_password_with_invalid_token(self):
+ self.data["token"] = "aaaaaaaaaaa"
+ resp = self.client.post(self.url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Token does not exist"})
+
+ def test_reset_password_with_expired_token(self):
+ user = User.objects.first()
+ user.reset_password_token_expire_time = now() - timedelta(seconds=30)
+ user.save()
+ resp = self.client.post(self.url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Token has expired"})
+
+
+class UserChangeEmailAPITest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("user_change_email_api")
+ self.user = self.create_user("test", "test123")
+ self.new_mail = "test@oj.com"
+ self.data = {"password": "test123", "new_email": self.new_mail}
+
+ def test_change_email_success(self):
+ resp = self.client.post(self.url, data=self.data)
+ self.assertSuccess(resp)
+
+ def test_wrong_password(self):
+ self.data["password"] = "aaaa"
+ resp = self.client.post(self.url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
+
+ def test_duplicate_email(self):
+ u = self.create_user("aa", "bb", login=False)
+ u.email = self.new_mail
+ u.save()
+ resp = self.client.post(self.url, data=self.data)
+ self.assertDictEqual(resp.data, {"error": "error", "data": "The email is owned by other account"})
+
+
+class UserChangePasswordAPITest(APITestCase):
def setUp(self):
- self.client = APIClient()
self.url = self.reverse("user_change_password_api")
# Create user at first
self.username = "test_user"
self.old_password = "testuserpassword"
self.new_password = "new_password"
- self.create_user(username=self.username, password=self.old_password, login=False)
+ self.user = self.create_user(username=self.username, password=self.old_password, login=False)
- self.data = {"old_password": self.old_password, "new_password": self.new_password,
- "captcha": self._set_captcha(self.client.session)}
+ self.data = {"old_password": self.old_password, "new_password": self.new_password}
+
+ def _get_tfa_code(self):
+ user = User.objects.first()
+ code = OtpAuth(user.tfa_token).totp()
+ if len(str(code)) < 6:
+ code = (6 - len(str(code))) * "0" + str(code)
+ return code
def test_login_required(self):
response = self.client.post(self.url, data=self.data)
- self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login in first"})
+ self.assertEqual(response.data, {"error": "permission-denied", "data": "Please login first"})
def test_valid_ola_password(self):
self.assertTrue(self.client.login(username=self.username, password=self.old_password))
@@ -172,6 +427,58 @@ class UserChangePasswordAPITest(CaptchaTest):
response = self.client.post(self.url, data=self.data)
self.assertEqual(response.data, {"error": "error", "data": "Invalid old password"})
+ def test_tfa_code_required(self):
+ self.user.two_factor_auth = True
+ self.user.tfa_token = "tfa_token"
+ self.user.save()
+ self.assertTrue(self.client.login(username=self.username, password=self.old_password))
+ self.data["tfa_code"] = rand_str(6)
+ resp = self.client.post(self.url, data=self.data)
+ self.assertEqual(resp.data, {"error": "error", "data": "Invalid two factor verification code"})
+
+ self.data["tfa_code"] = self._get_tfa_code()
+ resp = self.client.post(self.url, data=self.data)
+ self.assertSuccess(resp)
+
+
+class UserRankAPITest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("user_rank_api")
+ self.create_user("test1", "test123", login=False)
+ self.create_user("test2", "test123", login=False)
+ test1 = User.objects.get(username="test1")
+ profile1 = test1.userprofile
+ profile1.submission_number = 10
+ profile1.accepted_number = 10
+ profile1.total_score = 240
+ profile1.save()
+
+ test2 = User.objects.get(username="test2")
+ profile2 = test2.userprofile
+ profile2.submission_number = 15
+ profile2.accepted_number = 10
+ profile2.total_score = 700
+ profile2.save()
+
+ def test_get_acm_rank(self):
+ resp = self.client.get(self.url, data={"rule": ContestRuleType.ACM})
+ self.assertSuccess(resp)
+ data = resp.data["data"]["results"]
+ self.assertEqual(data[0]["user"]["username"], "test1")
+ self.assertEqual(data[1]["user"]["username"], "test2")
+
+ def test_get_oi_rank(self):
+ resp = self.client.get(self.url, data={"rule": ContestRuleType.OI})
+ self.assertSuccess(resp)
+ data = resp.data["data"]["results"]
+ self.assertEqual(data[0]["user"]["username"], "test2")
+ self.assertEqual(data[1]["user"]["username"], "test1")
+
+
+class ProfileProblemDisplayIDRefreshAPITest(APITestCase):
+ def setUp(self):
+ pass
+
class AdminUserTest(APITestCase):
def setUp(self):
@@ -194,7 +501,6 @@ class AdminUserTest(APITestCase):
resp_data = response.data["data"]
self.assertEqual(resp_data["username"], self.username)
self.assertEqual(resp_data["email"], "test@qq.com")
- self.assertEqual(resp_data["real_name"], "test_name")
self.assertEqual(resp_data["open_api"], True)
self.assertEqual(resp_data["two_factor_auth"], False)
self.assertEqual(resp_data["is_disabled"], False)
@@ -249,3 +555,75 @@ class AdminUserTest(APITestCase):
# if `openapi_app_key` is not None, the value is not changed
self.assertTrue(resp_data["open_api"])
self.assertEqual(User.objects.get(id=self.regular_user.id).open_api_appkey, key)
+
+ def test_import_users(self):
+ data = {"users": [["user1", "pass1", "eami1@e.com"],
+ ["user2", "pass3", "eamil3@e.com"]]
+ }
+ resp = self.client.post(self.url, data)
+ self.assertSuccess(resp)
+ # successfully created 2 users
+ self.assertEqual(User.objects.all().count(), 4)
+
+ def test_import_duplicate_user(self):
+ data = {"users": [["user1", "pass1", "eami1@e.com"],
+ ["user1", "pass1", "eami1@e.com"]]
+ }
+ resp = self.client.post(self.url, data)
+ self.assertFailed(resp, "DETAIL: Key (username)=(user1) already exists.")
+ # no user is created
+ self.assertEqual(User.objects.all().count(), 2)
+
+ def test_delete_users(self):
+ self.test_import_users()
+ user_ids = User.objects.filter(username__in=["user1", "user2"]).values_list("id", flat=True)
+ user_ids = ",".join([str(id) for id in user_ids])
+ resp = self.client.delete(self.url + "?id=" + user_ids)
+ self.assertSuccess(resp)
+ self.assertEqual(User.objects.all().count(), 2)
+
+
+class GenerateUserAPITest(APITestCase):
+ def setUp(self):
+ self.create_super_admin()
+ self.url = self.reverse("generate_user_api")
+ self.data = {
+ "number_from": 100, "number_to": 105,
+ "prefix": "pre", "suffix": "suf",
+ "default_email": "test@test.com",
+ "password_length": 8
+ }
+
+ def test_error_case(self):
+ data = deepcopy(self.data)
+ data["prefix"] = "t" * 16
+ data["suffix"] = "s" * 14
+ resp = self.client.post(self.url, data=data)
+ self.assertEqual(resp.data["data"], "Username should not more than 32 characters")
+
+ data2 = deepcopy(self.data)
+ data2["number_from"] = 106
+ resp = self.client.post(self.url, data=data2)
+ self.assertEqual(resp.data["data"], "Start number must be lower than end number")
+
+ @mock.patch("account.views.admin.xlsxwriter.Workbook")
+ def test_generate_user_success(self, mock_workbook):
+ resp = self.client.post(self.url, data=self.data)
+ self.assertSuccess(resp)
+ mock_workbook.assert_called()
+
+
+class OpenAPIAppkeyAPITest(APITestCase):
+ def setUp(self):
+ self.user = self.create_super_admin()
+ self.url = self.reverse("open_api_appkey_api")
+
+ def test_reset_appkey(self):
+ resp = self.client.post(self.url, data={})
+ self.assertFailed(resp)
+
+ self.user.open_api = True
+ self.user.save()
+ resp = self.client.post(self.url, data={})
+ self.assertSuccess(resp)
+ self.assertEqual(resp.data["data"]["appkey"], User.objects.get(username=self.user.username).open_api_appkey)
diff --git a/account/urls/admin.py b/account/urls/admin.py
index b10741e..5826ae2 100644
--- a/account/urls/admin.py
+++ b/account/urls/admin.py
@@ -1,7 +1,8 @@
from django.conf.urls import url
-from ..views.admin import UserAdminAPI
+from ..views.admin import UserAdminAPI, GenerateUserAPI
urlpatterns = [
url(r"^user/?$", UserAdminAPI.as_view(), name="user_admin_api"),
+ url(r"^generate_user/?$", GenerateUserAPI.as_view(), name="generate_user_api"),
]
diff --git a/account/urls/oj.py b/account/urls/oj.py
index fa47e33..1b26e14 100644
--- a/account/urls/oj.py
+++ b/account/urls/oj.py
@@ -1,12 +1,30 @@
from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
- UserChangePasswordAPI, UserLoginAPI, UserRegisterAPI)
+ UserChangePasswordAPI, UserRegisterAPI, UserChangeEmailAPI,
+ UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
+ AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
+ UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI,
+ ProfileProblemDisplayIDRefreshAPI, OpenAPIAppkeyAPI)
+
+from utils.captcha.views import CaptchaAPIView
urlpatterns = [
url(r"^login/?$", UserLoginAPI.as_view(), name="user_login_api"),
+ url(r"^logout/?$", UserLogoutAPI.as_view(), name="user_logout_api"),
url(r"^register/?$", UserRegisterAPI.as_view(), name="user_register_api"),
url(r"^change_password/?$", UserChangePasswordAPI.as_view(), name="user_change_password_api"),
+ url(r"^change_email/?$", UserChangeEmailAPI.as_view(), name="user_change_email_api"),
url(r"^apply_reset_password/?$", ApplyResetPasswordAPI.as_view(), name="apply_reset_password_api"),
- url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="apply_reset_password_api")
+ url(r"^reset_password/?$", ResetPasswordAPI.as_view(), name="reset_password_api"),
+ url(r"^captcha/?$", CaptchaAPIView.as_view(), name="show_captcha"),
+ url(r"^check_username_or_email", UsernameOrEmailCheck.as_view(), name="check_username_or_email"),
+ url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
+ url(r"^profile/fresh_display_id", ProfileProblemDisplayIDRefreshAPI.as_view(), name="display_id_fresh"),
+ url(r"^upload_avatar/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
+ url(r"^tfa_required/?$", CheckTFARequiredAPI.as_view(), name="tfa_required_check"),
+ url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api"),
+ url(r"^user_rank/?$", UserRankAPI.as_view(), name="user_rank_api"),
+ url(r"^sessions/?$", SessionManagementAPI.as_view(), name="session_management_api"),
+ url(r"^open_api_appkey/?$", OpenAPIAppkeyAPI.as_view(), name="open_api_appkey_api"),
]
diff --git a/account/urls/user.py b/account/urls/user.py
deleted file mode 100644
index 1676ddc..0000000
--- a/account/urls/user.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from django.conf.urls import url
-
-from ..views.user import (SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI,
- UserInfoAPI, UserProfileAPI)
-
-urlpatterns = [
- url(r"^user/?$", UserInfoAPI.as_view(), name="user_info_api"),
- url(r"^profile/?$", UserProfileAPI.as_view(), name="user_profile_api"),
- url(r"^avatar/upload/?$", AvatarUploadAPI.as_view(), name="avatar_upload_api"),
- url(r"^sso/?$", SSOAPI.as_view(), name="sso_api"),
- url(r"^two_factor_auth/?$", TwoFactorAuthAPI.as_view(), name="two_factor_auth_api")
-]
diff --git a/account/views/admin.py b/account/views/admin.py
index 62c5115..171fa37 100644
--- a/account/views/admin.py
+++ b/account/views/admin.py
@@ -1,15 +1,48 @@
-from django.core.exceptions import MultipleObjectsReturned
-from django.db.models import Q
+import os
+import re
+import xlsxwriter
+from django.db import transaction, IntegrityError
+from django.db.models import Q
+from django.http import HttpResponse
+from django.contrib.auth.hashers import make_password
+
+from submission.models import Submission
from utils.api import APIView, validate_serializer
from utils.shortcuts import rand_str
from ..decorators import super_admin_required
-from ..models import AdminType, ProblemPermission, User
-from ..serializers import EditUserSerializer, UserSerializer
+from ..models import AdminType, ProblemPermission, User, UserProfile
+from ..serializers import EditUserSerializer, UserSerializer, GenerateUserSerializer
+from ..serializers import ImportUserSeralizer
class UserAdminAPI(APIView):
+ @validate_serializer(ImportUserSeralizer)
+ @super_admin_required
+ def post(self, request):
+ """
+ Import User
+ """
+ data = request.data["users"]
+
+ user_list = []
+ for user_data in data:
+ if len(user_data) != 3 or len(user_data[0]) > 32:
+ return self.error(f"Error occurred while processing data '{user_data}'")
+ user_list.append(User(username=user_data[0], password=make_password(user_data[1]), email=user_data[2]))
+
+ try:
+ with transaction.atomic():
+ ret = User.objects.bulk_create(user_list)
+ UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret])
+ return self.success()
+ except IntegrityError as e:
+ # Extract detail from exception message
+ # duplicate key value violates unique constraint "user_username_key"
+ # DETAIL: Key (username)=(root11) already exists.
+ return self.error(str(e).split("\n")[1])
+
@validate_serializer(EditUserSerializer)
@super_admin_required
def put(self, request):
@@ -21,25 +54,13 @@ class UserAdminAPI(APIView):
user = User.objects.get(id=data["id"])
except User.DoesNotExist:
return self.error("User does not exist")
- try:
- user = User.objects.get(username=data["username"])
- if user.id != data["id"]:
- return self.error("Username already exists")
- except User.DoesNotExist:
- pass
-
- try:
- user = User.objects.get(email=data["email"])
- if user.id != data["id"]:
- return self.error("Email already exists")
- # Some old data has duplicate email
- except MultipleObjectsReturned:
+ if User.objects.filter(username=data["username"]).exclude(id=user.id).exists():
+ return self.error("Username already exists")
+ if User.objects.filter(email=data["email"].lower()).exclude(id=user.id).exists():
return self.error("Email already exists")
- except User.DoesNotExist:
- pass
+ pre_username = user.username
user.username = data["username"]
- user.real_name = data["real_name"]
user.email = data["email"]
user.admin_type = data["admin_type"]
user.is_disabled = data["is_disabled"]
@@ -72,6 +93,8 @@ class UserAdminAPI(APIView):
user.two_factor_auth = data["two_factor_auth"]
user.save()
+ if pre_username != user.username:
+ Submission.objects.filter(username=pre_username).update(username=user.username)
return self.success(UserSerializer(user).data)
@super_admin_required
@@ -91,7 +114,97 @@ class UserAdminAPI(APIView):
keyword = request.GET.get("keyword", None)
if keyword:
- user = user.filter(Q(username__contains=keyword) |
- Q(real_name__contains=keyword) |
- Q(email__contains=keyword))
+ user = user.filter(Q(username__icontains=keyword) |
+ Q(userprofile__real_name__icontains=keyword) |
+ Q(email__icontains=keyword))
return self.success(self.paginate_data(request, user, UserSerializer))
+
+ def delete_one(self, user_id):
+ try:
+ user = User.objects.get(id=user_id)
+ except User.DoesNotExist:
+ return f"User {user_id} does not exist"
+ if Submission.objects.filter(user_id=user_id).exists():
+ return f"Can't delete the user {user_id} as he/she has submissions"
+ user.delete()
+
+ @super_admin_required
+ def delete(self, request):
+ id = request.GET.get("id")
+ if not id:
+ return self.error("Invalid Parameter, id is required")
+ for user_id in id.split(","):
+ if user_id:
+ error = self.delete_one(user_id)
+ if error:
+ return self.error(error)
+ return self.success()
+
+
+class GenerateUserAPI(APIView):
+ @super_admin_required
+ def get(self, request):
+ """
+ download users excel
+ """
+ file_id = request.GET.get("file_id")
+ if not file_id:
+ return self.error("Invalid Parameter, file_id is required")
+ if not re.match(r"^[a-zA-Z0-9]+$", file_id):
+ return self.error("Illegal file_id")
+ file_path = f"/tmp/{file_id}.xlsx"
+ if not os.path.isfile(file_path):
+ return self.error("File does not exist")
+ with open(file_path, "rb") as f:
+ raw_data = f.read()
+ os.remove(file_path)
+ response = HttpResponse(raw_data)
+ response["Content-Disposition"] = f"attachment; filename=users.xlsx"
+ response["Content-Type"] = "application/xlsx"
+ return response
+
+ @validate_serializer(GenerateUserSerializer)
+ @super_admin_required
+ def post(self, request):
+ """
+ Generate User
+ """
+ data = request.data
+ number_max_length = max(len(str(data["number_from"])), len(str(data["number_to"])))
+ if number_max_length + len(data["prefix"]) + len(data["suffix"]) > 32:
+ return self.error("Username should not more than 32 characters")
+ if data["number_from"] > data["number_to"]:
+ return self.error("Start number must be lower than end number")
+
+ file_id = rand_str(8)
+ filename = f"/tmp/{file_id}.xlsx"
+ workbook = xlsxwriter.Workbook(filename)
+ worksheet = workbook.add_worksheet()
+ worksheet.set_column("A:B", 20)
+ worksheet.write("A1", "Username")
+ worksheet.write("B1", "Password")
+ i = 1
+
+ user_list = []
+ for number in range(data["number_from"], data["number_to"] + 1):
+ raw_password = rand_str(data["password_length"])
+ user = User(username=f"{data['prefix']}{number}{data['suffix']}", password=make_password(raw_password))
+ user.raw_password = raw_password
+ user_list.append(user)
+
+ try:
+ with transaction.atomic():
+
+ ret = User.objects.bulk_create(user_list)
+ UserProfile.objects.bulk_create([UserProfile(user=user) for user in ret])
+ for item in user_list:
+ worksheet.write_string(i, 0, item.username)
+ worksheet.write_string(i, 1, item.raw_password)
+ i += 1
+ workbook.close()
+ return self.success({"file_id": file_id})
+ except IntegrityError as e:
+ # Extract detail from exception message
+ # duplicate key value violates unique constraint "user_username_key"
+ # DETAIL: Key (username)=(root11) already exists.
+ return self.error(str(e).split("\n")[1])
diff --git a/account/views/oj.py b/account/views/oj.py
index 7db3bbb..5f25817 100644
--- a/account/views/oj.py
+++ b/account/views/oj.py
@@ -1,25 +1,154 @@
+import os
from datetime import timedelta
+from importlib import import_module
+import qrcode
from django.conf import settings
from django.contrib import auth
-from django.core.exceptions import MultipleObjectsReturned
+from django.template.loader import render_to_string
+from django.utils.decorators import method_decorator
from django.utils.timezone import now
+from django.views.decorators.csrf import ensure_csrf_cookie
from otpauth import OtpAuth
-from conf.models import WebsiteConfig
+from problem.models import Problem
+from utils.constants import ContestRuleType
+from options.options import SysOptions
from utils.api import APIView, validate_serializer
from utils.captcha import Captcha
-from utils.shortcuts import rand_str
-
+from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required
from ..models import User, UserProfile
-from ..serializers import (ApplyResetPasswordSerializer,
- ResetPasswordSerializer,
+from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer,
- UserRegisterSerializer)
+ UserRegisterSerializer, UsernameOrEmailCheckSerializer,
+ RankInfoSerializer, UserChangeEmailSerializer)
+from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
+ EditUserProfileSerializer, ImageUploadForm)
from ..tasks import send_email_async
+class UserProfileAPI(APIView):
+ @method_decorator(ensure_csrf_cookie)
+ def get(self, request, **kwargs):
+ """
+ 判断是否登录, 若登录返回用户信息
+ """
+ user = request.user
+ if not user.is_authenticated():
+ return self.success()
+ username = request.GET.get("username")
+ try:
+ if username:
+ user = User.objects.get(username=username, is_disabled=False)
+ else:
+ user = request.user
+ except User.DoesNotExist:
+ return self.error("User does not exist")
+ return self.success(UserProfileSerializer(user.userprofile).data)
+
+ @validate_serializer(EditUserProfileSerializer)
+ @login_required
+ def put(self, request):
+ data = request.data
+ user_profile = request.user.userprofile
+ for k, v in data.items():
+ setattr(user_profile, k, v)
+ user_profile.save()
+ return self.success(UserProfileSerializer(user_profile).data)
+
+
+class AvatarUploadAPI(APIView):
+ request_parsers = ()
+
+ @login_required
+ def post(self, request):
+ form = ImageUploadForm(request.POST, request.FILES)
+ if form.is_valid():
+ avatar = form.cleaned_data["image"]
+ else:
+ return self.error("Invalid file content")
+ if avatar.size > 2 * 1024 * 1024:
+ return self.error("Picture is too large")
+ suffix = os.path.splitext(avatar.name)[-1].lower()
+ if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
+ return self.error("Unsupported file format")
+
+ name = rand_str(10) + suffix
+ with open(os.path.join(settings.AVATAR_UPLOAD_DIR, name), "wb") as img:
+ for chunk in avatar:
+ img.write(chunk)
+ user_profile = request.user.userprofile
+
+ user_profile.avatar = f"{settings.AVATAR_URI_PREFIX}/{name}"
+ user_profile.save()
+ return self.success("Succeeded")
+
+
+class TwoFactorAuthAPI(APIView):
+ @login_required
+ def get(self, request):
+ """
+ Get QR code
+ """
+ user = request.user
+ if user.two_factor_auth:
+ return self.error("2FA is already turned on")
+ token = rand_str()
+ user.tfa_token = token
+ user.save()
+
+ label = f"{SysOptions.website_name_shortcut}:{user.username}"
+ image = qrcode.make(OtpAuth(token).to_uri("totp", label, SysOptions.website_name))
+ return self.success(img2base64(image))
+
+ @login_required
+ @validate_serializer(TwoFactorAuthCodeSerializer)
+ def post(self, request):
+ """
+ Open 2FA
+ """
+ code = request.data["code"]
+ user = request.user
+ if OtpAuth(user.tfa_token).valid_totp(code):
+ user.two_factor_auth = True
+ user.save()
+ return self.success("Succeeded")
+ else:
+ return self.error("Invalid code")
+
+ @login_required
+ @validate_serializer(TwoFactorAuthCodeSerializer)
+ def put(self, request):
+ code = request.data["code"]
+ user = request.user
+ if not user.two_factor_auth:
+ return self.error("2FA is already turned off")
+ if OtpAuth(user.tfa_token).valid_totp(code):
+ user.two_factor_auth = False
+ user.save()
+ return self.success("Succeeded")
+ else:
+ return self.error("Invalid code")
+
+
+class CheckTFARequiredAPI(APIView):
+ @validate_serializer(UsernameOrEmailCheckSerializer)
+ def post(self, request):
+ """
+ Check TFA is required
+ """
+ data = request.data
+ result = False
+ if data.get("username"):
+ try:
+ user = User.objects.get(username=data["username"])
+ result = user.two_factor_auth
+ except User.DoesNotExist:
+ pass
+ return self.success({"result": result})
+
+
class UserLoginAPI(APIView):
@validate_serializer(UserLoginSerializer)
def post(self, request):
@@ -30,13 +159,15 @@ class UserLoginAPI(APIView):
user = auth.authenticate(username=data["username"], password=data["password"])
# None is returned if username or password is wrong
if user:
+ if user.is_disabled:
+ return self.error("Your account has been disabled")
if not user.two_factor_auth:
auth.login(request, user)
return self.success("Succeeded")
# `tfa_code` not in post data
if user.two_factor_auth and "tfa_code" not in data:
- return self.success("tfa_required")
+ return self.error("tfa_required")
if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
auth.login(request, user)
@@ -46,10 +177,30 @@ class UserLoginAPI(APIView):
else:
return self.error("Invalid username or password")
- # todo remove this, only for debug use
+
+class UserLogoutAPI(APIView):
def get(self, request):
- auth.login(request, auth.authenticate(username=request.GET["username"], password=request.GET["password"]))
- return self.success({})
+ auth.logout(request)
+ return self.success()
+
+
+class UsernameOrEmailCheck(APIView):
+ @validate_serializer(UsernameOrEmailCheckSerializer)
+ def post(self, request):
+ """
+ check username or email is duplicate
+ """
+ data = request.data
+ # True means already exist.
+ result = {
+ "username": False,
+ "email": False
+ }
+ if data.get("username"):
+ result["username"] = User.objects.filter(username=data["username"].lower()).exists()
+ if data.get("email"):
+ result["email"] = User.objects.filter(email=data["email"].lower()).exists()
+ return self.success(result)
class UserRegisterAPI(APIView):
@@ -58,27 +209,46 @@ class UserRegisterAPI(APIView):
"""
User register api
"""
+
+ if not SysOptions.allow_register:
+ return self.error("Register function has been disabled by admin")
+
data = request.data
captcha = Captcha(request)
if not captcha.check(data["captcha"]):
return self.error("Invalid captcha")
- try:
- User.objects.get(username=data["username"])
+ if User.objects.filter(username=data["username"]).exists():
return self.error("Username already exists")
- except User.DoesNotExist:
- pass
- try:
- User.objects.get(email=data["email"])
+ data["email"] = data["email"].lower()
+ if User.objects.filter(email=data["email"]).exists():
return self.error("Email already exists")
- # Some old data has duplicate email
- except MultipleObjectsReturned:
- return self.error("Email already exists")
- except User.DoesNotExist:
- user = User.objects.create(username=data["username"], email=data["email"])
- user.set_password(data["password"])
+ user = User.objects.create(username=data["username"], email=data["email"])
+ user.set_password(data["password"])
+ user.save()
+ UserProfile.objects.create(user=user)
+ return self.success("Succeeded")
+
+
+class UserChangeEmailAPI(APIView):
+ @validate_serializer(UserChangeEmailSerializer)
+ @login_required
+ def post(self, request):
+ data = request.data
+ user = auth.authenticate(username=request.user.username, password=data["password"])
+ if user:
+ if user.two_factor_auth:
+ if "tfa_code" not in data:
+ return self.error("tfa_required")
+ if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
+ return self.error("Invalid two factor verification code")
+ data["new_email"] = data["new_email"].lower()
+ if User.objects.filter(email=data["new_email"]).exists():
+ return self.error("The email is owned by other account")
+ user.email = data["new_email"]
user.save()
- UserProfile.objects.create(user=user)
return self.success("Succeeded")
+ else:
+ return self.error("Wrong password")
class UserChangePasswordAPI(APIView):
@@ -89,12 +259,14 @@ class UserChangePasswordAPI(APIView):
User change password api
"""
data = request.data
- captcha = Captcha(request)
- if not captcha.check(data["captcha"]):
- return self.error("Invalid captcha")
username = request.user.username
user = auth.authenticate(username=username, password=data["old_password"])
if user:
+ if user.two_factor_auth:
+ if "tfa_code" not in data:
+ return self.error("tfa_required")
+ if not OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]):
+ return self.error("Invalid two factor verification code")
user.set_password(data["new_password"])
user.save()
return self.success("Succeeded")
@@ -105,33 +277,33 @@ class UserChangePasswordAPI(APIView):
class ApplyResetPasswordAPI(APIView):
@validate_serializer(ApplyResetPasswordSerializer)
def post(self, request):
+ if request.user.is_authenticated():
+ return self.error("You have already logged in, are you kidding me? ")
data = request.data
captcha = Captcha(request)
- config = WebsiteConfig.objects.first()
if not captcha.check(data["captcha"]):
return self.error("Invalid captcha")
try:
- user = User.objects.get(email=data["email"])
+ user = User.objects.get(email__iexact=data["email"])
except User.DoesNotExist:
return self.error("User does not exist")
- if user.reset_password_token_expire_time and 0 < (
- user.reset_password_token_expire_time - now()).total_seconds() < 20 * 60:
+ if user.reset_password_token_expire_time and 0 < int(
+ (user.reset_password_token_expire_time - now()).total_seconds()) < 20 * 60:
return self.error("You can only reset password once per 20 minutes")
user.reset_password_token = rand_str()
-
user.reset_password_token_expire_time = now() + timedelta(minutes=20)
user.save()
- email_template = open("reset_password_email.html", "w",
- encoding="utf-8").read()
- email_template = email_template.replace("{{ username }}", user.username). \
- replace("{{ website_name }}", settings.WEBSITE_INFO["website_name"]). \
- replace("{{ link }}", settings.WEBSITE_INFO["url"] + "/reset_password/t/" +
- user.reset_password_token)
- send_email_async.delay(config.name,
+ render_data = {
+ "username": user.username,
+ "website_name": SysOptions.website_name,
+ "link": f"{SysOptions.website_base_url}/reset-password/{user.reset_password_token}"
+ }
+ email_html = render_to_string("reset_password_email.html", render_data)
+ send_email_async.delay(SysOptions.website_name,
user.email,
user.username,
- config.name + " 登录信息找回邮件",
- email_template)
+ f"{SysOptions.website_name} 登录信息找回邮件",
+ email_html)
return self.success("Succeeded")
@@ -145,10 +317,99 @@ class ResetPasswordAPI(APIView):
try:
user = User.objects.get(reset_password_token=data["token"])
except User.DoesNotExist:
- return self.error("Token dose not exist")
- if 0 < (user.reset_password_token_expire_time - now()).total_seconds() < 30 * 60:
- return self.error("Token expired")
+ return self.error("Token does not exist")
+ if user.reset_password_token_expire_time < now():
+ return self.error("Token has expired")
user.reset_password_token = None
+ user.two_factor_auth = False
user.set_password(data["password"])
user.save()
return self.success("Succeeded")
+
+
+class SessionManagementAPI(APIView):
+ @login_required
+ def get(self, request):
+ engine = import_module(settings.SESSION_ENGINE)
+ session_store = engine.SessionStore
+ current_session = request.session.session_key
+ session_keys = request.user.session_keys
+ result = []
+ modified = False
+ for key in session_keys[:]:
+ session = session_store(key)
+ # session does not exist or is expiry
+ if not session._session:
+ session_keys.remove(key)
+ modified = True
+ continue
+
+ s = {}
+ if current_session == key:
+ s["current_session"] = True
+ s["ip"] = session["ip"]
+ s["user_agent"] = session["user_agent"]
+ s["last_activity"] = datetime2str(session["last_activity"])
+ s["session_key"] = key
+ result.append(s)
+ if modified:
+ request.user.save()
+ return self.success(result)
+
+ @login_required
+ def delete(self, request):
+ session_key = request.GET.get("session_key")
+ if not session_key:
+ return self.error("Parameter Error")
+ request.session.delete(session_key)
+ if session_key in request.user.session_keys:
+ request.user.session_keys.remove(session_key)
+ request.user.save()
+ return self.success("Succeeded")
+ else:
+ return self.error("Invalid session_key")
+
+
+class UserRankAPI(APIView):
+ def get(self, request):
+ rule_type = request.GET.get("rule")
+ if rule_type not in ContestRuleType.choices():
+ rule_type = ContestRuleType.ACM
+ profiles = UserProfile.objects.select_related("user")\
+ .exclude(user__is_disabled=True)
+ if rule_type == ContestRuleType.ACM:
+ profiles = profiles.filter(submission_number__gt=0).order_by("-accepted_number", "submission_number")
+ else:
+ profiles = profiles.filter(total_score__gt=0).order_by("-total_score")
+ return self.success(self.paginate_data(request, profiles, RankInfoSerializer))
+
+
+class ProfileProblemDisplayIDRefreshAPI(APIView):
+ @login_required
+ def get(self, request):
+ profile = request.user.userprofile
+ acm_problems = profile.acm_problems_status.get("problems", {})
+ oi_problems = profile.oi_problems_status.get("problems", {})
+ ids = list(acm_problems.keys()) + list(oi_problems.keys())
+ if not ids:
+ return self.success()
+ display_ids = Problem.objects.filter(id__in=ids).values_list("_id", flat=True)
+ id_map = dict(zip(ids, display_ids))
+ for k, v in acm_problems.items():
+ v["_id"] = id_map[k]
+ for k, v in oi_problems.items():
+ v["_id"] = id_map[k]
+ profile.save(update_fields=["acm_problems_status", "oi_problems_status"])
+ return self.success()
+
+
+class OpenAPIAppkeyAPI(APIView):
+ @login_required
+ def post(self, request):
+ user = request.user
+ if not user.open_api:
+ return self.error("Permission denied")
+ api_appkey = rand_str()
+ user.open_api_appkey = api_appkey
+ user.save()
+ return self.success({"appkey": api_appkey})
diff --git a/account/views/user.py b/account/views/user.py
deleted file mode 100644
index 19eb893..0000000
--- a/account/views/user.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import os
-from io import StringIO
-
-import qrcode
-from django.conf import settings
-from django.http import HttpResponse
-from otpauth import OtpAuth
-
-from conf.models import WebsiteConfig
-from utils.api import APIView, validate_serializer
-from utils.shortcuts import rand_str
-
-from ..decorators import login_required
-from ..models import User
-from ..serializers import (EditUserSerializer, SSOSerializer,
- TwoFactorAuthCodeSerializer, UserSerializer)
-
-
-class UserInfoAPI(APIView):
- @login_required
- def get(self, request):
- """
- Return user info api
- """
- return self.success(UserSerializer(request.user).data)
-
-
-class UserProfileAPI(APIView):
- @login_required
- def get(self, request):
- """
- Return user info api
- """
- return self.success(UserSerializer(request.user).data)
-
- @validate_serializer(EditUserSerializer)
- @login_required
- def put(self, request):
- data = request.data
- user_profile = request.user.userprofile
- if data["avatar"]:
- user_profile.avatar = data["avatar"]
- else:
- user_profile.mood = data["mood"]
- user_profile.blog = data["blog"]
- user_profile.school = data["school"]
- user_profile.student_id = data["student_id"]
- user_profile.phone_number = data["phone_number"]
- user_profile.major = data["major"]
- # Timezone & language 暂时不加
- user_profile.save()
- return self.success("Succeeded")
-
-
-class AvatarUploadAPI(APIView):
- def post(self, request):
- if "file" not in request.FILES:
- return self.error("Upload failed")
-
- f = request.FILES["file"]
- if f.size > 1024 * 1024:
- return self.error("Picture too large")
- if os.path.splitext(f.name)[-1].lower() not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
- return self.error("Unsupported file format")
-
- name = "avatar_" + rand_str(5) + os.path.splitext(f.name)[-1]
- with open(os.path.join(settings.IMAGE_UPLOAD_DIR, name), "wb") as img:
- for chunk in request.FILES["file"]:
- img.write(chunk)
- return self.success({"path": "/static/upload/" + name})
-
-
-class SSOAPI(APIView):
- @login_required
- def get(self, request):
- callback = request.GET.get("callback", None)
- if not callback:
- return self.error("Parameter Error")
- token = rand_str()
- request.user.auth_token = token
- request.user.save()
- return self.success({"redirect_url": callback + "?token=" + token,
- "callback": callback})
-
- @validate_serializer(SSOSerializer)
- def post(self, request):
- data = request.data
- try:
- User.objects.get(open_api_appkey=data["appkey"])
- except User.DoesNotExist:
- return self.error("Invalid appkey")
- try:
- user = User.objects.get(auth_token=data["token"])
- user.auth_token = None
- user.save()
- return self.success({"username": user.username,
- "id": user.id,
- "admin_type": user.admin_type,
- "avatar": user.userprofile.avatar})
- except User.DoesNotExist:
- return self.error("User does not exist")
-
-
-class TwoFactorAuthAPI(APIView):
- @login_required
- def get(self, request):
- """
- Get QR code
- """
- user = request.user
- if user.two_factor_auth:
- return self.error("Already open 2FA")
- token = rand_str()
- user.tfa_token = token
- user.save()
-
- config = WebsiteConfig.objects.first()
- image = qrcode.make(OtpAuth(token).to_uri("totp", config.base_url, config.name))
- buf = StringIO()
- image.save(buf, "gif")
-
- return HttpResponse(buf.getvalue(), "image/gif")
-
- @login_required
- @validate_serializer(TwoFactorAuthCodeSerializer)
- def post(self, request):
- """
- Open 2FA
- """
- code = request.data["code"]
- user = request.user
- if OtpAuth(user.tfa_token).valid_totp(code):
- user.two_factor_auth = True
- user.save()
- return self.success("Succeeded")
- else:
- return self.error("Invalid captcha")
-
- @login_required
- @validate_serializer(TwoFactorAuthCodeSerializer)
- def put(self, request):
- code = request.data["code"]
- user = request.user
- if OtpAuth(user.tfa_token).valid_totp(code):
- user.two_factor_auth = False
- user.save()
- else:
- return self.error("Invalid captcha")
diff --git a/announcement/migrations/0002_auto_20171011_1214.py b/announcement/migrations/0002_auto_20171011_1214.py
new file mode 100644
index 0000000..e2d5abe
--- /dev/null
+++ b/announcement/migrations/0002_auto_20171011_1214.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-11 12:14
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('announcement', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='announcement',
+ name='title',
+ field=models.CharField(max_length=64),
+ ),
+ migrations.AlterModelOptions(
+ name='announcement',
+ options={'ordering': ('-create_time',)},
+ ),
+ ]
diff --git a/announcement/models.py b/announcement/models.py
index 186d4ea..a19b06c 100644
--- a/announcement/models.py
+++ b/announcement/models.py
@@ -5,7 +5,7 @@ from utils.models import RichTextField
class Announcement(models.Model):
- title = models.CharField(max_length=50)
+ title = models.CharField(max_length=64)
# HTML
content = RichTextField()
create_time = models.DateTimeField(auto_now_add=True)
@@ -15,3 +15,4 @@ class Announcement(models.Model):
class Meta:
db_table = "announcement"
+ ordering = ("-create_time",)
diff --git a/announcement/serializers.py b/announcement/serializers.py
index 0c0becc..b660a61 100644
--- a/announcement/serializers.py
+++ b/announcement/serializers.py
@@ -5,8 +5,8 @@ from .models import Announcement
class CreateAnnouncementSerializer(serializers.Serializer):
- title = serializers.CharField(max_length=50)
- content = serializers.CharField(max_length=10000)
+ title = serializers.CharField(max_length=64)
+ content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField()
@@ -21,6 +21,6 @@ class AnnouncementSerializer(serializers.ModelSerializer):
class EditAnnouncementSerializer(serializers.Serializer):
id = serializers.IntegerField()
- title = serializers.CharField(max_length=50)
- content = serializers.CharField(max_length=10000)
+ title = serializers.CharField(max_length=64)
+ content = serializers.CharField(max_length=1024 * 1024 * 8)
visible = serializers.BooleanField()
diff --git a/announcement/tests.py b/announcement/tests.py
index dd702ba..98caa1c 100644
--- a/announcement/tests.py
+++ b/announcement/tests.py
@@ -35,3 +35,14 @@ class AnnouncementAdminTest(APITestCase):
resp = self.client.delete(self.url + "?id=" + str(id))
self.assertSuccess(resp)
self.assertFalse(Announcement.objects.filter(id=id).exists())
+
+
+class AnnouncementAPITest(APITestCase):
+ def setUp(self):
+ self.user = self.create_super_admin()
+ Announcement.objects.create(title="title", content="content", visible=True, created_by=self.user)
+ self.url = self.reverse("announcement_api")
+
+ def test_get_announcement_list(self):
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
diff --git a/announcement/urls/admin.py b/announcement/urls/admin.py
index 6b9ce0f..09673e6 100644
--- a/announcement/urls/admin.py
+++ b/announcement/urls/admin.py
@@ -1,6 +1,6 @@
from django.conf.urls import url
-from ..views import AnnouncementAdminAPI
+from ..views.admin import AnnouncementAdminAPI
urlpatterns = [
url(r"^announcement/?$", AnnouncementAdminAPI.as_view(), name="announcement_admin_api"),
diff --git a/announcement/urls/oj.py b/announcement/urls/oj.py
new file mode 100644
index 0000000..67178b0
--- /dev/null
+++ b/announcement/urls/oj.py
@@ -0,0 +1,7 @@
+from django.conf.urls import url
+
+from ..views.oj import AnnouncementAPI
+
+urlpatterns = [
+ url(r"^announcement/?$", AnnouncementAPI.as_view(), name="announcement_api"),
+]
diff --git a/announcement/views/__init__.py b/announcement/views/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/announcement/views.py b/announcement/views/admin.py
similarity index 84%
rename from announcement/views.py
rename to announcement/views/admin.py
index f607a7d..58d3578 100644
--- a/announcement/views.py
+++ b/announcement/views/admin.py
@@ -1,9 +1,9 @@
from account.decorators import super_admin_required
from utils.api import APIView, validate_serializer
-from .models import Announcement
-from .serializers import (AnnouncementSerializer, CreateAnnouncementSerializer,
- EditAnnouncementSerializer)
+from announcement.models import Announcement
+from announcement.serializers import (AnnouncementSerializer, CreateAnnouncementSerializer,
+ EditAnnouncementSerializer)
class AnnouncementAdminAPI(APIView):
@@ -28,13 +28,12 @@ class AnnouncementAdminAPI(APIView):
"""
data = request.data
try:
- announcement = Announcement.objects.get(id=data["id"])
+ announcement = Announcement.objects.get(id=data.pop("id"))
except Announcement.DoesNotExist:
return self.error("Announcement does not exist")
- announcement.title = data["title"]
- announcement.content = data["content"]
- announcement.visible = data["visible"]
+ for k, v in data.items():
+ setattr(announcement, k, v)
announcement.save()
return self.success(AnnouncementSerializer(announcement).data)
diff --git a/announcement/views/oj.py b/announcement/views/oj.py
new file mode 100644
index 0000000..1176c36
--- /dev/null
+++ b/announcement/views/oj.py
@@ -0,0 +1,10 @@
+from utils.api import APIView
+
+from announcement.models import Announcement
+from announcement.serializers import AnnouncementSerializer
+
+
+class AnnouncementAPI(APIView):
+ def get(self, request):
+ announcements = Announcement.objects.filter(visible=True)
+ return self.success(self.paginate_data(request, announcements, AnnouncementSerializer))
diff --git a/conf/migrations/0002_auto_20171011_1214.py b/conf/migrations/0002_auto_20171011_1214.py
new file mode 100644
index 0000000..ef355b5
--- /dev/null
+++ b/conf/migrations/0002_auto_20171011_1214.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-11 12:14
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('conf', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.DeleteModel(
+ name='JudgeServerToken',
+ ),
+ migrations.DeleteModel(
+ name='SMTPConfig',
+ ),
+ migrations.DeleteModel(
+ name='WebsiteConfig',
+ ),
+ migrations.AlterField(
+ model_name='judgeserver',
+ name='hostname',
+ field=models.CharField(max_length=128),
+ ),
+ migrations.AlterField(
+ model_name='judgeserver',
+ name='judger_version',
+ field=models.CharField(max_length=32),
+ ),
+ migrations.AlterField(
+ model_name='judgeserver',
+ name='service_url',
+ field=models.CharField(blank=True, max_length=256, null=True),
+ ),
+ ]
diff --git a/conf/models.py b/conf/models.py
index 9a88fdf..4c6348d 100644
--- a/conf/models.py
+++ b/conf/models.py
@@ -2,55 +2,24 @@ from django.db import models
from django.utils import timezone
-class SMTPConfig(models.Model):
- server = models.CharField(max_length=128)
- port = models.IntegerField(default=25)
- email = models.CharField(max_length=128)
- password = models.CharField(max_length=128)
- tls = models.BooleanField()
-
- class Meta:
- db_table = "smtp_config"
-
-
-class WebsiteConfig(models.Model):
- base_url = models.CharField(max_length=128, default="http://127.0.0.1")
- name = models.CharField(max_length=32, default="Online Judge")
- name_shortcut = models.CharField(max_length=32, default="oj")
- footer = models.TextField(default="Online Judge Footer")
- # allow register
- allow_register = models.BooleanField(default=True)
- # submission list show all user's submission
- submission_list_show_all = models.BooleanField(default=True)
-
- class Meta:
- db_table = "website_config"
-
-
class JudgeServer(models.Model):
- hostname = models.CharField(max_length=64)
+ hostname = models.CharField(max_length=128)
ip = models.CharField(max_length=32, blank=True, null=True)
- judger_version = models.CharField(max_length=24)
+ judger_version = models.CharField(max_length=32)
cpu_core = models.IntegerField()
memory_usage = models.FloatField()
cpu_usage = models.FloatField()
last_heartbeat = models.DateTimeField()
create_time = models.DateTimeField(auto_now_add=True)
task_number = models.IntegerField(default=0)
- service_url = models.CharField(max_length=128, blank=True, null=True)
+ service_url = models.CharField(max_length=256, blank=True, null=True)
@property
def status(self):
- if (timezone.now() - self.last_heartbeat).total_seconds() > 5:
+ # 增加一秒延时,提高对网络环境的适应性
+ if (timezone.now() - self.last_heartbeat).total_seconds() > 6:
return "abnormal"
return "normal"
class Meta:
db_table = "judge_server"
-
-
-class JudgeServerToken(models.Model):
- token = models.CharField(max_length=32)
-
- class Meta:
- db_table = "judge_server_token"
diff --git a/conf/serializers.py b/conf/serializers.py
index 59b7203..7f0cf57 100644
--- a/conf/serializers.py
+++ b/conf/serializers.py
@@ -1,6 +1,6 @@
from utils.api import DateTimeTZField, serializers
-from .models import JudgeServer, SMTPConfig, WebsiteConfig
+from .models import JudgeServer
class EditSMTPConfigSerializer(serializers.Serializer):
@@ -15,31 +15,19 @@ class CreateSMTPConfigSerializer(EditSMTPConfigSerializer):
password = serializers.CharField(max_length=128)
-class SMTPConfigSerializer(serializers.ModelSerializer):
- class Meta:
- model = SMTPConfig
- exclude = ["id", "password"]
-
-
class TestSMTPConfigSerializer(serializers.Serializer):
email = serializers.EmailField()
class CreateEditWebsiteConfigSerializer(serializers.Serializer):
- base_url = serializers.CharField(max_length=128)
- name = serializers.CharField(max_length=32)
- name_shortcut = serializers.CharField(max_length=32)
- footer = serializers.CharField(max_length=1024)
+ website_base_url = serializers.CharField(max_length=128)
+ website_name = serializers.CharField(max_length=64)
+ website_name_shortcut = serializers.CharField(max_length=64)
+ website_footer = serializers.CharField(max_length=1024 * 1024)
allow_register = serializers.BooleanField()
submission_list_show_all = serializers.BooleanField()
-class WebsiteConfigSerializer(serializers.ModelSerializer):
- class Meta:
- model = WebsiteConfig
- exclude = ["id"]
-
-
class JudgeServerSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField()
last_heartbeat = DateTimeTZField()
@@ -47,13 +35,14 @@ class JudgeServerSerializer(serializers.ModelSerializer):
class Meta:
model = JudgeServer
+ fields = "__all__"
class JudgeServerHeartbeatSerializer(serializers.Serializer):
- hostname = serializers.CharField(max_length=64)
- judger_version = serializers.CharField(max_length=24)
+ hostname = serializers.CharField(max_length=128)
+ judger_version = serializers.CharField(max_length=32)
cpu_core = serializers.IntegerField(min_value=1)
memory = serializers.FloatField(min_value=0, max_value=100)
cpu = serializers.FloatField(min_value=0, max_value=100)
action = serializers.ChoiceField(choices=("heartbeat", ))
- service_url = serializers.CharField(max_length=128, required=False)
+ service_url = serializers.CharField(max_length=256, required=False)
diff --git a/conf/tests.py b/conf/tests.py
index fd05b43..cd83444 100644
--- a/conf/tests.py
+++ b/conf/tests.py
@@ -2,9 +2,9 @@ import hashlib
from django.utils import timezone
+from options.options import SysOptions
from utils.api.tests import APITestCase
-
-from .models import JudgeServer, JudgeServerToken, SMTPConfig
+from .models import JudgeServer
class SMTPConfigTest(APITestCase):
@@ -27,10 +27,6 @@ class SMTPConfigTest(APITestCase):
"tls": True}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
- smtp = SMTPConfig.objects.first()
- self.assertEqual(smtp.password, self.password)
- self.assertEqual(smtp.server, "smtp1.test.com")
- self.assertEqual(smtp.email, "test2@test.com")
def test_edit_without_password1(self):
self.test_create_smtp_config()
@@ -38,7 +34,6 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": ""}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
- self.assertEqual(SMTPConfig.objects.first().password, self.password)
def test_edit_with_password(self):
self.test_create_smtp_config()
@@ -46,18 +41,14 @@ class SMTPConfigTest(APITestCase):
"tls": True, "password": "newpassword"}
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
- smtp = SMTPConfig.objects.first()
- self.assertEqual(smtp.password, "newpassword")
- self.assertEqual(smtp.server, "smtp1.test.com")
- self.assertEqual(smtp.email, "test2@test.com")
class WebsiteConfigAPITest(APITestCase):
def test_create_website_config(self):
self.create_super_admin()
url = self.reverse("website_config_api")
- data = {"base_url": "http://test.com", "name": "test name",
- "name_shortcut": "test oj", "footer": "test",
+ data = {"website_base_url": "http://test.com", "website_name": "test name",
+ "website_name_shortcut": "test oj", "website_footer": "test",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
@@ -65,8 +56,8 @@ class WebsiteConfigAPITest(APITestCase):
def test_edit_website_config(self):
self.create_super_admin()
url = self.reverse("website_config_api")
- data = {"base_url": "http://test.com", "name": "test name",
- "name_shortcut": "test oj", "footer": "test",
+ data = {"website_base_url": "http://test.com", "website_name": "test name",
+ "website_name_shortcut": "test oj", "website_footer": "test",
"allow_register": True, "submission_list_show_all": False}
resp = self.client.post(url, data=data)
self.assertSuccess(resp)
@@ -76,7 +67,6 @@ class WebsiteConfigAPITest(APITestCase):
url = self.reverse("website_info_api")
resp = self.client.get(url)
self.assertSuccess(resp)
- self.assertEqual(resp.data["data"]["name_shortcut"], "oj")
class JudgeServerHeartbeatTest(APITestCase):
@@ -86,7 +76,7 @@ class JudgeServerHeartbeatTest(APITestCase):
"cpu": 90.5, "memory": 80.3, "action": "heartbeat"}
self.token = "test"
self.hashed_token = hashlib.sha256(self.token.encode("utf-8")).hexdigest()
- JudgeServerToken.objects.create(token=self.token)
+ SysOptions.judge_server_token = self.token
def test_new_heartbeat(self):
resp = self.client.post(self.url, data=self.data, **{"HTTP_X_JUDGE_SERVER_TOKEN": self.hashed_token})
@@ -122,11 +112,9 @@ class JudgeServerAPITest(APITestCase):
self.create_super_admin()
def test_get_judge_server(self):
- self.assertFalse(JudgeServerToken.objects.exists())
resp = self.client.get(self.url)
self.assertSuccess(resp)
self.assertEqual(len(resp.data["data"]["servers"]), 1)
- self.assertEqual(JudgeServerToken.objects.first().token, resp.data["data"]["token"])
def test_delete_judge_server(self):
resp = self.client.delete(self.url + "?hostname=testhostname")
diff --git a/conf/views.py b/conf/views.py
index 3dc9959..f09972c 100644
--- a/conf/views.py
+++ b/conf/views.py
@@ -3,48 +3,43 @@ import hashlib
from django.utils import timezone
from account.decorators import super_admin_required
+from judge.dispatcher import process_pending_task
from judge.languages import languages, spj_languages
+from options.options import SysOptions
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
-from utils.shortcuts import rand_str
-
-from .models import JudgeServer, JudgeServerToken, SMTPConfig, WebsiteConfig
+from .models import JudgeServer
from .serializers import (CreateEditWebsiteConfigSerializer,
CreateSMTPConfigSerializer, EditSMTPConfigSerializer,
JudgeServerHeartbeatSerializer,
- JudgeServerSerializer, SMTPConfigSerializer,
- TestSMTPConfigSerializer, WebsiteConfigSerializer)
+ JudgeServerSerializer, TestSMTPConfigSerializer)
class SMTPAPI(APIView):
@super_admin_required
def get(self, request):
- smtp = SMTPConfig.objects.first()
+ smtp = SysOptions.smtp_config
if not smtp:
return self.success(None)
- return self.success(SMTPConfigSerializer(smtp).data)
+ smtp.pop("password")
+ return self.success(smtp)
@validate_serializer(CreateSMTPConfigSerializer)
@super_admin_required
def post(self, request):
- SMTPConfig.objects.all().delete()
- smtp = SMTPConfig.objects.create(**request.data)
- return self.success(SMTPConfigSerializer(smtp).data)
+ SysOptions.smtp_config = request.data
+ return self.success()
@validate_serializer(EditSMTPConfigSerializer)
@super_admin_required
def put(self, request):
+ smtp = SysOptions.smtp_config
data = request.data
- smtp = SMTPConfig.objects.first()
- if not smtp:
- return self.error("SMTP config is missing")
- smtp.server = data["server"]
- smtp.port = data["port"]
- smtp.email = data["email"]
- smtp.tls = data["tls"]
- if data.get("password"):
- smtp.password = data["password"]
- smtp.save()
- return self.success(SMTPConfigSerializer(smtp).data)
+ for item in ["server", "port", "email", "tls"]:
+ smtp[item] = data[item]
+ if "password" in data:
+ smtp["password"] = data["password"]
+ SysOptions.smtp_config = smtp
+ return self.success()
class SMTPTestAPI(APIView):
@@ -56,31 +51,24 @@ class SMTPTestAPI(APIView):
class WebsiteConfigAPI(APIView):
def get(self, request):
- config = WebsiteConfig.objects.first()
- if not config:
- config = WebsiteConfig.objects.create()
- return self.success(WebsiteConfigSerializer(config).data)
+ ret = {key: getattr(SysOptions, key) for key in
+ ["website_base_url", "website_name", "website_name_shortcut",
+ "website_footer", "allow_register", "submission_list_show_all"]}
+ return self.success(ret)
@validate_serializer(CreateEditWebsiteConfigSerializer)
@super_admin_required
def post(self, request):
- data = request.data
- WebsiteConfig.objects.all().delete()
- config = WebsiteConfig.objects.create(**data)
- return self.success(WebsiteConfigSerializer(config).data)
+ for k, v in request.data.items():
+ setattr(SysOptions, k, v)
+ return self.success()
class JudgeServerAPI(APIView):
@super_admin_required
def get(self, request):
- judge_server_token = JudgeServerToken.objects.first()
- if not judge_server_token:
- token = rand_str(12)
- JudgeServerToken.objects.create(token=token)
- else:
- token = judge_server_token.token
servers = JudgeServer.objects.all().order_by("-last_heartbeat")
- return self.success({"token": token,
+ return self.success({"token": SysOptions.judge_server_token,
"servers": JudgeServerSerializer(servers, many=True).data})
@super_admin_required
@@ -94,15 +82,9 @@ class JudgeServerAPI(APIView):
class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
@validate_serializer(JudgeServerHeartbeatSerializer)
def post(self, request):
- judge_server_token = JudgeServerToken.objects.first()
- if not judge_server_token:
- token = rand_str(12)
- JudgeServerToken.objects.create(token=token)
- else:
- token = judge_server_token.token
data = request.data
client_token = request.META.get("HTTP_X_JUDGE_SERVER_TOKEN")
- if hashlib.sha256(token.encode("utf-8")).hexdigest() != client_token:
+ if hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() != client_token:
return self.error("Invalid token")
service_url = data.get("service_url")
@@ -126,6 +108,9 @@ class JudgeServerHeartbeatAPI(CSRFExemptAPIView):
service_url=service_url,
last_heartbeat=timezone.now(),
)
+ # 新server上线 处理队列中的,防止没有新的提交而导致一直waiting
+ process_pending_task()
+
return self.success()
diff --git a/contest/migrations/0004_auto_20170717_1324.py b/contest/migrations/0004_auto_20170717_1324.py
new file mode 100644
index 0000000..617790a
--- /dev/null
+++ b/contest/migrations/0004_auto_20170717_1324.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-07-17 13:24
+from __future__ import unicode_literals
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0003_auto_20170217_0820'),
+ ]
+
+ operations = [
+ migrations.AlterModelOptions(
+ name='contest',
+ options={'ordering': ('-create_time',)},
+ ),
+ migrations.AlterModelOptions(
+ name='contestannouncement',
+ options={'ordering': ('-create_time',)},
+ ),
+ ]
diff --git a/contest/migrations/0005_auto_20170823_0918.py b/contest/migrations/0005_auto_20170823_0918.py
new file mode 100644
index 0000000..dbf12c6
--- /dev/null
+++ b/contest/migrations/0005_auto_20170823_0918.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-08-23 09:18
+from __future__ import unicode_literals
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0004_auto_20170717_1324'),
+ ]
+
+ operations = [
+ migrations.RenameField(
+ model_name='acmcontestrank',
+ old_name='total_ac_number',
+ new_name='accepted_number',
+ ),
+ migrations.RenameField(
+ model_name='acmcontestrank',
+ old_name='total_submission_number',
+ new_name='submission_number',
+ ),
+ migrations.RenameField(
+ model_name='oicontestrank',
+ old_name='total_submission_number',
+ new_name='submission_number',
+ ),
+ ]
diff --git a/contest/migrations/0006_auto_20171011_1214.py b/contest/migrations/0006_auto_20171011_1214.py
new file mode 100644
index 0000000..0134a5b
--- /dev/null
+++ b/contest/migrations/0006_auto_20171011_1214.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-11 12:14
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0005_auto_20170823_0918'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='acmcontestrank',
+ name='submission_info',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterField(
+ model_name='oicontestrank',
+ name='submission_info',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterModelOptions(
+ name='contest',
+ options={'ordering': ('-start_time',)},
+ ),
+ ]
diff --git a/contest/migrations/0007_contestannouncement_visible.py b/contest/migrations/0007_contestannouncement_visible.py
new file mode 100644
index 0000000..679874f
--- /dev/null
+++ b/contest/migrations/0007_contestannouncement_visible.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-11-06 09:02
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0006_auto_20171011_1214'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='contestannouncement',
+ name='visible',
+ field=models.BooleanField(default=True),
+ ),
+ ]
diff --git a/contest/migrations/0008_contest_allowed_ip_ranges.py b/contest/migrations/0008_contest_allowed_ip_ranges.py
new file mode 100644
index 0000000..fd6c6ff
--- /dev/null
+++ b/contest/migrations/0008_contest_allowed_ip_ranges.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-11-10 06:57
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0007_contestannouncement_visible'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='contest',
+ name='allowed_ip_ranges',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=list),
+ ),
+ ]
diff --git a/contest/models.py b/contest/models.py
index 72e60cb..2f07b37 100644
--- a/contest/models.py
+++ b/contest/models.py
@@ -1,27 +1,13 @@
+from utils.constants import ContestRuleType # noqa
from django.db import models
from django.utils.timezone import now
-from jsonfield import JSONField
+from utils.models import JSONField
+from utils.constants import ContestStatus, ContestType
from account.models import User
from utils.models import RichTextField
-class ContestType(object):
- PUBLIC_CONTEST = "Public"
- PASSWORD_PROTECTED_CONTEST = "Password Protected"
-
-
-class ContestStatus(object):
- CONTEST_NOT_START = "Not Started"
- CONTEST_ENDED = "Ended"
- CONTEST_UNDERWAY = "Underway"
-
-
-class ContestRuleType(object):
- ACM = "ACM"
- OI = "OI"
-
-
class Contest(models.Model):
title = models.CharField(max_length=40)
description = RichTextField()
@@ -37,6 +23,7 @@ class Contest(models.Model):
created_by = models.ForeignKey(User)
# 是否可见 false的话相当于删除
visible = models.BooleanField(default=True)
+ allowed_ip_ranges = JSONField(default=list)
@property
def status(self):
@@ -56,36 +43,44 @@ class Contest(models.Model):
return ContestType.PASSWORD_PROTECTED_CONTEST
return ContestType.PUBLIC_CONTEST
+ # 是否有权查看problem 的一些统计信息 诸如submission_number, accepted_number 等
+ def problem_details_permission(self, user):
+ return self.rule_type == ContestRuleType.ACM or \
+ self.status == ContestStatus.CONTEST_ENDED or \
+ user.is_authenticated() and user.is_contest_admin(self) or \
+ self.real_time_rank
+
class Meta:
db_table = "contest"
+ ordering = ("-start_time",)
-class ContestRank(models.Model):
+class AbstractContestRank(models.Model):
user = models.ForeignKey(User)
contest = models.ForeignKey(Contest)
- total_submission_number = models.IntegerField(default=0)
+ submission_number = models.IntegerField(default=0)
class Meta:
abstract = True
-class ACMContestRank(ContestRank):
- total_ac_number = models.IntegerField(default=0)
+class ACMContestRank(AbstractContestRank):
+ accepted_number = models.IntegerField(default=0)
# total_time is only for ACM contest total_time = ac time + none-ac times * 20 * 60
total_time = models.IntegerField(default=0)
# {23: {"is_ac": True, "ac_time": 8999, "error_number": 2, "is_first_ac": True}}
# key is problem id
- submission_info = JSONField(default={})
+ submission_info = JSONField(default=dict)
class Meta:
db_table = "acm_contest_rank"
-class OIContestRank(ContestRank):
+class OIContestRank(AbstractContestRank):
total_score = models.IntegerField(default=0)
- # {23: {"score": 80, "total_score": 100}}
- # key is problem id
- submission_info = JSONField(default={})
+ # {23: 333}}
+ # key is problem id, value is current score
+ submission_info = JSONField(default=dict)
class Meta:
db_table = "oi_contest_rank"
@@ -96,7 +91,9 @@ class ContestAnnouncement(models.Model):
title = models.CharField(max_length=128)
content = RichTextField()
created_by = models.ForeignKey(User)
+ visible = models.BooleanField(default=True)
create_time = models.DateTimeField(auto_now_add=True)
class Meta:
db_table = "contest_announcement"
+ ordering = ("-create_time",)
diff --git a/contest/serializers.py b/contest/serializers.py
index c99c16d..abbdccc 100644
--- a/contest/serializers.py
+++ b/contest/serializers.py
@@ -1,6 +1,7 @@
from utils.api import DateTimeTZField, UsernameSerializer, serializers
from .models import Contest, ContestAnnouncement, ContestRuleType
+from .models import ACMContestRank, OIContestRank
class CreateConetestSeriaizer(serializers.Serializer):
@@ -12,9 +13,22 @@ class CreateConetestSeriaizer(serializers.Serializer):
password = serializers.CharField(allow_blank=True, max_length=32)
visible = serializers.BooleanField()
real_time_rank = serializers.BooleanField()
+ allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=True)
-class ContestSerializer(serializers.ModelSerializer):
+class EditConetestSeriaizer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=128)
+ description = serializers.CharField()
+ start_time = serializers.DateTimeField()
+ end_time = serializers.DateTimeField()
+ password = serializers.CharField(allow_blank=True, allow_null=True, max_length=32)
+ visible = serializers.BooleanField()
+ real_time_rank = serializers.BooleanField()
+ allowed_ip_ranges = serializers.ListField(child=serializers.CharField(max_length=32))
+
+
+class ContestAdminSerializer(serializers.ModelSerializer):
start_time = DateTimeTZField()
end_time = DateTimeTZField()
create_time = DateTimeTZField()
@@ -27,15 +41,10 @@ class ContestSerializer(serializers.ModelSerializer):
model = Contest
-class EditConetestSeriaizer(serializers.Serializer):
- id = serializers.IntegerField()
- title = serializers.CharField(max_length=128)
- description = serializers.CharField()
- start_time = serializers.DateTimeField()
- end_time = serializers.DateTimeField()
- password = serializers.CharField(allow_blank=True, allow_null=True, max_length=32)
- visible = serializers.BooleanField()
- real_time_rank = serializers.BooleanField()
+class ContestSerializer(ContestAdminSerializer):
+ class Meta:
+ model = Contest
+ exclude = ("password", "visible", "allowed_ip_ranges")
class ContestAnnouncementSerializer(serializers.ModelSerializer):
@@ -47,6 +56,35 @@ class ContestAnnouncementSerializer(serializers.ModelSerializer):
class CreateContestAnnouncementSerializer(serializers.Serializer):
+ contest_id = serializers.IntegerField()
title = serializers.CharField(max_length=128)
content = serializers.CharField()
+ visible = serializers.BooleanField()
+
+
+class EditContestAnnouncementSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=128, required=False)
+ content = serializers.CharField(required=False, allow_blank=True)
+ visible = serializers.BooleanField(required=False)
+
+
+class ContestPasswordVerifySerializer(serializers.Serializer):
contest_id = serializers.IntegerField()
+ password = serializers.CharField(max_length=30, required=True)
+
+
+class ACMContestRankSerializer(serializers.ModelSerializer):
+ user = UsernameSerializer()
+ submission_info = serializers.JSONField()
+
+ class Meta:
+ model = ACMContestRank
+
+
+class OIContestRankSerializer(serializers.ModelSerializer):
+ user = UsernameSerializer()
+ submission_info = serializers.JSONField()
+
+ class Meta:
+ model = OIContestRank
diff --git a/contest/tests.py b/contest/tests.py
index 8a8134f..965ee8d 100644
--- a/contest/tests.py
+++ b/contest/tests.py
@@ -6,27 +6,33 @@ from django.utils import timezone
from utils.api._serializers import DateTimeTZField
from utils.api.tests import APITestCase
-from .models import ContestAnnouncement, ContestRuleType
+from .models import ContestAnnouncement, ContestRuleType, Contest
DEFAULT_CONTEST_DATA = {"title": "test title", "description": "test description",
"start_time": timezone.localtime(timezone.now()),
"end_time": timezone.localtime(timezone.now()) + timedelta(days=1),
"rule_type": ContestRuleType.ACM,
"password": "123",
+ "allowed_ip_ranges": [],
"visible": True, "real_time_rank": True}
-class ContestAPITest(APITestCase):
+class ContestAdminAPITest(APITestCase):
def setUp(self):
self.create_super_admin()
- self.url = self.reverse("contest_api")
- self.data = DEFAULT_CONTEST_DATA
+ self.url = self.reverse("contest_admin_api")
+ self.data = copy.deepcopy(DEFAULT_CONTEST_DATA)
def test_create_contest(self):
response = self.client.post(self.url, data=self.data)
self.assertSuccess(response)
return response
+ def test_create_contest_with_invalid_cidr(self):
+ self.data["allowed_ip_ranges"] = ["127.0.0"]
+ resp = self.client.post(self.url, data=self.data)
+ self.assertTrue(resp.data["data"].endswith("is not a valid cidr network"))
+
def test_update_contest(self):
id = self.test_create_contest().data["data"]["id"]
update_data = {"id": id, "title": "update title",
@@ -55,15 +61,54 @@ class ContestAPITest(APITestCase):
self.assertSuccess(response)
-class ContestAnnouncementAPITest(APITestCase):
+class ContestAPITest(APITestCase):
+ def setUp(self):
+ user = self.create_admin()
+ self.contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA)
+ self.url = self.reverse("contest_api") + "?id=" + str(self.contest.id)
+
+ def test_get_contest_list(self):
+ url = self.reverse("contest_list_api")
+ response = self.client.get(url + "?limit=10")
+ self.assertSuccess(response)
+ self.assertEqual(len(response.data["data"]["results"]), 1)
+
+ def test_get_one_contest(self):
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
+
+ def test_regular_user_validate_contest_password(self):
+ self.create_user("test", "test123")
+ url = self.reverse("contest_password_api")
+ resp = self.client.post(url, {"contest_id": self.contest.id, "password": "error_password"})
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Wrong password"})
+
+ resp = self.client.post(url, {"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]})
+ self.assertSuccess(resp)
+
+ def test_regular_user_access_contest(self):
+ self.create_user("test", "test123")
+ url = self.reverse("contest_access_api")
+ resp = self.client.get(url + "?contest_id=" + str(self.contest.id))
+ self.assertFalse(resp.data["data"]["access"])
+
+ password_url = self.reverse("contest_password_api")
+ resp = self.client.post(password_url,
+ {"contest_id": self.contest.id, "password": DEFAULT_CONTEST_DATA["password"]})
+ self.assertSuccess(resp)
+ resp = self.client.get(self.url)
+ self.assertSuccess(resp)
+
+
+class ContestAnnouncementAdminAPITest(APITestCase):
def setUp(self):
self.create_super_admin()
self.url = self.reverse("contest_announcement_admin_api")
contest_id = self.create_contest().data["data"]["id"]
- self.data = {"title": "test title", "content": "test content", "contest_id": contest_id}
+ self.data = {"title": "test title", "content": "test content", "contest_id": contest_id, "visible": True}
def create_contest(self):
- url = self.reverse("contest_api")
+ url = self.reverse("contest_admin_api")
data = DEFAULT_CONTEST_DATA
return self.client.post(url, data=data)
@@ -80,7 +125,7 @@ class ContestAnnouncementAPITest(APITestCase):
def test_get_contest_announcements(self):
self.test_create_contest_announcement()
- response = self.client.get(self.url)
+ response = self.client.get(self.url + "?contest_id=" + str(self.data["contest_id"]))
self.assertSuccess(response)
def test_get_one_contest_announcement(self):
@@ -92,10 +137,10 @@ class ContestAnnouncementAPITest(APITestCase):
class ContestAnnouncementListAPITest(APITestCase):
def setUp(self):
self.create_super_admin()
- self.url = self.reverse("contest_list_api")
+ self.url = self.reverse("contest_announcement_api")
def create_contest_announcements(self):
- contest_id = self.client.post(self.reverse("contest_api"), data=DEFAULT_CONTEST_DATA).data["data"]["id"]
+ contest_id = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]["id"]
url = self.reverse("contest_announcement_admin_api")
self.client.post(url, data={"title": "test title1", "content": "test content1", "contest_id": contest_id})
self.client.post(url, data={"title": "test title2", "content": "test content2", "contest_id": contest_id})
@@ -105,3 +150,15 @@ class ContestAnnouncementListAPITest(APITestCase):
contest_id = self.create_contest_announcements()
response = self.client.get(self.url, data={"contest_id": contest_id})
self.assertSuccess(response)
+
+
+class ContestRankAPITest(APITestCase):
+ def setUp(self):
+ user = self.create_admin()
+ self.acm_contest = Contest.objects.create(created_by=user, **DEFAULT_CONTEST_DATA)
+ self.create_user("test", "test123")
+ self.url = self.reverse("contest_rank_api")
+
+ def get_contest_rank(self):
+ resp = self.client.get(self.url + "?contest_id=" + self.acm_contest.id)
+ self.assertSuccess(resp)
diff --git a/contest/urls/admin.py b/contest/urls/admin.py
index 2a7705a..5a8bc75 100644
--- a/contest/urls/admin.py
+++ b/contest/urls/admin.py
@@ -3,6 +3,6 @@ from django.conf.urls import url
from ..views.admin import ContestAnnouncementAPI, ContestAPI
urlpatterns = [
- url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
+ url(r"^contest/?$", ContestAPI.as_view(), name="contest_admin_api"),
url(r"^contest/announcement/?$", ContestAnnouncementAPI.as_view(), name="contest_announcement_admin_api")
]
diff --git a/contest/urls/oj.py b/contest/urls/oj.py
index bfc80b8..9e94fa5 100644
--- a/contest/urls/oj.py
+++ b/contest/urls/oj.py
@@ -1,7 +1,15 @@
from django.conf.urls import url
from ..views.oj import ContestAnnouncementListAPI
+from ..views.oj import ContestPasswordVerifyAPI, ContestAccessAPI
+from ..views.oj import ContestListAPI, ContestAPI
+from ..views.oj import ContestRankAPI
urlpatterns = [
- url(r"^contest/?$", ContestAnnouncementListAPI.as_view(), name="contest_list_api"),
+ url(r"^contests/?$", ContestListAPI.as_view(), name="contest_list_api"),
+ url(r"^contest/?$", ContestAPI.as_view(), name="contest_api"),
+ url(r"^contest/password/?$", ContestPasswordVerifyAPI.as_view(), name="contest_password_api"),
+ url(r"^contest/announcement/?$", ContestAnnouncementListAPI.as_view(), name="contest_announcement_api"),
+ url(r"^contest/access/?$", ContestAccessAPI.as_view(), name="contest_access_api"),
+ url(r"^contest_rank/?$", ContestRankAPI.as_view(), name="contest_rank_api"),
]
diff --git a/contest/views/admin.py b/contest/views/admin.py
index 60bb161..d5ce347 100644
--- a/contest/views/admin.py
+++ b/contest/views/admin.py
@@ -1,12 +1,14 @@
+from ipaddress import ip_network
import dateutil.parser
from utils.api import APIView, validate_serializer
from ..models import Contest, ContestAnnouncement
-from ..serializers import (ContestAnnouncementSerializer, ContestSerializer,
+from ..serializers import (ContestAnnouncementSerializer, ContestAdminSerializer,
CreateConetestSeriaizer,
CreateContestAnnouncementSerializer,
- EditConetestSeriaizer)
+ EditConetestSeriaizer,
+ EditContestAnnouncementSerializer)
class ContestAPI(APIView):
@@ -18,10 +20,15 @@ class ContestAPI(APIView):
data["created_by"] = request.user
if data["end_time"] <= data["start_time"]:
return self.error("Start time must occur earlier than end time")
- if not data["password"]:
+ if data.get("password") and data["password"] == "":
data["password"] = None
+ for ip_range in data["allowed_ip_ranges"]:
+ try:
+ ip_network(ip_range, strict=False)
+ except ValueError:
+ return self.error(f"{ip_range} is not a valid cidr network")
contest = Contest.objects.create(**data)
- return self.success(ContestSerializer(contest).data)
+ return self.success(ContestAdminSerializer(contest).data)
@validate_serializer(EditConetestSeriaizer)
def put(self, request):
@@ -38,10 +45,16 @@ class ContestAPI(APIView):
return self.error("Start time must occur earlier than end time")
if not data["password"]:
data["password"] = None
+ for ip_range in data["allowed_ip_ranges"]:
+ try:
+ ip_network(ip_range, strict=False)
+ except ValueError as e:
+ return self.error(f"{ip_range} is not a valid cidr network")
+
for k, v in data.items():
setattr(contest, k, v)
contest.save()
- return self.success(ContestSerializer(contest).data)
+ return self.success(ContestAdminSerializer(contest).data)
def get(self, request):
contest_id = request.GET.get("id")
@@ -50,7 +63,7 @@ class ContestAPI(APIView):
contest = Contest.objects.get(id=contest_id)
if request.user.is_admin() and contest.created_by != request.user:
return self.error("Contest does not exist")
- return self.success(ContestSerializer(contest).data)
+ return self.success(ContestAdminSerializer(contest).data)
except Contest.DoesNotExist:
return self.error("Contest does not exist")
@@ -62,7 +75,7 @@ class ContestAPI(APIView):
if request.user.is_admin():
contests = contests.filter(created_by=request.user)
- return self.success(self.paginate_data(request, contests, ContestSerializer))
+ return self.success(self.paginate_data(request, contests, ContestAdminSerializer))
class ContestAnnouncementAPI(APIView):
@@ -83,6 +96,23 @@ class ContestAnnouncementAPI(APIView):
announcement = ContestAnnouncement.objects.create(**data)
return self.success(ContestAnnouncementSerializer(announcement).data)
+ @validate_serializer(EditContestAnnouncementSerializer)
+ def put(self, request):
+ """
+ update contest_announcement
+ """
+ data = request.data
+ try:
+ contest_announcement = ContestAnnouncement.objects.get(id=data.pop("id"))
+ if request.user.is_admin() and contest_announcement.created_by != request.user:
+ return self.error("Contest announcement does not exist")
+ except ContestAnnouncement.DoesNotExist:
+ return self.error("Contest announcement does not exist")
+ for k, v in data.items():
+ setattr(contest_announcement, k, v)
+ contest_announcement.save()
+ return self.success()
+
def delete(self, request):
"""
Delete one contest_announcement.
@@ -110,10 +140,13 @@ class ContestAnnouncementAPI(APIView):
except ContestAnnouncement.DoesNotExist:
return self.error("Contest announcement does not exist")
- contest_announcements = ContestAnnouncement.objects.all().order_by("-create_time")
+ contest_id = request.GET.get("contest_id")
+ if not contest_id:
+ return self.error("Paramater error")
+ contest_announcements = ContestAnnouncement.objects.filter(contest_id=contest_id)
if request.user.is_admin():
contest_announcements = contest_announcements.filter(created_by=request.user)
keyword = request.GET.get("keyword")
if keyword:
contest_announcements = contest_announcements.filter(title__contains=keyword)
- return self.success(self.paginate_data(request, contest_announcements, ContestAnnouncementSerializer))
+ return self.success(ContestAnnouncementSerializer(contest_announcements, many=True).data)
diff --git a/contest/views/oj.py b/contest/views/oj.py
index e9ffe81..66889d5 100644
--- a/contest/views/oj.py
+++ b/contest/views/oj.py
@@ -1,16 +1,115 @@
-from utils.api import APIView
+from django.utils.timezone import now
+from django.core.cache import cache
+from utils.api import APIView, validate_serializer
+from utils.constants import CacheKey
+from utils.shortcuts import datetime2str
+from account.decorators import login_required, check_contest_permission
-from ..models import ContestAnnouncement
+from utils.constants import ContestRuleType, ContestStatus
+from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
from ..serializers import ContestAnnouncementSerializer
+from ..serializers import ContestSerializer, ContestPasswordVerifySerializer
+from ..serializers import OIContestRankSerializer, ACMContestRankSerializer
class ContestAnnouncementListAPI(APIView):
+ @check_contest_permission(check_type="announcements")
def get(self, request):
contest_id = request.GET.get("contest_id")
if not contest_id:
- return self.error("Invalid parameter")
- data = ContestAnnouncement.objects.filter(contest_id=contest_id).order_by("-create_time")
+ return self.error("Invalid parameter, contest_id is required")
+ data = ContestAnnouncement.objects.select_related("created_by").filter(contest_id=contest_id, visible=True)
max_id = request.GET.get("max_id")
if max_id:
data = data.filter(id__gt=max_id)
return self.success(ContestAnnouncementSerializer(data, many=True).data)
+
+
+class ContestAPI(APIView):
+ def get(self, request):
+ id = request.GET.get("id")
+ if not id:
+ return self.error("Invalid parameter, id is required")
+ try:
+ contest = Contest.objects.get(id=id)
+ except Contest.DoesNotExist:
+ return self.error("Contest does not exist")
+ data = ContestSerializer(contest).data
+ data["now"] = datetime2str(now())
+ return self.success(data)
+
+
+class ContestListAPI(APIView):
+ def get(self, request):
+ contests = Contest.objects.select_related("created_by").filter(visible=True)
+ keyword = request.GET.get("keyword")
+ rule_type = request.GET.get("rule_type")
+ status = request.GET.get("status")
+ if keyword:
+ contests = contests.filter(title__contains=keyword)
+ if rule_type:
+ contests = contests.filter(rule_type=rule_type)
+ if status:
+ cur = now()
+ if status == ContestStatus.CONTEST_NOT_START:
+ contests = contests.filter(start_time__gt=cur)
+ elif status == ContestStatus.CONTEST_ENDED:
+ contests = contests.filter(end_time__lt=cur)
+ else:
+ contests = contests.filter(start_time__lte=cur, end_time__gte=cur)
+ data = self.paginate_data(request, contests, ContestSerializer)
+ return self.success(data)
+
+
+class ContestPasswordVerifyAPI(APIView):
+ @validate_serializer(ContestPasswordVerifySerializer)
+ @login_required
+ def post(self, request):
+ data = request.data
+ try:
+ contest = Contest.objects.get(id=data["contest_id"], visible=True, password__isnull=False)
+ except Contest.DoesNotExist:
+ return self.error("Contest does not exist")
+ if contest.password != data["password"]:
+ return self.error("Wrong password")
+
+ # password verify OK.
+ if "accessible_contests" not in request.session:
+ request.session["accessible_contests"] = []
+ request.session["accessible_contests"].append(contest.id)
+ # https://docs.djangoproject.com/en/dev/topics/http/sessions/#when-sessions-are-saved
+ request.session.modified = True
+ return self.success(True)
+
+
+class ContestAccessAPI(APIView):
+ @login_required
+ def get(self, request):
+ contest_id = request.GET.get("contest_id")
+ if not contest_id:
+ return self.error()
+ return self.success({"access": int(contest_id) in request.session.get("accessible_contests", [])})
+
+
+class ContestRankAPI(APIView):
+ def get_rank(self):
+ if self.contest.rule_type == ContestRuleType.ACM:
+ return ACMContestRank.objects.filter(contest=self.contest). \
+ select_related("user").order_by("-accepted_number", "total_time")
+ else:
+ return OIContestRank.objects.filter(contest=self.contest). \
+ select_related("user").order_by("-total_score")
+
+ @check_contest_permission(check_type="ranks")
+ def get(self, request):
+ if self.contest.rule_type == ContestRuleType.OI:
+ serializer = OIContestRankSerializer
+ else:
+ serializer = ACMContestRankSerializer
+
+ cache_key = f"{CacheKey.contest_rank_cache}:{self.contest.id}"
+ qs = cache.get(cache_key)
+ if not qs:
+ qs = self.get_rank()
+ cache.set(cache_key, qs)
+ return self.success(self.paginate_data(request, qs, serializer))
diff --git a/data/log/.gitkeep b/data/log/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/data/public/avatar/default.png b/data/public/avatar/default.png
new file mode 100644
index 0000000..97f3495
Binary files /dev/null and b/data/public/avatar/default.png differ
diff --git a/data/public/upload/.gitkeep b/data/public/upload/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/data/ssl/.gitkeep b/data/ssl/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/data/test_case/.gitkeep b/data/test_case/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
deleted file mode 100644
index 6f3d17d..0000000
--- a/deploy/Dockerfile
+++ /dev/null
@@ -1,5 +0,0 @@
-FROM python:3.5
-ADD requirements.txt /tmp
-RUN pip install -r /tmp/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
-WORKDIR /app
-CMD python manage.py runserver 0.0.0.0:8085
diff --git a/deploy/nginx/common.conf b/deploy/nginx/common.conf
new file mode 100644
index 0000000..ecfd251
--- /dev/null
+++ b/deploy/nginx/common.conf
@@ -0,0 +1,20 @@
+location /public {
+ root /data;
+}
+
+location /api {
+ proxy_pass http://backend;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header Host $http_host;
+ client_max_body_size 200M;
+}
+
+location /admin {
+ root /app/dist/admin;
+ try_files $uri $uri/ /index.html =404;
+}
+
+location / {
+ root /app/dist;
+ try_files $uri $uri/ /index.html =404;
+}
\ No newline at end of file
diff --git a/deploy/nginx/nginx.conf b/deploy/nginx/nginx.conf
new file mode 100644
index 0000000..0942890
--- /dev/null
+++ b/deploy/nginx/nginx.conf
@@ -0,0 +1,57 @@
+user nobody;
+daemon off;
+pid /tmp/nginx.pid;
+worker_processes auto;
+pcre_jit on;
+error_log /data/log/nginx_error.log warn;
+
+events {
+ worker_connections 1024;
+}
+
+http {
+ include /etc/nginx/mime.types;
+ default_type application/octet-stream;
+ server_tokens off;
+ keepalive_timeout 65;
+ sendfile on;
+ tcp_nodelay on;
+
+ gzip on;
+ gzip_vary on;
+ gzip_types application/javascript text/css;
+ client_body_temp_path /tmp 1 2;
+
+ log_format main '$remote_addr - $remote_user [$time_local] "$request" '
+ '$status $body_bytes_sent "$http_referer" '
+ '"$http_user_agent" "$http_x_forwarded_for"';
+
+ access_log /data/log/nginx_access.log main;
+
+ upstream backend {
+ server 127.0.0.1:8080;
+ keepalive 32;
+ }
+
+ server {
+ listen 8000 default_server;
+ server_name _;
+
+ include common.conf;
+ }
+
+ server {
+ listen 1443 ssl http2 default_server;
+ server_name _;
+ ssl_certificate /data/ssl/server.crt;
+ ssl_certificate_key /data/ssl/server.key;
+ ssl_protocols TLSv1 TLSv1.1 TLSv1.2;
+ ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:ECDHE-RSA-AES128-GCM-SHA256:AES256+EECDH:DHE-RSA-AES128-GCM-SHA256:AES256+EDH:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA:ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES256-GCM-SHA384:AES128-GCM-SHA256:AES256-SHA256:AES128-SHA256:AES256-SHA:AES128-SHA:DES-CBC3-SHA:HIGH:!aNULL:!eNULL:!EXPORT:!DES:!MD5:!PSK:!RC4";
+ ssl_prefer_server_ciphers on;
+ ssl_session_cache shared:SSL:10m;
+
+ include common.conf;
+ }
+
+}
+
diff --git a/deploy/requirements.txt b/deploy/requirements.txt
index 3d8bf5a..3d83ebf 100644
--- a/deploy/requirements.txt
+++ b/deploy/requirements.txt
@@ -1,7 +1,6 @@
-django==1.9.6
+django==1.11.4
djangorestframework==3.4.0
pillow
-jsonfield
otpauth
flake8-quotes
pytz
@@ -11,3 +10,9 @@ celery
Envelopes
qrcode
flake8-coding
+requests
+django-redis
+psycopg2
+gunicorn
+jsonfield
+XlsxWriter
diff --git a/deploy/run.sh b/deploy/run.sh
new file mode 100644
index 0000000..b710cba
--- /dev/null
+++ b/deploy/run.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+APP=/app
+DATA=/data
+
+if [ ! -f "$APP/oj/custom_settings.py" ]; then
+ echo SECRET_KEY=\"$(cat /dev/urandom | head -1 | md5sum | head -c 32)\" >> $APP/oj/custom_settings.py
+fi
+
+mkdir -p $DATA/log $DATA/ssl $DATA/test_case $DATA/public/upload $DATA/public/avatar
+
+SSL="$DATA/ssl"
+if [ ! -f "$SSL/server.key" ]; then
+ openssl req -x509 -newkey rsa:2048 -keyout "$SSL/server.key" -out "$SSL/server.crt" -days 1000 \
+ -subj "/C=CN/ST=Beijing/L=Beijing/O=Beijing OnlineJudge Technology Co., Ltd./OU=Service Infrastructure Department/CN=`hostname`" -nodes
+fi
+
+cd $APP
+
+n=0
+while [ $n -lt 5 ]
+do
+ python manage.py migrate --no-input &&
+ python manage.py inituser --username=root --password=rootroot --action=create_super_admin &&
+ break
+ n=$(($n+1))
+ echo "Failed to migrate, going to retry..."
+ sleep 8
+done
+
+cp data/public/avatar/default.png /data/public/avatar
+chown -R nobody:nogroup $DATA $APP/dist
+exec supervisord -c /app/deploy/supervisord.conf
diff --git a/deploy/supervisord.conf b/deploy/supervisord.conf
new file mode 100644
index 0000000..6b23166
--- /dev/null
+++ b/deploy/supervisord.conf
@@ -0,0 +1,52 @@
+[supervisord]
+logfile=/data/log/supervisord.log
+logfile_maxbytes=10MB
+logfile_backups=10
+loglevel=info
+pidfile=/tmp/supervisord.pid
+nodaemon=true
+childlogdir=/data/log/
+
+[inet_http_server]
+port=127.0.0.1:9005
+
+[rpcinterface:supervisor]
+supervisor.rpcinterface_factory=supervisor.rpcinterface:make_main_rpcinterface
+
+[supervisorctl]
+serverurl=http://127.0.0.1:9005
+
+[program:nginx]
+command=nginx -c /app/deploy/nginx/nginx.conf
+directory=/app/
+stdout_logfile=/data/log/nginx.log
+stderr_logfile=/data/log/nginx.log
+autostart=true
+autorestart=true
+startsecs=5
+stopwaitsecs = 5
+killasgroup=true
+
+[program:gunicorn]
+command=sh -c "gunicorn oj.wsgi --user nobody -b 127.0.0.1:8080 --reload -w `grep -c ^processor /proc/cpuinfo`"
+directory=/app/
+user=nobody
+stdout_logfile=/data/log/gunicorn.log
+stderr_logfile=/data/log/gunicorn.log
+autostart=true
+autorestart=true
+startsecs=5
+stopwaitsecs = 5
+killasgroup=true
+
+[program:celery]
+command=celery -A oj worker -l warning
+directory=/app/
+user=nobody
+stdout_logfile=/data/log/celery.log
+stderr_logfile=/data/log/celery.log
+autostart=true
+autorestart=true
+startsecs=5
+stopwaitsecs = 5
+killasgroup=true
diff --git a/judge/dispatcher.py b/judge/dispatcher.py
new file mode 100644
index 0000000..0f5e3cb
--- /dev/null
+++ b/judge/dispatcher.py
@@ -0,0 +1,352 @@
+import hashlib
+import json
+import logging
+from urllib.parse import urljoin
+
+import requests
+from django.db import transaction
+from django.db.models import F
+from django.conf import settings
+
+from account.models import User
+from conf.models import JudgeServer
+from contest.models import ContestRuleType, ACMContestRank, OIContestRank, ContestStatus
+from judge.languages import languages, spj_languages
+from options.options import SysOptions
+from problem.models import Problem, ProblemRuleType
+from problem.utils import parse_problem_template
+from submission.models import JudgeStatus, Submission
+from utils.cache import cache
+from utils.constants import CacheKey
+
+logger = logging.getLogger(__name__)
+
+
+# 继续处理在队列中的问题
+def process_pending_task():
+ if cache.llen(CacheKey.waiting_queue):
+ # 防止循环引入
+ from judge.tasks import judge_task
+ data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
+ judge_task.delay(**data)
+
+
+class DispatcherBase(object):
+ def __init__(self):
+ self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
+
+ def _request(self, url, data=None):
+ kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
+ if data:
+ kwargs["json"] = data
+ try:
+ return requests.post(url, **kwargs).json()
+ except Exception as e:
+ logger.exception(e)
+
+ @staticmethod
+ def choose_judge_server():
+ with transaction.atomic():
+ servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
+ servers = [s for s in servers if s.status == "normal"]
+ if servers:
+ server = servers[0]
+ server.used_instance_number = F("task_number") + 1
+ server.save()
+ return server
+
+ @staticmethod
+ def release_judge_server(judge_server_id):
+ with transaction.atomic():
+ # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
+ server = JudgeServer.objects.get(id=judge_server_id)
+ server.used_instance_number = F("task_number") - 1
+ server.save()
+
+
+class SPJCompiler(DispatcherBase):
+ def __init__(self, spj_code, spj_version, spj_language):
+ super().__init__()
+ spj_compile_config = list(filter(lambda config: spj_language == config["name"], spj_languages))[0]["spj"][
+ "compile"]
+ self.data = {
+ "src": spj_code,
+ "spj_version": spj_version,
+ "spj_compile_config": spj_compile_config
+ }
+
+ def compile_spj(self):
+ server = self.choose_judge_server()
+ if not server:
+ return "No available judge_server"
+ result = self._request(urljoin(server.service_url, "compile_spj"), data=self.data)
+ self.release_judge_server(server.id)
+ if result["err"]:
+ return result["data"]
+
+
+class JudgeDispatcher(DispatcherBase):
+ def __init__(self, submission_id, problem_id):
+ super().__init__()
+ self.submission = Submission.objects.get(id=submission_id)
+ self.contest_id = self.submission.contest_id
+ if self.contest_id:
+ self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
+ self.contest = self.problem.contest
+ else:
+ self.problem = Problem.objects.get(id=problem_id)
+
+ def _compute_statistic_info(self, resp_data):
+ # 用时和内存占用保存为多个测试点中最长的那个
+ self.submission.statistic_info["time_cost"] = max([x["cpu_time"] for x in resp_data])
+ self.submission.statistic_info["memory_cost"] = max([x["memory"] for x in resp_data])
+
+ # sum up the score in OI mode
+ if self.problem.rule_type == ProblemRuleType.OI:
+ score = 0
+ try:
+ for i in range(len(resp_data)):
+ if resp_data[i]["result"] == JudgeStatus.ACCEPTED:
+ resp_data[i]["score"] = self.problem.test_case_score[i]["score"]
+ score += resp_data[i]["score"]
+ else:
+ resp_data[i]["score"] = 0
+ except IndexError:
+ logger.error(f"Index Error raised when summing up the score in problem {self.problem.id}")
+ self.submission.statistic_info["score"] = 0
+ return
+ self.submission.statistic_info["score"] = score
+
+ def judge(self, output=True):
+ server = self.choose_judge_server()
+ if not server:
+ data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
+ cache.lpush(CacheKey.waiting_queue, json.dumps(data))
+ return
+
+ language = self.submission.language
+ sub_config = list(filter(lambda item: language == item["name"], languages))[0]
+ spj_config = {}
+ if self.problem.spj_code:
+ for lang in spj_languages:
+ if lang["name"] == self.problem.spj_language:
+ spj_config = lang["spj"]
+ break
+
+ if language in self.problem.template:
+ template = parse_problem_template(self.problem.template[language])
+ code = f"{template['prepend']}\n{self.submission.code}\n{template['append']}"
+ else:
+ code = self.submission.code
+
+ data = {
+ "language_config": sub_config["config"],
+ "src": code,
+ "max_cpu_time": self.problem.time_limit,
+ "max_memory": 1024 * 1024 * self.problem.memory_limit,
+ "test_case_id": self.problem.test_case_id,
+ "output": output,
+ "spj_version": self.problem.spj_version,
+ "spj_config": spj_config.get("config"),
+ "spj_compile_config": spj_config.get("compile"),
+ "spj_src": self.problem.spj_code
+ }
+
+ Submission.objects.filter(id=self.submission.id).update(result=JudgeStatus.JUDGING)
+
+ service_url = server.service_url
+ # not set service_url, it should be a linked container
+ if not service_url:
+ service_url = settings.DEFAULT_JUDGE_SERVER_SERVICE_URL
+ resp = self._request(urljoin(service_url, "/judge"), data=data)
+ if resp["err"]:
+ self.submission.result = JudgeStatus.COMPILE_ERROR
+ self.submission.statistic_info["err_info"] = resp["data"]
+ self.submission.statistic_info["score"] = 0
+ else:
+ resp["data"].sort(key=lambda x: int(x["test_case"]))
+ self.submission.info = resp
+ self._compute_statistic_info(resp["data"])
+ error_test_case = list(filter(lambda case: case["result"] != 0, resp["data"]))
+ # ACM模式下,多个测试点全部正确则AC,否则取第一个错误的测试点的状态
+ # OI模式下, 若多个测试点全部正确则AC, 若全部错误则取第一个错误测试点状态,否则为部分正确
+ if not error_test_case:
+ self.submission.result = JudgeStatus.ACCEPTED
+ elif self.problem.rule_type == ProblemRuleType.ACM or len(error_test_case) == len(resp["data"]):
+ self.submission.result = error_test_case[0]["result"]
+ else:
+ self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
+ self.submission.save()
+ self.release_judge_server(server.id)
+
+ if self.contest_id:
+ self.update_contest_problem_status()
+ self.update_contest_rank()
+ else:
+ self.update_problem_status()
+
+ # 至此判题结束,尝试处理任务队列中剩余的任务
+ process_pending_task()
+
+ def update_problem_status(self):
+ result = str(self.submission.result)
+ problem_id = str(self.problem.id)
+ with transaction.atomic():
+ # update problem status
+ problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
+ problem.submission_number += 1
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ problem.accepted_number += 1
+ problem_info = problem.statistic_info
+ problem_info[result] = problem_info.get(result, 0) + 1
+ problem.save(update_fields=["accepted_number", "submission_number", "statistic_info"])
+
+ # update_userprofile
+ user = User.objects.select_for_update().get(id=self.submission.user_id)
+ user_profile = user.userprofile
+ user_profile.submission_number += 1
+ if problem.rule_type == ProblemRuleType.ACM:
+ acm_problems_status = user_profile.acm_problems_status.get("problems", {})
+ if problem_id not in acm_problems_status:
+ acm_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ user_profile.accepted_number += 1
+ elif acm_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
+ acm_problems_status[problem_id]["status"] = self.submission.result
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ user_profile.accepted_number += 1
+ user_profile.acm_problems_status["problems"] = acm_problems_status
+ user_profile.save(update_fields=["submission_number", "accepted_number", "acm_problems_status"])
+
+ else:
+ oi_problems_status = user_profile.oi_problems_status.get("problems", {})
+ score = self.submission.statistic_info["score"]
+ if problem_id not in oi_problems_status:
+ user_profile.add_score(score)
+ oi_problems_status[problem_id] = {"status": self.submission.result,
+ "_id": self.problem._id,
+ "score": score}
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ user_profile.accepted_number += 1
+ else:
+ if oi_problems_status[problem_id]["status"] == JudgeStatus.ACCEPTED and \
+ self.submission.result != JudgeStatus.ACCEPTED:
+ user_profile.accepted_number -= 1
+ elif oi_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED and \
+ self.submission.result == JudgeStatus:
+ user_profile.accepted_number += 1
+
+ # minus last time score, add this time score
+ user_profile.add_score(this_time_score=score,
+ last_time_score=oi_problems_status[problem_id]["score"])
+ oi_problems_status[problem_id]["score"] = score
+ oi_problems_status[problem_id]["status"] = self.submission.result
+ user_profile.oi_problems_status["problems"] = oi_problems_status
+ user_profile.save(update_fields=["submission_number", "accepted_number", "oi_problems_status"])
+
+ def update_contest_problem_status(self):
+ if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
+ logger.info("Contest debug mode, id: " + str(self.contest_id) + ", submission id: " + self.submission.id)
+ return
+ with transaction.atomic():
+ user = User.objects.select_for_update().get(id=self.submission.user_id)
+ user_profile = user.userprofile
+ problem_id = str(self.problem.id)
+ if self.contest.rule_type == ContestRuleType.ACM:
+ contest_problems_status = user_profile.acm_problems_status.get("contest_problems", {})
+ if problem_id not in contest_problems_status:
+ contest_problems_status[problem_id] = {"status": self.submission.result, "_id": self.problem._id}
+ elif contest_problems_status[problem_id]["status"] != JudgeStatus.ACCEPTED:
+ contest_problems_status[problem_id]["status"] = self.submission.result
+ else:
+ # 如果已AC, 直接跳过 不计入任何计数器
+ return
+ user_profile.acm_problems_status["contest_problems"] = contest_problems_status
+ user_profile.save(update_fields=["acm_problems_status"])
+
+ elif self.contest.rule_type == ContestRuleType.OI:
+ contest_problems_status = user_profile.oi_problems_status.get("contest_problems", {})
+ score = self.submission.statistic_info["score"]
+ if problem_id not in contest_problems_status:
+ contest_problems_status[problem_id] = {"status": self.submission.result,
+ "_id": self.problem._id,
+ "score": score}
+ else:
+ contest_problems_status[problem_id]["score"] = score
+ contest_problems_status[problem_id]["status"] = self.submission.result
+ user_profile.oi_problems_status["contest_problems"] = contest_problems_status
+ user_profile.save(update_fields=["oi_problems_status"])
+
+ problem = Problem.objects.select_for_update().get(contest_id=self.contest_id, id=self.problem.id)
+ result = str(self.submission.result)
+ problem_info = problem.statistic_info
+ problem_info[result] = problem_info.get(result, 0) + 1
+ problem.submission_number += 1
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ problem.accepted_number += 1
+ problem.save(update_fields=["submission_number", "accepted_number", "statistic_info"])
+
+ def update_contest_rank(self):
+ if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
+ return
+ if self.contest.rule_type == ContestRuleType.OI or self.contest.real_time_rank:
+ cache.delete(f"{CacheKey.contest_rank_cache}:{self.contest.id}")
+ with transaction.atomic():
+ if self.contest.rule_type == ContestRuleType.ACM:
+ acm_rank, _ = ACMContestRank.objects.select_for_update(). \
+ get_or_create(user_id=self.submission.user_id, contest=self.contest)
+ self._update_acm_contest_rank(acm_rank)
+ else:
+ oi_rank, _ = OIContestRank.objects.select_for_update(). \
+ get_or_create(user_id=self.submission.user_id, contest=self.contest)
+ self._update_oi_contest_rank(oi_rank)
+
+ def _update_acm_contest_rank(self, rank):
+ info = rank.submission_info.get(str(self.submission.problem_id))
+ # 因前面更改过,这里需要重新获取
+ problem = Problem.objects.get(contest_id=self.contest_id, id=self.problem.id)
+ # 此题提交过
+ if info:
+ if info["is_ac"]:
+ return
+
+ rank.submission_number += 1
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ rank.accepted_number += 1
+ info["is_ac"] = True
+ info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds()
+ rank.total_time += info["ac_time"] + info["error_number"] * 20 * 60
+
+ if problem.accepted_number == 1:
+ info["is_first_ac"] = True
+ else:
+ info["error_number"] += 1
+
+ # 第一次提交
+ else:
+ rank.submission_number += 1
+ info = {"is_ac": False, "ac_time": 0, "error_number": 0, "is_first_ac": False}
+ if self.submission.result == JudgeStatus.ACCEPTED:
+ rank.accepted_number += 1
+ info["is_ac"] = True
+ info["ac_time"] = (self.submission.create_time - self.contest.start_time).total_seconds()
+ rank.total_time += info["ac_time"]
+
+ if problem.accepted_number == 1:
+ info["is_first_ac"] = True
+
+ else:
+ info["error_number"] = 1
+ rank.submission_info[str(self.submission.problem_id)] = info
+ rank.save()
+
+ def _update_oi_contest_rank(self, rank):
+ problem_id = str(self.submission.problem_id)
+ current_score = self.submission.statistic_info["score"]
+ last_score = rank.submission_info.get(problem_id)
+ if last_score:
+ rank.total_score = rank.total_score - last_score + current_score
+ else:
+ rank.total_score = rank.total_score + current_score
+ rank.submission_info[problem_id] = current_score
+ rank.save()
diff --git a/judge/languages.py b/judge/languages.py
index 6002760..e7f35ff 100644
--- a/judge/languages.py
+++ b/judge/languages.py
@@ -1,7 +1,7 @@
_c_lang_config = {
- "template": """//PREPEND START
+ "template": """//PREPEND BEGIN
#include
//PREPEND END
@@ -12,7 +12,7 @@ int add(int a, int b) {
}
//TEMPLATE END
-//APPEND START
+//APPEND BEGIN
int main() {
printf("%d", add(1, 2));
return 0;
@@ -23,7 +23,7 @@ int main() {
"exe_name": "main",
"max_cpu_time": 3000,
"max_real_time": 5000,
- "max_memory": 128 * 1024 * 1024,
+ "max_memory": 256 * 1024 * 1024,
"compile_command": "/usr/bin/gcc -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c99 {src_path} -lm -o {exe_path}",
},
"run": {
@@ -48,18 +48,29 @@ _c_lang_spj_config = {
}
_cpp_lang_config = {
- "template": """/*--PREPEND START--*/
-/*--PREPEND END--*/
-/*--TEMPLATE BEGIN--*/
-/*--TEMPLATE END--*/
-/*--APPEND START--*/
-/*--APPEND END--*/""",
+ "template": """//PREPEND BEGIN
+#include
+//PREPEND END
+
+//TEMPLATE BEGIN
+int add(int a, int b) {
+ // Please fill this blank
+ return ___________;
+}
+//TEMPLATE END
+
+//APPEND BEGIN
+int main() {
+ std::cout << add(1, 2);
+ return 0;
+}
+//APPEND END""",
"compile": {
"src_name": "main.cpp",
"exe_name": "main",
"max_cpu_time": 3000,
"max_real_time": 5000,
- "max_memory": 128 * 1024 * 1024,
+ "max_memory": 512 * 1024 * 1024,
"compile_command": "/usr/bin/g++ -DONLINE_JUDGE -O2 -w -fmax-errors=3 -std=c++11 {src_path} -lm -o {exe_path}",
},
"run": {
@@ -99,8 +110,8 @@ _java_lang_config = {
"compile_command": "/usr/bin/javac {src_path} -d {exe_dir} -encoding UTF8"
},
"run": {
- "command": "/usr/bin/java -cp {exe_dir} -Xss1M -XX:MaxPermSize=16M -XX:PermSize=8M -Xms16M -Xmx{max_memory}k "
- "-Djava.security.manager -Djava.security.policy==/etc/java_policy -Djava.awt.headless=true Main",
+ "command": "/usr/bin/java -cp {exe_dir} -Xss1M -Xms16M -Xmx{max_memory}k "
+ "-Djava.security.manager -Djava.security.policy=/etc/java_policy -Djava.awt.headless=true Main",
"seccomp_rule": None,
"env": ["MALLOC_ARENA_MAX=1"]
}
diff --git a/judge/tasks.py b/judge/tasks.py
new file mode 100644
index 0000000..eda9e0f
--- /dev/null
+++ b/judge/tasks.py
@@ -0,0 +1,8 @@
+from __future__ import absolute_import, unicode_literals
+from celery import shared_task
+from judge.dispatcher import JudgeDispatcher
+
+
+@shared_task
+def judge_task(submission_id, problem_id):
+ JudgeDispatcher(submission_id, problem_id).judge()
diff --git a/oj/__init__.py b/oj/__init__.py
index e69de29..23fc183 100644
--- a/oj/__init__.py
+++ b/oj/__init__.py
@@ -0,0 +1,6 @@
+from __future__ import absolute_import, unicode_literals
+
+# Django starts so that shared_task will use this app.
+from .celery import app as celery_app
+
+__all__ = ["celery_app"]
diff --git a/oj/celery.py b/oj/celery.py
new file mode 100644
index 0000000..4f24c7e
--- /dev/null
+++ b/oj/celery.py
@@ -0,0 +1,18 @@
+from __future__ import absolute_import, unicode_literals
+import os
+from celery import Celery
+from django.conf import settings
+
+# set the default Django settings module for the "celery" program.
+os.environ.setdefault("DJANGO_SETTINGS_MODULE", "oj.settings")
+
+
+app = Celery("oj")
+
+# Using a string here means the worker will not have to
+# pickle the object when using Windows.
+app.config_from_object("django.conf:settings")
+
+# load task modules from all registered Django app configs.
+app.autodiscover_tasks(lambda: settings.INSTALLED_APPS)
+# app.autodiscover_tasks()
diff --git a/oj/db_router.py b/oj/db_router.py
deleted file mode 100644
index 823d784..0000000
--- a/oj/db_router.py
+++ /dev/null
@@ -1,19 +0,0 @@
-class DBRouter(object):
- def db_for_read(self, model, **hints):
- if model._meta.app_label == "submission":
- return "submission"
- return "default"
-
- def db_for_write(self, model, **hints):
- if model._meta.app_label == "submission":
- return "submission"
- return "default"
-
- def allow_relation(self, obj1, obj2, **hints):
- return True
-
- def allow_migrate(self, db, app_label, model=None, **hints):
- if app_label == "submission":
- return db == app_label
- else:
- return db == "default"
diff --git a/oj/dev_settings.py b/oj/dev_settings.py
new file mode 100644
index 0000000..5c75a0b
--- /dev/null
+++ b/oj/dev_settings.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+import os
+
+BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+DATABASES = {
+ 'default': {
+ 'ENGINE': 'django.db.backends.postgresql_psycopg2',
+ 'HOST': '127.0.0.1',
+ 'PORT': 5433,
+ 'NAME': "onlinejudge",
+ 'USER': "onlinejudge",
+ 'PASSWORD': 'onlinejudge'
+ }
+}
+
+REDIS_CONF = {
+ "host": "127.0.0.1",
+ "port": "6379"
+}
+
+
+DEBUG = True
+
+ALLOWED_HOSTS = ["*"]
+
+DATA_DIR = f"{BASE_DIR}/data"
diff --git a/oj/local_settings.py b/oj/local_settings.py
deleted file mode 100644
index 79dc44e..0000000
--- a/oj/local_settings.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# coding=utf-8
-import os
-
-BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-
-DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
- 'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
- }
-}
-
-REDIS_CACHE = {
- "host": "127.0.0.1",
- "port": 6379,
- "db": 1
-}
-
-REDIS_QUEUE = {
- "host": "127.0.0.1",
- "port": 6379,
- "db": 2
-}
-
-DEBUG = True
-
-ALLOWED_HOSTS = ["*"]
-
-TEST_CASE_DIR = "/tmp"
-
-LOG_PATH = "log/"
diff --git a/oj/production_settings.py b/oj/production_settings.py
new file mode 100644
index 0000000..026c53a
--- /dev/null
+++ b/oj/production_settings.py
@@ -0,0 +1,28 @@
+import os
+
+
+def get_env(name, default=""):
+ return os.environ.get(name, default)
+
+
+DATABASES = {
+ 'default': {
+ 'ENGINE': 'django.db.backends.postgresql_psycopg2',
+ 'HOST': get_env("POSTGRES_HOST", "oj-postgres"),
+ 'PORT': get_env("POSTGRES_PORT", "5432"),
+ 'NAME': get_env("POSTGRES_DB"),
+ 'USER': get_env("POSTGRES_USER"),
+ 'PASSWORD': get_env("POSTGRES_PASSWORD")
+ }
+}
+
+REDIS_CONF = {
+ "host": get_env("REDIS_HOST", "oj-redis"),
+ "port": get_env("REDIS_PORT", "6379")
+}
+
+DEBUG = False
+
+ALLOWED_HOSTS = ['*']
+
+DATA_DIR = "/data"
diff --git a/oj/server_settings.py b/oj/server_settings.py
deleted file mode 100644
index 7434ee7..0000000
--- a/oj/server_settings.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# coding=utf-8
-import os
-
-BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-
-DATABASES = {
- 'default': {
- 'ENGINE': 'django.db.backends.mysql',
- 'NAME': "oj",
- 'CONN_MAX_AGE': 0.1,
- 'HOST': os.environ["MYSQL_PORT_3306_TCP_ADDR"],
- 'PORT': 3306,
- 'USER': os.environ["MYSQL_ENV_MYSQL_USER"],
- 'PASSWORD': os.environ["MYSQL_ENV_MYSQL_ROOT_PASSWORD"]
- }
-}
-
-REDIS_CACHE = {
- "host": os.environ["REDIS_PORT_6379_TCP_ADDR"],
- "port": 6379,
- "db": 1
-}
-
-REDIS_QUEUE = {
- "host": os.environ["REDIS_PORT_6379_TCP_ADDR"],
- "port": 6379,
- "db": 2
-}
-
-DEBUG = False
-
-ALLOWED_HOSTS = ['*']
-
-
-TEST_CASE_DIR = "/test_case"
-
-LOG_PATH = "log/"
diff --git a/oj/settings.py b/oj/settings.py
index e7199d9..40335a1 100644
--- a/oj/settings.py
+++ b/oj/settings.py
@@ -1,8 +1,7 @@
-# coding=utf-8
"""
Django settings for oj project.
-Generated by 'django-admin startproject' using Django 1.8.
+Generated by 'django-admin startproject' using Django 1.11.
For more information on this file, see
https://docs.djangoproject.com/en/1.8/topics/settings/
@@ -10,59 +9,54 @@ https://docs.djangoproject.com/en/1.8/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.8/ref/settings/
"""
-# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
import os
+from copy import deepcopy
+
+if os.environ.get("OJ_ENV") == "production":
+ from .production_settings import *
+else:
+ from .dev_settings import *
from .custom_settings import *
-# 判断运行环境
-ENV = os.environ.get("oj_env", "local")
-
-if ENV == "local":
- from .local_settings import *
-elif ENV == "server":
- from .server_settings import *
-
-
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-
-# Quick-start development settings - unsuitable for production
-# See https://docs.djangoproject.com/en/1.8/howto/deployment/checklist/
-
-# Application definition
-
-INSTALLED_APPS = (
+# Applications
+VENDOR_APPS = (
'django.contrib.auth',
- 'django.contrib.contenttypes',
'django.contrib.sessions',
+ 'django.contrib.contenttypes',
'django.contrib.messages',
'django.contrib.staticfiles',
-
+ 'rest_framework',
+)
+LOCAL_APPS = (
'account',
'announcement',
'conf',
'problem',
'contest',
'utils',
-
- 'rest_framework',
+ 'submission',
+ 'options',
+ 'judge',
)
+INSTALLED_APPS = VENDOR_APPS + LOCAL_APPS
+
MIDDLEWARE_CLASSES = (
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
+ 'account.middleware.APITokenAuthMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.security.SecurityMiddleware',
'account.middleware.AdminRoleRequiredMiddleware',
- 'account.middleware.SessionSecurityMiddleware',
- 'account.middleware.TimezoneMiddleware'
+ 'account.middleware.SessionRecordMiddleware',
+ # 'account.middleware.LogSqlMiddleware',
)
-
ROOT_URLCONF = 'oj.urls'
TEMPLATES = [
@@ -80,9 +74,26 @@ TEMPLATES = [
},
},
]
-
WSGI_APPLICATION = 'oj.wsgi.application'
+# Password validation
+# https://docs.djangoproject.com/en/1.9/ref/settings/#auth-password-validators
+
+AUTH_PASSWORD_VALIDATORS = [
+ {
+ 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
+ },
+ {
+ 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
+ },
+ {
+ 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
+ },
+ {
+ 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
+ },
+]
+
# Internationalization
# https://docs.djangoproject.com/en/1.8/topics/i18n/
@@ -96,60 +107,58 @@ USE_L10N = True
USE_TZ = True
-
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.8/howto/static-files/
-STATIC_URL = '/static/'
+STATIC_URL = '/public/'
AUTH_USER_MODEL = 'account.User'
-LOGGING = {
- 'version': 1,
- 'disable_existing_loggers': True,
- 'formatters': {
- 'standard': {
- 'format': '%(asctime)s [%(threadName)s:%(thread)d] [%(name)s:%(lineno)d] [%(module)s:%(funcName)s] [%(levelname)s]- %(message)s'}
- # 日志格式
- },
- 'handlers': {
- 'django_error': {
- 'level': 'DEBUG',
- 'class': 'logging.handlers.RotatingFileHandler',
- 'filename': os.path.join(LOG_PATH, 'django.log'),
- 'formatter': 'standard'
- },
- 'app_info': {
- 'level': 'DEBUG',
- 'class': 'logging.handlers.RotatingFileHandler',
- 'filename': os.path.join(LOG_PATH, 'app_info.log'),
- 'formatter': 'standard'
- },
- 'console': {
- 'level': 'DEBUG',
- 'class': 'logging.StreamHandler',
- 'formatter': 'standard'
- }
- },
- 'loggers': {
- 'app_info': {
- 'handlers': ['app_info', "console"],
- 'level': 'DEBUG',
- 'propagate': True
- },
- 'django.request': {
- 'handlers': ['django_error', 'console'],
- 'level': 'DEBUG',
- 'propagate': True,
- },
- 'django.db.backends': {
- 'handlers': ['console'],
- 'level': 'ERROR',
- 'propagate': True,
- }
- },
-}
+TEST_CASE_DIR = os.path.join(DATA_DIR, "test_case")
+LOG_PATH = os.path.join(DATA_DIR, "log")
+AVATAR_URI_PREFIX = "/public/avatar"
+AVATAR_UPLOAD_DIR = f"{DATA_DIR}{AVATAR_URI_PREFIX}"
+
+UPLOAD_PREFIX = "/public/upload"
+UPLOAD_DIR = f"{DATA_DIR}{UPLOAD_PREFIX}"
+
+STATICFILES_DIRS = [os.path.join(DATA_DIR, "public")]
+
+LOGGING = {
+ 'version': 1,
+ 'disable_existing_loggers': False,
+ 'formatters': {
+ 'standard': {
+ 'format': '[%(asctime)s] - [%(levelname)s] - [%(name)s:%(lineno)d] - %(message)s',
+ 'datefmt': '%Y-%m-%d %H:%M:%S'
+ }
+ },
+ 'handlers': {
+ 'console': {
+ 'level': 'DEBUG',
+ 'class': 'logging.StreamHandler',
+ 'formatter': 'standard'
+ }
+ },
+ 'loggers': {
+ 'django.request': {
+ 'handlers': ['console'],
+ 'level': 'ERROR',
+ 'propagate': True,
+ },
+ 'django.db.backends': {
+ 'handlers': ['console'],
+ 'level': 'ERROR',
+ 'propagate': True,
+ },
+ '': {
+ 'handlers': ['console'],
+ 'level': 'WARNING',
+ 'propagate': True,
+ }
+ },
+}
REST_FRAMEWORK = {
'TEST_REQUEST_DEFAULT_FORMAT': 'json',
@@ -158,17 +167,37 @@ REST_FRAMEWORK = {
)
}
-# for celery
-BROKER_URL = 'redis://%s:%s/%s' % (REDIS_QUEUE["host"], str(REDIS_QUEUE["port"]), str(REDIS_QUEUE["db"]))
+REDIS_URL = "redis://%s:%s" % (REDIS_CONF["host"], REDIS_CONF["port"])
+
+
+def redis_config(db):
+ def make_key(key, key_prefix, version):
+ return key
+
+ return {
+ "BACKEND": "utils.cache.MyRedisCache",
+ "LOCATION": f"{REDIS_URL}/{db}",
+ "TIMEOUT": None,
+ "KEY_PREFIX": "",
+ "KEY_FUNCTION": make_key
+ }
+
+
+CACHES = {
+ "default": redis_config(db=1)
+}
+
+SESSION_ENGINE = "django.contrib.sessions.backends.cache"
+SESSION_CACHE_ALIAS = "default"
+
+CELERY_RESULT_BACKEND = f"{REDIS_URL}/2"
+BROKER_URL = f"{REDIS_URL}/3"
+CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
CELERY_ACCEPT_CONTENT = ["json"]
CELERY_TASK_SERIALIZER = "json"
-DATABASE_ROUTERS = ['oj.db_router.DBRouter']
-
-IMAGE_UPLOAD_DIR = os.path.join(BASE_DIR, 'upload/')
-
# 用于限制用户恶意提交大量代码
-TOKEN_BUCKET_DEFAULT_CAPACITY = 50
+TOKEN_BUCKET_DEFAULT_CAPACITY = 20
# 单位:每分钟
TOKEN_BUCKET_FILL_RATE = 2
diff --git a/oj/urls.py b/oj/urls.py
index 79e6b03..c626656 100644
--- a/oj/urls.py
+++ b/oj/urls.py
@@ -3,12 +3,15 @@ from django.conf.urls import include, url
urlpatterns = [
url(r"^api/", include("account.urls.oj")),
url(r"^api/admin/", include("account.urls.admin")),
- url(r"^api/account/", include("account.urls.user")),
+ url(r"^api/", include("announcement.urls.oj")),
url(r"^api/admin/", include("announcement.urls.admin")),
url(r"^api/", include("conf.urls.oj")),
url(r"^api/admin/", include("conf.urls.admin")),
url(r"^api/", include("problem.urls.oj")),
url(r"^api/admin/", include("problem.urls.admin")),
+ url(r"^api/", include("contest.urls.oj")),
url(r"^api/admin/", include("contest.urls.admin")),
- url(r"^api/", include("contest.urls.oj"))
+ url(r"^api/", include("submission.urls.oj")),
+ url(r"^api/admin/", include("submission.urls.admin")),
+ url(r"^api/admin/", include("utils.urls")),
]
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/options/migrations/0001_initial.py b/options/migrations/0001_initial.py
new file mode 100644
index 0000000..a109c91
--- /dev/null
+++ b/options/migrations/0001_initial.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-23 08:11
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = [
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='SysOptions',
+ fields=[
+ ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
+ ('key', models.CharField(db_index=True, max_length=128, unique=True)),
+ ('value', django.contrib.postgres.fields.jsonb.JSONField()),
+ ],
+ ),
+ ]
diff --git a/options/migrations/__init__.py b/options/migrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/options/models.py b/options/models.py
new file mode 100644
index 0000000..04dee5e
--- /dev/null
+++ b/options/models.py
@@ -0,0 +1,7 @@
+from django.db import models
+from utils.models import JSONField
+
+
+class SysOptions(models.Model):
+ key = models.CharField(max_length=128, unique=True, db_index=True)
+ value = JSONField()
diff --git a/options/options.py b/options/options.py
new file mode 100644
index 0000000..7d8b9a9
--- /dev/null
+++ b/options/options.py
@@ -0,0 +1,185 @@
+import os
+from django.core.cache import cache
+from django.db import transaction, IntegrityError
+
+from utils.constants import CacheKey
+from utils.shortcuts import rand_str
+from .models import SysOptions as SysOptionsModel
+
+
+def default_token():
+ token = os.environ.get("JUDGE_SERVER_TOKEN")
+ return token if token else rand_str()
+
+
+class OptionKeys:
+ website_base_url = "website_base_url"
+ website_name = "website_name"
+ website_name_shortcut = "website_name_shortcut"
+ website_footer = "website_footer"
+ allow_register = "allow_register"
+ submission_list_show_all = "submission_list_show_all"
+ smtp_config = "smtp_config"
+ judge_server_token = "judge_server_token"
+
+
+class OptionDefaultValue:
+ website_base_url = "http://127.0.0.1"
+ website_name = "Online Judge"
+ website_name_shortcut = "oj"
+ website_footer = "Online Judge Footer"
+ allow_register = True
+ submission_list_show_all = True
+ smtp_config = {}
+ judge_server_token = default_token
+
+
+class _SysOptionsMeta(type):
+ @classmethod
+ def _set_cache(mcs, option_key, option_value):
+ cache.set(f"{CacheKey.option}:{option_key}", option_value, timeout=60)
+
+ @classmethod
+ def _del_cache(mcs, option_key):
+ cache.delete(f"{CacheKey.option}:{option_key}")
+
+ @classmethod
+ def _get_keys(cls):
+ return [key for key in OptionKeys.__dict__ if not key.startswith("__")]
+
+ def rebuild_cache(cls):
+ for key in cls._get_keys():
+ # get option 的时候会写 cache 的
+ cls._get_option(key, use_cache=False)
+
+ @classmethod
+ def _init_option(mcs):
+ for item in mcs._get_keys():
+ if not SysOptionsModel.objects.filter(key=item).exists():
+ default_value = getattr(OptionDefaultValue, item)
+ if callable(default_value):
+ default_value = default_value()
+ try:
+ SysOptionsModel.objects.create(key=item, value=default_value)
+ except IntegrityError:
+ pass
+
+ @classmethod
+ def _get_option(mcs, option_key, use_cache=True):
+ try:
+ if use_cache:
+ option = cache.get(f"{CacheKey.option}:{option_key}")
+ if option:
+ return option
+ option = SysOptionsModel.objects.get(key=option_key)
+ value = option.value
+ mcs._set_cache(option_key, value)
+ return value
+ except SysOptionsModel.DoesNotExist:
+ mcs._init_option()
+ return mcs._get_option(option_key, use_cache=use_cache)
+
+ @classmethod
+ def _set_option(mcs, option_key: str, option_value):
+ try:
+ with transaction.atomic():
+ option = SysOptionsModel.objects.select_for_update().get(key=option_key)
+ option.value = option_value
+ option.save()
+ mcs._del_cache(option_key)
+ except SysOptionsModel.DoesNotExist:
+ mcs._init_option()
+ mcs._set_option(option_key, option_value)
+
+ @classmethod
+ def _increment(mcs, option_key):
+ try:
+ with transaction.atomic():
+ option = SysOptionsModel.objects.select_for_update().get(key=option_key)
+ value = option.value + 1
+ option.value = value
+ option.save()
+ mcs._del_cache(option_key)
+ except SysOptionsModel.DoesNotExist:
+ mcs._init_option()
+ return mcs._increment(option_key)
+
+ @classmethod
+ def set_options(mcs, options):
+ for key, value in options:
+ mcs._set_option(key, value)
+
+ @classmethod
+ def get_options(mcs, keys):
+ result = {}
+ for key in keys:
+ result[key] = mcs._get_option(key)
+ return result
+
+ @property
+ def website_base_url(cls):
+ return cls._get_option(OptionKeys.website_base_url)
+
+ @website_base_url.setter
+ def website_base_url(cls, value):
+ cls._set_option(OptionKeys.website_base_url, value)
+
+ @property
+ def website_name(cls):
+ return cls._get_option(OptionKeys.website_name)
+
+ @website_name.setter
+ def website_name(cls, value):
+ cls._set_option(OptionKeys.website_name, value)
+
+ @property
+ def website_name_shortcut(cls):
+ return cls._get_option(OptionKeys.website_name_shortcut)
+
+ @website_name_shortcut.setter
+ def website_name_shortcut(cls, value):
+ cls._set_option(OptionKeys.website_name_shortcut, value)
+
+ @property
+ def website_footer(cls):
+ return cls._get_option(OptionKeys.website_footer)
+
+ @website_footer.setter
+ def website_footer(cls, value):
+ cls._set_option(OptionKeys.website_footer, value)
+
+ @property
+ def allow_register(cls):
+ return cls._get_option(OptionKeys.allow_register)
+
+ @allow_register.setter
+ def allow_register(cls, value):
+ cls._set_option(OptionKeys.allow_register, value)
+
+ @property
+ def submission_list_show_all(cls):
+ return cls._get_option(OptionKeys.submission_list_show_all)
+
+ @submission_list_show_all.setter
+ def submission_list_show_all(cls, value):
+ cls._set_option(OptionKeys.submission_list_show_all, value)
+
+ @property
+ def smtp_config(cls):
+ return cls._get_option(OptionKeys.smtp_config)
+
+ @smtp_config.setter
+ def smtp_config(cls, value):
+ cls._set_option(OptionKeys.smtp_config, value)
+
+ @property
+ def judge_server_token(cls):
+ return cls._get_option(OptionKeys.judge_server_token)
+
+ @judge_server_token.setter
+ def judge_server_token(cls, value):
+ cls._set_option(OptionKeys.judge_server_token, value)
+
+
+class SysOptions(metaclass=_SysOptionsMeta):
+ pass
diff --git a/options/tests.py b/options/tests.py
new file mode 100644
index 0000000..a39b155
--- /dev/null
+++ b/options/tests.py
@@ -0,0 +1 @@
+# Create your tests here.
diff --git a/options/views.py b/options/views.py
new file mode 100644
index 0000000..60f00ef
--- /dev/null
+++ b/options/views.py
@@ -0,0 +1 @@
+# Create your views here.
diff --git a/problem/migrations/0004_auto_20170501_0637.py b/problem/migrations/0004_auto_20170501_0637.py
new file mode 100644
index 0000000..5d55ace
--- /dev/null
+++ b/problem/migrations/0004_auto_20170501_0637.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-05-01 06:37
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('problem', '0003_auto_20170217_0820'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='contestproblem',
+ name='total_accepted_number',
+ field=models.BigIntegerField(default=0),
+ ),
+ migrations.AlterField(
+ model_name='contestproblem',
+ name='total_submit_number',
+ field=models.BigIntegerField(default=0),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='total_accepted_number',
+ field=models.BigIntegerField(default=0),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='total_submit_number',
+ field=models.BigIntegerField(default=0),
+ ),
+ ]
diff --git a/problem/migrations/0005_auto_20170815_1258.py b/problem/migrations/0005_auto_20170815_1258.py
new file mode 100644
index 0000000..1949696
--- /dev/null
+++ b/problem/migrations/0005_auto_20170815_1258.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-08-15 12:58
+from __future__ import unicode_literals
+
+from django.db import migrations
+import jsonfield.fields
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('problem', '0004_auto_20170501_0637'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='contestproblem',
+ name='statistic_info',
+ field=jsonfield.fields.JSONField(default={}),
+ ),
+ migrations.AddField(
+ model_name='problem',
+ name='statistic_info',
+ field=jsonfield.fields.JSONField(default={}),
+ ),
+ ]
diff --git a/problem/migrations/0006_auto_20170823_0918.py b/problem/migrations/0006_auto_20170823_0918.py
new file mode 100644
index 0000000..933070f
--- /dev/null
+++ b/problem/migrations/0006_auto_20170823_0918.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-08-23 09:18
+from __future__ import unicode_literals
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('problem', '0005_auto_20170815_1258'),
+ ]
+
+ operations = [
+ migrations.RenameField(
+ model_name='contestproblem',
+ old_name='total_accepted_number',
+ new_name='accepted_number',
+ ),
+ migrations.RenameField(
+ model_name='contestproblem',
+ old_name='total_submit_number',
+ new_name='submission_number',
+ ),
+ migrations.RenameField(
+ model_name='problem',
+ old_name='total_accepted_number',
+ new_name='accepted_number',
+ ),
+ migrations.RenameField(
+ model_name='problem',
+ old_name='total_submit_number',
+ new_name='submission_number',
+ ),
+ ]
diff --git a/problem/migrations/0008_auto_20170923_1318.py b/problem/migrations/0008_auto_20170923_1318.py
new file mode 100644
index 0000000..4f5bb99
--- /dev/null
+++ b/problem/migrations/0008_auto_20170923_1318.py
@@ -0,0 +1,66 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-09-23 13:18
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('contest', '0005_auto_20170823_0918'),
+ ('problem', '0006_auto_20170823_0918'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='contestproblem',
+ name='total_score',
+ field=models.IntegerField(blank=True, default=0),
+ ),
+ migrations.AddField(
+ model_name='problem',
+ name='total_score',
+ field=models.IntegerField(blank=True, default=0),
+ ),
+ migrations.AlterUniqueTogether(
+ name='contestproblem',
+ unique_together=set([]),
+ ),
+ migrations.RemoveField(
+ model_name='contestproblem',
+ name='contest',
+ ),
+ migrations.RemoveField(
+ model_name='contestproblem',
+ name='created_by',
+ ),
+ migrations.RemoveField(
+ model_name='contestproblem',
+ name='tags',
+ ),
+ migrations.AddField(
+ model_name='problem',
+ name='contest',
+ field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'),
+ preserve_default=False,
+ ),
+ migrations.AddField(
+ model_name='problem',
+ name='is_public',
+ field=models.BooleanField(default=False),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='_id',
+ field=models.CharField(db_index=True, max_length=24),
+ ),
+ migrations.AlterUniqueTogether(
+ name='problem',
+ unique_together=set([('_id', 'contest')]),
+ ),
+ migrations.DeleteModel(
+ name='ContestProblem',
+ ),
+ ]
diff --git a/problem/migrations/0009_auto_20171011_1214.py b/problem/migrations/0009_auto_20171011_1214.py
new file mode 100644
index 0000000..e219ff2
--- /dev/null
+++ b/problem/migrations/0009_auto_20171011_1214.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-10-11 12:14
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('problem', '0008_auto_20170923_1318'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='problem',
+ name='languages',
+ field=django.contrib.postgres.fields.jsonb.JSONField(),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='samples',
+ field=django.contrib.postgres.fields.jsonb.JSONField(),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='statistic_info',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='template',
+ field=django.contrib.postgres.fields.jsonb.JSONField(),
+ ),
+ migrations.AlterField(
+ model_name='problem',
+ name='test_case_score',
+ field=django.contrib.postgres.fields.jsonb.JSONField(),
+ ),
+ migrations.AlterModelOptions(
+ name='problem',
+ options={'ordering': ('create_time',)},
+ ),
+ ]
diff --git a/problem/migrations/0010_problem_spj_compile_ok.py b/problem/migrations/0010_problem_spj_compile_ok.py
new file mode 100644
index 0000000..0df1b36
--- /dev/null
+++ b/problem/migrations/0010_problem_spj_compile_ok.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-11-16 12:42
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('problem', '0009_auto_20171011_1214'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='problem',
+ name='spj_compile_ok',
+ field=models.BooleanField(default=False),
+ ),
+ ]
diff --git a/problem/models.py b/problem/models.py
index 17d4e5d..5116f4b 100644
--- a/problem/models.py
+++ b/problem/models.py
@@ -1,5 +1,5 @@
from django.db import models
-from jsonfield import JSONField
+from utils.models import JSONField
from account.models import User
from contest.models import Contest
@@ -18,7 +18,18 @@ class ProblemRuleType(object):
OI = "OI"
-class AbstractProblem(models.Model):
+class ProblemDifficulty(object):
+ High = "High"
+ Mid = "Mid"
+ Low = "Low"
+
+
+class Problem(models.Model):
+ # display ID
+ _id = models.CharField(max_length=24, db_index=True)
+ contest = models.ForeignKey(Contest, null=True, blank=True)
+ # for contest problem
+ is_public = models.BooleanField(default=False)
title = models.CharField(max_length=128)
# HTML
description = RichTextField()
@@ -27,6 +38,7 @@ class AbstractProblem(models.Model):
# [{input: "test", output: "123"}, {input: "test123", output: "456"}]
samples = JSONField()
test_case_id = models.CharField(max_length=32)
+ # [{"input_name": "1.in", "output_name": "1.out", "score": 0}]
test_case_score = JSONField()
hint = RichTextField(blank=True, null=True)
languages = JSONField()
@@ -44,37 +56,28 @@ class AbstractProblem(models.Model):
spj_language = models.CharField(max_length=32, blank=True, null=True)
spj_code = models.TextField(blank=True, null=True)
spj_version = models.CharField(max_length=32, blank=True, null=True)
+ spj_compile_ok = models.BooleanField(default=False)
rule_type = models.CharField(max_length=32)
visible = models.BooleanField(default=True)
difficulty = models.CharField(max_length=32)
tags = models.ManyToManyField(ProblemTag)
source = models.CharField(max_length=200, blank=True, null=True)
- total_submit_number = models.IntegerField(default=0)
- total_accepted_number = models.IntegerField(default=0)
+ # for OI mode
+ total_score = models.IntegerField(default=0, blank=True)
+ submission_number = models.BigIntegerField(default=0)
+ accepted_number = models.BigIntegerField(default=0)
+ # {JudgeStatus.ACCEPTED: 3, JudgeStaus.WRONG_ANSWER: 11}, the number means count
+ statistic_info = JSONField(default=dict)
class Meta:
db_table = "problem"
- abstract = True
+ unique_together = (("_id", "contest"),)
+ ordering = ("create_time",)
def add_submission_number(self):
- self.accepted_problem_number = models.F("total_submit_number") + 1
- self.save()
+ self.submission_number = models.F("submission_number") + 1
+ self.save(update_fields=["submission_number"])
def add_ac_number(self):
- self.accepted_problem_number = models.F("total_accepted_number") + 1
- self.save()
-
-
-class Problem(AbstractProblem):
- _id = models.CharField(max_length=24, unique=True, db_index=True)
-
-
-class ContestProblem(AbstractProblem):
- _id = models.CharField(max_length=24, db_index=True)
- contest = models.ForeignKey(Contest)
- # 是否已经公开了题目,防止重复公开
- is_public = models.BooleanField(default=False)
-
- class Meta:
- db_table = "contest_problem"
- unique_together = (("_id", "contest"), )
+ self.accepted_number = models.F("accepted_number") + 1
+ self.save(update_fields=["accepted_number"])
diff --git a/problem/serializers.py b/problem/serializers.py
index 0425f3d..46b2e22 100644
--- a/problem/serializers.py
+++ b/problem/serializers.py
@@ -4,6 +4,7 @@ from judge.languages import language_names, spj_language_names
from utils.api import DateTimeTZField, UsernameSerializer, serializers
from .models import Problem, ProblemRuleType, ProblemTag
+from .utils import parse_problem_template
class TestCaseUploadForm(forms.Form):
@@ -12,8 +13,8 @@ class TestCaseUploadForm(forms.Form):
class CreateSampleSerializer(serializers.Serializer):
- input = serializers.CharField()
- output = serializers.CharField()
+ input = serializers.CharField(trim_whitespace=False)
+ output = serializers.CharField(trim_whitespace=False)
class CreateTestCaseScoreSerializer(serializers.Serializer):
@@ -39,7 +40,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
input_description = serializers.CharField()
output_description = serializers.CharField()
samples = serializers.ListField(child=CreateSampleSerializer(), allow_empty=False)
- test_case_id = serializers.CharField(min_length=32, max_length=32)
+ test_case_id = serializers.CharField(max_length=32)
test_case_score = serializers.ListField(child=CreateTestCaseScoreSerializer(), allow_empty=False)
time_limit = serializers.IntegerField(min_value=1, max_value=1000 * 60)
memory_limit = serializers.IntegerField(min_value=1, max_value=1024)
@@ -49,6 +50,7 @@ class CreateOrEditProblemSerializer(serializers.Serializer):
spj = serializers.BooleanField()
spj_language = serializers.ChoiceField(choices=spj_language_names, allow_blank=True, allow_null=True)
spj_code = serializers.CharField(allow_blank=True, allow_null=True)
+ spj_compile_ok = serializers.BooleanField(default=False)
visible = serializers.BooleanField()
difficulty = serializers.ChoiceField(choices=[Difficulty.LOW, Difficulty.MID, Difficulty.HIGH])
tags = serializers.ListField(child=serializers.CharField(max_length=32), allow_empty=False)
@@ -68,12 +70,23 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
contest_id = serializers.IntegerField()
+class EditContestProblemSerializer(CreateOrEditProblemSerializer):
+ id = serializers.IntegerField()
+ contest_id = serializers.IntegerField()
+
+
class TagSerializer(serializers.ModelSerializer):
class Meta:
model = ProblemTag
+ fields = "__all__"
-class ProblemSerializer(serializers.ModelSerializer):
+class CompileSPJSerializer(serializers.Serializer):
+ spj_language = serializers.ChoiceField(choices=spj_language_names)
+ spj_code = serializers.CharField()
+
+
+class BaseProblemSerializer(serializers.ModelSerializer):
samples = serializers.JSONField()
test_case_score = serializers.JSONField()
languages = serializers.JSONField()
@@ -82,6 +95,100 @@ class ProblemSerializer(serializers.ModelSerializer):
create_time = DateTimeTZField()
last_update_time = DateTimeTZField()
created_by = UsernameSerializer()
+ statistic_info = serializers.JSONField()
+
+
+class ProblemAdminSerializer(BaseProblemSerializer):
+ class Meta:
+ model = Problem
+ fields = "__all__"
+
+
+class ContestProblemAdminSerializer(BaseProblemSerializer):
+ class Meta:
+ model = Problem
+ fields = "__all__"
+
+
+class ProblemSerializer(BaseProblemSerializer):
+ template = serializers.SerializerMethodField()
+
+ def get_template(self, obj):
+ ret = {}
+ for lang, code in obj.template.items():
+ ret[lang] = parse_problem_template(code)["template"]
+ return ret
class Meta:
model = Problem
+ exclude = ("contest", "test_case_score", "test_case_id", "visible", "is_public",
+ "template", "spj_code", "spj_version", "spj_compile_ok")
+
+
+class ContestProblemSerializer(BaseProblemSerializer):
+ class Meta:
+ model = Problem
+ exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty")
+
+
+class ContestProblemSafeSerializer(BaseProblemSerializer):
+ class Meta:
+ model = Problem
+ exclude = ("test_case_score", "test_case_id", "visible", "is_public", "difficulty",
+ "submission_number", "accepted_number", "statistic_info")
+
+
+class ContestProblemMakePublicSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ display_id = serializers.CharField(max_length=32)
+
+
+class ExportProblemSerializer(serializers.ModelSerializer):
+ description = serializers.SerializerMethodField()
+ input_description = serializers.SerializerMethodField()
+ output_description = serializers.SerializerMethodField()
+ test_case_score = serializers.SerializerMethodField()
+ hint = serializers.SerializerMethodField()
+ time_limit = serializers.SerializerMethodField()
+ memory_limit = serializers.SerializerMethodField()
+ spj = serializers.SerializerMethodField()
+ template = serializers.SerializerMethodField()
+
+ def get_description(self, obj):
+ return {"format": "html", "value": obj.description}
+
+ def get_input_description(self, obj):
+ return {"format": "html", "value": obj.input_description}
+
+ def get_output_description(self, obj):
+ return {"format": "html", "value": obj.output_description}
+
+ def get_hint(self, obj):
+ return {"format": "html", "value": obj.hint}
+
+ def get_test_case_score(self, obj):
+ return obj.test_case_score if obj.rule_type == ProblemRuleType.OI else []
+
+ def get_time_limit(self, obj):
+ return {"unit": "ms", "value": obj.time_limit}
+
+ def get_memory_limit(self, obj):
+ return {"unit": "MB", "value": obj.memory_limit}
+
+ def get_spj(self, obj):
+ return {"enabled": obj.spj,
+ "code": obj.spj_code if obj.spj else None,
+ "language": obj.spj_language if obj.spj else None}
+
+ def get_template(self, obj):
+ ret = {}
+ for k, v in obj.template.items():
+ ret[k] = parse_problem_template(v)
+ return ret
+
+ class Meta:
+ model = Problem
+ fields = ("_id", "title", "description",
+ "input_description", "output_description",
+ "test_case_score", "hint", "time_limit", "memory_limit", "samples",
+ "template", "spj", "rule_type", "source", "template")
diff --git a/problem/tests.py b/problem/tests.py
index d80262a..93cd058 100644
--- a/problem/tests.py
+++ b/problem/tests.py
@@ -1,6 +1,8 @@
import copy
+import hashlib
import os
import shutil
+from datetime import timedelta
from zipfile import ZipFile
from django.conf import settings
@@ -8,7 +10,59 @@ from django.conf import settings
from utils.api.tests import APITestCase
from .models import ProblemTag
-from .views.admin import TestCaseUploadAPI
+from .models import Problem, ProblemRuleType
+from contest.models import Contest
+from contest.tests import DEFAULT_CONTEST_DATA
+
+from .views.admin import TestCaseAPI
+from .utils import parse_problem_template
+
+
+DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "test
", "input_description": "test",
+ "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
+ "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
+ "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
+ "spj_code": "", "spj_compile_ok": True, "test_case_id": "499b26290cc7994e0b497212e842ea85",
+ "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
+ "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
+ "input_size": 0, "score": 0}],
+ "rule_type": "ACM", "hint": "test
", "source": "test"}
+
+
+class ProblemCreateTestBase(APITestCase):
+ @staticmethod
+ def add_problem(problem_data, created_by):
+ data = copy.deepcopy(problem_data)
+ if data["spj"]:
+ if not data["spj_language"] or not data["spj_code"]:
+ raise ValueError("Invalid spj")
+ data["spj_version"] = hashlib.md5(
+ (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
+ else:
+ data["spj_language"] = None
+ data["spj_code"] = None
+ if data["rule_type"] == ProblemRuleType.OI:
+ total_score = 0
+ for item in data["test_case_score"]:
+ if item["score"] <= 0:
+ raise ValueError("invalid score")
+ else:
+ total_score += item["score"]
+ data["total_score"] = total_score
+ data["created_by"] = created_by
+ tags = data.pop("tags")
+
+ data["languages"] = list(data["languages"])
+
+ problem = Problem.objects.create(**data)
+
+ for item in tags:
+ try:
+ tag = ProblemTag.objects.get(name=item)
+ except ProblemTag.DoesNotExist:
+ tag = ProblemTag.objects.create(name=item)
+ problem.tags.add(tag)
+ return problem
class ProblemTagListAPITest(APITestCase):
@@ -17,17 +71,20 @@ class ProblemTagListAPITest(APITestCase):
ProblemTag.objects.create(name="name2")
resp = self.client.get(self.reverse("problem_tag_list_api"))
self.assertSuccess(resp)
- self.assertEqual(resp.data["data"], ["name1", "name2"])
+ resp_data = resp.data["data"]
+ self.assertEqual(resp_data[0]["name"], "name1")
+ self.assertEqual(resp_data[1]["name"], "name2")
class TestCaseUploadAPITest(APITestCase):
def setUp(self):
- self.api = TestCaseUploadAPI()
- self.url = self.reverse("test_case_upload_api")
+ self.api = TestCaseAPI()
+ self.url = self.reverse("test_case_api")
self.create_super_admin()
def test_filter_file_name(self):
- self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False), ["1.in", "1.out"])
+ self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in", ".DS_Store"], spj=False),
+ ["1.in", "1.out"])
self.assertEqual(self.api.filter_name_list(["2.in", "2.out"], spj=False), [])
self.assertEqual(self.api.filter_name_list(["1.in", "1.out", "2.in"], spj=True), ["1.in", "2.in"])
@@ -76,19 +133,11 @@ class TestCaseUploadAPITest(APITestCase):
self.assertEqual(f.read(), name + "\n" + name + "\n" + "end")
-class ProblemAPITest(APITestCase):
+class ProblemAdminAPITest(APITestCase):
def setUp(self):
- self.url = self.reverse("problem_api")
+ self.url = self.reverse("problem_admin_api")
self.create_super_admin()
- self.data = {"_id": "A-110", "title": "test", "description": "test
", "input_description": "test",
- "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
- "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
- "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
- "spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
- "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
- "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
- "input_size": 0, "score": 0}],
- "rule_type": "ACM", "hint": "test
", "source": "test"}
+ self.data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
def test_create_problem(self):
resp = self.client.post(self.url, data=self.data)
@@ -128,3 +177,127 @@ class ProblemAPITest(APITestCase):
data["id"] = problem_id
resp = self.client.put(self.url, data=data)
self.assertSuccess(resp)
+
+
+class ProblemAPITest(ProblemCreateTestBase):
+ def setUp(self):
+ self.url = self.reverse("problem_api")
+ admin = self.create_admin(login=False)
+ self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
+ self.create_user("test", "test123")
+
+ def test_get_problem_list(self):
+ resp = self.client.get(f"{self.url}?limit=10")
+ self.assertSuccess(resp)
+
+ def get_one_problem(self):
+ resp = self.client.get(self.url + "?id=" + self.problem._id)
+ self.assertSuccess(resp)
+
+
+class ContestProblemAdminTest(APITestCase):
+ def setUp(self):
+ self.url = self.reverse("contest_problem_admin_api")
+ self.create_admin()
+ self.contest = self.client.post(self.reverse("contest_admin_api"), data=DEFAULT_CONTEST_DATA).data["data"]
+
+ def test_create_contest_problem(self):
+ data = copy.deepcopy(DEFAULT_PROBLEM_DATA)
+ data["contest_id"] = self.contest["id"]
+ resp = self.client.post(self.url, data=data)
+ self.assertSuccess(resp)
+ return resp.data["data"]
+
+ def test_get_contest_problem(self):
+ self.test_create_contest_problem()
+ contest_id = self.contest["id"]
+ resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
+ self.assertSuccess(resp)
+ self.assertEqual(len(resp.data["data"]["results"]), 1)
+
+ def test_get_one_contest_problem(self):
+ contest_problem = self.test_create_contest_problem()
+ contest_id = self.contest["id"]
+ problem_id = contest_problem["id"]
+ resp = self.client.get(f"{self.url}?contest_id={contest_id}&id={problem_id}")
+ self.assertSuccess(resp)
+
+
+class ContestProblemTest(ProblemCreateTestBase):
+ def setUp(self):
+ admin = self.create_admin()
+ url = self.reverse("contest_admin_api")
+ contest_data = copy.deepcopy(DEFAULT_CONTEST_DATA)
+ contest_data["password"] = ""
+ contest_data["start_time"] = contest_data["start_time"] + timedelta(hours=1)
+ self.contest = self.client.post(url, data=contest_data).data["data"]
+ self.problem = self.add_problem(DEFAULT_PROBLEM_DATA, admin)
+ self.problem.contest_id = self.contest["id"]
+ self.problem.save()
+ self.url = self.reverse("contest_problem_api")
+
+ def test_admin_get_contest_problem_list(self):
+ contest_id = self.contest["id"]
+ resp = self.client.get(self.url + "?contest_id=" + str(contest_id))
+ self.assertSuccess(resp)
+ self.assertEqual(len(resp.data["data"]), 1)
+
+ def test_admin_get_one_contest_problem(self):
+ contest_id = self.contest["id"]
+ problem_id = self.problem._id
+ resp = self.client.get("{}?contest_id={}&problem_id={}".format(self.url, contest_id, problem_id))
+ self.assertSuccess(resp)
+
+ def test_regular_user_get_not_started_contest_problem(self):
+ self.create_user("test", "test123")
+ resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
+ self.assertDictEqual(resp.data, {"error": "error", "data": "Contest has not started yet."})
+
+ def test_reguar_user_get_started_contest_problem(self):
+ self.create_user("test", "test123")
+ contest = Contest.objects.first()
+ contest.start_time = contest.start_time - timedelta(hours=1)
+ contest.save()
+ resp = self.client.get(self.url + "?contest_id=" + str(self.contest["id"]))
+ self.assertSuccess(resp)
+
+
+class ParseProblemTemplateTest(APITestCase):
+ def test_parse(self):
+ template_str = """
+//PREPEND BEGIN
+aaa
+//PREPEND END
+
+//TEMPLATE BEGIN
+bbb
+//TEMPLATE END
+
+//APPEND BEGIN
+ccc
+//APPEND END
+"""
+
+ ret = parse_problem_template(template_str)
+ self.assertEqual(ret["prepend"], "aaa\n")
+ self.assertEqual(ret["template"], "bbb\n")
+ self.assertEqual(ret["append"], "ccc\n")
+
+ def test_parse1(self):
+ template_str = """
+//PREPEND BEGIN
+aaa
+//PREPEND END
+
+//APPEND BEGIN
+ccc
+//APPEND END
+//APPEND BEGIN
+ddd
+//APPEND END
+"""
+
+ ret = parse_problem_template(template_str)
+ self.assertEqual(ret["prepend"], "aaa\n")
+ self.assertEqual(ret["template"], "")
+ self.assertEqual(ret["append"], "ccc\n")
diff --git a/problem/urls/admin.py b/problem/urls/admin.py
index b4813c5..d4f9974 100644
--- a/problem/urls/admin.py
+++ b/problem/urls/admin.py
@@ -1,9 +1,12 @@
from django.conf.urls import url
-from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseUploadAPI
+from ..views.admin import ContestProblemAPI, ProblemAPI, TestCaseAPI, MakeContestProblemPublicAPIView
+from ..views.admin import CompileSPJAPI
urlpatterns = [
- url(r"^test_case/upload/?$", TestCaseUploadAPI.as_view(), name="test_case_upload_api"),
- url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"),
- url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api")
+ url(r"^test_case/?$", TestCaseAPI.as_view(), name="test_case_api"),
+ url(r"^compile_spj/?$", CompileSPJAPI.as_view(), name="compile_spj"),
+ url(r"^problem/?$", ProblemAPI.as_view(), name="problem_admin_api"),
+ url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_admin_api"),
+ url(r"^contest_problem/make_public/?$", MakeContestProblemPublicAPIView.as_view(), name="make_public_api"),
]
diff --git a/problem/urls/oj.py b/problem/urls/oj.py
index a7613f1..f7cd3ae 100644
--- a/problem/urls/oj.py
+++ b/problem/urls/oj.py
@@ -1,7 +1,10 @@
from django.conf.urls import url
-from ..views.oj import ProblemTagAPI
+from ..views.oj import ProblemTagAPI, ProblemAPI, ContestProblemAPI, PickOneAPI
urlpatterns = [
- url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api")
+ url(r"^problem/tags/?$", ProblemTagAPI.as_view(), name="problem_tag_list_api"),
+ url(r"^problem/?$", ProblemAPI.as_view(), name="problem_api"),
+ url(r"^pickone/?$", PickOneAPI.as_view(), name="pick_one_api"),
+ url(r"^contest/problem/?$", ContestProblemAPI.as_view(), name="contest_problem_api"),
]
diff --git a/problem/utils.py b/problem/utils.py
new file mode 100644
index 0000000..f824309
--- /dev/null
+++ b/problem/utils.py
@@ -0,0 +1,10 @@
+import re
+
+
+def parse_problem_template(template_str):
+ prepend = re.findall("//PREPEND BEGIN\n([\s\S]+?)//PREPEND END", template_str)
+ template = re.findall("//TEMPLATE BEGIN\n([\s\S]+?)//TEMPLATE END", template_str)
+ append = re.findall("//APPEND BEGIN\n([\s\S]+?)//APPEND END", template_str)
+ return {"prepend": prepend[0] if prepend else "",
+ "template": template[0] if template else "",
+ "append": append[0] if append else ""}
diff --git a/problem/views/admin.py b/problem/views/admin.py
index ad818ad..36fe655 100644
--- a/problem/views/admin.py
+++ b/problem/views/admin.py
@@ -1,22 +1,27 @@
import hashlib
import json
import os
+import shutil
import zipfile
+from wsgiref.util import FileWrapper
from django.conf import settings
+from django.http import StreamingHttpResponse
from account.decorators import problem_permission_required
+from judge.dispatcher import SPJCompiler
from contest.models import Contest
+from submission.models import Submission
from utils.api import APIView, CSRFExemptAPIView, validate_serializer
-from utils.shortcuts import rand_str
+from utils.shortcuts import rand_str, natural_sort_key
-from ..models import ContestProblem, Problem, ProblemRuleType, ProblemTag
-from ..serializers import (CreateContestProblemSerializer,
- CreateProblemSerializer, EditProblemSerializer,
- ProblemSerializer, TestCaseUploadForm)
+from ..models import Problem, ProblemRuleType, ProblemTag
+from ..serializers import (CreateContestProblemSerializer, ContestProblemAdminSerializer, CompileSPJSerializer,
+ CreateProblemSerializer, EditProblemSerializer, EditContestProblemSerializer,
+ ProblemAdminSerializer, TestCaseUploadForm, ContestProblemMakePublicSerializer)
-class TestCaseUploadAPI(CSRFExemptAPIView):
+class TestCaseAPI(CSRFExemptAPIView):
request_parsers = ()
def filter_name_list(self, name_list, spj):
@@ -30,7 +35,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
prefix += 1
continue
else:
- return sorted(ret)
+ return sorted(ret, key=natural_sort_key)
else:
while True:
in_name = str(prefix) + ".in"
@@ -41,7 +46,30 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
prefix += 1
continue
else:
- return sorted(ret)
+ return sorted(ret, key=natural_sort_key)
+
+ @problem_permission_required
+ def get(self, request):
+ problem_id = request.GET.get("problem_id")
+ if not problem_id:
+ return self.error("Parameter error, problem_id is required")
+ try:
+ problem = Problem.objects.get(id=problem_id)
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exists")
+
+ test_case_dir = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
+ if not os.path.isdir(test_case_dir):
+ return self.error("Test case does not exists")
+ name_list = self.filter_name_list(os.listdir(test_case_dir), problem.spj)
+ name_list.append("info")
+ file_name = os.path.join(test_case_dir, problem.test_case_id + ".zip")
+ with zipfile.ZipFile(file_name, "w") as file:
+ for test_case in name_list:
+ file.write(f"{test_case_dir}/{test_case}", test_case)
+ response = StreamingHttpResponse(FileWrapper(open(file_name, "rb")), content_type="application/zip")
+ response["Content-Disposition"] = f"attachment; filename=problem_{problem.id}_test_cases.zip"
+ return response
@problem_permission_required
def post(self, request):
@@ -76,7 +104,7 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
content = zip_file.read(item).replace(b"\r\n", b"\n")
size_cache[item] = len(content)
if item.endswith(".out"):
- md5_cache[item] = hashlib.md5(content).hexdigest()
+ md5_cache[item] = hashlib.md5(content.rstrip()).hexdigest()
f.write(content)
test_case_info = {"spj": spj, "test_cases": {}}
@@ -109,44 +137,80 @@ class TestCaseUploadAPI(CSRFExemptAPIView):
return self.success({"id": test_case_id, "info": ret, "hint": hint, "spj": spj})
-class ProblemAPI(APIView):
+class CompileSPJAPI(APIView):
+ @validate_serializer(CompileSPJSerializer)
+ @problem_permission_required
+ def post(self, request):
+ data = request.data
+ spj_version = rand_str(8)
+ error = SPJCompiler(data["spj_code"], spj_version, data["spj_language"]).compile_spj()
+ if error:
+ return self.error(error)
+ else:
+ return self.success()
+
+
+class ProblemBase(APIView):
+ def common_checks(self, request):
+ data = request.data
+ if data["spj"]:
+ if not data["spj_language"] or not data["spj_code"]:
+ return "Invalid spj"
+ if not data["spj_compile_ok"]:
+ return "SPJ code must be compiled successfully"
+ data["spj_version"] = hashlib.md5(
+ (data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
+ else:
+ data["spj_language"] = None
+ data["spj_code"] = None
+ if data["rule_type"] == ProblemRuleType.OI:
+ total_score = 0
+ for item in data["test_case_score"]:
+ if item["score"] <= 0:
+ return "Invalid score"
+ else:
+ total_score += item["score"]
+ data["total_score"] = total_score
+ data["created_by"] = request.user
+ data["languages"] = list(data["languages"])
+
+ @problem_permission_required
+ def delete(self, request):
+ id = request.GET.get("id")
+ if not id:
+ return self.error("Invalid parameter, id is requred")
+ try:
+ problem = Problem.objects.get(id=id)
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exists")
+ if Submission.objects.filter(problem=problem).exists():
+ return self.error("Can't delete the problem as it has submissions")
+ d = os.path.join(settings.TEST_CASE_DIR, problem.test_case_id)
+ if os.path.isdir(d):
+ shutil.rmtree(d, ignore_errors=True)
+ problem.delete()
+ return self.success()
+
+
+class ProblemAPI(ProblemBase):
@validate_serializer(CreateProblemSerializer)
@problem_permission_required
def post(self, request):
data = request.data
_id = data["_id"]
- if _id:
- try:
- Problem.objects.get(_id=_id)
- return self.error("Display ID already exists")
- except Problem.DoesNotExist:
- pass
- else:
- data["_id"] = rand_str(8)
-
- if data["spj"]:
- if not data["spj_language"] or not data["spj_code"]:
- return self.error("Invalid spj")
- data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
- else:
- data["spj_language"] = None
- data["spj_code"] = None
- if data["rule_type"] == ProblemRuleType.OI:
- for item in data["test_case_score"]:
- if item["score"] <= 0:
- return self.error("Invalid score")
- # todo check filename and score info
- data["created_by"] = request.user
- tags = data.pop("tags")
-
- data["languages"] = list(data["languages"])
-
- problem = Problem.objects.create(**data)
-
if not _id:
- problem._id = str(problem.id)
- problem.save()
+ return self.error("Display ID is required")
+ if Problem.objects.filter(_id=_id, contest_id__isnull=True).exists():
+ return self.error("Display ID already exists")
+
+ error_info = self.common_checks(request)
+ if error_info:
+ return self.error(error_info)
+
+ # todo check filename and score info
+ tags = data.pop("tags")
+ problem = Problem.objects.create(**data)
for item in tags:
try:
@@ -154,7 +218,7 @@ class ProblemAPI(APIView):
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
- return self.success(ProblemSerializer(problem).data)
+ return self.success(ProblemAdminSerializer(problem).data)
@problem_permission_required
def get(self, request):
@@ -165,17 +229,17 @@ class ProblemAPI(APIView):
problem = Problem.objects.get(id=problem_id)
if not user.can_mgmt_all_problem() and problem.created_by != user:
return self.error("Problem does not exist")
- return self.success(ProblemSerializer(problem).data)
+ return self.success(ProblemAdminSerializer(problem).data)
except Problem.DoesNotExist:
return self.error("Problem does not exist")
- problems = Problem.objects.all().order_by("-create_time")
+ problems = Problem.objects.filter(contest_id__isnull=True).order_by("-create_time")
if not user.can_mgmt_all_problem():
problems = problems.filter(created_by=user)
keyword = request.GET.get("keyword")
if keyword:
problems = problems.filter(title__contains=keyword)
- return self.success(self.paginate_data(request, problems, ProblemSerializer))
+ return self.success(self.paginate_data(request, problems, ProblemAdminSerializer))
@validate_serializer(EditProblemSerializer)
@problem_permission_required
@@ -192,29 +256,17 @@ class ProblemAPI(APIView):
return self.error("Problem does not exist")
_id = data["_id"]
- if _id:
- try:
- Problem.objects.exclude(id=problem_id).get(_id=_id)
- return self.error("Display ID already exists")
- except Problem.DoesNotExist:
- pass
- else:
- data["_id"] = str(problem_id)
+ if not _id:
+ return self.error("Display ID is required")
+ if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest_id__isnull=True).exists():
+ return self.error("Display ID already exists")
- if data["spj"]:
- if not data["spj_language"] or not data["spj_code"]:
- return self.error("Invalid spj")
- data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
- else:
- data["spj_language"] = None
- data["spj_code"] = None
-
- if data["rule_type"] == ProblemRuleType.OI:
- for item in data["test_case_score"]:
- if item["score"] <= 0:
- return self.error("Invalid score")
+ error_info = self.common_checks(request)
+ if error_info:
+ return self.error(error_info)
# todo check filename and score info
tags = data.pop("tags")
+ data["languages"] = list(data["languages"])
for k, v in data.items():
setattr(problem, k, v)
@@ -231,11 +283,11 @@ class ProblemAPI(APIView):
return self.success()
-class ContestProblemAPI(APIView):
+class ContestProblemAPI(ProblemBase):
@validate_serializer(CreateContestProblemSerializer)
+ @problem_permission_required
def post(self, request):
data = request.data
-
try:
contest = Contest.objects.get(id=data.pop("contest_id"))
if request.user.is_admin() and contest.created_by != request.user:
@@ -248,33 +300,19 @@ class ContestProblemAPI(APIView):
_id = data["_id"]
if not _id:
- return self.error("Display id is required for contest problem")
- try:
- ContestProblem.objects.get(_id=_id, contest=contest)
+ return self.error("Display ID is required")
+
+ if Problem.objects.filter(_id=_id, contest=contest).exists():
return self.error("Duplicate Display id")
- except ContestProblem.DoesNotExist:
- pass
- if data["spj"]:
- if not data["spj_language"] or not data["spj_code"]:
- return self.error("Invalid spj")
- data["spj_version"] = hashlib.md5((data["spj_language"] + ":" + data["spj_code"]).encode("utf-8")).hexdigest()
- else:
- data["spj_language"] = None
- data["spj_code"] = None
+ error_info = self.common_checks(request)
+ if error_info:
+ return self.error(error_info)
- if data["rule_type"] == ProblemRuleType.OI:
- for item in data["test_case_score"]:
- if item["score"] <= 0:
- return self.error("Invalid score")
# todo check filename and score info
-
- data["created_by"] = request.user
data["contest"] = contest
tags = data.pop("tags")
- data["languages"] = list(data["languages"])
-
- problem = ContestProblem.objects.create(**data)
+ problem = Problem.objects.create(**data)
for item in tags:
try:
@@ -282,28 +320,109 @@ class ContestProblemAPI(APIView):
except ProblemTag.DoesNotExist:
tag = ProblemTag.objects.create(name=item)
problem.tags.add(tag)
- return self.success(ProblemSerializer(problem).data)
+ return self.success(ContestProblemAdminSerializer(problem).data)
+ @problem_permission_required
def get(self, request):
problem_id = request.GET.get("id")
contest_id = request.GET.get("contest_id")
user = request.user
if problem_id:
try:
- problem = ContestProblem.objects.get(id=problem_id)
+ problem = Problem.objects.get(id=problem_id)
if user.is_admin() and problem.contest.created_by != user:
return self.error("Problem does not exist")
- except ContestProblem.DoesNotExist:
+ except Problem.DoesNotExist:
return self.error("Problem does not exist")
- return self.success(ProblemSerializer(problem).data)
+ return self.success(ProblemAdminSerializer(problem).data)
if not contest_id:
return self.error("Contest id is required")
- problems = ContestProblem.objects.filter(contest_id=contest_id).order_by("-create_time")
+ problems = Problem.objects.filter(contest_id=contest_id).order_by("-create_time")
if user.is_admin():
problems = problems.filter(contest__created_by=user)
keyword = request.GET.get("keyword")
if keyword:
problems = problems.filter(title__contains=keyword)
- return self.success(self.paginate_data(request, problems, ProblemSerializer))
+ return self.success(self.paginate_data(request, problems, ContestProblemAdminSerializer))
+
+ @validate_serializer(EditContestProblemSerializer)
+ @problem_permission_required
+ def put(self, request):
+ data = request.data
+ try:
+ contest = Contest.objects.get(id=data.pop("contest_id"))
+ if request.user.is_admin() and contest.created_by != request.user:
+ return self.error("Contest does not exist")
+ except Contest.DoesNotExist:
+ return self.error("Contest does not exist")
+
+ if data["rule_type"] != contest.rule_type:
+ return self.error("Invalid rule type")
+
+ problem_id = data.pop("id")
+ user = request.user
+
+ try:
+ problem = Problem.objects.get(id=problem_id)
+ if not user.can_mgmt_all_problem() and problem.created_by != user:
+ return self.error("Problem does not exist")
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exist")
+
+ _id = data["_id"]
+ if not _id:
+ return self.error("Display ID is required")
+ if Problem.objects.exclude(id=problem_id).filter(_id=_id, contest=contest).exists():
+ return self.error("Display ID already exists")
+
+ error_info = self.common_checks(request)
+ if error_info:
+ return self.error(error_info)
+ # todo check filename and score info
+ tags = data.pop("tags")
+ data["languages"] = list(data["languages"])
+
+ for k, v in data.items():
+ setattr(problem, k, v)
+ problem.save()
+
+ problem.tags.remove(*problem.tags.all())
+ for tag in tags:
+ try:
+ tag = ProblemTag.objects.get(name=tag)
+ except ProblemTag.DoesNotExist:
+ tag = ProblemTag.objects.create(name=tag)
+ problem.tags.add(tag)
+ return self.success()
+
+
+class MakeContestProblemPublicAPIView(APIView):
+ @validate_serializer(ContestProblemMakePublicSerializer)
+ @problem_permission_required
+ def post(self, request):
+ data = request.data
+ display_id = data.get("display_id")
+ if Problem.objects.filter(_id=display_id, contest_id__isnull=True).exists():
+ return self.error("Duplicate display ID")
+
+ try:
+ problem = Problem.objects.get(id=data["id"])
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exist")
+
+ if not problem.contest or problem.is_public:
+ return self.error("Alreay be a public problem")
+ problem.is_public = True
+ problem.save()
+ # https://docs.djangoproject.com/en/1.11/topics/db/queries/#copying-model-instances
+ tags = problem.tags.all()
+ problem.pk = None
+ problem.contest = None
+ problem._id = display_id
+ problem.submission_number = problem.accepted_number = 0
+ problem.statistic_info = {}
+ problem.save()
+ problem.tags.set(tags)
+ return self.success()
diff --git a/problem/views/oj.py b/problem/views/oj.py
index 94496ce..5d1efb4 100644
--- a/problem/views/oj.py
+++ b/problem/views/oj.py
@@ -1,8 +1,116 @@
+import random
+from django.db.models import Q
from utils.api import APIView
-
-from ..models import ProblemTag
+from account.decorators import check_contest_permission
+from ..models import ProblemTag, Problem, ProblemRuleType
+from ..serializers import ProblemSerializer, TagSerializer
+from ..serializers import ContestProblemSerializer, ContestProblemSafeSerializer
+from contest.models import ContestRuleType
class ProblemTagAPI(APIView):
def get(self, request):
- return self.success([item.name for item in ProblemTag.objects.all().order_by("id")])
+ return self.success(TagSerializer(ProblemTag.objects.all(), many=True).data)
+
+
+class PickOneAPI(APIView):
+ def get(self, request):
+ problems = Problem.objects.filter(contest_id__isnull=True, visible=True)
+ count = problems.count()
+ if count == 0:
+ return self.error("No problem to pick")
+ return self.success(problems[random.randint(0, count - 1)]._id)
+
+
+class ProblemAPI(APIView):
+ @staticmethod
+ def _add_problem_status(request, queryset_values):
+ if request.user.is_authenticated():
+ profile = request.user.userprofile
+ acm_problems_status = profile.acm_problems_status.get("problems", {})
+ oi_problems_status = profile.oi_problems_status.get("problems", {})
+ # paginate data
+ results = queryset_values.get("results")
+ if results is not None:
+ problems = results
+ else:
+ problems = [queryset_values, ]
+ for problem in problems:
+ if problem["rule_type"] == ProblemRuleType.ACM:
+ problem["my_status"] = acm_problems_status.get(str(problem["id"]), {}).get("status")
+ else:
+ problem["my_status"] = oi_problems_status.get(str(problem["id"]), {}).get("status")
+
+ def get(self, request):
+ # 问题详情页
+ problem_id = request.GET.get("problem_id")
+ if problem_id:
+ try:
+ problem = Problem.objects.select_related("created_by") \
+ .get(_id=problem_id, contest_id__isnull=True, visible=True)
+ problem_data = ProblemSerializer(problem).data
+ self._add_problem_status(request, problem_data)
+ return self.success(problem_data)
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exist")
+
+ limit = request.GET.get("limit")
+ if not limit:
+ return self.error("Limit is needed")
+
+ problems = Problem.objects.select_related("created_by").filter(contest_id__isnull=True, visible=True)
+ # 按照标签筛选
+ tag_text = request.GET.get("tag")
+ if tag_text:
+ problems = problems.filter(tags__name=tag_text)
+
+ # 搜索的情况
+ keyword = request.GET.get("keyword", "").strip()
+ if keyword:
+ problems = problems.filter(Q(title__icontains=keyword) | Q(_id__icontains=keyword))
+
+ # 难度筛选
+ difficulty = request.GET.get("difficulty")
+ if difficulty:
+ problems = problems.filter(difficulty=difficulty)
+ # 根据profile 为做过的题目添加标记
+ data = self.paginate_data(request, problems, ProblemSerializer)
+ self._add_problem_status(request, data)
+ return self.success(data)
+
+
+class ContestProblemAPI(APIView):
+ def _add_problem_status(self, request, queryset_values):
+ if request.user.is_authenticated():
+ profile = request.user.userprofile
+ if self.contest.rule_type == ContestRuleType.ACM:
+ problems_status = profile.acm_problems_status.get("contest_problems", {})
+ else:
+ problems_status = profile.oi_problems_status.get("contest_problems", {})
+ for problem in queryset_values:
+ problem["my_status"] = problems_status.get(str(problem["id"]), {}).get("status")
+
+ @check_contest_permission(check_type="problems")
+ def get(self, request):
+ problem_id = request.GET.get("problem_id")
+ if problem_id:
+ try:
+ problem = Problem.objects.select_related("created_by").get(_id=problem_id,
+ contest=self.contest,
+ visible=True)
+ except Problem.DoesNotExist:
+ return self.error("Problem does not exist.")
+ if self.contest.problem_details_permission(request.user):
+ problem_data = ContestProblemSerializer(problem).data
+ self._add_problem_status(request, [problem_data, ])
+ else:
+ problem_data = ContestProblemSafeSerializer(problem).data
+ return self.success(problem_data)
+
+ contest_problems = Problem.objects.select_related("created_by").filter(contest=self.contest, visible=True)
+ if self.contest.problem_details_permission(request.user):
+ data = ContestProblemSerializer(contest_problems, many=True).data
+ self._add_problem_status(request, data)
+ else:
+ data = ContestProblemSafeSerializer(contest_problems, many=True).data
+ return self.success(data)
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 9b0d805..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-django==1.9.6
-djangorestframework==3.4.0
-otpauth
-pillow
-python-dateutil
-celery
-Envelopes
-pytz
-jsonfield
-qrcode
\ No newline at end of file
diff --git a/run_test.py b/run_test.py
index 8358f12..cb4c630 100644
--- a/run_test.py
+++ b/run_test.py
@@ -21,7 +21,7 @@ print("running flake8...")
if os.system("flake8 --statistics ."):
exit()
-ret = os.system("coverage run ./manage.py test {module} --settings={setting}".format(module=test_module, setting=setting))
+ret = os.system("coverage run --include=\"$PWD/*\" manage.py test {module} --settings={setting}".format(module=test_module, setting=setting))
if not ret and is_coverage:
os.system("coverage html && open htmlcov/index.html")
diff --git a/submission/__init__.py b/submission/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/submission/migrations/0001_initial.py b/submission/migrations/0001_initial.py
new file mode 100644
index 0000000..42a5352
--- /dev/null
+++ b/submission/migrations/0001_initial.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-05-09 06:41
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+import jsonfield.fields
+import utils.models
+import utils.shortcuts
+
+
+class Migration(migrations.Migration):
+
+ initial = True
+
+ dependencies = [
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name='Submission',
+ fields=[
+ ('id', models.CharField(db_index=True, default=utils.shortcuts.rand_str, max_length=32, primary_key=True, serialize=False)),
+ ('contest_id', models.IntegerField(db_index=True, null=True)),
+ ('problem_id', models.IntegerField(db_index=True)),
+ ('created_time', models.DateTimeField(auto_now_add=True)),
+ ('user_id', models.IntegerField(db_index=True)),
+ ('code', utils.models.RichTextField()),
+ ('result', models.IntegerField(default=6)),
+ ('info', jsonfield.fields.JSONField(default={})),
+ ('language', models.CharField(max_length=20)),
+ ('shared', models.BooleanField(default=False)),
+ ('accepted_time', models.IntegerField(blank=True, null=True)),
+ ('accepted_info', jsonfield.fields.JSONField(default={})),
+ ],
+ options={
+ 'db_table': 'submission',
+ },
+ ),
+ ]
diff --git a/submission/migrations/0002_auto_20170509_1203.py b/submission/migrations/0002_auto_20170509_1203.py
new file mode 100644
index 0000000..78dcbe9
--- /dev/null
+++ b/submission/migrations/0002_auto_20170509_1203.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.9.6 on 2017-05-09 12:03
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('submission', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='submission',
+ name='code',
+ field=models.TextField(),
+ ),
+ migrations.RenameField(
+ model_name='submission',
+ old_name='accepted_info',
+ new_name='statistic_info',
+ ),
+ migrations.RemoveField(
+ model_name='submission',
+ name='accepted_time',
+ ),
+ migrations.RenameField(
+ model_name='submission',
+ old_name='created_time',
+ new_name='create_time',
+ ),
+ migrations.AlterModelOptions(
+ name='submission',
+ options={'ordering': ('-create_time',)},
+ )
+ ]
diff --git a/submission/migrations/0005_submission_username.py b/submission/migrations/0005_submission_username.py
new file mode 100644
index 0000000..68a3243
--- /dev/null
+++ b/submission/migrations/0005_submission_username.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-08-26 03:47
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('submission', '0002_auto_20170509_1203'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='submission',
+ name='username',
+ field=models.CharField(default="", max_length=30),
+ preserve_default=False,
+ ),
+ ]
diff --git a/submission/migrations/0006_auto_20170830_1154.py b/submission/migrations/0006_auto_20170830_1154.py
new file mode 100644
index 0000000..675cc86
--- /dev/null
+++ b/submission/migrations/0006_auto_20170830_1154.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-08-30 11:54
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('submission', '0005_submission_username'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='submission',
+ name='result',
+ field=models.IntegerField(db_index=True, default=6),
+ ),
+ ]
diff --git a/submission/migrations/0007_auto_20170923_1318.py b/submission/migrations/0007_auto_20170923_1318.py
new file mode 100644
index 0000000..6356680
--- /dev/null
+++ b/submission/migrations/0007_auto_20170923_1318.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-09-23 13:18
+from __future__ import unicode_literals
+
+import django.contrib.postgres.fields.jsonb
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('submission', '0006_auto_20170830_1154'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='submission',
+ name='contest_id',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, to='contest.Contest'),
+ ),
+ migrations.AlterField(
+ model_name='submission',
+ name='problem_id',
+ field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='problem.Problem'),
+ ),
+ migrations.RenameField(
+ model_name='submission',
+ old_name='contest_id',
+ new_name='contest',
+ ),
+ migrations.RenameField(
+ model_name='submission',
+ old_name='problem_id',
+ new_name='problem',
+ ),
+ migrations.AlterField(
+ model_name='submission',
+ name='info',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ migrations.AlterField(
+ model_name='submission',
+ name='statistic_info',
+ field=django.contrib.postgres.fields.jsonb.JSONField(default=dict),
+ ),
+ ]
diff --git a/submission/migrations/0008_submission_ip.py b/submission/migrations/0008_submission_ip.py
new file mode 100644
index 0000000..e60841b
--- /dev/null
+++ b/submission/migrations/0008_submission_ip.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+# Generated by Django 1.11.4 on 2017-11-10 06:57
+from __future__ import unicode_literals
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('submission', '0007_auto_20170923_1318'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='submission',
+ name='ip',
+ field=models.CharField(blank=True, max_length=32, null=True),
+ ),
+ ]
diff --git a/submission/migrations/__init__.py b/submission/migrations/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/submission/models.py b/submission/models.py
new file mode 100644
index 0000000..0c4aeda
--- /dev/null
+++ b/submission/models.py
@@ -0,0 +1,53 @@
+from django.db import models
+from utils.models import JSONField
+from problem.models import Problem
+from contest.models import Contest
+
+from utils.shortcuts import rand_str
+
+
+class JudgeStatus:
+ COMPILE_ERROR = -2
+ WRONG_ANSWER = -1
+ ACCEPTED = 0
+ CPU_TIME_LIMIT_EXCEEDED = 1
+ REAL_TIME_LIMIT_EXCEEDED = 2
+ MEMORY_LIMIT_EXCEEDED = 3
+ RUNTIME_ERROR = 4
+ SYSTEM_ERROR = 5
+ PENDING = 6
+ JUDGING = 7
+ PARTIALLY_ACCEPTED = 8
+
+
+class Submission(models.Model):
+ id = models.CharField(max_length=32, default=rand_str, primary_key=True, db_index=True)
+ contest = models.ForeignKey(Contest, null=True)
+ problem = models.ForeignKey(Problem)
+ create_time = models.DateTimeField(auto_now_add=True)
+ user_id = models.IntegerField(db_index=True)
+ username = models.CharField(max_length=30)
+ code = models.TextField()
+ result = models.IntegerField(db_index=True, default=JudgeStatus.PENDING)
+ # 从JudgeServer返回的判题详情
+ info = JSONField(default=dict)
+ language = models.CharField(max_length=20)
+ shared = models.BooleanField(default=False)
+ # 存储该提交所用时间和内存值,方便提交列表显示
+ # {time_cost: "", memory_cost: "", err_info: "", score: 0}
+ statistic_info = JSONField(default=dict)
+ ip = models.CharField(max_length=32, null=True, blank=True)
+
+ def check_user_permission(self, user, check_share=True):
+ return self.user_id == user.id or \
+ (check_share and self.shared is True) or \
+ user.is_super_admin() or \
+ user.can_mgmt_all_problem() or \
+ self.problem.created_by_id == user.id
+
+ class Meta:
+ db_table = "submission"
+ ordering = ("-create_time",)
+
+ def __str__(self):
+ return self.id
diff --git a/submission/serializers.py b/submission/serializers.py
new file mode 100644
index 0000000..2b67ab8
--- /dev/null
+++ b/submission/serializers.py
@@ -0,0 +1,54 @@
+from .models import Submission
+from utils.api import serializers
+from judge.languages import language_names
+
+
+class CreateSubmissionSerializer(serializers.Serializer):
+ problem_id = serializers.IntegerField()
+ language = serializers.ChoiceField(choices=language_names)
+ code = serializers.CharField(max_length=20000)
+ contest_id = serializers.IntegerField(required=False)
+ captcha = serializers.CharField(required=False)
+
+
+class ShareSubmissionSerializer(serializers.Serializer):
+ id = serializers.CharField()
+ shared = serializers.BooleanField()
+
+
+class SubmissionModelSerializer(serializers.ModelSerializer):
+ info = serializers.JSONField()
+ statistic_info = serializers.JSONField()
+
+ class Meta:
+ model = Submission
+
+
+# 不显示submission info的serializer, 用于ACM rule_type
+class SubmissionSafeModelSerializer(serializers.ModelSerializer):
+ problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
+ statistic_info = serializers.JSONField()
+
+ class Meta:
+ model = Submission
+ exclude = ("info", "contest", "ip")
+
+
+class SubmissionListSerializer(serializers.ModelSerializer):
+ problem = serializers.SlugRelatedField(read_only=True, slug_field="_id")
+ statistic_info = serializers.JSONField()
+ show_link = serializers.SerializerMethodField()
+
+ def __init__(self, *args, **kwargs):
+ self.user = kwargs.pop("user", None)
+ super().__init__(*args, **kwargs)
+
+ class Meta:
+ model = Submission
+ exclude = ("info", "contest", "code", "ip")
+
+ def get_show_link(self, obj):
+ # 没传user或为匿名user
+ if self.user is None or not self.user.is_authenticated():
+ return False
+ return obj.check_user_permission(self.user)
diff --git a/submission/tests.py b/submission/tests.py
new file mode 100644
index 0000000..fdcd12e
--- /dev/null
+++ b/submission/tests.py
@@ -0,0 +1,68 @@
+from unittest import mock
+from copy import deepcopy
+
+from .models import Submission
+from problem.models import Problem, ProblemTag
+from utils.api.tests import APITestCase
+
+DEFAULT_PROBLEM_DATA = {"_id": "A-110", "title": "test", "description": "test
", "input_description": "test",
+ "output_description": "test", "time_limit": 1000, "memory_limit": 256, "difficulty": "Low",
+ "visible": True, "tags": ["test"], "languages": ["C", "C++", "Java", "Python2"], "template": {},
+ "samples": [{"input": "test", "output": "test"}], "spj": False, "spj_language": "C",
+ "spj_code": "", "test_case_id": "499b26290cc7994e0b497212e842ea85",
+ "test_case_score": [{"output_name": "1.out", "input_name": "1.in", "output_size": 0,
+ "stripped_output_md5": "d41d8cd98f00b204e9800998ecf8427e",
+ "input_size": 0, "score": 0}],
+ "rule_type": "ACM", "hint": "test
", "source": "test"}
+
+DEFAULT_SUBMISSION_DATA = {
+ "problem_id": "1",
+ "user_id": 1,
+ "username": "test",
+ "code": "xxxxxxxxxxxxxx",
+ "result": -2,
+ "info": {},
+ "language": "C",
+ "statistic_info": {}
+}
+# todo contest submission
+
+
+class SubmissionPrepare(APITestCase):
+ def _create_problem_and_submission(self):
+ user = self.create_admin("test", "test123", login=False)
+ problem_data = deepcopy(DEFAULT_PROBLEM_DATA)
+ tags = problem_data.pop("tags")
+ problem_data["created_by"] = user
+ self.problem = Problem.objects.create(**problem_data)
+ for tag in tags:
+ tag = ProblemTag.objects.create(name=tag)
+ self.problem.tags.add(tag)
+ self.problem.save()
+ self.submission_data = deepcopy(DEFAULT_SUBMISSION_DATA)
+ self.submission_data["problem_id"] = self.problem.id
+ self.submission = Submission.objects.create(**self.submission_data)
+
+
+class SubmissionListTest(SubmissionPrepare):
+ def setUp(self):
+ self._create_problem_and_submission()
+ self.create_user("123", "345")
+ self.url = self.reverse("submission_list_api")
+
+ def test_get_submission_list(self):
+ resp = self.client.get(self.url, data={"limit": "10"})
+ self.assertSuccess(resp)
+
+
+@mock.patch("submission.views.oj.judge_task.delay")
+class SubmissionAPITest(SubmissionPrepare):
+ def setUp(self):
+ self._create_problem_and_submission()
+ self.user = self.create_user("123", "test123")
+ self.url = self.reverse("submission_api")
+
+ def test_create_submission(self, judge_task):
+ resp = self.client.post(self.url, self.submission_data)
+ self.assertSuccess(resp)
+ judge_task.assert_called()
diff --git a/submission/urls/__init__.py b/submission/urls/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/submission/urls/admin.py b/submission/urls/admin.py
new file mode 100644
index 0000000..bf86022
--- /dev/null
+++ b/submission/urls/admin.py
@@ -0,0 +1,7 @@
+from django.conf.urls import url
+
+from ..views.admin import SubmissionRejudgeAPI
+
+urlpatterns = [
+ url(r"^submission/rejudge?$", SubmissionRejudgeAPI.as_view(), name="submission_rejudge_api"),
+]
diff --git a/submission/urls/oj.py b/submission/urls/oj.py
new file mode 100644
index 0000000..49116b9
--- /dev/null
+++ b/submission/urls/oj.py
@@ -0,0 +1,10 @@
+from django.conf.urls import url
+
+from ..views.oj import SubmissionAPI, SubmissionListAPI, ContestSubmissionListAPI, SubmissionExistsAPI
+
+urlpatterns = [
+ url(r"^submission/?$", SubmissionAPI.as_view(), name="submission_api"),
+ url(r"^submissions/?$", SubmissionListAPI.as_view(), name="submission_list_api"),
+ url(r"^submission_exists/?$", SubmissionExistsAPI.as_view(), name="submission_exists"),
+ url(r"^contest_submissions/?$", ContestSubmissionListAPI.as_view(), name="contest_submission_list_api"),
+]
diff --git a/submission/views/__init__.py b/submission/views/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/submission/views/admin.py b/submission/views/admin.py
new file mode 100644
index 0000000..679360b
--- /dev/null
+++ b/submission/views/admin.py
@@ -0,0 +1,24 @@
+from account.decorators import super_admin_required
+from judge.tasks import judge_task
+# from judge.dispatcher import JudgeDispatcher
+from utils.api import APIView
+from ..models import Submission, JudgeStatus
+
+
+class SubmissionRejudgeAPI(APIView):
+ @super_admin_required
+ def get(self, request):
+ id = request.GET.get("id")
+ if not id:
+ return self.error("Paramater error, id is required")
+ try:
+ submission = Submission.objects.select_related("problem").get(id=id, contest_id__isnull=True)
+ except Submission.DoesNotExist:
+ return self.error("Submission does not exists")
+ submission.result = JudgeStatus.PENDING
+ submission.info = {}
+ submission.statistic_info = {}
+ submission.save()
+
+ judge_task.delay(submission.id, submission.problem.id)
+ return self.success()
diff --git a/submission/views/oj.py b/submission/views/oj.py
new file mode 100644
index 0000000..8514c13
--- /dev/null
+++ b/submission/views/oj.py
@@ -0,0 +1,211 @@
+import ipaddress
+
+from django.conf import settings
+from account.decorators import login_required, check_contest_permission
+from judge.tasks import judge_task
+# from judge.dispatcher import JudgeDispatcher
+from problem.models import Problem, ProblemRuleType
+from contest.models import Contest, ContestStatus, ContestRuleType
+from options.options import SysOptions
+from utils.api import APIView, validate_serializer
+from utils.throttling import TokenBucket, BucketController
+from utils.captcha import Captcha
+from utils.cache import cache
+from ..models import Submission
+from ..serializers import (CreateSubmissionSerializer, SubmissionModelSerializer,
+ ShareSubmissionSerializer)
+from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerializer
+
+
+class SubmissionAPI(APIView):
+ def throttling(self, request):
+ user_controller = BucketController(factor=request.user.id,
+ redis_conn=cache,
+ default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY)
+ user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE,
+ capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY,
+ last_capacity=user_controller.last_capacity,
+ last_timestamp=user_controller.last_timestamp)
+ if user_bucket.consume():
+ user_controller.last_capacity -= 1
+ else:
+ return "Please wait %d seconds" % int(user_bucket.expected_time() + 1)
+
+ ip_controller = BucketController(factor=request.session["ip"],
+ redis_conn=cache,
+ default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3)
+
+ ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3,
+ capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3,
+ last_capacity=ip_controller.last_capacity,
+ last_timestamp=ip_controller.last_timestamp)
+ if ip_bucket.consume():
+ ip_controller.last_capacity -= 1
+ else:
+ return "Captcha is required"
+
+ @validate_serializer(CreateSubmissionSerializer)
+ @login_required
+ def post(self, request):
+ data = request.data
+ hide_id = False
+ if data.get("contest_id"):
+ try:
+ contest = Contest.objects.get(id=data["contest_id"])
+ except Contest.DoesNotExist:
+ return self.error("Contest doesn't exist.")
+ if contest.status == ContestStatus.CONTEST_ENDED:
+ return self.error("The contest have ended")
+ if not request.user.is_contest_admin(contest):
+ if contest.status == ContestStatus.CONTEST_NOT_START:
+ return self.error("Contest have not started")
+ user_ip = ipaddress.ip_address(request.session.get("ip"))
+ if contest.allowed_ip_ranges:
+ if not any(user_ip in ipaddress.ip_network(cidr) for cidr in contest.allowed_ip_ranges):
+ return self.error("Your IP is not allowed in this contest")
+
+ if not contest.problem_details_permission(request.user):
+ hide_id = True
+
+ if data.get("captcha"):
+ if not Captcha(request).check(data["captcha"]):
+ return self.error("Invalid captcha")
+ error = self.throttling(request)
+ if error:
+ return self.error(error)
+
+ try:
+ problem = Problem.objects.get(id=data["problem_id"], contest_id=data.get("contest_id"), visible=True)
+ except Problem.DoesNotExist:
+ return self.error("Problem not exist")
+
+ submission = Submission.objects.create(user_id=request.user.id,
+ username=request.user.username,
+ language=data["language"],
+ code=data["code"],
+ problem_id=problem.id,
+ ip=request.session["ip"],
+ contest_id=data.get("contest_id"))
+ # use this for debug
+ # JudgeDispatcher(submission.id, problem.id).judge()
+ judge_task.delay(submission.id, problem.id)
+ if hide_id:
+ return self.success()
+ else:
+ return self.success({"submission_id": submission.id})
+
+ @login_required
+ def get(self, request):
+ submission_id = request.GET.get("id")
+ if not submission_id:
+ return self.error("Parameter id doesn't exist")
+ try:
+ submission = Submission.objects.select_related("problem").get(id=submission_id)
+ except Submission.DoesNotExist:
+ return self.error("Submission doesn't exist")
+ if not submission.check_user_permission(request.user):
+ return self.error("No permission for this submission")
+
+ if submission.problem.rule_type == ProblemRuleType.OI or request.user.is_admin_role():
+ submission_data = SubmissionModelSerializer(submission).data
+ else:
+ submission_data = SubmissionSafeModelSerializer(submission).data
+ # 是否有权限取消共享
+ submission_data["can_unshare"] = submission.check_user_permission(request.user, check_share=False)
+ return self.success(submission_data)
+
+ @validate_serializer(ShareSubmissionSerializer)
+ @login_required
+ def put(self, request):
+ """
+ share submission
+ """
+ try:
+ submission = Submission.objects.select_related("problem").get(id=request.data["id"])
+ except Submission.DoesNotExist:
+ return self.error("Submission doesn't exist")
+ if not submission.check_user_permission(request.user, check_share=False):
+ return self.error("No permission to share the submission")
+ if submission.contest and submission.contest.status == ContestStatus.CONTEST_UNDERWAY:
+ return self.error("Can not share submission now")
+ submission.shared = request.data["shared"]
+ submission.save(update_fields=["shared"])
+ return self.success()
+
+
+class SubmissionListAPI(APIView):
+ def get(self, request):
+ if not request.GET.get("limit"):
+ return self.error("Limit is needed")
+ if request.GET.get("contest_id"):
+ return self.error("Parameter error")
+
+ submissions = Submission.objects.filter(contest_id__isnull=True).select_related("problem__created_by")
+ problem_id = request.GET.get("problem_id")
+ myself = request.GET.get("myself")
+ result = request.GET.get("result")
+ username = request.GET.get("username")
+ if problem_id:
+ try:
+ problem = Problem.objects.get(_id=problem_id, contest_id__isnull=True, visible=True)
+ except Problem.DoesNotExist:
+ return self.error("Problem doesn't exist")
+ submissions = submissions.filter(problem=problem)
+ if (myself and myself == "1") or not SysOptions.submission_list_show_all:
+ submissions = submissions.filter(user_id=request.user.id)
+ elif username:
+ submissions = submissions.filter(username__icontains=username)
+ if result:
+ submissions = submissions.filter(result=result)
+ data = self.paginate_data(request, submissions)
+ data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data
+ return self.success(data)
+
+
+class ContestSubmissionListAPI(APIView):
+ @check_contest_permission(check_type="submissions")
+ def get(self, request):
+ if not request.GET.get("limit"):
+ return self.error("Limit is needed")
+
+ contest = self.contest
+ submissions = Submission.objects.filter(contest_id=contest.id).select_related("problem__created_by")
+ problem_id = request.GET.get("problem_id")
+ myself = request.GET.get("myself")
+ result = request.GET.get("result")
+ username = request.GET.get("username")
+ if problem_id:
+ try:
+ problem = Problem.objects.get(_id=problem_id, contest_id=contest.id, visible=True)
+ except Problem.DoesNotExist:
+ return self.error("Problem doesn't exist")
+ submissions = submissions.filter(problem=problem)
+
+ if myself and myself == "1":
+ submissions = submissions.filter(user_id=request.user.id)
+ elif username:
+ submissions = submissions.filter(username__icontains=username)
+ if result:
+ submissions = submissions.filter(result=result)
+
+ # filter the test submissions submitted before contest start
+ if contest.status != ContestStatus.CONTEST_NOT_START:
+ submissions = submissions.filter(create_time__gte=contest.start_time)
+
+ # 封榜的时候只能看到自己的提交
+ if contest.rule_type == ContestRuleType.ACM:
+ if not contest.real_time_rank and not contest.is_contest_admin(request.user):
+ submissions = submissions.filter(user_id=request.user.id)
+
+ data = self.paginate_data(request, submissions)
+ data["results"] = SubmissionListSerializer(data["results"], many=True, user=request.user).data
+ return self.success(data)
+
+
+class SubmissionExistsAPI(APIView):
+ def get(self, request):
+ if not request.GET.get("problem_id"):
+ return self.error("Parameter error, problem_id is required")
+ return self.success(request.user.is_authenticated() and
+ Submission.objects.filter(problem_id=request.GET["problem_id"],
+ user_id=request.user.id).exists())
diff --git a/utils/api/__init__.py b/utils/api/__init__.py
index dedbe3a..9384481 100644
--- a/utils/api/__init__.py
+++ b/utils/api/__init__.py
@@ -1,2 +1,2 @@
-from .api import * # NOQA
from ._serializers import * # NOQA
+from .api import * # NOQA
diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py
index 4cfe3f4..737a965 100644
--- a/utils/api/_serializers.py
+++ b/utils/api/_serializers.py
@@ -1,11 +1,9 @@
-from django.utils import timezone
from rest_framework import serializers
class DateTimeTZField(serializers.DateTimeField):
def to_representation(self, value):
- self.format = "%Y-%-m-%d %H:%M:%S %Z"
- value = timezone.localtime(value)
+ # value = timezone.localtime(value)
return super(DateTimeTZField, self).to_representation(value)
diff --git a/utils/api/api.py b/utils/api/api.py
index 78018cd..e33daaf 100644
--- a/utils/api/api.py
+++ b/utils/api/api.py
@@ -65,6 +65,7 @@ class APIView(View):
for parser in self.request_parsers:
if content_type.startswith(parser.content_type):
break
+ # else means the for loop is not interrupted by break
else:
raise ValueError("unknown content_type '%s'" % content_type)
if body:
@@ -78,7 +79,7 @@ class APIView(View):
def success(self, data=None):
return self.response({"error": None, "data": data})
- def error(self, msg, err="error"):
+ def error(self, msg="error", err="error"):
return self.response({"error": err, "data": msg})
def _serializer_error_to_str(self, errors):
@@ -106,18 +107,12 @@ class APIView(View):
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
:return:
"""
- need_paginate = request.GET.get("limit", None)
- if need_paginate is None:
- if object_serializer:
- return object_serializer(query_set, many=True).data
- else:
- return query_set
try:
- limit = int(request.GET.get("limit", "100"))
+ limit = int(request.GET.get("limit", "10"))
except ValueError:
- limit = 100
- if limit < 0:
- limit = 100
+ limit = 10
+ if limit < 0 or limit > 250:
+ limit = 10
try:
offset = int(request.GET.get("offset", "0"))
except ValueError:
@@ -129,7 +124,7 @@ class APIView(View):
count = query_set.count()
results = object_serializer(results, many=True).data
else:
- count = len(query_set)
+ count = query_set.count()
data = {"results": results,
"total": count}
return data
diff --git a/utils/api/tests.py b/utils/api/tests.py
index ff72791..b47ceae 100644
--- a/utils/api/tests.py
+++ b/utils/api/tests.py
@@ -8,25 +8,27 @@ from account.models import AdminType, ProblemPermission, User, UserProfile
class APITestCase(TestCase):
client_class = APIClient
- def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True, problem_permission=ProblemPermission.NONE):
+ def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=True,
+ problem_permission=ProblemPermission.NONE):
user = User.objects.create(username=username, admin_type=admin_type, problem_permission=problem_permission)
user.set_password(password)
- UserProfile.objects.create(user=user, time_zone="Asia/Shanghai")
+ UserProfile.objects.create(user=user)
user.save()
if login:
self.client.login(username=username, password=password)
return user
def create_admin(self, username="admin", password="admin", login=True):
- return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, problem_permission=ProblemPermission.OWN,
+ return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN,
+ problem_permission=ProblemPermission.OWN,
login=login)
def create_super_admin(self, username="root", password="root", login=True):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN,
problem_permission=ProblemPermission.ALL, login=login)
- def reverse(self, url_name):
- return reverse(url_name)
+ def reverse(self, url_name, *args, **kwargs):
+ return reverse(url_name, *args, **kwargs)
def assertSuccess(self, response):
if not response.data["error"] is None:
diff --git a/utils/cache.py b/utils/cache.py
new file mode 100644
index 0000000..ed9059b
--- /dev/null
+++ b/utils/cache.py
@@ -0,0 +1,27 @@
+from django.core.cache import cache, caches # noqa
+from django.conf import settings # noqa
+
+from django_redis.cache import RedisCache
+from django_redis.client.default import DefaultClient
+
+
+class MyRedisClient(DefaultClient):
+ def __getattr__(self, item):
+ client = self.get_client(write=True)
+ return getattr(client, item)
+
+ def redis_incr(self, key, count=1):
+ """
+ django 默认的 incr 在 key 不存在时候会抛异常
+ """
+ client = self.get_client(write=True)
+ return client.incr(key, count)
+
+
+class MyRedisCache(RedisCache):
+ def __init__(self, server, params):
+ super().__init__(server, params)
+ self._client_cls = MyRedisClient
+
+ def __getattr__(self, item):
+ return getattr(self.client, item)
diff --git a/utils/captcha/__init__.py b/utils/captcha/__init__.py
index ee3fe25..8b8375c 100644
--- a/utils/captcha/__init__.py
+++ b/utils/captcha/__init__.py
@@ -1,12 +1,9 @@
"""
Copyright 2013 TY
-
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
-
http://www.apache.org/licenses/LICENSE-2.0
-
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -18,13 +15,10 @@ import os
import time
import random
-from io import BytesIO
-from django.http import HttpResponse
from PIL import Image, ImageDraw, ImageFont
class Captcha(object):
-
def __init__(self, request):
"""
初始化,设置各种属性
@@ -60,9 +54,9 @@ class Captcha(object):
self._set_answer("".join(string))
return string
- def display(self):
+ def get(self):
"""
- 生成验证码图片
+ 生成验证码图片,返回值为图片的bytes
"""
background = (random.randrange(200, 255), random.randrange(200, 255), random.randrange(200, 255))
code_color = (random.randrange(0, 50), random.randrange(0, 50), random.randrange(0, 50), 255)
@@ -86,11 +80,8 @@ class Captcha(object):
# 随机化字符之间的距离 字符粘连可以降低识别率
x += font_size * random.randrange(6, 8) / 10
- buf = BytesIO()
- image.save(buf, "gif")
-
self.django_request.session[self.session_key] = "".join(code)
- return HttpResponse(buf.getvalue(), "image/gif")
+ return image
def check(self, code):
"""
diff --git a/utils/captcha/views.py b/utils/captcha/views.py
index ba42059..4642e32 100644
--- a/utils/captcha/views.py
+++ b/utils/captcha/views.py
@@ -1,7 +1,8 @@
-from django.http import HttpResponse
-
-from utils.captcha import Captcha
+from . import Captcha
+from ..api import APIView
+from ..shortcuts import img2base64
-def show_captcha(request):
- return HttpResponse(Captcha(request).display(), content_type="image/gif")
+class CaptchaAPIView(APIView):
+ def get(self, request):
+ return self.success(img2base64(Captcha(request).get()))
diff --git a/utils/constants.py b/utils/constants.py
new file mode 100644
index 0000000..390d568
--- /dev/null
+++ b/utils/constants.py
@@ -0,0 +1,28 @@
+class Choices:
+ @classmethod
+ def choices(cls):
+ d = cls.__dict__
+ return [d[item] for item in d.keys() if not item.startswith("__")]
+
+
+class ContestType:
+ PUBLIC_CONTEST = "Public"
+ PASSWORD_PROTECTED_CONTEST = "Password Protected"
+
+
+class ContestStatus:
+ CONTEST_NOT_START = "1"
+ CONTEST_ENDED = "-1"
+ CONTEST_UNDERWAY = "0"
+
+
+class ContestRuleType(Choices):
+ ACM = "ACM"
+ OI = "OI"
+
+
+class CacheKey:
+ waiting_queue = "waiting_queue"
+ contest_rank_cache = "contest_rank_cache"
+ website_config = "website_config"
+ option = "option"
diff --git a/utils/management/commands/initadmin.py b/utils/management/commands/initadmin.py
deleted file mode 100644
index 5829178..0000000
--- a/utils/management/commands/initadmin.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from django.core.management.base import BaseCommand
-
-from account.models import AdminType, ProblemPermission, User, UserProfile
-from utils.shortcuts import rand_str # NOQA
-
-
-class Command(BaseCommand):
- def handle(self, *args, **options):
- try:
- admin = User.objects.get(username="root")
- if admin.admin_type == AdminType.SUPER_ADMIN:
- self.stdout.write(self.style.WARNING("Super admin user 'root' already exists, "
- "would you like to reset it's password?\n"
- "Input yes to confirm: "))
- if input() == "yes":
- # for dev
- # rand_password = rand_str(length=6)
- rand_password = "rootroot"
- admin.save()
- self.stdout.write(self.style.SUCCESS("Successfully created super admin user password.\n"
- "Username: root\nPassword: %s\n"
- "Remember to change password and turn on two factors auth "
- "after installation." % rand_password))
- else:
- self.stdout.write(self.style.SUCCESS("Nothing happened"))
- else:
- self.stdout.write(self.style.ERROR("User 'root' is not super admin."))
- except User.DoesNotExist:
- user = User.objects.create(username="root", email="root@oj.com", admin_type=AdminType.SUPER_ADMIN,
- problem_permission=ProblemPermission.ALL)
- # for dev
- # rand_password = rand_str(length=6)
- rand_password = "rootroot"
- user.set_password(rand_password)
- user.save()
- UserProfile.objects.create(user=user, time_zone="Asia/Shanghai")
- self.stdout.write(self.style.SUCCESS("Successfully created super admin user.\n"
- "Username: root\nPassword: %s\n"
- "Remember to change password and turn on two factors auth "
- "after installation." % rand_password))
diff --git a/utils/management/commands/inituser.py b/utils/management/commands/inituser.py
new file mode 100644
index 0000000..c3f0827
--- /dev/null
+++ b/utils/management/commands/inituser.py
@@ -0,0 +1,44 @@
+from django.core.management.base import BaseCommand
+
+from account.models import AdminType, ProblemPermission, User, UserProfile
+from utils.shortcuts import rand_str # NOQA
+
+
+class Command(BaseCommand):
+ def add_arguments(self, parser):
+ parser.add_argument("--username", type=str)
+ parser.add_argument("--password", type=str)
+ parser.add_argument("--action", type=str)
+
+ def handle(self, *args, **options):
+ username = options["username"]
+ password = options["password"]
+ action = options["action"]
+
+ if not(username and password and action):
+ self.stdout.write(self.style.ERROR("Invalid args"))
+ exit(1)
+
+ if action == "create_super_admin":
+ if User.objects.filter(username=username).exists():
+ self.stdout.write(self.style.SUCCESS(f"User {username} exists, operation ignored"))
+ exit()
+
+ user = User.objects.create(username=username, admin_type=AdminType.SUPER_ADMIN,
+ problem_permission=ProblemPermission.ALL)
+ user.set_password(password)
+ user.save()
+ UserProfile.objects.create(user=user)
+
+ self.stdout.write(self.style.SUCCESS("User created"))
+ elif action == "reset":
+ try:
+ user = User.objects.get(username=username)
+ user.set_password(password)
+ user.save()
+ self.stdout.write(self.style.SUCCESS(f"Password is rested"))
+ except User.DoesNotExist:
+ self.stdout.write(self.style.ERROR(f"User {username} doesnot exist, operation ignored"))
+ exit(1)
+ else:
+ raise ValueError("Invalid action")
diff --git a/utils/models.py b/utils/models.py
index a3e15b0..3c11452 100644
--- a/utils/models.py
+++ b/utils/models.py
@@ -1,11 +1,10 @@
+from django.contrib.postgres.fields import JSONField # NOQA
from django.db import models
from utils.xss_filter import XssHtml
class RichTextField(models.TextField):
- __metaclass__ = models.SubfieldBase
-
def get_prep_value(self, value):
if not value:
value = ""
diff --git a/utils/shortcuts.py b/utils/shortcuts.py
index 9962ade..9340564 100644
--- a/utils/shortcuts.py
+++ b/utils/shortcuts.py
@@ -1,32 +1,10 @@
-import logging
+import re
+import datetime
import random
+from base64 import b64encode
+from io import BytesIO
from django.utils.crypto import get_random_string
-from envelopes import Envelope
-
-from conf.models import SMTPConfig
-
-logger = logging.getLogger(__name__)
-
-
-def send_email(from_name, to_email, to_name, subject, content):
- smtp = SMTPConfig.objects.first()
- if not smtp:
- return
- envlope = Envelope(from_addr=(smtp.email, from_name),
- to_addr=(to_email, to_name),
- subject=subject,
- html_body=content)
- try:
- envlope.send(smtp.server,
- login=smtp.email,
- password=smtp.password,
- port=smtp.port,
- tls=smtp.tls)
- return True
- except Exception as e:
- logger.exception(e)
- return False
def rand_str(length=32, type="lower_hex"):
@@ -44,3 +22,44 @@ def rand_str(length=32, type="lower_hex"):
return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef")
else:
return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789")
+
+
+def build_query_string(kv_data, ignore_none=True):
+ # {"a": 1, "b": "test"} -> "?a=1&b=test"
+ query_string = ""
+ for k, v in kv_data.items():
+ if ignore_none is True and kv_data[k] is None:
+ continue
+ if query_string != "":
+ query_string += "&"
+ else:
+ query_string = "?"
+ query_string += (k + "=" + str(v))
+ return query_string
+
+
+def img2base64(img):
+ with BytesIO() as buf:
+ img.save(buf, "gif")
+ buf_str = buf.getvalue()
+ img_prefix = "data:image/png;base64,"
+ b64_str = img_prefix + b64encode(buf_str).decode("utf-8")
+ return b64_str
+
+
+def datetime2str(value, format="iso-8601"):
+ if format.lower() == "iso-8601":
+ value = value.isoformat()
+ if value.endswith("+00:00"):
+ value = value[:-6] + "Z"
+ return value
+ return value.strftime(format)
+
+
+def timestamp2utcstr(value):
+ return datetime.datetime.utcfromtimestamp(value).isoformat()
+
+
+def natural_sort_key(s, _nsre=re.compile(r"(\d+)")):
+ return [int(text) if text.isdigit() else text.lower()
+ for text in re.split(_nsre, s)]
diff --git a/utils/throttling.py b/utils/throttling.py
new file mode 100644
index 0000000..7c5f54a
--- /dev/null
+++ b/utils/throttling.py
@@ -0,0 +1,90 @@
+from __future__ import print_function
+import time
+
+
+class TokenBucket:
+ def __init__(self, fill_rate, capacity, last_capacity, last_timestamp):
+ self.capacity = float(capacity)
+ self._left_tokens = last_capacity
+ self.fill_rate = float(fill_rate)
+ self.timestamp = last_timestamp
+
+ def consume(self, tokens=1):
+ if tokens <= self.tokens:
+ self._left_tokens -= tokens
+ return True
+ return False
+
+ def expected_time(self, tokens=1):
+ _tokens = self.tokens
+ tokens = max(tokens, _tokens)
+ return (tokens - _tokens) / self.fill_rate * 60
+
+ @property
+ def tokens(self):
+ if self._left_tokens < self.capacity:
+ now = time.time()
+ delta = self.fill_rate * ((now - self.timestamp) / 60)
+ self._left_tokens = min(self.capacity, self._left_tokens + delta)
+ self.timestamp = now
+ return self._left_tokens
+
+
+class BucketController:
+ def __init__(self, factor, redis_conn, default_capacity):
+ self.default_capacity = default_capacity
+ self.redis = redis_conn
+ self.key = "bucket_" + str(factor)
+
+ @property
+ def last_capacity(self):
+ value = self.redis.hget(self.key, "last_capacity")
+ if value is None:
+ self.last_capacity = self.default_capacity
+ return self.default_capacity
+ return int(value)
+
+ @last_capacity.setter
+ def last_capacity(self, value):
+ self.redis.hset(self.key, "last_capacity", value)
+
+ @property
+ def last_timestamp(self):
+ value = self.redis.hget(self.key, "last_timestamp")
+ if value is None:
+ timestamp = int(time.time())
+ self.last_timestamp = timestamp
+ return timestamp
+ return int(value)
+
+ @last_timestamp.setter
+ def last_timestamp(self, value):
+ self.redis.hset(self.key, "last_timestamp", value)
+
+
+"""
+# # Token bucket, to limit submission rate
+# # Demo
+
+success = failure = 0
+current_user_id = 1
+token_bucket_default_capacity = 50
+token_bucket_fill_rate = 10
+for i in range(5000):
+ controller = BucketController(user_id=current_user_id,
+ redis_conn=redis.Redis(),
+ default_capacity=token_bucket_default_capacity)
+ bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
+ capacity=token_bucket_default_capacity,
+ last_capacity=controller.last_capacity,
+ last_timestamp=controller.last_timestamp)
+ time.sleep(0.05)
+ if bucket.consume():
+ success += 1
+ print(i, ": Accepted")
+ controller.last_capacity -= 1
+ else:
+ failure += 1
+ print(i, "Dropped, time left ", bucket.expected_time())
+print(success, failure)
+"""
diff --git a/utils/urls.py b/utils/urls.py
new file mode 100644
index 0000000..ca9fb0f
--- /dev/null
+++ b/utils/urls.py
@@ -0,0 +1,7 @@
+from django.conf.urls import url
+
+from .views import SimditorImageUploadAPIView
+
+urlpatterns = [
+ url(r"^upload_image/?$", SimditorImageUploadAPIView.as_view(), name="upload_image")
+]
diff --git a/utils/views.py b/utils/views.py
new file mode 100644
index 0000000..c3e3861
--- /dev/null
+++ b/utils/views.py
@@ -0,0 +1,44 @@
+import os
+from django.conf import settings
+from account.serializers import ImageUploadForm
+from utils.shortcuts import rand_str
+from utils.api import CSRFExemptAPIView
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class SimditorImageUploadAPIView(CSRFExemptAPIView):
+ request_parsers = ()
+
+ def post(self, request):
+ form = ImageUploadForm(request.POST, request.FILES)
+ if form.is_valid():
+ img = form.cleaned_data["image"]
+ else:
+ return self.response({
+ "success": False,
+ "msg": "Upload failed",
+ "file_path": ""})
+
+ suffix = os.path.splitext(img.name)[-1].lower()
+ if suffix not in [".gif", ".jpg", ".jpeg", ".bmp", ".png"]:
+ return self.response({
+ "success": False,
+ "msg": "Unsupported file format",
+ "file_path": ""})
+ img_name = rand_str(10) + suffix
+ try:
+ with open(os.path.join(settings.UPLOAD_DIR, img_name), "wb") as imgFile:
+ for chunk in img:
+ imgFile.write(chunk)
+ except IOError as e:
+ logger.error(e)
+ return self.response({
+ "success": True,
+ "msg": "Upload Error",
+ "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
+ return self.response({
+ "success": True,
+ "msg": "Success",
+ "file_path": f"{settings.UPLOAD_PREFIX}/{img_name}"})
diff --git a/utils/xss_filter.py b/utils/xss_filter.py
index d29495b..34d65a8 100644
--- a/utils/xss_filter.py
+++ b/utils/xss_filter.py
@@ -26,11 +26,8 @@ Cannot defense xss in browser which is belowed IE7
浏览器版本:IE7+ 或其他浏览器,无法防御IE6及以下版本浏览器中的XSS
"""
import re
-
-try:
- from html.parser import HTMLParser
-except:
- from HTMLParser import HTMLParser
+import copy
+from html.parser import HTMLParser
class XssHtml(HTMLParser):
@@ -163,7 +160,7 @@ class XssHtml(HTMLParser):
else:
other = []
if attrs:
- for (key, value) in attrs.items():
+ for key, value in copy.deepcopy(attrs).items():
if key not in self.common_attrs + other:
del attrs[key]
return attrs