diff --git a/announcement/tests.py b/announcement/tests.py index c226597..0354306 100644 --- a/announcement/tests.py +++ b/announcement/tests.py @@ -1,4 +1,6 @@ -from utils.api.tests import APITestCase, APIClient +from utils.api.tests import APITestCase + +from .models import Announcement class AnnouncementAdminTest(APITestCase): @@ -16,6 +18,7 @@ class AnnouncementAdminTest(APITestCase): def test_create_announcement(self): resp = self.create_announcement() self.assertSuccess(resp) + return resp def test_edit_announcement(self): data = {"id": self.create_announcement().data["data"]["id"], "title": "ahaha", "content": "test content", @@ -26,3 +29,9 @@ class AnnouncementAdminTest(APITestCase): self.assertEqual(resp_data["title"], "ahaha") self.assertEqual(resp_data["content"], "test content") self.assertEqual(resp_data["visible"], False) + + def test_delete_announcemen(self): + id = self.test_create_announcement().data["data"]["id"] + resp = self.client.delete(self.url, data={'id': id}) + self.assertSuccess(resp) + self.assertFalse(Announcement.objects.filter(id=id).exists()) diff --git a/announcement/views.py b/announcement/views.py index 2cd1d0a..dbb3c6b 100644 --- a/announcement/views.py +++ b/announcement/views.py @@ -1,7 +1,7 @@ from django.utils.translation import ugettext as _ from account.decorators import super_admin_required -from utils.api import APIView +from utils.api import APIView, validate_serializer, IDOnlySerializer from .models import Announcement from .serializers import (CreateAnnouncementSerializer, AnnouncementSerializer, @@ -9,42 +9,36 @@ from .serializers import (CreateAnnouncementSerializer, AnnouncementSerializer, class AnnouncementAdminAPI(APIView): + @validate_serializer(CreateAnnouncementSerializer) @super_admin_required def post(self, request): """ publish announcement """ - serializer = CreateAnnouncementSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - announcement = Announcement.objects.create(title=data["title"], - content=data["content"], - created_by=request.user) - return self.success(AnnouncementSerializer(announcement).data) - else: - return self.invalid_serializer(serializer) + data = request.data + announcement = Announcement.objects.create(title=data["title"], + content=data["content"], + created_by=request.user) + return self.success(AnnouncementSerializer(announcement).data) + @validate_serializer(EditAnnouncementSerializer) @super_admin_required def put(self, request): """ edit announcement """ - serializer = EditAnnouncementSerializer(data=request.data) - if serializer.is_valid(): - data = serializer.data - try: - announcement = Announcement.objects.get(id=data["id"]) - except Announcement.DoesNotExist: - return self.error(_("Announcement does not exist")) + data = request.data + try: + announcement = Announcement.objects.get(id=data["id"]) + except Announcement.DoesNotExist: + return self.error(_("Announcement does not exist")) - announcement.title = data["title"] - announcement.content = data["content"] - announcement.visible = data["visible"] - announcement.save() + announcement.title = data["title"] + announcement.content = data["content"] + announcement.visible = data["visible"] + announcement.save() - return self.success(AnnouncementSerializer(announcement).data) - else: - return self.invalid_serializer(serializer) + return self.success(AnnouncementSerializer(announcement).data) @super_admin_required def get(self, request): @@ -62,3 +56,9 @@ class AnnouncementAdminAPI(APIView): if request.GET.get("visible") == "true": announcement = announcement.filter(visible=True) return self.success(self.paginate_data(request, announcement, AnnouncementSerializer)) + + @validate_serializer(IDOnlySerializer) + @super_admin_required + def delete(self, request): + Announcement.objects.filter(id=request.data["id"]).delete() + return self.success() diff --git a/utils/api/_serializers.py b/utils/api/_serializers.py index 0e4884d..d048890 100644 --- a/utils/api/_serializers.py +++ b/utils/api/_serializers.py @@ -14,4 +14,8 @@ class DateTimeTZField(serializers.DateTimeField): def to_representation(self, value): self.format = "%Y-%-m-%d %-H:%-M:%-S" value = timezone.localtime(value) - return super(DateTimeTZField, self).to_representation(value) \ No newline at end of file + return super(DateTimeTZField, self).to_representation(value) + + +class IDOnlySerializer(serializers.Serializer): + id = serializers.IntegerField() \ No newline at end of file diff --git a/utils/api/api.py b/utils/api/api.py index a8fd055..5141a87 100644 --- a/utils/api/api.py +++ b/utils/api/api.py @@ -1,6 +1,6 @@ -# coding=utf-8 import json import logging +import functools from django.http import HttpResponse, QueryDict from django.utils.decorators import method_decorator @@ -31,7 +31,7 @@ class URLEncodedParser(object): @staticmethod def parse(body): - return QueryDict(body).dict() + return QueryDict(body) class JSONResponse(object): @@ -127,10 +127,11 @@ class APIView(View): return data def dispatch(self, request, *args, **kwargs): - try: - request.data = self._get_request_data(self.request) - except ValueError as e: - return self.error(err="invalid-request", msg=str(e)) + if self.request_parsers: + try: + request.data = self._get_request_data(self.request) + except ValueError as e: + return self.error(err="invalid-request", msg=str(e)) try: return super(APIView, self).dispatch(request, *args, **kwargs) except Exception as e: @@ -144,18 +145,6 @@ class CSRFExemptAPIView(APIView): return super(CSRFExemptAPIView, self).dispatch(request, *args, **kwargs) -class SNServerAPIView(CSRFExemptAPIView): - def empty_response(self): - resp = HttpResponse() - resp["Content-Length"] = 0 - return resp - - def response(self, data): - resp = super(SNServerAPIView, self).response(data) - resp["Content-Length"] = len(resp.content) - return resp - - def validate_serializer(serializer): """ @validate_serializer(TestSerializer) @@ -163,6 +152,7 @@ def validate_serializer(serializer): return self.success(request.data) """ def validate(view_method): + @functools.wraps(view_method) def handle(*args, **kwargs): self = args[0] request = args[1] @@ -176,4 +166,4 @@ def validate_serializer(serializer): return handle - return validate \ No newline at end of file + return validate