From 172fd4b1f4e1bbffad056dfe86a5228a10d51fa6 Mon Sep 17 00:00:00 2001 From: virusdefender Date: Sat, 19 Nov 2016 12:32:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BD=BF=E7=94=A8Python3=E5=92=8C=E6=9B=B4?= =?UTF-8?q?=E7=A7=91=E5=AD=A6=E7=9A=84API=E5=86=99=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- account/decorators.py | 7 +- account/middleware.py | 10 +- account/models.py | 2 - account/serializers.py | 3 +- account/tests.py | 18 ++-- account/urls/admin.py | 1 - account/urls/oj.py | 1 - account/views/admin.py | 100 ++++++++++---------- account/views/oj.py | 114 ++++++++++------------- announcement/models.py | 2 - announcement/serializers.py | 7 +- announcement/tests.py | 6 +- announcement/urls/admin.py | 1 - announcement/views.py | 8 +- conf/models.py | 4 +- utils/api/__init__.py | 2 + utils/api/_serializers.py | 17 ++++ utils/api/api.py | 179 ++++++++++++++++++++++++++++++++++++ utils/{ => api}/tests.py | 7 +- utils/serializers.py | 18 ---- utils/shortcuts.py | 98 ++------------------ 21 files changed, 335 insertions(+), 270 deletions(-) create mode 100644 utils/api/__init__.py create mode 100644 utils/api/_serializers.py create mode 100644 utils/api/api.py rename utils/{ => api}/tests.py (90%) diff --git a/account/decorators.py b/account/decorators.py index d68dda7..c5da921 100644 --- a/account/decorators.py +++ b/account/decorators.py @@ -1,13 +1,10 @@ -# coding=utf-8 from __future__ import unicode_literals -import urllib -import json import functools from django.http import HttpResponse from django.utils.translation import ugettext as _ -from utils.shortcuts import JSONResponse +from utils.api import JSONResponse from .models import AdminType @@ -19,7 +16,7 @@ class BasePermissionDecorator(object): return functools.partial(self.__call__, obj) def error(self, data): - return JSONResponse({"error": "permission-denied", "data": data}) + return JSONResponse.response({"error": "permission-denied", "data": data}) def __call__(self, *args, **kwargs): self.request = args[1] diff --git a/account/middleware.py b/account/middleware.py index bc1c089..d092721 100644 --- a/account/middleware.py +++ b/account/middleware.py @@ -1,12 +1,8 @@ -# coding=utf-8 import time -import json - -from django.http import HttpResponse from django.utils.translation import ugettext as _ from django.contrib import auth -from utils.shortcuts import JSONResponse +from utils.api import JSONResponse from .models import AdminType @@ -17,7 +13,7 @@ class SessionSecurityMiddleware(object): # 24 hours passed since last visit if time.time() - request.session["last_activity"] >= 24 * 60 * 60: auth.logout(request) - return JSONResponse({"error": "login-required", "data": _("Please login in first")}) + return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) # update last active time request.session["last_activity"] = time.time() @@ -27,4 +23,4 @@ class AdminRequiredMiddleware(object): path = request.path_info if path.startswith("/admin/") or path.startswith("/api/admin/"): if not(request.user.is_authenticated() and request.user.is_admin()): - return JSONResponse({"error": "login-required", "data": _("Please login in first")}) \ No newline at end of file + return JSONResponse.response({"error": "login-required", "data": _("Please login in first")}) \ No newline at end of file diff --git a/account/models.py b/account/models.py index 79ad34a..faeab5d 100644 --- a/account/models.py +++ b/account/models.py @@ -1,5 +1,3 @@ -# coding=utf-8 -from __future__ import unicode_literals from django.contrib.auth.models import AbstractBaseUser from django.db import models from jsonfield import JSONField diff --git a/account/serializers.py b/account/serializers.py index 2dbd34a..ad2197c 100644 --- a/account/serializers.py +++ b/account/serializers.py @@ -1,7 +1,6 @@ # coding=utf-8 -from rest_framework import serializers +from utils.api import serializers, DateTimeTZField -from utils.serializers import DateTimeTZField from .models import User, AdminType diff --git a/account/tests.py b/account/tests.py index a3d55ef..8e5b855 100644 --- a/account/tests.py +++ b/account/tests.py @@ -1,17 +1,13 @@ -# coding=utf-8 -from __future__ import unicode_literals - import time +from unittest import mock -import mock from django.contrib import auth -from django.core.urlresolvers import reverse from django.utils.translation import ugettext as _ -from rest_framework.test import APIClient from utils.otp_auth import OtpAuth from utils.shortcuts import rand_str -from utils.tests import APITestCase +from utils.api.tests import APITestCase, APIClient + from .models import User, AdminType @@ -37,7 +33,7 @@ class UserLoginAPITest(APITestCase): def setUp(self): self.username = self.password = "test" self.user = self.create_user(username=self.username, password=self.password) - self.login_url = reverse("user_login_api") + self.login_url = self.reverse("user_login_api") def _set_tfa(self): self.user.two_factor_auth = True @@ -110,7 +106,7 @@ class CaptchaTest(APITestCase): class UserRegisterAPITest(CaptchaTest): def setUp(self): self.client = APIClient() - self.register_url = reverse("user_register_api") + self.register_url = self.reverse("user_register_api") self.captcha = rand_str(4) self.data = {"username": "test_user", "password": "testuserpassword", @@ -150,7 +146,7 @@ class UserRegisterAPITest(CaptchaTest): class UserChangePasswordAPITest(CaptchaTest): def setUp(self): self.client = APIClient() - self.url = reverse("user_change_password_api") + self.url = self.reverse("user_change_password_api") # Create user at first self.username = "test_user" @@ -183,7 +179,7 @@ class AdminUserTest(APITestCase): self.user = self.create_super_admin(login=True) self.username = self.password = "test" self.regular_user = self.create_user(username=self.username, password=self.password) - self.url = reverse("user_admin_api") + self.url = self.reverse("user_admin_api") self.data = {"id": self.regular_user.id, "username": self.username, "real_name": "test_name", "email": "test@qq.com", "admin_type": AdminType.REGULAR_USER, "open_api": True, "two_factor_auth": False, "is_disabled": False} diff --git a/account/urls/admin.py b/account/urls/admin.py index 946bbc0..c6193b2 100644 --- a/account/urls/admin.py +++ b/account/urls/admin.py @@ -1,4 +1,3 @@ -# coding=utf-8 from django.conf.urls import url from ..views.admin import UserAdminAPIView diff --git a/account/urls/oj.py b/account/urls/oj.py index f2d5aed..4669cd4 100644 --- a/account/urls/oj.py +++ b/account/urls/oj.py @@ -1,4 +1,3 @@ -# coding=utf-8 from django.conf.urls import url from ..views.oj import UserLoginAPIView, UserRegisterAPIView, UserChangePasswordAPIView diff --git a/account/views/admin.py b/account/views/admin.py index 62be6b3..c6661ea 100644 --- a/account/views/admin.py +++ b/account/views/admin.py @@ -1,75 +1,73 @@ -# coding=utf-8 from __future__ import unicode_literals from django.core.exceptions import MultipleObjectsReturned from django.db.models import Q from django.utils.translation import ugettext as _ -from utils.shortcuts import (APIView, paginate_data, rand_str) +from utils.api import APIView, validate_serializer +from utils.shortcuts import rand_str + from ..decorators import super_admin_required -from ..models import User, AdminType +from ..models import User from ..serializers import (UserSerializer, EditUserSerializer) class UserAdminAPIView(APIView): + @validate_serializer(EditUserSerializer) @super_admin_required def put(self, request): """ Edit user api """ - serializer = EditUserSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - try: - user = User.objects.get(id=data["id"]) - except User.DoesNotExist: - return self.error(_("User does not exist")) - try: - user = User.objects.get(username=data["username"]) - if user.id != data["id"]: - return self.error(_("Username already exists")) - except User.DoesNotExist: - pass + data = request.data + try: + user = User.objects.get(id=data["id"]) + except User.DoesNotExist: + return self.error(_("User does not exist")) + try: + user = User.objects.get(username=data["username"]) + if user.id != data["id"]: + return self.error(_("Username already exists")) + except User.DoesNotExist: + pass - try: - user = User.objects.get(email=data["email"]) - if user.id != data["id"]: - return self.error(_("Email already exists")) - # Some old data has duplicate email - except MultipleObjectsReturned: + try: + user = User.objects.get(email=data["email"]) + if user.id != data["id"]: return self.error(_("Email already exists")) - except User.DoesNotExist: - pass + # Some old data has duplicate email + except MultipleObjectsReturned: + return self.error(_("Email already exists")) + except User.DoesNotExist: + pass - user.username = data["username"] - user.real_name = data["real_name"] - user.email = data["email"] - user.admin_type = data["admin_type"] - user.is_disabled = data["is_disabled"] + user.username = data["username"] + user.real_name = data["real_name"] + user.email = data["email"] + user.admin_type = data["admin_type"] + user.is_disabled = data["is_disabled"] - if data["password"]: - user.set_password(data["password"]) + if data["password"]: + user.set_password(data["password"]) - if data["open_api"]: - # Avoid reset user appkey after saving changes - if not user.open_api: - user.open_api_appkey = rand_str() - else: - user.open_api_appkey = None - user.open_api = data["open_api"] - - if data["two_factor_auth"]: - # Avoid reset user tfa_token after saving changes - if not user.two_factor_auth: - user.tfa_token = rand_str() - else: - user.tfa_token = None - user.two_factor_auth = data["two_factor_auth"] - - user.save() - return self.success(UserSerializer(user).data) + if data["open_api"]: + # Avoid reset user appkey after saving changes + if not user.open_api: + user.open_api_appkey = rand_str() else: - return self.invalid_serializer(serializer) + user.open_api_appkey = None + user.open_api = data["open_api"] + + if data["two_factor_auth"]: + # Avoid reset user tfa_token after saving changes + if not user.two_factor_auth: + user.tfa_token = rand_str() + else: + user.tfa_token = None + user.two_factor_auth = data["two_factor_auth"] + + user.save() + return self.success(UserSerializer(user).data) @super_admin_required def get(self, request): @@ -97,4 +95,4 @@ class UserAdminAPIView(APIView): user = user.filter(Q(username__contains=keyword) | Q(real_name__contains=keyword) | Q(email__contains=keyword)) - return self.success(paginate_data(request, user, UserSerializer)) + return self.success(self.paginate_data(request, user, UserSerializer)) diff --git a/account/views/oj.py b/account/views/oj.py index 853c6ae..8a0079f 100644 --- a/account/views/oj.py +++ b/account/views/oj.py @@ -1,13 +1,10 @@ -# coding=utf-8 -from __future__ import unicode_literals - from django.contrib import auth from django.core.exceptions import MultipleObjectsReturned from django.utils.translation import ugettext as _ +from utils.api import APIView, validate_serializer from utils.captcha import Captcha from utils.otp_auth import OtpAuth -from utils.shortcuts import (APIView, ) from ..decorators import login_required from ..models import User, UserProfile from ..serializers import (UserLoginSerializer, UserRegisterSerializer, @@ -15,33 +12,30 @@ from ..serializers import (UserLoginSerializer, UserRegisterSerializer, class UserLoginAPIView(APIView): + @validate_serializer(UserLoginSerializer) def post(self, request): """ User login api """ - serializer = UserLoginSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - user = auth.authenticate(username=data["username"], password=data["password"]) - # None is returned if username or password is wrong - if user: - if not user.two_factor_auth: - auth.login(request, user) - return self.success(_("Succeeded")) + data = request.data + user = auth.authenticate(username=data["username"], password=data["password"]) + # None is returned if username or password is wrong + if user: + if not user.two_factor_auth: + auth.login(request, user) + return self.success(_("Succeeded")) - # `tfa_code` not in post data - if user.two_factor_auth and "tfa_code" not in data: - return self.success("tfa_required") + # `tfa_code` not in post data + if user.two_factor_auth and "tfa_code" not in data: + return self.success("tfa_required") - if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): - auth.login(request, user) - return self.success(_("Succeeded")) - else: - return self.error(_("Invalid two factor verification code")) + if OtpAuth(user.tfa_token).valid_totp(data["tfa_code"]): + auth.login(request, user) + return self.success(_("Succeeded")) else: - return self.error(_("Invalid username or password")) + return self.error(_("Invalid two factor verification code")) else: - return self.invalid_serializer(serializer) + return self.error(_("Invalid username or password")) # todo remove this, only for debug use def get(self, request): @@ -50,56 +44,50 @@ class UserLoginAPIView(APIView): class UserRegisterAPIView(APIView): + @validate_serializer(UserRegisterSerializer) def post(self, request): """ User register api """ - serializer = UserRegisterSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - captcha = Captcha(request) - if not captcha.check(data["captcha"]): - return self.error(_("Invalid captcha")) - try: - User.objects.get(username=data["username"]) - return self.error(_("Username already exists")) - except User.DoesNotExist: - pass - try: - User.objects.get(email=data["email"]) - return self.error(_("Email already exists")) - # Some old data has duplicate email - except MultipleObjectsReturned: - return self.error(_("Email already exists")) - except User.DoesNotExist: - user = User.objects.create(username=data["username"], email=data["email"]) - user.set_password(data["password"]) - user.save() - UserProfile.objects.create(user=user) - return self.success(_("Succeeded")) - else: - return self.invalid_serializer(serializer) + data = request.data + captcha = Captcha(request) + if not captcha.check(data["captcha"]): + return self.error(_("Invalid captcha")) + try: + User.objects.get(username=data["username"]) + return self.error(_("Username already exists")) + except User.DoesNotExist: + pass + try: + User.objects.get(email=data["email"]) + return self.error(_("Email already exists")) + # Some old data has duplicate email + except MultipleObjectsReturned: + return self.error(_("Email already exists")) + except User.DoesNotExist: + user = User.objects.create(username=data["username"], email=data["email"]) + user.set_password(data["password"]) + user.save() + UserProfile.objects.create(user=user) + return self.success(_("Succeeded")) class UserChangePasswordAPIView(APIView): + @validate_serializer(UserChangePasswordSerializer) @login_required def post(self, request): """ User change password api """ - serializer = UserChangePasswordSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - captcha = Captcha(request) - if not captcha.check(data["captcha"]): - return self.error(_("Invalid captcha")) - username = request.user.username - user = auth.authenticate(username=username, password=data["old_password"]) - if user: - user.set_password(data["new_password"]) - user.save() - return self.success(_("Succeeded")) - else: - return self.error(_("Invalid old password")) + data = request.data + captcha = Captcha(request) + if not captcha.check(data["captcha"]): + return self.error(_("Invalid captcha")) + username = request.user.username + user = auth.authenticate(username=username, password=data["old_password"]) + if user: + user.set_password(data["new_password"]) + user.save() + return self.success(_("Succeeded")) else: - return self.invalid_serializer(serializer) + return self.error(_("Invalid old password")) diff --git a/announcement/models.py b/announcement/models.py index 7d1e43b..186d4ea 100644 --- a/announcement/models.py +++ b/announcement/models.py @@ -1,5 +1,3 @@ -# coding=utf-8 -from __future__ import unicode_literals from django.db import models from account.models import User diff --git a/announcement/serializers.py b/announcement/serializers.py index d21ce90..72393c8 100644 --- a/announcement/serializers.py +++ b/announcement/serializers.py @@ -1,10 +1,7 @@ -# coding=utf-8 -from __future__ import unicode_literals - -from rest_framework import serializers +from utils.api import serializers from account.models import User -from utils.serializers import DateTimeTZField +from utils.api._serializers import DateTimeTZField from .models import Announcement diff --git a/announcement/tests.py b/announcement/tests.py index 86037d9..5de3179 100644 --- a/announcement/tests.py +++ b/announcement/tests.py @@ -1,12 +1,10 @@ -# coding=utf-8 -from django.core.urlresolvers import reverse -from utils.tests import APITestCase +from utils.api.tests import APITestCase, APIClient class AnnouncementAdminTest(APITestCase): def setUp(self): self.user = self.create_super_admin(login=True) - self.url = reverse("announcement_admin_api") + self.url = self.reverse("announcement_admin_api") def test_announcement_list(self): response = self.client.get(self.url) diff --git a/announcement/urls/admin.py b/announcement/urls/admin.py index c8b3a67..746ab3a 100644 --- a/announcement/urls/admin.py +++ b/announcement/urls/admin.py @@ -1,4 +1,3 @@ -# coding=utf-8 from django.conf.urls import url from ..views import AnnouncementAdminAPIView diff --git a/announcement/views.py b/announcement/views.py index c0d42c0..406bdcd 100644 --- a/announcement/views.py +++ b/announcement/views.py @@ -1,10 +1,8 @@ -# coding=utf-8 -from __future__ import unicode_literals - from django.utils.translation import ugettext as _ from account.decorators import super_admin_required -from utils.shortcuts import paginate_data, APIView +from utils.api import APIView + from .models import Announcement from .serializers import (CreateAnnouncementSerializer, AnnouncementSerializer, EditAnnouncementSerializer) @@ -63,4 +61,4 @@ class AnnouncementAdminAPIView(APIView): announcement = Announcement.objects.all().order_by("-create_time") if request.GET.get("visible") == "true": announcement = announcement.filter(visible=True) - return self.success(paginate_data(request, announcement, AnnouncementSerializer)) + return self.success(self.paginate_data(request, announcement, AnnouncementSerializer)) diff --git a/conf/models.py b/conf/models.py index 1ce796a..59f0dc6 100644 --- a/conf/models.py +++ b/conf/models.py @@ -19,9 +19,9 @@ class WebsiteConfig(models.Model): base_url = models.CharField(max_length=128, default=None) name = models.CharField(max_length=32, default="Online Judge") name_shortcut = models.CharField(max_length=32, default="oj") - website_footer = models.CharField(max_length=256, default="Online Judge") + website_footer = models.TextField(default="Online Judge") # allow register - register = models.BooleanField(default=True) + allow_register = models.BooleanField(default=True) # submission list show all user's submission submission_list_show_all = models.BooleanField(default=False) diff --git a/utils/api/__init__.py b/utils/api/__init__.py new file mode 100644 index 0000000..0cf3185 --- /dev/null +++ b/utils/api/__init__.py @@ -0,0 +1,2 @@ +from .api import * +from ._serializers import * \ No newline at end of file diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py new file mode 100644 index 0000000..0e4884d --- /dev/null +++ b/utils/api/_serializers.py @@ -0,0 +1,17 @@ +import json + +from django.utils import timezone + +from rest_framework import serializers + + +class JSONField(serializers.Field): + def to_representation(self, value): + return json.loads(value) + + +class DateTimeTZField(serializers.DateTimeField): + def to_representation(self, value): + self.format = "%Y-%-m-%d %-H:%-M:%-S" + value = timezone.localtime(value) + return super(DateTimeTZField, self).to_representation(value) \ No newline at end of file diff --git a/utils/api/api.py b/utils/api/api.py new file mode 100644 index 0000000..a8fd055 --- /dev/null +++ b/utils/api/api.py @@ -0,0 +1,179 @@ +# coding=utf-8 +import json +import logging + +from django.http import HttpResponse, QueryDict +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt +from django.views.generic import View + + +logger = logging.getLogger(__name__) + + +class ContentType(object): + json_request = "application/json" + json_response = "application/json;charset=UTF-8" + url_encoded_request = "application/x-www-form-urlencoded" + binary_response = "application/octet-stream" + + +class JSONParser(object): + content_type = ContentType.json_request + + @staticmethod + def parse(body): + return json.loads(body.decode("utf-8")) + + +class URLEncodedParser(object): + content_type = ContentType.url_encoded_request + + @staticmethod + def parse(body): + return QueryDict(body).dict() + + +class JSONResponse(object): + content_type = ContentType.json_response + + @classmethod + def response(cls, data): + resp = HttpResponse(json.dumps(data, indent=4), content_type=cls.content_type) + resp.data = data + return resp + + +class APIView(View): + """ + Django view的父类, 和django-rest-framework的用法基本一致 + - request.data获取解析之后的json或者urlencoded数据, dict类型 + - self.success, self.error和self.invalid_serializer可以根据业需求修改, + 写到父类中是为了不同的人开发写法统一,不再使用自己的success/error格式 + - self.response 返回一个django HttpResponse, 具体在self.response_class中实现 + - parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据 + """ + request_parsers = (JSONParser, URLEncodedParser) + response_class = JSONResponse + + def _get_request_data(self, request): + if request.method != "GET": + body = request.body + content_type = request.META.get("CONTENT_TYPE") + if not content_type: + raise ValueError("content_type is required") + for parser in self.request_parsers: + if content_type.startswith(parser.content_type): + break + else: + raise ValueError("unknown content_type '%s'" % content_type) + if body: + return parser.parse(body) + return {} + return request.GET + + def response(self, data): + return self.response_class.response(data) + + def success(self, data=None): + return self.response({"error": None, "data": data}) + + def error(self, msg, err="error"): + return self.response({"error": err, "data": msg}) + + def invalid_serializer(self, serializer): + for k, v in serializer.errors.items(): + if k != "non_field_errors": + return self.error(err="invalid-" + k, msg=k + ": " + v[0]) + else: + return self.error(err="invalid-field", msg=k[0]) + + def server_error(self): + return self.error(err="server-error", msg="server error") + + def paginate_data(self, request, query_set, object_serializer=None): + """ + :param request: django的request + :param query_set: django model的query set或者其他list like objects + :param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片 + :return: + """ + need_paginate = request.GET.get("limit", None) + if need_paginate is None: + if object_serializer: + return object_serializer(query_set, many=True).data + else: + return query_set + try: + limit = int(request.GET.get("limit", "100")) + except ValueError: + limit = 100 + if limit < 0: + limit = 100 + try: + offset = int(request.GET.get("offset", "0")) + except ValueError: + offset = 0 + if offset < 0: + offset = 0 + results = query_set[offset:offset + limit] + if object_serializer: + count = query_set.count() + results = object_serializer(results, many=True).data + else: + count = len(query_set) + data = {"results": results, + "total": count} + return data + + def dispatch(self, request, *args, **kwargs): + try: + request.data = self._get_request_data(self.request) + except ValueError as e: + return self.error(err="invalid-request", msg=str(e)) + try: + return super(APIView, self).dispatch(request, *args, **kwargs) + except Exception as e: + logger.exception(e) + return self.server_error() + + +class CSRFExemptAPIView(APIView): + @method_decorator(csrf_exempt) + def dispatch(self, request, *args, **kwargs): + return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs) + + +class SNServerAPIView(CSRFExemptAPIView): + def empty_response(self): + resp = HttpResponse() + resp["Content-Length"] = 0 + return resp + + def response(self, data): + resp = super(SNServerAPIView, self).response(data) + resp["Content-Length"] = len(resp.content) + return resp + + +def validate_serializer(serializer): + """ + @validate_serializer(TestSerializer) + def post(self, request): + return self.success(request.data) + """ + def validate(view_method): + def handle(*args, **kwargs): + self = args[0] + request = args[1] + s = serializer(data=request.data) + if s.is_valid(): + request.data = s.data + request.serializer = s + return view_method(*args, **kwargs) + else: + return self.invalid_serializer(s) + + return handle + + return validate \ No newline at end of file diff --git a/utils/tests.py b/utils/api/tests.py similarity index 90% rename from utils/tests.py rename to utils/api/tests.py index 01d88d0..44a03b0 100644 --- a/utils/tests.py +++ b/utils/api/tests.py @@ -1,6 +1,6 @@ -# coding=utf-8 -from __future__ import unicode_literals from django.test.testcases import TestCase +from django.core.urlresolvers import reverse + from rest_framework.test import APIClient from account.models import User, AdminType @@ -23,6 +23,9 @@ class APITestCase(TestCase): def create_super_admin(self, username="root", password="root", login=False): return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login) + def reverse(self, url_name): + return reverse(url_name) + def assertSuccess(self, response): self.assertTrue(response.data["error"] is None) diff --git a/utils/serializers.py b/utils/serializers.py index 154ce81..e69de29 100644 --- a/utils/serializers.py +++ b/utils/serializers.py @@ -1,18 +0,0 @@ -# coding=utf-8 -import json - -from django.utils import timezone - -from rest_framework import serializers - - -class JSONField(serializers.Field): - def to_representation(self, value): - return json.loads(value) - - -class DateTimeTZField(serializers.DateTimeField): - def to_representation(self, value): - self.format = "%Y-%-m-%d %-H:%-M:%-S" - value = timezone.localtime(value) - return super(DateTimeTZField, self).to_representation(value) \ No newline at end of file diff --git a/utils/shortcuts.py b/utils/shortcuts.py index 68cfdef..0506ad9 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -1,103 +1,25 @@ # coding=utf-8 -import json import logging import random -from django.http import HttpResponse -from django.views.generic import View +from django.utils.crypto import get_random_string + logger = logging.getLogger(__name__) -def JSONResponse(data, content_type="application/json"): - resp = HttpResponse(json.dumps(data, indent=4), content_type=content_type) - resp.data = data - return resp - - -class APIView(View): - def _get_request_json(self, request): - if request.method != "GET": - body = request.body - if body: - return json.loads(body.decode("utf-8")) - return {} - return request.GET - - def success(self, data=None): - return JSONResponse({"error": None, "data": data}) - - def error(self, message, error="error"): - return JSONResponse({"error": error, "data": message}) - - def invalid_serializer(self, serializer): - for k, v in serializer.errors.items(): - return self.error(k + ": " + v[0], error="invalid-data-format") - - def server_error(self): - return self.error("Server Error") - - def dispatch(self, request, *args, **kwargs): - try: - request.data = self._get_request_json(self.request) - except ValueError: - return self.error("Invalid JSON") - try: - return super(APIView, self).dispatch(request, *args, **kwargs) - except Exception as e: - logging.exception(e) - return self.server_error() - - -def paginate_data(request, query_set, object_serializer): - """ - function used to paginate data - """ - need_paginate = request.GET.get("paging", None) - # if paging=true not in request.GET, then we return all data - if need_paginate != "true": - if object_serializer: - return object_serializer(query_set, many=True).data - else: - return query_set - - try: - limit = int(request.GET.get("limit", "100")) - except ValueError: - limit = 100 - if limit < 0: - limit = 100 - - try: - offset = int(request.GET.get("offset", "0")) - except ValueError: - offset = 0 - if offset < 0: - offset = 0 - - results = query_set[offset:offset + limit] - if object_serializer: - count = query_set.count() - results = object_serializer(results, many=True).data - else: - count = len(query_set) - - data = {"results": results, - "count": count} - - return data - - def rand_str(length=32, type="lower_hex"): """ - generate types of random string or number with specific length - DO NOT USE TO GENERATE SECRET KEY! + 生成指定长度的随机字符串或者数字, 可以用于密钥等安全场景 + :param length: 字符串或者数字的长度 + :param type: str 代表随机字符串,num 代表随机数字 + :return: 字符串 """ if type == "str": - return ''.join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(length)) + return get_random_string(length, allowed_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") elif type == "lower_str": - return ''.join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for i in range(length)) + return get_random_string(length, allowed_chars="abcdefghijklmnopqrstuvwxyz0123456789") elif type == "lower_hex": - return ''.join(random.choice("0123456789abcdef") for i in range(length)) + return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef") else: - return random.choice("123456789") + ''.join(random.choice("0123456789") for i in range(length - 1)) \ No newline at end of file + return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789") \ No newline at end of file