使用Python3和更科学的API写法
This commit is contained in:
2
utils/api/__init__.py
Normal file
2
utils/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .api import *
|
||||
from ._serializers import *
|
||||
17
utils/api/_serializers.py
Normal file
17
utils/api/_serializers.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import json
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class JSONField(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class DateTimeTZField(serializers.DateTimeField):
|
||||
def to_representation(self, value):
|
||||
self.format = "%Y-%-m-%d %-H:%-M:%-S"
|
||||
value = timezone.localtime(value)
|
||||
return super(DateTimeTZField, self).to_representation(value)
|
||||
179
utils/api/api.py
Normal file
179
utils/api/api.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# coding=utf-8
|
||||
import json
|
||||
import logging
|
||||
|
||||
from django.http import HttpResponse, QueryDict
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.views.generic import View
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentType(object):
|
||||
json_request = "application/json"
|
||||
json_response = "application/json;charset=UTF-8"
|
||||
url_encoded_request = "application/x-www-form-urlencoded"
|
||||
binary_response = "application/octet-stream"
|
||||
|
||||
|
||||
class JSONParser(object):
|
||||
content_type = ContentType.json_request
|
||||
|
||||
@staticmethod
|
||||
def parse(body):
|
||||
return json.loads(body.decode("utf-8"))
|
||||
|
||||
|
||||
class URLEncodedParser(object):
|
||||
content_type = ContentType.url_encoded_request
|
||||
|
||||
@staticmethod
|
||||
def parse(body):
|
||||
return QueryDict(body).dict()
|
||||
|
||||
|
||||
class JSONResponse(object):
|
||||
content_type = ContentType.json_response
|
||||
|
||||
@classmethod
|
||||
def response(cls, data):
|
||||
resp = HttpResponse(json.dumps(data, indent=4), content_type=cls.content_type)
|
||||
resp.data = data
|
||||
return resp
|
||||
|
||||
|
||||
class APIView(View):
|
||||
"""
|
||||
Django view的父类, 和django-rest-framework的用法基本一致
|
||||
- request.data获取解析之后的json或者urlencoded数据, dict类型
|
||||
- self.success, self.error和self.invalid_serializer可以根据业需求修改,
|
||||
写到父类中是为了不同的人开发写法统一,不再使用自己的success/error格式
|
||||
- self.response 返回一个django HttpResponse, 具体在self.response_class中实现
|
||||
- parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据
|
||||
"""
|
||||
request_parsers = (JSONParser, URLEncodedParser)
|
||||
response_class = JSONResponse
|
||||
|
||||
def _get_request_data(self, request):
|
||||
if request.method != "GET":
|
||||
body = request.body
|
||||
content_type = request.META.get("CONTENT_TYPE")
|
||||
if not content_type:
|
||||
raise ValueError("content_type is required")
|
||||
for parser in self.request_parsers:
|
||||
if content_type.startswith(parser.content_type):
|
||||
break
|
||||
else:
|
||||
raise ValueError("unknown content_type '%s'" % content_type)
|
||||
if body:
|
||||
return parser.parse(body)
|
||||
return {}
|
||||
return request.GET
|
||||
|
||||
def response(self, data):
|
||||
return self.response_class.response(data)
|
||||
|
||||
def success(self, data=None):
|
||||
return self.response({"error": None, "data": data})
|
||||
|
||||
def error(self, msg, err="error"):
|
||||
return self.response({"error": err, "data": msg})
|
||||
|
||||
def invalid_serializer(self, serializer):
|
||||
for k, v in serializer.errors.items():
|
||||
if k != "non_field_errors":
|
||||
return self.error(err="invalid-" + k, msg=k + ": " + v[0])
|
||||
else:
|
||||
return self.error(err="invalid-field", msg=k[0])
|
||||
|
||||
def server_error(self):
|
||||
return self.error(err="server-error", msg="server error")
|
||||
|
||||
def paginate_data(self, request, query_set, object_serializer=None):
|
||||
"""
|
||||
:param request: django的request
|
||||
:param query_set: django model的query set或者其他list like objects
|
||||
:param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片
|
||||
:return:
|
||||
"""
|
||||
need_paginate = request.GET.get("limit", None)
|
||||
if need_paginate is None:
|
||||
if object_serializer:
|
||||
return object_serializer(query_set, many=True).data
|
||||
else:
|
||||
return query_set
|
||||
try:
|
||||
limit = int(request.GET.get("limit", "100"))
|
||||
except ValueError:
|
||||
limit = 100
|
||||
if limit < 0:
|
||||
limit = 100
|
||||
try:
|
||||
offset = int(request.GET.get("offset", "0"))
|
||||
except ValueError:
|
||||
offset = 0
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
results = query_set[offset:offset + limit]
|
||||
if object_serializer:
|
||||
count = query_set.count()
|
||||
results = object_serializer(results, many=True).data
|
||||
else:
|
||||
count = len(query_set)
|
||||
data = {"results": results,
|
||||
"total": count}
|
||||
return data
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
request.data = self._get_request_data(self.request)
|
||||
except ValueError as e:
|
||||
return self.error(err="invalid-request", msg=str(e))
|
||||
try:
|
||||
return super(APIView, self).dispatch(request, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return self.server_error()
|
||||
|
||||
|
||||
class CSRFExemptAPIView(APIView):
|
||||
@method_decorator(csrf_exempt)
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class SNServerAPIView(CSRFExemptAPIView):
|
||||
def empty_response(self):
|
||||
resp = HttpResponse()
|
||||
resp["Content-Length"] = 0
|
||||
return resp
|
||||
|
||||
def response(self, data):
|
||||
resp = super(SNServerAPIView, self).response(data)
|
||||
resp["Content-Length"] = len(resp.content)
|
||||
return resp
|
||||
|
||||
|
||||
def validate_serializer(serializer):
|
||||
"""
|
||||
@validate_serializer(TestSerializer)
|
||||
def post(self, request):
|
||||
return self.success(request.data)
|
||||
"""
|
||||
def validate(view_method):
|
||||
def handle(*args, **kwargs):
|
||||
self = args[0]
|
||||
request = args[1]
|
||||
s = serializer(data=request.data)
|
||||
if s.is_valid():
|
||||
request.data = s.data
|
||||
request.serializer = s
|
||||
return view_method(*args, **kwargs)
|
||||
else:
|
||||
return self.invalid_serializer(s)
|
||||
|
||||
return handle
|
||||
|
||||
return validate
|
||||
@@ -1,6 +1,6 @@
|
||||
# coding=utf-8
|
||||
from __future__ import unicode_literals
|
||||
from django.test.testcases import TestCase
|
||||
from django.core.urlresolvers import reverse
|
||||
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from account.models import User, AdminType
|
||||
@@ -23,6 +23,9 @@ class APITestCase(TestCase):
|
||||
def create_super_admin(self, username="root", password="root", login=False):
|
||||
return self.create_user(username=username, password=password, admin_type=AdminType.SUPER_ADMIN, login=login)
|
||||
|
||||
def reverse(self, url_name):
|
||||
return reverse(url_name)
|
||||
|
||||
def assertSuccess(self, response):
|
||||
self.assertTrue(response.data["error"] is None)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# coding=utf-8
|
||||
import json
|
||||
|
||||
from django.utils import timezone
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class JSONField(serializers.Field):
|
||||
def to_representation(self, value):
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
class DateTimeTZField(serializers.DateTimeField):
|
||||
def to_representation(self, value):
|
||||
self.format = "%Y-%-m-%d %-H:%-M:%-S"
|
||||
value = timezone.localtime(value)
|
||||
return super(DateTimeTZField, self).to_representation(value)
|
||||
@@ -1,103 +1,25 @@
|
||||
# coding=utf-8
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
|
||||
from django.http import HttpResponse
|
||||
from django.views.generic import View
|
||||
from django.utils.crypto import get_random_string
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def JSONResponse(data, content_type="application/json"):
|
||||
resp = HttpResponse(json.dumps(data, indent=4), content_type=content_type)
|
||||
resp.data = data
|
||||
return resp
|
||||
|
||||
|
||||
class APIView(View):
|
||||
def _get_request_json(self, request):
|
||||
if request.method != "GET":
|
||||
body = request.body
|
||||
if body:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return {}
|
||||
return request.GET
|
||||
|
||||
def success(self, data=None):
|
||||
return JSONResponse({"error": None, "data": data})
|
||||
|
||||
def error(self, message, error="error"):
|
||||
return JSONResponse({"error": error, "data": message})
|
||||
|
||||
def invalid_serializer(self, serializer):
|
||||
for k, v in serializer.errors.items():
|
||||
return self.error(k + ": " + v[0], error="invalid-data-format")
|
||||
|
||||
def server_error(self):
|
||||
return self.error("Server Error")
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
request.data = self._get_request_json(self.request)
|
||||
except ValueError:
|
||||
return self.error("Invalid JSON")
|
||||
try:
|
||||
return super(APIView, self).dispatch(request, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return self.server_error()
|
||||
|
||||
|
||||
def paginate_data(request, query_set, object_serializer):
|
||||
"""
|
||||
function used to paginate data
|
||||
"""
|
||||
need_paginate = request.GET.get("paging", None)
|
||||
# if paging=true not in request.GET, then we return all data
|
||||
if need_paginate != "true":
|
||||
if object_serializer:
|
||||
return object_serializer(query_set, many=True).data
|
||||
else:
|
||||
return query_set
|
||||
|
||||
try:
|
||||
limit = int(request.GET.get("limit", "100"))
|
||||
except ValueError:
|
||||
limit = 100
|
||||
if limit < 0:
|
||||
limit = 100
|
||||
|
||||
try:
|
||||
offset = int(request.GET.get("offset", "0"))
|
||||
except ValueError:
|
||||
offset = 0
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
|
||||
results = query_set[offset:offset + limit]
|
||||
if object_serializer:
|
||||
count = query_set.count()
|
||||
results = object_serializer(results, many=True).data
|
||||
else:
|
||||
count = len(query_set)
|
||||
|
||||
data = {"results": results,
|
||||
"count": count}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def rand_str(length=32, type="lower_hex"):
|
||||
"""
|
||||
generate types of random string or number with specific length
|
||||
DO NOT USE TO GENERATE SECRET KEY!
|
||||
生成指定长度的随机字符串或者数字, 可以用于密钥等安全场景
|
||||
:param length: 字符串或者数字的长度
|
||||
:param type: str 代表随机字符串,num 代表随机数字
|
||||
:return: 字符串
|
||||
"""
|
||||
if type == "str":
|
||||
return ''.join(random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
||||
return get_random_string(length, allowed_chars="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
|
||||
elif type == "lower_str":
|
||||
return ''.join(random.choice("abcdefghijklmnopqrstuvwxyz0123456789") for i in range(length))
|
||||
return get_random_string(length, allowed_chars="abcdefghijklmnopqrstuvwxyz0123456789")
|
||||
elif type == "lower_hex":
|
||||
return ''.join(random.choice("0123456789abcdef") for i in range(length))
|
||||
return random.choice("123456789abcdef") + get_random_string(length - 1, allowed_chars="0123456789abcdef")
|
||||
else:
|
||||
return random.choice("123456789") + ''.join(random.choice("0123456789") for i in range(length - 1))
|
||||
return random.choice("123456789") + get_random_string(length - 1, allowed_chars="0123456789")
|
||||
Reference in New Issue
Block a user