async
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.http import HttpResponse, QueryDict
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
@@ -162,6 +165,77 @@ class CSRFExemptAPIView(APIView):
|
||||
return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
class AsyncAPIView(APIView):
|
||||
view_is_async = True
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls.view_is_async = True
|
||||
|
||||
async def dispatch(self, request, *args, **kwargs):
|
||||
if self.request_parsers:
|
||||
try:
|
||||
request.data = self._get_request_data(self.request)
|
||||
except ValueError as e:
|
||||
return self.error(err="invalid-request", msg=str(e))
|
||||
try:
|
||||
handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
|
||||
response = handler(request, *args, **kwargs)
|
||||
if asyncio.iscoroutine(response):
|
||||
response = await response
|
||||
return response
|
||||
except APIError as e:
|
||||
ret = {"msg": e.msg}
|
||||
if e.err:
|
||||
ret["err"] = e.err
|
||||
return self.error(**ret)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return self.server_error()
|
||||
|
||||
def serialize_data(self, object_serializer, data, **kwargs):
|
||||
return object_serializer(data, **kwargs).data
|
||||
|
||||
async def async_serialize_data(self, object_serializer, data, **kwargs):
|
||||
return await sync_to_async(
|
||||
self.serialize_data,
|
||||
thread_sensitive=True,
|
||||
)(object_serializer, data, **kwargs)
|
||||
|
||||
async def async_paginate_data(self, request, query_set, object_serializer=None):
|
||||
try:
|
||||
limit = int(request.GET.get("limit", "10"))
|
||||
except ValueError:
|
||||
limit = 10
|
||||
if limit < 0 or limit > 250:
|
||||
limit = 10
|
||||
try:
|
||||
offset = int(request.GET.get("offset", "0"))
|
||||
except ValueError:
|
||||
offset = 0
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
count, results = await asyncio.gather(
|
||||
query_set.acount(),
|
||||
sync_to_async(lambda: list(query_set[offset:offset + limit]), thread_sensitive=True)(),
|
||||
)
|
||||
if object_serializer:
|
||||
results = await self.async_serialize_data(
|
||||
object_serializer,
|
||||
results,
|
||||
many=True,
|
||||
context={"request": request},
|
||||
)
|
||||
data = {"results": results, "total": count}
|
||||
return data
|
||||
|
||||
|
||||
class CSRFExemptAsyncAPIView(AsyncAPIView):
|
||||
@method_decorator(csrf_exempt)
|
||||
async def dispatch(self, request, *args, **kwargs):
|
||||
return await super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
def validate_serializer(serializer):
|
||||
"""
|
||||
@validate_serializer(TestSerializer)
|
||||
@@ -169,6 +243,20 @@ def validate_serializer(serializer):
|
||||
return self.success(request.data)
|
||||
"""
|
||||
def validate(view_method):
|
||||
if inspect.iscoroutinefunction(view_method):
|
||||
@functools.wraps(view_method)
|
||||
async def async_handle(*args, **kwargs):
|
||||
self = args[0]
|
||||
request = args[1]
|
||||
s = serializer(data=request.data)
|
||||
if s.is_valid():
|
||||
request.data = s.data
|
||||
request.serializer = s
|
||||
return await view_method(*args, **kwargs)
|
||||
else:
|
||||
return self.invalid_serializer(s)
|
||||
return async_handle
|
||||
|
||||
@functools.wraps(view_method)
|
||||
def handle(*args, **kwargs):
|
||||
self = args[0]
|
||||
@@ -180,7 +268,6 @@ def validate_serializer(serializer):
|
||||
return view_method(*args, **kwargs)
|
||||
else:
|
||||
return self.invalid_serializer(s)
|
||||
|
||||
return handle
|
||||
|
||||
return validate
|
||||
|
||||
Reference in New Issue
Block a user