bug fixes

This commit is contained in:
virusdefender
2017-10-06 17:46:14 +08:00
committed by zema1
parent a324d55364
commit 93bd77d8d8
16 changed files with 91 additions and 94 deletions

View File

@@ -1,10 +1,5 @@
import time
import pytz
from django.contrib import auth
from django.utils import timezone
from django.utils.translation import ugettext as _
from django.db import connection from django.db import connection
from django.utils.timezone import now
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
from utils.api import JSONResponse from utils.api import JSONResponse
@@ -14,14 +9,11 @@ class SessionRecordMiddleware(MiddlewareMixin):
def process_request(self, request): def process_request(self, request):
if request.user.is_authenticated(): if request.user.is_authenticated():
session = request.session session = request.session
ip = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP") session["user_agent"] = request.META.get("HTTP_USER_AGENT", "")
user_agent = request.META.get("HTTP_USER_AGENT", "") session["ip"] = request.META.get("HTTP_X_REAL_IP", "UNKNOWN IP")
_ip = session.setdefault("ip", ip) session["last_activity"] = now()
_user_agent = session.setdefault("user_agent", user_agent)
if ip != _ip or user_agent != _user_agent:
session.modified = True
user_sessions = request.user.session_keys user_sessions = request.user.session_keys
if request.session.session_key not in user_sessions: if session.session_key not in user_sessions:
user_sessions.append(session.session_key) user_sessions.append(session.session_key)
request.user.save() request.user.save()

View File

@@ -50,7 +50,7 @@ class Migration(migrations.Migration):
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('problems_status', jsonfield.fields.JSONField(default={})), ('problems_status', jsonfield.fields.JSONField(default={})),
('avatar', models.CharField(default=account.models._default_avatar, max_length=50)), ('avatar', models.CharField(default="default.png", max_length=50)),
('blog', models.URLField(blank=True, null=True)), ('blog', models.URLField(blank=True, null=True)),
('mood', models.CharField(blank=True, max_length=200, null=True)), ('mood', models.CharField(blank=True, max_length=200, null=True)),
('accepted_problem_number', models.IntegerField(default=0)), ('accepted_problem_number', models.IntegerField(default=0)),

View File

@@ -72,7 +72,7 @@ class UserProfile(models.Model):
oi_problems_status = JSONField(default={}) oi_problems_status = JSONField(default={})
real_name = models.CharField(max_length=32, blank=True, null=True) real_name = models.CharField(max_length=32, blank=True, null=True)
avatar = models.CharField(max_length=256, default=f"{settings.IMAGE_UPLOAD_DIR}/default.png") avatar = models.CharField(max_length=256, default=f"/{settings.IMAGE_UPLOAD_DIR}/default.png")
blog = models.URLField(blank=True, null=True) blog = models.URLField(blank=True, null=True)
mood = models.CharField(max_length=256, blank=True, null=True) mood = models.CharField(max_length=256, blank=True, null=True)
github = models.CharField(max_length=64, blank=True, null=True) github = models.CharField(max_length=64, blank=True, null=True)

View File

@@ -26,7 +26,6 @@ class UserRegisterSerializer(serializers.Serializer):
class UserChangePasswordSerializer(serializers.Serializer): class UserChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField() old_password = serializers.CharField()
new_password = serializers.CharField(min_length=6) new_password = serializers.CharField(min_length=6)
captcha = serializers.CharField()
class UserSerializer(serializers.ModelSerializer): class UserSerializer(serializers.ModelSerializer):
@@ -46,6 +45,7 @@ class UserProfileSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = UserProfile model = UserProfile
fields = "__all__"
class UserInfoSerializer(serializers.ModelSerializer): class UserInfoSerializer(serializers.ModelSerializer):

View File

@@ -8,11 +8,9 @@ from otpauth import OtpAuth
from utils.api.tests import APIClient, APITestCase from utils.api.tests import APIClient, APITestCase
from utils.shortcuts import rand_str from utils.shortcuts import rand_str
from utils.cache import default_cache from options.options import SysOptions
from utils.constants import CacheKey
from .models import AdminType, ProblemPermission, User from .models import AdminType, ProblemPermission, User
from conf.models import WebsiteConfig
class PermissionDecoratorTest(APITestCase): class PermissionDecoratorTest(APITestCase):
@@ -157,13 +155,9 @@ class UserRegisterAPITest(CaptchaTest):
self.data = {"username": "test_user", "password": "testuserpassword", self.data = {"username": "test_user", "password": "testuserpassword",
"real_name": "real_name", "email": "test@qduoj.com", "real_name": "real_name", "email": "test@qduoj.com",
"captcha": self._set_captcha(self.client.session)} "captcha": self._set_captcha(self.client.session)}
# clea cache in redis
default_cache.delete(CacheKey.website_config)
def test_website_config_limit(self): def test_website_config_limit(self):
website = WebsiteConfig.objects.create() SysOptions.allow_register = False
website.allow_register = False
website.save()
resp = self.client.post(self.register_url, data=self.data) resp = self.client.post(self.register_url, data=self.data)
self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"}) self.assertDictEqual(resp.data, {"error": "error", "data": "Register have been disabled by admin"})
@@ -247,7 +241,6 @@ class TwoFactorAuthAPITest(APITestCase):
def setUp(self): def setUp(self):
self.url = self.reverse("two_factor_auth_api") self.url = self.reverse("two_factor_auth_api")
self.create_user("test", "test123") self.create_user("test", "test123")
self.create_website_config()
def _get_tfa_code(self): def _get_tfa_code(self):
user = User.objects.first() user = User.objects.first()
@@ -295,7 +288,6 @@ class ApplyResetPasswordAPITest(CaptchaTest):
user.email = "test@oj.com" user.email = "test@oj.com"
user.save() user.save()
self.url = self.reverse("apply_reset_password_api") self.url = self.reverse("apply_reset_password_api")
self.create_website_config()
self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)} self.data = {"email": "test@oj.com", "captcha": self._set_captcha(self.client.session)}
def _refresh_captcha(self): def _refresh_captcha(self):

View File

@@ -3,7 +3,7 @@ from django.conf.urls import url
from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI, from ..views.oj import (ApplyResetPasswordAPI, ResetPasswordAPI,
UserChangePasswordAPI, UserRegisterAPI, UserChangePasswordAPI, UserRegisterAPI,
UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck, UserLoginAPI, UserLogoutAPI, UsernameOrEmailCheck,
SSOAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI, AvatarUploadAPI, TwoFactorAuthAPI, UserProfileAPI,
UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI) UserRankAPI, CheckTFARequiredAPI, SessionManagementAPI)
from utils.captcha.views import CaptchaAPIView from utils.captcha.views import CaptchaAPIView

View File

@@ -1,5 +1,4 @@
import os import os
import pickle
from datetime import timedelta from datetime import timedelta
from importlib import import_module from importlib import import_module
@@ -16,15 +15,14 @@ from utils.constants import ContestRuleType
from options.options import SysOptions from options.options import SysOptions
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.captcha import Captcha from utils.captcha import Captcha
from utils.shortcuts import rand_str, img2base64, timestamp2utcstr from utils.shortcuts import rand_str, img2base64, datetime2str
from ..decorators import login_required from ..decorators import login_required
from ..models import User, UserProfile from ..models import User, UserProfile
from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer, from ..serializers import (ApplyResetPasswordSerializer, ResetPasswordSerializer,
UserChangePasswordSerializer, UserLoginSerializer, UserChangePasswordSerializer, UserLoginSerializer,
UserRegisterSerializer, UsernameOrEmailCheckSerializer, UserRegisterSerializer, UsernameOrEmailCheckSerializer,
RankInfoSerializer) RankInfoSerializer)
from ..serializers import (SSOSerializer, TwoFactorAuthCodeSerializer, from ..serializers import (TwoFactorAuthCodeSerializer, UserProfileSerializer,
UserProfileSerializer,
EditUserProfileSerializer, AvatarUploadForm) EditUserProfileSerializer, AvatarUploadForm)
from ..tasks import send_email_async from ..tasks import send_email_async
@@ -81,7 +79,7 @@ class AvatarUploadAPI(APIView):
img.write(chunk) img.write(chunk)
user_profile = request.user.userprofile user_profile = request.user.userprofile
user_profile.avatar = f"{settings.IMAGE_UPLOAD_DIR}/{name}" user_profile.avatar = f"/{settings.IMAGE_UPLOAD_DIR}/{name}"
user_profile.save() user_profile.save()
return self.success("Succeeded") return self.success("Succeeded")
@@ -327,7 +325,7 @@ class SessionManagementAPI(APIView):
s["current_session"] = True s["current_session"] = True
s["ip"] = session["ip"] s["ip"] = session["ip"]
s["user_agent"] = session["user_agent"] s["user_agent"] = session["user_agent"]
s["last_activity"] = timestamp2utcstr(session["last_activity"]) s["last_activity"] = datetime2str(session["last_activity"])
s["session_key"] = key s["session_key"] = key
result.append(s) result.append(s)
if modified: if modified:

View File

@@ -1,8 +1,9 @@
from utils.constants import ContestRuleType # noqa
from django.db import models from django.db import models
from django.utils.timezone import now from django.utils.timezone import now
from jsonfield import JSONField from jsonfield import JSONField
from utils.constants import ContestStatus, ContestRuleType, ContestType from utils.constants import ContestStatus, ContestType
from account.models import User, AdminType from account.models import User, AdminType
from utils.models import RichTextField from utils.models import RichTextField

View File

@@ -4,7 +4,7 @@ from utils.api import APIView, validate_serializer
from utils.constants import CacheKey from utils.constants import CacheKey
from account.decorators import login_required, check_contest_permission from account.decorators import login_required, check_contest_permission
from utils.constants import ContestRuleType, ContestType, ContestStatus from utils.constants import ContestRuleType, ContestStatus
from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank from ..models import ContestAnnouncement, Contest, OIContestRank, ACMContestRank
from ..serializers import ContestAnnouncementSerializer from ..serializers import ContestAnnouncementSerializer
from ..serializers import ContestSerializer, ContestPasswordVerifySerializer from ..serializers import ContestSerializer, ContestPasswordVerifySerializer

View File

@@ -14,7 +14,7 @@ from judge.languages import languages
from options.options import SysOptions from options.options import SysOptions
from problem.models import Problem, ProblemRuleType from problem.models import Problem, ProblemRuleType
from submission.models import JudgeStatus, Submission from submission.models import JudgeStatus, Submission
from utils.cache import judge_cache, default_cache from utils.cache import cache
from utils.constants import CacheKey from utils.constants import CacheKey
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,31 +22,28 @@ logger = logging.getLogger(__name__)
# 继续处理在队列中的问题 # 继续处理在队列中的问题
def process_pending_task(): def process_pending_task():
if judge_cache.llen(CacheKey.waiting_queue): if cache.llen(CacheKey.waiting_queue):
# 防止循环引入 # 防止循环引入
from judge.tasks import judge_task from judge.tasks import judge_task
data = json.loads(judge_cache.rpop(CacheKey.waiting_queue).decode("utf-8")) data = json.loads(cache.rpop(CacheKey.waiting_queue).decode("utf-8"))
judge_task.delay(**data) judge_task.delay(**data)
class JudgeDispatcher(object): class JudgeDispatcher(object):
def __init__(self, submission_id, problem_id): def __init__(self, submission_id, problem_id):
self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest() self.token = hashlib.sha256(SysOptions.judge_server_token.encode("utf-8")).hexdigest()
self.redis_conn = judge_cache self.submission = Submission.objects.get(id=submission_id)
self.submission = Submission.objects.get(pk=submission_id)
self.contest_id = self.submission.contest_id self.contest_id = self.submission.contest_id
if self.contest_id: if self.contest_id:
self.problem = Problem.objects.select_related("contest") \ self.problem = Problem.objects.select_related("contest").get(id=problem_id, contest_id=self.contest_id)
.get(id=problem_id, contest_id=self.contest_id)
self.contest = self.problem.contest self.contest = self.problem.contest
else: else:
self.problem = Problem.objects.get(id=problem_id) self.problem = Problem.objects.get(id=problem_id)
def _request(self, url, data=None): def _request(self, url, data=None):
kwargs = {"headers": {"X-Judge-Server-Token": self.token, kwargs = {"headers": {"X-Judge-Server-Token": self.token}}
"Content-Type": "application/json"}}
if data: if data:
kwargs["data"] = json.dumps(data) kwargs["json"] = data
try: try:
return requests.post(url, **kwargs).json() return requests.post(url, **kwargs).json()
except Exception as e: except Exception as e:
@@ -55,7 +52,6 @@ class JudgeDispatcher(object):
@staticmethod @staticmethod
def choose_judge_server(): def choose_judge_server():
with transaction.atomic(): with transaction.atomic():
# TODO: use more reasonable way
servers = JudgeServer.objects.select_for_update().all().order_by("task_number") servers = JudgeServer.objects.select_for_update().all().order_by("task_number")
servers = [s for s in servers if s.status == "normal"] servers = [s for s in servers if s.status == "normal"]
if servers: if servers:
@@ -65,10 +61,10 @@ class JudgeDispatcher(object):
return server return server
@staticmethod @staticmethod
def release_judge_res(judge_server_id): def release_judge_server(judge_server_id):
with transaction.atomic(): with transaction.atomic():
# 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下 # 使用原子操作, 同时因为use和release中间间隔了判题过程,需要重新查询一下
server = JudgeServer.objects.select_for_update().get(id=judge_server_id) server = JudgeServer.objects.get(id=judge_server_id)
server.used_instance_number = F("task_number") - 1 server.used_instance_number = F("task_number") - 1
server.save() server.save()
@@ -94,7 +90,7 @@ class JudgeDispatcher(object):
server = self.choose_judge_server() server = self.choose_judge_server()
if not server: if not server:
data = {"submission_id": self.submission.id, "problem_id": self.problem.id} data = {"submission_id": self.submission.id, "problem_id": self.problem.id}
self.redis_conn.lpush(CacheKey.waiting_queue, json.dumps(data)) cache.lpush(CacheKey.waiting_queue, json.dumps(data))
return return
sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0] sub_config = list(filter(lambda item: self.submission.language == item["name"], languages))[0]
@@ -138,7 +134,7 @@ class JudgeDispatcher(object):
else: else:
self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED self.submission.result = JudgeStatus.PARTIALLY_ACCEPTED
self.submission.save() self.submission.save()
self.release_judge_res(server.id) self.release_judge_server(server.id)
self.update_problem_status() self.update_problem_status()
if self.contest_id: if self.contest_id:
@@ -223,7 +219,7 @@ class JudgeDispatcher(object):
if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY: if self.contest_id and self.contest.status != ContestStatus.CONTEST_UNDERWAY:
return return
if self.contest.real_time_rank: if self.contest.real_time_rank:
default_cache.delete(CacheKey.contest_rank_cache + str(self.contest_id)) cache.delete(CacheKey.contest_rank_cache + str(self.contest_id))
with transaction.atomic(): with transaction.atomic():
if self.contest.rule_type == ContestRuleType.ACM: if self.contest.rule_type == ContestRuleType.ACM:
acm_rank, _ = ACMContestRank.objects.select_for_update(). \ acm_rank, _ = ACMContestRank.objects.select_for_update(). \

View File

@@ -5,8 +5,12 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.postgresql_psycopg2',
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'), 'HOST': '127.0.0.1',
'PORT': 5433,
'NAME': "onlinejudge",
'USER': "onlinejudge",
'PASSWORD': 'onlinejudge'
} }
} }

View File

@@ -61,7 +61,6 @@ MIDDLEWARE_CLASSES = (
'account.middleware.SessionRecordMiddleware', 'account.middleware.SessionRecordMiddleware',
# 'account.middleware.LogSqlMiddleware', # 'account.middleware.LogSqlMiddleware',
) )
SESSION_ENGINE = 'django.contrib.sessions.backends.cache'
ROOT_URLCONF = 'oj.urls' ROOT_URLCONF = 'oj.urls'
TEMPLATES = [ TEMPLATES = [
@@ -166,41 +165,33 @@ LOGGING = {
} }
REST_FRAMEWORK = { REDIS_URL = "redis://127.0.0.1:6379"
'TEST_REQUEST_DEFAULT_FORMAT': 'json',
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
)
}
CACHE_JUDGE_QUEUE = "judge_queue"
CACHE_THROTTLING = "throttling"
def redis_config(db):
def make_key(key, key_prefix, version):
return key
return {
"BACKEND": "utils.cache.MyRedisCache",
"LOCATION": f"{REDIS_URL}/{db}",
"TIMEOUT": None,
"KEY_PREFIX": "",
"KEY_FUNCTION": make_key
}
CACHES = { CACHES = {
"default": { "default": redis_config(db=1)
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/1",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_JUDGE_QUEUE: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/2",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
},
CACHE_THROTTLING: {
"BACKEND": "django_redis.cache.RedisCache",
"LOCATION": "redis://127.0.0.1:6379/3",
"OPTIONS": {
"CLIENT_CLASS": "django_redis.client.DefaultClient",
}
}
} }
CELERY_RESULT_BACKEND = CELERY_BROKER_URL = f"{REDIS_URL}/2"
CELERY_TASK_SOFT_TIME_LIMIT = CELERY_TASK_TIME_LIMIT = 180
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "default"
# For celery # For celery
REDIS_QUEUE = { REDIS_QUEUE = {
"host": "127.0.0.1", "host": "127.0.0.1",

View File

@@ -113,15 +113,15 @@ class _SysOptionsMeta(type):
@property @property
def website_base_url(cls): def website_base_url(cls):
return cls._get_option(OptionKeys.website_base_url) return cls._get_option(OptionKeys.website_base_url)
@website_base_url.setter @website_base_url.setter
def website_base_url(cls, value): def website_base_url(cls, value):
cls._set_option(OptionKeys.website_base_url, value) cls._set_option(OptionKeys.website_base_url, value)
@property @property
def website_name(cls): def website_name(cls):
return cls._get_option(OptionKeys.website_name) return cls._get_option(OptionKeys.website_name)
@website_name.setter @website_name.setter
def website_name(cls, value): def website_name(cls, value):
cls._set_option(OptionKeys.website_name, value) cls._set_option(OptionKeys.website_name, value)
@@ -173,7 +173,7 @@ class _SysOptionsMeta(type):
@judge_server_token.setter @judge_server_token.setter
def judge_server_token(cls, value): def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value) cls._set_option(OptionKeys.judge_server_token, value)
class SysOptions(metaclass=_SysOptionsMeta): class SysOptions(metaclass=_SysOptionsMeta):
pass pass

View File

@@ -71,6 +71,7 @@ class CreateContestProblemSerializer(CreateOrEditProblemSerializer):
class TagSerializer(serializers.ModelSerializer): class TagSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ProblemTag model = ProblemTag
fields = "__all__"
class BaseProblemSerializer(serializers.ModelSerializer): class BaseProblemSerializer(serializers.ModelSerializer):
@@ -88,6 +89,7 @@ class BaseProblemSerializer(serializers.ModelSerializer):
class ProblemAdminSerializer(BaseProblemSerializer): class ProblemAdminSerializer(BaseProblemSerializer):
class Meta: class Meta:
model = Problem model = Problem
fields = "__all__"
class ContestProblemAdminSerializer(BaseProblemSerializer): class ContestProblemAdminSerializer(BaseProblemSerializer):

View File

@@ -5,16 +5,16 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType from contest.models import Contest, ContestStatus, ContestRuleType
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController from utils.throttling import TokenBucket, BucketController
from utils.cache import cache
from ..models import Submission from ..models import Submission
from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer from ..serializers import CreateSubmissionSerializer, SubmissionModelSerializer
from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer from ..serializers import SubmissionSafeSerializer, SubmissionListSerializer
from utils.cache import throttling_cache
def _submit(response, user, problem_id, language, code, contest_id): def _submit(response, user, problem_id, language, code, contest_id):
# TODO: 预设默认值,需修改 # TODO: 预设默认值,需修改
controller = BucketController(user_id=user.id, controller = BucketController(user_id=user.id,
redis_conn=throttling_cache, redis_conn=cache,
default_capacity=30) default_capacity=30)
bucket = TokenBucket(fill_rate=10, capacity=20, bucket = TokenBucket(fill_rate=10, capacity=20,
last_capacity=controller.last_capacity, last_capacity=controller.last_capacity,

View File

@@ -1,6 +1,27 @@
from django.conf import settings from django.core.cache import cache, caches # noqa
from django_redis import get_redis_connection from django.conf import settings # noqa
judge_cache = get_redis_connection(settings.CACHE_JUDGE_QUEUE) from django_redis.cache import RedisCache
throttling_cache = get_redis_connection(settings.CACHE_THROTTLING) from django_redis.client.default import DefaultClient
default_cache = get_redis_connection("default")
class MyRedisClient(DefaultClient):
def __getattr__(self, item):
client = self.get_client(write=True)
return getattr(client, item)
def redis_incr(self, key, count=1):
"""
django 默认的 incr 在 key 不存在时候会抛异常
"""
client = self.get_client(write=True)
return client.incr(key, count)
class MyRedisCache(RedisCache):
def __init__(self, server, params):
super().__init__(server, params)
self._client_cls = MyRedisClient
def __getattr__(self, item):
return getattr(self.client, item)