add some tests

This commit is contained in:
virusdefender
2016-10-30 02:17:35 +08:00
parent 078de956e5
commit 39857d1b56
118 changed files with 326 additions and 19131 deletions

View File

@@ -1,9 +0,0 @@
# coding=utf-8
import redis
from django.conf import settings
def get_cache_redis():
return redis.Redis(host=settings.REDIS_CACHE["host"],
port=settings.REDIS_CACHE["port"],
db=settings.REDIS_CACHE["db"])

View File

@@ -1,15 +0,0 @@
# coding=utf-8
from envelopes import Envelope
from django.conf import settings
def send_email(from_name, to_email, to_name, subject, content):
envelope = Envelope(from_addr=(settings.SMTP_CONFIG["email"], from_name),
to_addr=(to_email, to_name),
subject=subject,
html_body=content)
envelope.send(settings.SMTP_CONFIG["smtp_server"],
login=settings.SMTP_CONFIG["email"],
password=settings.SMTP_CONFIG["password"],
tls=settings.SMTP_CONFIG["tls"])

View File

@@ -1,36 +0,0 @@
# coding=utf-8
import shutil
import os
from django.conf import settings
from django.core.management.base import BaseCommand
from problem.models import Problem
from contest.models import ContestProblem
class Command(BaseCommand):
"""
清除测试用例文件夹中无用的测试用例
"""
def handle(self, *args, **options):
self.stdout.write(self.style.WARNING("Please backup your test case dir firstly!"))
problem_test_cases = [item.test_case_id for item in Problem.objects.all()]
contest_problem_test_cases = [item.test_case_id for item in ContestProblem.objects.all()]
test_cases = list(set(problem_test_cases + contest_problem_test_cases))
test_cases_dir = os.listdir(settings.TEST_CASE_DIR)
# 在 test_cases_dir 而不在 test_cases 中的
dir_to_be_removed = list(set(test_cases_dir).difference(set(test_cases)))
if dir_to_be_removed:
self.stdout.write(self.style.ERROR("Following dirs will be removed: "))
for item in dir_to_be_removed:
self.stdout.write(self.style.WARNING(os.path.join(settings.TEST_CASE_DIR, item)))
self.stdout.write(self.style.ERROR("Input yes to confirm: "))
if raw_input() == "yes":
for item in dir_to_be_removed:
shutil.rmtree(os.path.join(settings.TEST_CASE_DIR, item), ignore_errors=True)
self.stdout.write(self.style.SUCCESS("Done"))
else:
self.stdout.write(self.style.SUCCESS("Nothing happened"))
else:
self.stdout.write(self.style.SUCCESS("Test case dir is clean, nothing to do"))

View File

@@ -1,21 +0,0 @@
# coding=utf-8
from django.core.management.base import BaseCommand
import os
class Command(BaseCommand):
def handle(self, *args, **options):
try:
if os.system("python manage.py migrate") != 0:
self.stdout.write(self.style.ERROR("Failed to execute command 'migrate'"))
exit(1)
if os.system("python manage.py migrate --database=submission") != 0:
self.stdout.write(self.style.ERROR("Failed to execute command 'migrate --database=submission'"))
exit(1)
if os.system("python manage.py initadmin") != 0:
self.stdout.write(self.style.ERROR("Failed to execute command 'initadmin'"))
exit(1)
self.stdout.write(self.style.SUCCESS("Done"))
except Exception as e:
self.stdout.write(self.style.ERROR("Failed to initialize, error: " + str(e)))

View File

@@ -1,19 +0,0 @@
# coding=utf-8
from django.core.management.base import BaseCommand
from account.models import UserProfile
from submission.models import Submission
class Command(BaseCommand):
def handle(self, *args, **options):
self.stdout.write(self.style.SUCCESS("Please wait a minute"))
for profile in UserProfile.objects.all():
submissions = Submission.objects.filter(user_id=profile.user.id)
profile.submission_number = submissions.count()
accepted_problem_number = len(set(Submission.objects.filter(user_id=profile.user.id, contest_id__isnull=True)\
.values_list("problem_id", flat=True)))
accepted_contest_problem_number = len(set(Submission.objects.filter(user_id=profile.user.id, contest_id__isnull=False)\
.values_list("problem_id", flat=True)))
profile.accepted_problem_number = accepted_problem_number + accepted_contest_problem_number
profile.save()
self.stdout.write(self.style.SUCCESS("Succeeded"))

View File

@@ -1,33 +1,54 @@
# coding=utf-8
import os
import hashlib
import json
import logging
import random
from django.shortcuts import render
from django.http import HttpResponse
from django.core.paginator import Paginator
from django.http import HttpResponseRedirect
from rest_framework.response import Response
from django.views.generic import View
logger = logging.getLogger("app_info")
logger = logging.getLogger(__name__)
def error_page(request, error_reason):
return render(request, "utils/error.html", {"error": error_reason})
def JSONResponse(data, content_type="application/json"):
resp = HttpResponse(json.dumps(data, indent=4), content_type=content_type)
resp.data = data
return resp
def error_response(error_reason):
return Response(data={"code": 1, "data": error_reason})
class APIView(View):
def _get_request_json(self, request):
if request.method != "GET":
body = request.body
if body:
return json.loads(body.decode("utf-8"))
return {}
return request.GET
def success(self, data=None):
return JSONResponse({"error": None, "data": data})
def serializer_invalid_response(serializer):
for k, v in serializer.errors.iteritems():
return error_response(k + " : " + v[0])
def error(self, message):
return JSONResponse({"error": "error", "data": message})
def invalid_serializer(self, serializer):
for k, v in serializer.errors.items():
return self.error(k + ": " + v[0])
def success_response(data):
return Response(data={"code": 0, "data": data})
def server_error(self):
return self.error("Server Error")
def dispatch(self, request, *args, **kwargs):
try:
request.data = self._get_request_json(self.request)
except ValueError:
return self.error("Invalid JSON")
try:
return super(APIView, self).dispatch(request, *args, **kwargs)
except Exception as e:
logging.exception(e)
return self.server_error()
def paginate_data(request, query_set, object_serializer):
@@ -41,9 +62,9 @@ def paginate_data(request, query_set, object_serializer):
}
]
如果 url 中有 paging=true 的参数,
然后还需要读取其余的两个参数page=[int],需要的页码p
age_size=[int],一页的数据条数
然后还需要读取其余的两个参数page=[int],需要的页码
page_size=[int],一页的数据条数
:param request
:param query_set 数据库查询结果
:param object_serializer: 序列化单个object的serializer
"""
@@ -97,34 +118,18 @@ def paginate_data(request, query_set, object_serializer):
return data
def paginate(request, query_set, object_serializer=None):
try:
data= paginate_data(request, query_set, object_serializer)
except Exception as e:
logger.error(str(e))
return error_response(u"参数错误")
return success_response(data)
def rand_str(length=32):
if length > 128:
raise ValueError("length must <= 128")
return hashlib.sha512(os.urandom(128)).hexdigest()[0:length]
def build_query_string(kv_data, ignore_none=True):
# {"a": 1, "b": "test"} -> "?a=1&b=test"
query_string = ""
for k, v in kv_data.iteritems():
if ignore_none is True and kv_data[k] is None:
continue
if query_string != "":
query_string += "&"
else:
query_string = "?"
query_string += (k + "=" + str(v))
return query_string
def redirect_to_login(request):
return HttpResponseRedirect("/login/?__from=" + urllib.quote(request.path))
def rand_str(length=32, type="lower_hex"):
"""
生成指定长度的随机字符串或者数字, 只用于随机编号等, 不要用于密钥等场景
:param length: 字符串或者数字的长度
:param type: str 代表随机字符串num 代表随机数字
:return: 字符串
"""
if type == "str":
return ''.join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
elif type == "lower_str":
return ''.join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
elif type == "lower_hex":
return ''.join(random.choice("0123456789abcdef") for i in range(length))
else:
return random.choice("123456789") + ''.join(random.choice("0123456789") for i in range(length - 1))

View File

@@ -1,26 +0,0 @@
# copy from https://github.com/DMOJ/judge/issues/162
try:
# in large part from http://code.activestate.com/recipes/578899-strsignal/
import signal
import ctypes
import ctypes.util
libc = ctypes.CDLL(ctypes.util.find_library("c"))
strsignal_c = ctypes.CFUNCTYPE(ctypes.c_char_p, ctypes.c_int)(("strsignal", libc), ((1,),))
NSIG = signal.NSIG
def strsignal_ctypes_wrapper(signo):
# The behavior of the C library strsignal() is unspecified if
# called with an out-of-range argument. Range-check on entry
# _and_ NULL-check on exit.
if 0 <= signo < NSIG:
s = strsignal_c(signo)
if s:
return s.decode("utf-8")
return "Unknown signal %d" % signo
strsignal = strsignal_ctypes_wrapper
except:
strsignal = lambda x: 'signal %d' % x

View File

@@ -1 +0,0 @@
# coding=utf-8

View File

@@ -1,11 +0,0 @@
# coding=utf-8
from django import template
from announcement.models import Announcement
def public_announcement_list():
return Announcement.objects.filter(visible=True).order_by("-create_time")
register = template.Library()
register.assignment_tag(public_announcement_list, name="public_announcement_list")

View File

@@ -1,73 +0,0 @@
# coding=utf-8
import json
def get_contest_status(contest):
status = contest.status
if status == 1:
return "没有开始"
elif status == -1:
return "已经结束"
else:
return "正在进行"
def get_contest_status_color(contest):
status = contest.status
if status == 1:
return "info"
elif status == -1:
return "warning"
else:
return "success"
def get_the_formatted_time(seconds):
if not seconds:
return ""
seconds = int(seconds)
hour = seconds / (60 * 60)
minute = (seconds - hour * 60 * 60) / 60
second = seconds - hour * 60 * 60 - minute * 60
return str(hour) + ":" + str(minute) + ":" + str(second)
def get_submission_class(rank, problem):
submission_info = json.loads(rank["submission_info"])
if str(problem.id) not in submission_info:
return ""
else:
submission = submission_info[str(problem.id)]
if submission["is_ac"]:
_class = "alert-success"
if submission["is_first_ac"]:
_class += " first-achieved"
else:
_class = "alert-danger"
return _class
def get_submission_content(rank, problem):
submission_info = json.loads(rank["submission_info"])
if str(problem.id) not in submission_info:
return ""
else:
submission = submission_info[str(problem.id)]
if submission["is_ac"]:
r = get_the_formatted_time(submission["ac_time"])
if submission["error_number"]:
r += "<br>-" + str(submission["error_number"]) + ""
return r
else:
return "-" + str(submission["error_number"]) + ""
from django import template
register = template.Library()
register.filter("contest_status", get_contest_status)
register.filter("contest_status_color", get_contest_status_color)
register.filter("format_seconds", get_the_formatted_time)
register.simple_tag(get_submission_class, name="get_submission_class")
register.simple_tag(get_submission_content, name="get_submission_content")

View File

@@ -1,27 +0,0 @@
# coding=utf-8
def get_problem_accepted_radio(problem):
if problem.total_submit_number:
return str(int((problem.total_accepted_number * 100) / problem.total_submit_number)) \
+ "% (" + str(problem.total_accepted_number) + " / " + str(problem.total_submit_number) + ")"
return "0%"
def get_problem_status(problems_status, problem_id):
# 用户没登陆 或者 user.problem_status 中没有这个字段都会到导致这里的problem_status 为 ""
if not problems_status:
return ""
if str(problem_id) in problems_status:
if problems_status[str(problem_id)] == 1:
return "glyphicon glyphicon-ok ac-flag"
return "glyphicon glyphicon-minus dealing-flag"
return ""
from django import template
register = template.Library()
register.filter("accepted_radio", get_problem_accepted_radio)
register.simple_tag(get_problem_status, name="get_problem_status")

View File

@@ -1,44 +0,0 @@
# coding=utf-8
from django import template
from utils.signal2str import strsignal
def translate_result(value):
results = {
0: "Accepted",
1: "Runtime Error",
2: "Time Limit Exceeded",
3: "Memory Limit Exceeded",
4: "Compile Error",
5: "Format Error",
6: "Wrong Answer",
7: "System Error",
8: "Waiting"
}
return results[value]
def translate_signal(value):
if not value:
return ""
else:
return strsignal(value)
def translate_language(value):
return {1: "C", 2: "C++", 3: "Java"}[value]
def translate_result_class(value):
if value == 0:
return "success"
elif value == 8:
return "info"
return "danger"
register = template.Library()
register.filter("translate_result", translate_result)
register.filter("translate_language", translate_language)
register.filter("translate_result_class", translate_result_class)
register.filter("translate_signal", translate_signal)

View File

@@ -1,17 +0,0 @@
# coding=utf-8
import datetime
from account.models import User
def get_username(user_id):
try:
return User.objects.get(id=user_id).username
except User.DoesNotExist:
return ""
from django import template
register = template.Library()
register.filter("get_username", get_username)

View File

@@ -1,9 +0,0 @@
# coding=utf-8
from django import template
from django.conf import settings
register = template.Library()
@register.simple_tag
def show_website_info(name):
return settings.WEBSITE_INFO[name]

View File

@@ -1,9 +0,0 @@
# coding=utf-8
from django.conf.urls import include, url
urlpatterns = [
url(r'^paginate_test/$', "utils.tests.pagination_test_func"),
]

View File

@@ -1,68 +1,30 @@
# coding=utf-8
from rest_framework.test import APIClient, APITestCase
from rest_framework import serializers
from rest_framework.decorators import api_view
from __future__ import unicode_literals
from django.test.testcases import TestCase
from rest_framework.test import APIClient
from account.models import User
from .shortcuts import paginate
from account.models import User, AdminType
class PaginationTestSerialiser(serializers.Serializer):
username = serializers.CharField(max_length=100)
class APITestCase(TestCase):
client_class = APIClient
def create_user(self, username, password, admin_type=AdminType.REGULAR_USER, login=False):
user = User.objects.create(username=username, admin_type=admin_type)
user.set_password(password)
user.save()
if login:
self.client.login(username=username, password=password)
return user
@api_view(["GET"])
def pagination_test_func(request):
return paginate(request, User.objects.all(), PaginationTestSerialiser)
def create_admin(self, username="admin", password="admin", login=False):
return self.create_user(username=username, password=password, admin_type=AdminType.ADMIN, login=login)
def create_super_admin(self, username="root", password="root", login=False):
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
class PaginatorTest(APITestCase):
urls = "utils.test_urls"
def assertSuccess(self, response):
self.assertTrue(response.data["error"] is None)
def setUp(self):
self.client = APIClient()
self.url = "/paginate_test/"
User.objects.create(username="test1")
User.objects.create(username="test2")
def test_no_paginate(self):
response = self.client.get(self.url)
self.assertEqual(response.data["code"], 0)
self.assertNotIn("next_page", response.data["data"])
self.assertNotIn("previous_page", response.data["data"])
def test_error_parameter(self):
response = self.client.get(self.url + "?paging=true")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&page_size=-1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&page_size=aa")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&page_size=1&page=-1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&page_size=aaa&page=1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&page_size=1&page=aaa")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
def test_correct_paginate(self):
response = self.client.get(self.url + "?paging=true&page_size=1&page=1")
self.assertEqual(response.data["code"], 0)
self.assertEqual(response.data["data"]["previous_page"], None)
self.assertEqual(response.data["data"]["next_page"], 2)
self.assertEqual(len(response.data["data"]["results"]), 1)
self.assertEqual(response.data["data"]["count"], 2)
self.assertEqual(response.data["data"]["total_page"], 2)
response = self.client.get(self.url + "?paging=true&page_size=2&page=1")
self.assertEqual(response.data["code"], 0)
self.assertEqual(response.data["data"]["previous_page"], None)
self.assertEqual(response.data["data"]["next_page"], None)
self.assertEqual(len(response.data["data"]["results"]), 2)
self.assertEqual(response.data["data"]["count"], 2)
self.assertEqual(response.data["data"]["total_page"], 1)
def assertFailed(self, response):
self.assertTrue(response.data["error"] is not None)

View File

@@ -1,94 +0,0 @@
# coding=utf-8
import time
import redis
class TokenBucket(object):
def __init__(self, fill_rate, capacity, last_capacity, last_timestamp):
self.capacity = float(capacity)
self._left_tokens = last_capacity
self.fill_rate = float(fill_rate)
self.timestamp = last_timestamp
def consume(self, tokens=1):
if tokens <= self.tokens:
self._left_tokens -= tokens
return True
return False
def expected_time(self, tokens=1):
_tokens = self.tokens
tokens = max(tokens, _tokens)
return (tokens - _tokens) / self.fill_rate * 60
@property
def tokens(self):
if self._left_tokens < self.capacity:
now = time.time()
delta = self.fill_rate * ((now - self.timestamp) / 60)
self._left_tokens = min(self.capacity, self._left_tokens + delta)
self.timestamp = now
return self._left_tokens
class BucketController(object):
def __init__(self, user_id, redis_conn, default_capacity):
self.user_id = user_id
self.default_capacity = default_capacity
self.redis = redis_conn
self.key = "bucket_" + str(self.user_id)
@property
def last_capacity(self):
value = self.redis.hget(self.key, "last_capacity")
if value is None:
self.last_capacity = self.default_capacity
return self.default_capacity
return int(value)
@last_capacity.setter
def last_capacity(self, value):
self.redis.hset(self.key, "last_capacity", value)
@property
def last_timestamp(self):
value = self.redis.hget(self.key, "last_timestamp")
if value is None:
timestamp = int(time.time())
self.last_timestamp = timestamp
return timestamp
return int(value)
@last_timestamp.setter
def last_timestamp(self, value):
self.redis.hset(self.key, "last_timestamp", value)
"""
# token bucket 机制限制用户提交大量代码
# demo
success = failure = 0
current_user_id = 1
token_bucket_default_capacity = 50
token_bucket_fill_rate = 10
for i in range(5000):
controller = BucketController(user_id=current_user_id,
redis_conn=redis.Redis(),
default_capacity=token_bucket_default_capacity)
bucket = TokenBucket(fill_rate=token_bucket_fill_rate,
capacity=token_bucket_default_capacity,
last_capacity=controller.last_capacity,
last_timestamp=controller.last_timestamp)
time.sleep(0.05)
if bucket.consume():
success += 1
print i, ": Accepted"
controller.last_capacity -= 1
else:
failure += 1
print i, "Dropped, time left ", bucket.expected_time()
print success, failure
"""

View File

@@ -1,37 +0,0 @@
# coding=utf-8
from rest_framework.views import APIView
from rest_framework.response import Response
from django.conf import settings
from utils.shortcuts import rand_str
import logging
logger = logging.getLogger("app_info")
class SimditorImageUploadAPIView(APIView):
def post(self, request):
if "image" not in request.FILES:
return Response(data={
"success": False,
"msg": "上传失败",
"file_path": "/"})
img = request.FILES["image"]
image_name = rand_str() + '.' + str(request.FILES["image"].name.split('.')[-1])
image_dir = settings.IMAGE_UPLOAD_DIR + image_name
try:
with open(image_dir, "wb") as imageFile:
for chunk in img:
imageFile.write(chunk)
except IOError as e:
logger.error(e)
return Response(data={
"success": True,
"msg": "上传错误",
"file_path": "/static/upload/" + image_name})
return Response(data={
"success": True,
"msg": "",
"file_path": "/static/upload/" + image_name})