增加通用分页函数和对应的测试

This commit is contained in:
virusdefender
2015-08-05 08:44:28 +08:00
parent 44be61dab6
commit 8a6093d645
4 changed files with 122 additions and 1 deletions

View File

@@ -47,6 +47,7 @@ INSTALLED_APPS = (
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'account', 'account',
'utils',
'rest_framework', 'rest_framework',
'rest_framework_swagger', 'rest_framework_swagger',

View File

@@ -1,4 +1,7 @@
# coding=utf-8 # coding=utf-8
from django.core.paginator import Paginator
from rest_framework import pagination
from rest_framework.response import Response from rest_framework.response import Response
@@ -11,4 +14,48 @@ def serializer_invalid_response(serializer):
def success_response(data): def success_response(data):
return Response(data={"code": 0, "data": data}) return Response(data={"code": 0, "data": data})
def paginate(request, query_set, object_serializer):
"""
用于分页的函数
:param query_set 数据库查询结果
:param object_serializer: 序列化单个object的serializer
:return response
"""
need_paginate = request.GET.get("paging", None)
# 如果请求的参数里面没有paging=true的话 就返回全部数据
if need_paginate != "true":
return success_response(data=object_serializer(query_set, many=True).data)
page_size = request.GET.get("page_size", None)
if not page_size:
return error_response(u"参数错误")
try:
page_size = int(page_size)
except Exception:
return error_response(u"参数错误")
paginator = Paginator(query_set, page_size)
page = request.GET.get("page", None)
try:
current_page = paginator.page(page)
except Exception:
return error_response(u"参数错误")
data = {"results": object_serializer(current_page, many=True).data, "previous_page": None, "next_page": None}
try:
data["previous_page"] = current_page.previous_page_number()
except Exception:
pass
try:
data["next_page"] = current_page.next_page_number()
except Exception:
pass
return success_response(data)

9
utils/test_urls.py Normal file
View File

@@ -0,0 +1,9 @@
# coding=utf-8
from django.conf.urls import include, url
urlpatterns = [
url(r'^paginate_test/$', "utils.tests.pagination_test_func"),
]

64
utils/tests.py Normal file
View File

@@ -0,0 +1,64 @@
# coding=utf-8
from rest_framework.test import APIClient, APITestCase
from rest_framework import serializers
from rest_framework.decorators import api_view
from account.models import User
from .shortcuts import paginate
class PaginationTestSerialiser(serializers.Serializer):
username = serializers.CharField(max_length=100)
@api_view(["GET"])
def pagination_test_func(request):
return paginate(request, User.objects.all(), PaginationTestSerialiser)
class PaginatorTest(APITestCase):
urls = "utils.test_urls"
def setUp(self):
self.client = APIClient()
self.url = "/paginate_test/"
User.objects.create(username="test1")
User.objects.create(username="test2")
def test_no_paginate(self):
response = self.client.get(self.url)
self.assertEqual(response.data["code"], 0)
self.assertNotIn("next_page", response.data["data"])
self.assertNotIn("previous_page", response.data["data"])
def test_error_parameter(self):
response = self.client.get(self.url + "?paging=true")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&limit=-1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&limit=aa")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=-1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&limit=1&page_size=aaa&page=1")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=aaa")
self.assertEqual(response.data, {"code": 1, "data": u"参数错误"})
def test_correct_paginate(self):
response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=1")
self.assertEqual(response.data["code"], 0)
self.assertEqual(response.data["data"]["previous_page"], None)
self.assertEqual(response.data["data"]["next_page"], 2)
self.assertEqual(len(response.data["data"]["results"]), 1)
response = self.client.get(self.url + "?paging=true&limit=1&page_size=2&page=1")
self.assertEqual(response.data["code"], 0)
self.assertEqual(response.data["data"]["previous_page"], None)
self.assertEqual(response.data["data"]["next_page"], None)
self.assertEqual(len(response.data["data"]["results"]), 2)