new throttling

This commit is contained in:
virusdefender
2017-12-23 22:27:53 +08:00
parent 3c8e6cd9f3
commit 072364497c
3 changed files with 84 additions and 105 deletions

View File

@@ -21,6 +21,7 @@ class OptionKeys:
submission_list_show_all = "submission_list_show_all" submission_list_show_all = "submission_list_show_all"
smtp_config = "smtp_config" smtp_config = "smtp_config"
judge_server_token = "judge_server_token" judge_server_token = "judge_server_token"
throttling = "throttling"
class OptionDefaultValue: class OptionDefaultValue:
@@ -32,6 +33,8 @@ class OptionDefaultValue:
submission_list_show_all = True submission_list_show_all = True
smtp_config = {} smtp_config = {}
judge_server_token = default_token judge_server_token = default_token
throttling = {"ip": {"capacity": 100, "fill_rate": 0.1, "default_capacity": 50},
"user": {"capacity": 20, "fill_rate": 0.03, "default_capacity": 10}}
class _SysOptionsMeta(type): class _SysOptionsMeta(type):
@@ -180,6 +183,14 @@ class _SysOptionsMeta(type):
def judge_server_token(cls, value): def judge_server_token(cls, value):
cls._set_option(OptionKeys.judge_server_token, value) cls._set_option(OptionKeys.judge_server_token, value)
@property
def throttling(cls):
return cls._get_option(OptionKeys.throttling)
@throttling.setter
def throttling(cls, value):
cls._set_option(OptionKeys.throttling, value)
class SysOptions(metaclass=_SysOptionsMeta): class SysOptions(metaclass=_SysOptionsMeta):
pass pass

View File

@@ -1,6 +1,5 @@
import ipaddress import ipaddress
from django.conf import settings
from account.decorators import login_required, check_contest_permission from account.decorators import login_required, check_contest_permission
from judge.tasks import judge_task from judge.tasks import judge_task
# from judge.dispatcher import JudgeDispatcher # from judge.dispatcher import JudgeDispatcher
@@ -8,7 +7,7 @@ from problem.models import Problem, ProblemRuleType
from contest.models import Contest, ContestStatus, ContestRuleType from contest.models import Contest, ContestStatus, ContestRuleType
from options.options import SysOptions from options.options import SysOptions
from utils.api import APIView, validate_serializer from utils.api import APIView, validate_serializer
from utils.throttling import TokenBucket, BucketController from utils.throttling import TokenBucket
from utils.captcha import Captcha from utils.captcha import Captcha
from utils.cache import cache from utils.cache import cache
from ..models import Submission from ..models import Submission
@@ -19,29 +18,16 @@ from ..serializers import SubmissionSafeModelSerializer, SubmissionListSerialize
class SubmissionAPI(APIView): class SubmissionAPI(APIView):
def throttling(self, request): def throttling(self, request):
user_controller = BucketController(factor=request.user.id, user_bucket = TokenBucket(key=str(request.user.id),
redis_conn=cache, redis_conn=cache, **SysOptions.throttling["user"])
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY) can_consume, wait = user_bucket.consume()
user_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE, if not can_consume:
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY, return "Please wait %d seconds" % (int(wait))
last_capacity=user_controller.last_capacity,
last_timestamp=user_controller.last_timestamp)
if user_bucket.consume():
user_controller.last_capacity -= 1
else:
return "Please wait %d seconds" % int(user_bucket.expected_time() + 1)
ip_controller = BucketController(factor=request.session["ip"], ip_bucket = TokenBucket(key=request.session["ip"],
redis_conn=cache, redis_conn=cache, **SysOptions.throttling["ip"])
default_capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3) can_consume, wait = ip_bucket.consume()
if not can_consume:
ip_bucket = TokenBucket(fill_rate=settings.TOKEN_BUCKET_FILL_RATE * 3,
capacity=settings.TOKEN_BUCKET_DEFAULT_CAPACITY * 3,
last_capacity=ip_controller.last_capacity,
last_timestamp=ip_controller.last_timestamp)
if ip_bucket.consume():
ip_controller.last_capacity -= 1
else:
return "Captcha is required" return "Captcha is required"
@validate_serializer(CreateSubmissionSerializer) @validate_serializer(CreateSubmissionSerializer)

View File

@@ -1,90 +1,72 @@
from __future__ import print_function
import time import time
class TokenBucket: class TokenBucket:
def __init__(self, fill_rate, capacity, last_capacity, last_timestamp): """
self.capacity = float(capacity) 注意对于单个key的操作不是线程安全的
self._left_tokens = last_capacity """
self.fill_rate = float(fill_rate) def __init__(self, key, capacity, fill_rate, default_capacity, redis_conn):
self.timestamp = last_timestamp """
:param capacity: 最大容量
:param fill_rate: 填充速度/每秒
:param default_capacity: 初始容量
:param redis_conn: redis connection
"""
self._key = key
self._capacity = capacity
self._fill_rate = fill_rate
self._default_capacity = default_capacity
self._redis_conn = redis_conn
def consume(self, tokens=1): self._last_capacity_key = "last_capacity"
if tokens <= self.tokens: self._last_timestamp_key = "last_timestamp"
self._left_tokens -= tokens
return True
return False
def expected_time(self, tokens=1): def _init_key(self):
_tokens = self.tokens self._last_capacity = self._default_capacity
tokens = max(tokens, _tokens) now = time.time()
return (tokens - _tokens) / self.fill_rate * 60 self._last_timestamp = now
return self._default_capacity, now
@property @property
def tokens(self): def _last_capacity(self):
if self._left_tokens < self.capacity: last_capacity = self._redis_conn.hget(self._key, self._last_capacity_key)
if last_capacity is None:
return self._init_key()[0]
else:
return float(last_capacity)
@_last_capacity.setter
def _last_capacity(self, value):
self._redis_conn.hset(self._key, self._last_capacity_key, value)
@property
def _last_timestamp(self):
return float(self._redis_conn.hget(self._key, self._last_timestamp_key))
@_last_timestamp.setter
def _last_timestamp(self, value):
self._redis_conn.hset(self._key, self._last_timestamp_key, value)
def _try_to_fill(self, now):
delta = self._fill_rate * (now - self._last_timestamp)
return min(self._last_capacity + delta, self._capacity)
def consume(self, num=1):
"""
消耗 num 个 token返回是否成功
:param num:
:return: result: bool, wait_time: float
"""
# print("capacity ", self.fill(time.time()))
if self._last_capacity >= num:
self._last_capacity -= num
return True, 0
else:
now = time.time() now = time.time()
delta = self.fill_rate * ((now - self.timestamp) / 60) cur_num = self._try_to_fill(now)
self._left_tokens = min(self.capacity, self._left_tokens + delta) if cur_num >= num:
self.timestamp = now self._last_capacity = cur_num - num
return self._left_tokens self._last_timestamp = now
return True, 0
else:
class BucketController: return False, (num - cur_num) / self._fill_rate
def __init__(self, factor, redis_conn, default_capacity):
self.default_capacity = default_capacity
self.redis = redis_conn
self.key = "bucket_" + str(factor)
@property
def last_capacity(self):
value = self.redis.hget(self.key, "last_capacity")
if value is None:
self.last_capacity = self.default_capacity
return self.default_capacity
return int(value)
@last_capacity.setter
def last_capacity(self, value):
self.redis.hset(self.key, "last_capacity", value)
@property
def last_timestamp(self):
value = self.redis.hget(self.key, "last_timestamp")
if value is None:
timestamp = int(time.time())
self.last_timestamp = timestamp
return timestamp
return int(value)
@last_timestamp.setter
def last_timestamp(self, value):
self.redis.hset(self.key, "last_timestamp", value)
"""
# # Token bucket, to limit submission rate
# # Demo
success = failure = 0
current_user_id = 1
token_bucket_default_capacity = 50
token_bucket_fill_rate = 10
for i in range(5000):
controller = BucketController(user_id=current_user_id,
redis_conn=redis.Redis(),
default_capacity=token_bucket_default_capacity)
bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
capacity=token_bucket_default_capacity,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
time.sleep(0.05)
if bucket.consume():
success += 1
print(i, ": Accepted")
controller.last_capacity -= 1
else:
failure += 1
print(i, "Dropped, time left ", bucket.expected_time())
print(success, failure)
"""