add some tests
This commit is contained in:
@@ -1,9 +0,0 @@
|
||||
# coding=utf-8
|
||||
import redis
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def get_cache_redis():
|
||||
return redis.Redis(host=settings.REDIS_CACHE["host"],
|
||||
port=settings.REDIS_CACHE["port"],
|
||||
db=settings.REDIS_CACHE["db"])
|
||||
@@ -1,15 +0,0 @@
|
||||
# coding=utf-8
|
||||
from envelopes import Envelope
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def send_email(from_name, to_email, to_name, subject, content):
|
||||
envelope = Envelope(from_addr=(settings.SMTP_CONFIG["email"], from_name),
|
||||
to_addr=(to_email, to_name),
|
||||
subject=subject,
|
||||
html_body=content)
|
||||
envelope.send(settings.SMTP_CONFIG["smtp_server"],
|
||||
login=settings.SMTP_CONFIG["email"],
|
||||
password=settings.SMTP_CONFIG["password"],
|
||||
tls=settings.SMTP_CONFIG["tls"])
|
||||
@@ -1,36 +0,0 @@
|
||||
# coding=utf-8
|
||||
import shutil
|
||||
import os
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
from problem.models import Problem
|
||||
from contest.models import ContestProblem
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
"""
|
||||
清除测试用例文件夹中无用的测试用例
|
||||
"""
|
||||
def handle(self, *args, **options):
|
||||
self.stdout.write(self.style.WARNING("Please backup your test case dir firstly!"))
|
||||
problem_test_cases = [item.test_case_id for item in Problem.objects.all()]
|
||||
contest_problem_test_cases = [item.test_case_id for item in ContestProblem.objects.all()]
|
||||
|
||||
test_cases = list(set(problem_test_cases + contest_problem_test_cases))
|
||||
test_cases_dir = os.listdir(settings.TEST_CASE_DIR)
|
||||
|
||||
# 在 test_cases_dir 而不在 test_cases 中的
|
||||
dir_to_be_removed = list(set(test_cases_dir).difference(set(test_cases)))
|
||||
if dir_to_be_removed:
|
||||
self.stdout.write(self.style.ERROR("Following dirs will be removed: "))
|
||||
for item in dir_to_be_removed:
|
||||
self.stdout.write(self.style.WARNING(os.path.join(settings.TEST_CASE_DIR, item)))
|
||||
self.stdout.write(self.style.ERROR("Input yes to confirm: "))
|
||||
if raw_input() == "yes":
|
||||
for item in dir_to_be_removed:
|
||||
shutil.rmtree(os.path.join(settings.TEST_CASE_DIR, item), ignore_errors=True)
|
||||
self.stdout.write(self.style.SUCCESS("Done"))
|
||||
else:
|
||||
self.stdout.write(self.style.SUCCESS("Nothing happened"))
|
||||
else:
|
||||
self.stdout.write(self.style.SUCCESS("Test case dir is clean, nothing to do"))
|
||||
@@ -1,21 +0,0 @@
|
||||
# coding=utf-8
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
import os
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
def handle(self, *args, **options):
|
||||
try:
|
||||
if os.system("python manage.py migrate") != 0:
|
||||
self.stdout.write(self.style.ERROR("Failed to execute command 'migrate'"))
|
||||
exit(1)
|
||||
if os.system("python manage.py migrate --database=submission") != 0:
|
||||
self.stdout.write(self.style.ERROR("Failed to execute command 'migrate --database=submission'"))
|
||||
exit(1)
|
||||
if os.system("python manage.py initadmin") != 0:
|
||||
self.stdout.write(self.style.ERROR("Failed to execute command 'initadmin'"))
|
||||
exit(1)
|
||||
self.stdout.write(self.style.SUCCESS("Done"))
|
||||
except Exception as e:
|
||||
self.stdout.write(self.style.ERROR("Failed to initialize, error: " + str(e)))
|
||||
@@ -1,19 +0,0 @@
|
||||
# coding=utf-8
|
||||
from django.core.management.base import BaseCommand
|
||||
from account.models import UserProfile
|
||||
from submission.models import Submission
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
def handle(self, *args, **options):
|
||||
self.stdout.write(self.style.SUCCESS("Please wait a minute"))
|
||||
for profile in UserProfile.objects.all():
|
||||
submissions = Submission.objects.filter(user_id=profile.user.id)
|
||||
profile.submission_number = submissions.count()
|
||||
accepted_problem_number = len(set(Submission.objects.filter(user_id=profile.user.id, contest_id__isnull=True)\
|
||||
.values_list("problem_id", flat=True)))
|
||||
accepted_contest_problem_number = len(set(Submission.objects.filter(user_id=profile.user.id, contest_id__isnull=False)\
|
||||
.values_list("problem_id", flat=True)))
|
||||
profile.accepted_problem_number = accepted_problem_number + accepted_contest_problem_number
|
||||
profile.save()
|
||||
self.stdout.write(self.style.SUCCESS("Succeeded"))
|
||||
@@ -1,33 +1,54 @@
|
||||
# coding=utf-8
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
from django.shortcuts import render
|
||||
from django.http import HttpResponse
|
||||
from django.core.paginator import Paginator
|
||||
from django.http import HttpResponseRedirect
|
||||
|
||||
from rest_framework.response import Response
|
||||
from django.views.generic import View
|
||||
|
||||
|
||||
logger = logging.getLogger("app_info")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def error_page(request, error_reason):
|
||||
return render(request, "utils/error.html", {"error": error_reason})
|
||||
def JSONResponse(data, content_type="application/json"):
|
||||
resp = HttpResponse(json.dumps(data, indent=4), content_type=content_type)
|
||||
resp.data = data
|
||||
return resp
|
||||
|
||||
|
||||
def error_response(error_reason):
|
||||
return Response(data={"code": 1, "data": error_reason})
|
||||
class APIView(View):
|
||||
def _get_request_json(self, request):
|
||||
if request.method != "GET":
|
||||
body = request.body
|
||||
if body:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return {}
|
||||
return request.GET
|
||||
|
||||
def success(self, data=None):
|
||||
return JSONResponse({"error": None, "data": data})
|
||||
|
||||
def serializer_invalid_response(serializer):
|
||||
for k, v in serializer.errors.iteritems():
|
||||
return error_response(k + " : " + v[0])
|
||||
def error(self, message):
|
||||
return JSONResponse({"error": "error", "data": message})
|
||||
|
||||
def invalid_serializer(self, serializer):
|
||||
for k, v in serializer.errors.items():
|
||||
return self.error(k + ": " + v[0])
|
||||
|
||||
def success_response(data):
|
||||
return Response(data={"code": 0, "data": data})
|
||||
def server_error(self):
|
||||
return self.error("Server Error")
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
request.data = self._get_request_json(self.request)
|
||||
except ValueError:
|
||||
return self.error("Invalid JSON")
|
||||
try:
|
||||
return super(APIView, self).dispatch(request, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return self.server_error()
|
||||
|
||||
|
||||
def paginate_data(request, query_set, object_serializer):
|
||||
@@ -41,9 +62,9 @@ def paginate_data(request, query_set, object_serializer):
|
||||
}
|
||||
]
|
||||
如果 url 中有 paging=true 的参数,
|
||||
然后还需要读取其余的两个参数,page=[int],需要的页码,p
|
||||
age_size=[int],一页的数据条数
|
||||
|
||||
然后还需要读取其余的两个参数,page=[int],需要的页码
|
||||
page_size=[int],一页的数据条数
|
||||
:param request
|
||||
:param query_set 数据库查询结果
|
||||
:param object_serializer: 序列化单个object的serializer
|
||||
"""
|
||||
@@ -97,34 +118,18 @@ def paginate_data(request, query_set, object_serializer):
|
||||
return data
|
||||
|
||||
|
||||
def paginate(request, query_set, object_serializer=None):
|
||||
try:
|
||||
data= paginate_data(request, query_set, object_serializer)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
return error_response(u"参数错误")
|
||||
return success_response(data)
|
||||
|
||||
|
||||
def rand_str(length=32):
|
||||
if length > 128:
|
||||
raise ValueError("length must <= 128")
|
||||
return hashlib.sha512(os.urandom(128)).hexdigest()[0:length]
|
||||
|
||||
|
||||
def build_query_string(kv_data, ignore_none=True):
|
||||
# {"a": 1, "b": "test"} -> "?a=1&b=test"
|
||||
query_string = ""
|
||||
for k, v in kv_data.iteritems():
|
||||
if ignore_none is True and kv_data[k] is None:
|
||||
continue
|
||||
if query_string != "":
|
||||
query_string += "&"
|
||||
else:
|
||||
query_string = "?"
|
||||
query_string += (k + "=" + str(v))
|
||||
return query_string
|
||||
|
||||
|
||||
def redirect_to_login(request):
|
||||
return HttpResponseRedirect("/login/?__from=" + urllib.quote(request.path))
|
||||
def rand_str(length=32, type="lower_hex"):
|
||||
"""
|
||||
生成指定长度的随机字符串或者数字, 只用于随机编号等, 不要用于密钥等场景
|
||||
:param length: 字符串或者数字的长度
|
||||
:param type: str 代表随机字符串,num 代表随机数字
|
||||
:return: 字符串
|
||||
"""
|
||||
if type == "str":
|
||||
return ''.join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
||||
elif type == "lower_str":
|
||||
return ''.join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
||||
elif type == "lower_hex":
|
||||
return ''.join(random.choice("0123456789abcdef") for i in range(length))
|
||||
else:
|
||||
return random.choice("123456789") + ''.join(random.choice("0123456789") for i in range(length - 1))
|
||||
@@ -1,26 +0,0 @@
|
||||
# copy from https://github.com/DMOJ/judge/issues/162
|
||||
try:
|
||||
# in large part from http://code.activestate.com/recipes/578899-strsignal/
|
||||
import signal
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
|
||||
libc = ctypes.CDLL(ctypes.util.find_library("c"))
|
||||
strsignal_c = ctypes.CFUNCTYPE(ctypes.c_char_p, ctypes.c_int)(("strsignal", libc), ((1,),))
|
||||
NSIG = signal.NSIG
|
||||
|
||||
|
||||
def strsignal_ctypes_wrapper(signo):
|
||||
# The behavior of the C library strsignal() is unspecified if
|
||||
# called with an out-of-range argument. Range-check on entry
|
||||
# _and_ NULL-check on exit.
|
||||
if 0 <= signo < NSIG:
|
||||
s = strsignal_c(signo)
|
||||
if s:
|
||||
return s.decode("utf-8")
|
||||
return "Unknown signal %d" % signo
|
||||
|
||||
|
||||
strsignal = strsignal_ctypes_wrapper
|
||||
except:
|
||||
strsignal = lambda x: 'signal %d' % x
|
||||
@@ -1 +0,0 @@
|
||||
# coding=utf-8
|
||||
@@ -1,11 +0,0 @@
|
||||
# coding=utf-8
|
||||
from django import template
|
||||
|
||||
from announcement.models import Announcement
|
||||
|
||||
|
||||
def public_announcement_list():
|
||||
return Announcement.objects.filter(visible=True).order_by("-create_time")
|
||||
|
||||
register = template.Library()
|
||||
register.assignment_tag(public_announcement_list, name="public_announcement_list")
|
||||
@@ -1,73 +0,0 @@
|
||||
# coding=utf-8
|
||||
import json
|
||||
|
||||
|
||||
def get_contest_status(contest):
|
||||
status = contest.status
|
||||
if status == 1:
|
||||
return "没有开始"
|
||||
elif status == -1:
|
||||
return "已经结束"
|
||||
else:
|
||||
return "正在进行"
|
||||
|
||||
|
||||
def get_contest_status_color(contest):
|
||||
status = contest.status
|
||||
if status == 1:
|
||||
return "info"
|
||||
elif status == -1:
|
||||
return "warning"
|
||||
else:
|
||||
return "success"
|
||||
|
||||
|
||||
def get_the_formatted_time(seconds):
|
||||
if not seconds:
|
||||
return ""
|
||||
seconds = int(seconds)
|
||||
hour = seconds / (60 * 60)
|
||||
minute = (seconds - hour * 60 * 60) / 60
|
||||
second = seconds - hour * 60 * 60 - minute * 60
|
||||
return str(hour) + ":" + str(minute) + ":" + str(second)
|
||||
|
||||
|
||||
def get_submission_class(rank, problem):
|
||||
submission_info = json.loads(rank["submission_info"])
|
||||
if str(problem.id) not in submission_info:
|
||||
return ""
|
||||
else:
|
||||
submission = submission_info[str(problem.id)]
|
||||
if submission["is_ac"]:
|
||||
_class = "alert-success"
|
||||
if submission["is_first_ac"]:
|
||||
_class += " first-achieved"
|
||||
else:
|
||||
_class = "alert-danger"
|
||||
return _class
|
||||
|
||||
|
||||
def get_submission_content(rank, problem):
|
||||
submission_info = json.loads(rank["submission_info"])
|
||||
if str(problem.id) not in submission_info:
|
||||
return ""
|
||||
else:
|
||||
submission = submission_info[str(problem.id)]
|
||||
if submission["is_ac"]:
|
||||
r = get_the_formatted_time(submission["ac_time"])
|
||||
if submission["error_number"]:
|
||||
r += "<br>(-" + str(submission["error_number"]) + ")"
|
||||
return r
|
||||
else:
|
||||
return "(-" + str(submission["error_number"]) + ")"
|
||||
|
||||
|
||||
from django import template
|
||||
|
||||
register = template.Library()
|
||||
register.filter("contest_status", get_contest_status)
|
||||
register.filter("contest_status_color", get_contest_status_color)
|
||||
register.filter("format_seconds", get_the_formatted_time)
|
||||
register.simple_tag(get_submission_class, name="get_submission_class")
|
||||
register.simple_tag(get_submission_content, name="get_submission_content")
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# coding=utf-8
|
||||
|
||||
|
||||
def get_problem_accepted_radio(problem):
|
||||
if problem.total_submit_number:
|
||||
return str(int((problem.total_accepted_number * 100) / problem.total_submit_number)) \
|
||||
+ "% (" + str(problem.total_accepted_number) + " / " + str(problem.total_submit_number) + ")"
|
||||
return "0%"
|
||||
|
||||
|
||||
def get_problem_status(problems_status, problem_id):
|
||||
# 用户没登陆 或者 user.problem_status 中没有这个字段都会到导致这里的problem_status 为 ""
|
||||
if not problems_status:
|
||||
return ""
|
||||
|
||||
if str(problem_id) in problems_status:
|
||||
if problems_status[str(problem_id)] == 1:
|
||||
return "glyphicon glyphicon-ok ac-flag"
|
||||
return "glyphicon glyphicon-minus dealing-flag"
|
||||
return ""
|
||||
|
||||
|
||||
from django import template
|
||||
|
||||
register = template.Library()
|
||||
register.filter("accepted_radio", get_problem_accepted_radio)
|
||||
register.simple_tag(get_problem_status, name="get_problem_status")
|
||||
@@ -1,44 +0,0 @@
|
||||
# coding=utf-8
|
||||
from django import template
|
||||
from utils.signal2str import strsignal
|
||||
|
||||
|
||||
def translate_result(value):
|
||||
results = {
|
||||
0: "Accepted",
|
||||
1: "Runtime Error",
|
||||
2: "Time Limit Exceeded",
|
||||
3: "Memory Limit Exceeded",
|
||||
4: "Compile Error",
|
||||
5: "Format Error",
|
||||
6: "Wrong Answer",
|
||||
7: "System Error",
|
||||
8: "Waiting"
|
||||
}
|
||||
return results[value]
|
||||
|
||||
|
||||
def translate_signal(value):
|
||||
if not value:
|
||||
return ""
|
||||
else:
|
||||
return strsignal(value)
|
||||
|
||||
|
||||
def translate_language(value):
|
||||
return {1: "C", 2: "C++", 3: "Java"}[value]
|
||||
|
||||
|
||||
def translate_result_class(value):
|
||||
if value == 0:
|
||||
return "success"
|
||||
elif value == 8:
|
||||
return "info"
|
||||
return "danger"
|
||||
|
||||
|
||||
register = template.Library()
|
||||
register.filter("translate_result", translate_result)
|
||||
register.filter("translate_language", translate_language)
|
||||
register.filter("translate_result_class", translate_result_class)
|
||||
register.filter("translate_signal", translate_signal)
|
||||
@@ -1,17 +0,0 @@
|
||||
# coding=utf-8
|
||||
import datetime
|
||||
from account.models import User
|
||||
|
||||
|
||||
def get_username(user_id):
|
||||
try:
|
||||
return User.objects.get(id=user_id).username
|
||||
except User.DoesNotExist:
|
||||
return ""
|
||||
|
||||
|
||||
from django import template
|
||||
|
||||
register = template.Library()
|
||||
register.filter("get_username", get_username)
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
# coding=utf-8
|
||||
from django import template
|
||||
from django.conf import settings
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@register.simple_tag
|
||||
def show_website_info(name):
|
||||
return settings.WEBSITE_INFO[name]
|
||||
@@ -1,9 +0,0 @@
|
||||
# coding=utf-8
|
||||
from django.conf.urls import include, url
|
||||
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
url(r'^paginate_test/$', "utils.tests.pagination_test_func"),
|
||||
]
|
||||
|
||||
@@ -1,68 +1,30 @@
|
||||
# coding=utf-8
|
||||
from rest_framework.test import APIClient, APITestCase
|
||||
from rest_framework import serializers
|
||||
from rest_framework.decorators import api_view
|
||||
from __future__ import unicode_literals
|
||||
from django.test.testcases import TestCase
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from account.models import User
|
||||
from .shortcuts import paginate
|
||||
from account.models import User, AdminType
|
||||
|
||||
|
||||
class PaginationTestSerialiser(serializers.Serializer):
|
||||
username = serializers.CharField(max_length=100)
|
||||
class APITestCase(TestCase):
|
||||
client_class = APIClient
|
||||
|
||||
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=False):
|
||||
user = User.objects.create(username=username, admin_type=admin_type)
|
||||
user.set_password(password)
|
||||
user.save()
|
||||
if login:
|
||||
self.client.login(username=username, password=password)
|
||||
return user
|
||||
|
||||
@api_view(["GET"])
|
||||
def pagination_test_func(request):
|
||||
return paginate(request, User.objects.all(), PaginationTestSerialiser)
|
||||
def create_admin(self, username="admin", password="admin", login=False):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, login=login)
|
||||
|
||||
def create_super_admin(self, username="root", password="root", login=False):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
|
||||
|
||||
class PaginatorTest(APITestCase):
|
||||
urls = "utils.test_urls"
|
||||
def assertSuccess(self, response):
|
||||
self.assertTrue(response.data["error"] is None)
|
||||
|
||||
def setUp(self):
|
||||
self.client = APIClient()
|
||||
self.url = "/paginate_test/"
|
||||
User.objects.create(username="test1")
|
||||
User.objects.create(username="test2")
|
||||
|
||||
def test_no_paginate(self):
|
||||
response = self.client.get(self.url)
|
||||
self.assertEqual(response.data["code"], 0)
|
||||
self.assertNotIn("next_page", response.data["data"])
|
||||
self.assertNotIn("previous_page", response.data["data"])
|
||||
|
||||
def test_error_parameter(self):
|
||||
response = self.client.get(self.url + "?paging=true")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=-1")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=aa")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=1&page=-1")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=aaa&page=1")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=1&page=aaa")
|
||||
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
|
||||
|
||||
def test_correct_paginate(self):
|
||||
response = self.client.get(self.url + "?paging=true&page_size=1&page=1")
|
||||
self.assertEqual(response.data["code"], 0)
|
||||
self.assertEqual(response.data["data"]["previous_page"], None)
|
||||
self.assertEqual(response.data["data"]["next_page"], 2)
|
||||
self.assertEqual(len(response.data["data"]["results"]), 1)
|
||||
self.assertEqual(response.data["data"]["count"], 2)
|
||||
self.assertEqual(response.data["data"]["total_page"], 2)
|
||||
|
||||
response = self.client.get(self.url + "?paging=true&page_size=2&page=1")
|
||||
self.assertEqual(response.data["code"], 0)
|
||||
self.assertEqual(response.data["data"]["previous_page"], None)
|
||||
self.assertEqual(response.data["data"]["next_page"], None)
|
||||
self.assertEqual(len(response.data["data"]["results"]), 2)
|
||||
self.assertEqual(response.data["data"]["count"], 2)
|
||||
self.assertEqual(response.data["data"]["total_page"], 1)
|
||||
def assertFailed(self, response):
|
||||
self.assertTrue(response.data["error"] is not None)
|
||||
@@ -1,94 +0,0 @@
|
||||
# coding=utf-8
|
||||
import time
|
||||
import redis
|
||||
|
||||
|
||||
class TokenBucket(object):
|
||||
def __init__(self, fill_rate, capacity, last_capacity, last_timestamp):
|
||||
self.capacity = float(capacity)
|
||||
self._left_tokens = last_capacity
|
||||
self.fill_rate = float(fill_rate)
|
||||
self.timestamp = last_timestamp
|
||||
|
||||
def consume(self, tokens=1):
|
||||
if tokens <= self.tokens:
|
||||
self._left_tokens -= tokens
|
||||
return True
|
||||
return False
|
||||
|
||||
def expected_time(self, tokens=1):
|
||||
_tokens = self.tokens
|
||||
tokens = max(tokens, _tokens)
|
||||
return (tokens - _tokens) / self.fill_rate * 60
|
||||
|
||||
@property
|
||||
def tokens(self):
|
||||
if self._left_tokens < self.capacity:
|
||||
now = time.time()
|
||||
delta = self.fill_rate * ((now - self.timestamp) / 60)
|
||||
self._left_tokens = min(self.capacity, self._left_tokens + delta)
|
||||
self.timestamp = now
|
||||
return self._left_tokens
|
||||
|
||||
|
||||
class BucketController(object):
|
||||
def __init__(self, user_id, redis_conn, default_capacity):
|
||||
self.user_id = user_id
|
||||
self.default_capacity = default_capacity
|
||||
self.redis = redis_conn
|
||||
self.key = "bucket_" + str(self.user_id)
|
||||
|
||||
@property
|
||||
def last_capacity(self):
|
||||
value = self.redis.hget(self.key, "last_capacity")
|
||||
if value is None:
|
||||
self.last_capacity = self.default_capacity
|
||||
return self.default_capacity
|
||||
return int(value)
|
||||
|
||||
@last_capacity.setter
|
||||
def last_capacity(self, value):
|
||||
self.redis.hset(self.key, "last_capacity", value)
|
||||
|
||||
@property
|
||||
def last_timestamp(self):
|
||||
value = self.redis.hget(self.key, "last_timestamp")
|
||||
if value is None:
|
||||
timestamp = int(time.time())
|
||||
self.last_timestamp = timestamp
|
||||
return timestamp
|
||||
return int(value)
|
||||
|
||||
@last_timestamp.setter
|
||||
def last_timestamp(self, value):
|
||||
self.redis.hset(self.key, "last_timestamp", value)
|
||||
|
||||
|
||||
"""
|
||||
# token bucket 机制限制用户提交大量代码
|
||||
# demo
|
||||
success = failure = 0
|
||||
current_user_id = 1
|
||||
token_bucket_default_capacity = 50
|
||||
token_bucket_fill_rate = 10
|
||||
|
||||
|
||||
for i in range(5000):
|
||||
controller = BucketController(user_id=current_user_id,
|
||||
redis_conn=redis.Redis(),
|
||||
default_capacity=token_bucket_default_capacity)
|
||||
bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
|
||||
capacity=token_bucket_default_capacity,
|
||||
last_capacity=controller.last_capacity,
|
||||
last_timestamp=controller.last_timestamp)
|
||||
|
||||
time.sleep(0.05)
|
||||
if bucket.consume():
|
||||
success += 1
|
||||
print i, ": Accepted"
|
||||
controller.last_capacity -= 1
|
||||
else:
|
||||
failure += 1
|
||||
print i, "Dropped, time left ", bucket.expected_time()
|
||||
print success, failure
|
||||
"""
|
||||
@@ -1,37 +0,0 @@
|
||||
# coding=utf-8
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from django.conf import settings
|
||||
|
||||
from utils.shortcuts import rand_str
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("app_info")
|
||||
|
||||
|
||||
class SimditorImageUploadAPIView(APIView):
|
||||
def post(self, request):
|
||||
if "image" not in request.FILES:
|
||||
return Response(data={
|
||||
"success": False,
|
||||
"msg": "上传失败",
|
||||
"file_path": "/"})
|
||||
img = request.FILES["image"]
|
||||
|
||||
image_name = rand_str() + '.' + str(request.FILES["image"].name.split('.')[-1])
|
||||
image_dir = settings.IMAGE_UPLOAD_DIR + image_name
|
||||
try:
|
||||
with open(image_dir, "wb") as imageFile:
|
||||
for chunk in img:
|
||||
imageFile.write(chunk)
|
||||
except IOError as e:
|
||||
logger.error(e)
|
||||
return Response(data={
|
||||
"success": True,
|
||||
"msg": "上传错误",
|
||||
"file_path": "/static/upload/" + image_name})
|
||||
return Response(data={
|
||||
"success": True,
|
||||
"msg": "",
|
||||
"file_path": "/static/upload/" + image_name})
|
||||
Reference in New Issue
Block a user