import asyncio import functools import inspect import json import logging from asgiref.sync import sync_to_async from django.http import HttpResponse, QueryDict from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt from django.views.generic import View logger = logging.getLogger("") class APIError(Exception): def __init__(self, msg, err=None): self.err = err self.msg = msg super().__init__(err, msg) class ContentType(object): json_request = "application/json" json_response = "application/json;charset=UTF-8" url_encoded_request = "application/x-www-form-urlencoded" binary_response = "application/octet-stream" class JSONParser(object): content_type = ContentType.json_request @staticmethod def parse(body): return json.loads(body.decode("utf-8")) class URLEncodedParser(object): content_type = ContentType.url_encoded_request @staticmethod def parse(body): return QueryDict(body) class JSONResponse(object): content_type = ContentType.json_response @classmethod def response(cls, data): resp = HttpResponse(json.dumps(data, indent=4), content_type=cls.content_type) resp.data = data return resp class APIView(View): """ Django view的父类, 和django-rest-framework的用法基本一致 - request.data获取解析之后的json或者urlencoded数据, dict类型 - self.success, self.error和self.invalid_serializer可以根据业需求修改, 写到父类中是为了不同的人开发写法统一,不再使用自己的success/error格式 - self.response 返回一个django HttpResponse, 具体在self.response_class中实现 - parse请求的类需要定义在request_parser中, 目前只支持json和urlencoded的类型, 用来解析请求的数据 """ request_parsers = (JSONParser, URLEncodedParser) response_class = JSONResponse def _get_request_data(self, request): if request.method not in ["GET", "DELETE"]: body = request.body content_type = request.META.get("CONTENT_TYPE") if not content_type: raise ValueError("content_type is required") for parser in self.request_parsers: if content_type.startswith(parser.content_type): break # else means the for loop is not interrupted by break else: raise ValueError("unknown content_type '%s'" % content_type) if body: return parser.parse(body) return {} return request.GET def response(self, data): return self.response_class.response(data) def success(self, data=None): return self.response({"error": None, "data": data}) def error(self, msg="error", err="error"): return self.response({"error": err, "data": msg}) def extract_errors(self, errors, key="field"): if isinstance(errors, dict): if not errors: return key, "Invalid field" key = list(errors.keys())[0] return self.extract_errors(errors.pop(key), key) elif isinstance(errors, list): return self.extract_errors(errors[0], key) return key, errors def invalid_serializer(self, serializer): key, error = self.extract_errors(serializer.errors) if key == "non_field_errors": msg = error else: msg = f"{key}: {error}" return self.error(err=f"invalid-{key}", msg=msg) def server_error(self): return self.error(err="server-error", msg="server error") def paginate_data(self, request, query_set, object_serializer=None): """ :param request: django的request :param query_set: django model的query set或者其他list like objects :param object_serializer: 用来序列化query set, 如果为None, 则直接对query set切片 :return: """ try: limit = int(request.GET.get("limit", "10")) except ValueError: limit = 10 if limit < 0 or limit > 250: limit = 10 try: offset = int(request.GET.get("offset", "0")) except ValueError: offset = 0 if offset < 0: offset = 0 # 只调用一次 count(),避免重复查询 count = query_set.count() results = query_set[offset:offset + limit] if object_serializer: results = object_serializer(results, many=True, context={"request": request}).data data = {"results": results, "total": count} return data def dispatch(self, request, *args, **kwargs): if self.request_parsers: try: request.data = self._get_request_data(self.request) except ValueError as e: return self.error(err="invalid-request", msg=str(e)) try: return super(APIView, self).dispatch(request, *args, **kwargs) except APIError as e: ret = {"msg": e.msg} if e.err: ret["err"] = e.err return self.error(**ret) except Exception as e: logger.exception(e) return self.server_error() class CSRFExemptAPIView(APIView): @method_decorator(csrf_exempt) def dispatch(self, request, *args, **kwargs): return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs) class AsyncAPIView(APIView): view_is_async = True def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls.view_is_async = True async def dispatch(self, request, *args, **kwargs): if self.request_parsers: try: request.data = self._get_request_data(self.request) except ValueError as e: return self.error(err="invalid-request", msg=str(e)) try: handler = getattr(self, request.method.lower(), self.http_method_not_allowed) response = handler(request, *args, **kwargs) if asyncio.iscoroutine(response): response = await response return response except APIError as e: ret = {"msg": e.msg} if e.err: ret["err"] = e.err return self.error(**ret) except Exception as e: logger.exception(e) return self.server_error() def serialize_data(self, object_serializer, data, **kwargs): return object_serializer(data, **kwargs).data async def async_serialize_data(self, object_serializer, data, **kwargs): return await sync_to_async( self.serialize_data, thread_sensitive=True, )(object_serializer, data, **kwargs) async def async_paginate_data(self, request, query_set, object_serializer=None): try: limit = int(request.GET.get("limit", "10")) except ValueError: limit = 10 if limit < 0 or limit > 250: limit = 10 try: offset = int(request.GET.get("offset", "0")) except ValueError: offset = 0 if offset < 0: offset = 0 async def _slice(): return [item async for item in query_set[offset:offset + limit]] count, results = await asyncio.gather( query_set.acount(), _slice(), ) if object_serializer: results = await self.async_serialize_data( object_serializer, results, many=True, context={"request": request}, ) data = {"results": results, "total": count} return data class CSRFExemptAsyncAPIView(AsyncAPIView): @method_decorator(csrf_exempt) async def dispatch(self, request, *args, **kwargs): return await super().dispatch(request, *args, **kwargs) def validate_serializer(serializer): """ @validate_serializer(TestSerializer) def post(self, request): return self.success(request.data) """ def validate(view_method): if inspect.iscoroutinefunction(view_method): @functools.wraps(view_method) async def async_handle(*args, **kwargs): self = args[0] request = args[1] s = serializer(data=request.data) if await sync_to_async(s.is_valid)(): request.data = s.data request.serializer = s response = view_method(*args, **kwargs) if asyncio.iscoroutine(response): return await response return response else: return self.invalid_serializer(s) return async_handle @functools.wraps(view_method) def handle(*args, **kwargs): self = args[0] request = args[1] s = serializer(data=request.data) if s.is_valid(): request.data = s.data request.serializer = s return view_method(*args, **kwargs) else: return self.invalid_serializer(s) return handle return validate