From 8a6093d6457c1e5ddbeacf87575a95fa4ce1d222 Mon Sep 17 00:00:00 2001 From: virusdefender <1670873886@qq.com> Date: Wed, 5 Aug 2015 08:44:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=80=9A=E7=94=A8=E5=88=86?= =?UTF-8?q?=E9=A1=B5=E5=87=BD=E6=95=B0=E5=92=8C=E5=AF=B9=E5=BA=94=E7=9A=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- oj/settings.py | 1 + utils/shortcuts.py | 49 ++++++++++++++++++++++++++++++++++- utils/test_urls.py | 9 +++++++ utils/tests.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 utils/test_urls.py create mode 100644 utils/tests.py diff --git a/oj/settings.py b/oj/settings.py index db9ea8f..793f605 100644 --- a/oj/settings.py +++ b/oj/settings.py @@ -47,6 +47,7 @@ INSTALLED_APPS = ( 'django.contrib.staticfiles', 'account', + 'utils', 'rest_framework', 'rest_framework_swagger', diff --git a/utils/shortcuts.py b/utils/shortcuts.py index bd5f286..3d37063 100644 --- a/utils/shortcuts.py +++ b/utils/shortcuts.py @@ -1,4 +1,7 @@ # coding=utf-8 +from django.core.paginator import Paginator + +from rest_framework import pagination from rest_framework.response import Response @@ -11,4 +14,48 @@ def serializer_invalid_response(serializer): def success_response(data): - return Response(data={"code": 0, "data": data}) \ No newline at end of file + return Response(data={"code": 0, "data": data}) + + +def paginate(request, query_set, object_serializer): + """ + 用于分页的函数 + :param query_set 数据库查询结果 + :param object_serializer: 序列化单个object的serializer + :return response + """ + need_paginate = request.GET.get("paging", None) + # 如果请求的参数里面没有paging=true的话 就返回全部数据 + if need_paginate != "true": + return success_response(data=object_serializer(query_set, many=True).data) + + page_size = request.GET.get("page_size", None) + if not page_size: + return error_response(u"参数错误") + + try: + page_size = int(page_size) + except Exception: + return error_response(u"参数错误") + + paginator = Paginator(query_set, page_size) + page = request.GET.get("page", None) + + try: + current_page = paginator.page(page) + except Exception: + return error_response(u"参数错误") + + data = {"results": object_serializer(current_page, many=True).data, "previous_page": None, "next_page": None} + + try: + data["previous_page"] = current_page.previous_page_number() + except Exception: + pass + + try: + data["next_page"] = current_page.next_page_number() + except Exception: + pass + + return success_response(data) \ No newline at end of file diff --git a/utils/test_urls.py b/utils/test_urls.py new file mode 100644 index 0000000..0500d16 --- /dev/null +++ b/utils/test_urls.py @@ -0,0 +1,9 @@ +# coding=utf-8 +from django.conf.urls import include, url + + + +urlpatterns = [ + url(r'^paginate_test/$', "utils.tests.pagination_test_func"), +] + diff --git a/utils/tests.py b/utils/tests.py new file mode 100644 index 0000000..9ab31f9 --- /dev/null +++ b/utils/tests.py @@ -0,0 +1,64 @@ +# coding=utf-8 +from rest_framework.test import APIClient, APITestCase +from rest_framework import serializers +from rest_framework.decorators import api_view + +from account.models import User +from .shortcuts import paginate + + +class PaginationTestSerialiser(serializers.Serializer): + username = serializers.CharField(max_length=100) + + +@api_view(["GET"]) +def pagination_test_func(request): + return paginate(request, User.objects.all(), PaginationTestSerialiser) + + +class PaginatorTest(APITestCase): + urls = "utils.test_urls" + + def setUp(self): + self.client = APIClient() + self.url = "/paginate_test/" + User.objects.create(username="test1") + User.objects.create(username="test2") + + def test_no_paginate(self): + response = self.client.get(self.url) + self.assertEqual(response.data["code"], 0) + self.assertNotIn("next_page", response.data["data"]) + self.assertNotIn("previous_page", response.data["data"]) + + def test_error_parameter(self): + response = self.client.get(self.url + "?paging=true") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + response = self.client.get(self.url + "?paging=true&limit=-1") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + response = self.client.get(self.url + "?paging=true&limit=aa") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=-1") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + response = self.client.get(self.url + "?paging=true&limit=1&page_size=aaa&page=1") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=aaa") + self.assertEqual(response.data, {"code": 1, "data": u"参数错误"}) + + def test_correct_paginate(self): + response = self.client.get(self.url + "?paging=true&limit=1&page_size=1&page=1") + self.assertEqual(response.data["code"], 0) + self.assertEqual(response.data["data"]["previous_page"], None) + self.assertEqual(response.data["data"]["next_page"], 2) + self.assertEqual(len(response.data["data"]["results"]), 1) + + response = self.client.get(self.url + "?paging=true&limit=1&page_size=2&page=1") + self.assertEqual(response.data["code"], 0) + self.assertEqual(response.data["data"]["previous_page"], None) + self.assertEqual(response.data["data"]["next_page"], None) + self.assertEqual(len(response.data["data"]["results"]), 2)