diff --git a/api/pagination.py b/api/pagination.py new file mode 100644 index 00000000..64c08026 --- /dev/null +++ b/api/pagination.py @@ -0,0 +1,95 @@ +""" +Custom paginators + +""" +from drf_yasg import openapi +from drf_yasg.inspectors import DjangoRestResponsePagination +from rest_framework.compat import coreapi, coreschema +from rest_framework import pagination + + +class ExtendedCursorPagination(pagination.CursorPagination): + """ + Custom DRF paginator based on the standard CursorPagination with the extra wrinkle of allowing + a total object count to be added to the result. By default this parameter is called + "include_count" and as long as it is set to "true", a count will be returned. + + """ + include_count_query_param = 'include_count' + + def paginate_queryset(self, queryset, request, view=None): + if self.get_should_include_count(request): + self.queryset_count = queryset.count() + else: + self.queryset_count = None + return super().paginate_queryset(queryset, request, view) + + def get_paginated_response(self, data): + response = super().get_paginated_response(data) + + if self.queryset_count is not None: + response.data['count'] = self.queryset_count + + return response + + def get_should_include_count(self, request): + if self.include_count_query_param: + # This rather odd error handling logic is copied from DRF's implementation of similar + # methods. + try: + value = request.query_params[self.include_count_query_param] + return str(value).lower() == 'true' + except (KeyError, ValueError): + pass + return False + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + + # Explicitly get super class' fields as a list since the API just mandates that they be an + # iterable. + fields = list(super().get_schema_fields(view)) + + if self.include_count_query_param: + fields.append(coreapi.Field( + name=self.include_count_query_param, + required=False, + location='query', + schema=coreschema.Boolean( + title='Include count of resources', + description=( + 'Include total resource count in response. ' + 'By default the count is not included for performance reasons.' + ), + ), + )) + + return fields + + +class ExtendedCursorPaginationInspector(DjangoRestResponsePagination): + """ + Inspector for DRF YASG which understands :py:class:`~.ExtendedCursorPagination`. Either add + this to the `DEFAULT_PAGINATOR_INSPECTORS` setting for drf-yasg or decorate your view: + + .. code:: + + from django.utils.decorators import method_decorator + from drf_yasg import utils as yasg_utils + + # ... + + @method_decorator(name='get', decorator=yasg_utils.swagger_auto_schema( + paginator_inspectors=[ExtendedCursorPaginationInspector] + )) + class MyCountView(ListAPIView): + pagination_class = ExtendedCursorPagination + + """ + def get_paginated_response(self, paginator, response_schema): + schema = None + if isinstance(paginator, ExtendedCursorPagination): + schema = super().get_paginated_response(paginator, response_schema) + schema['properties']['count'] = openapi.Schema(type=openapi.TYPE_INTEGER) + return schema diff --git a/api/tests/test_views.py b/api/tests/test_views.py index e1bcabbb..0a779fcd 100644 --- a/api/tests/test_views.py +++ b/api/tests/test_views.py @@ -267,6 +267,25 @@ def test_token_auth_list(self): for item in response_data['results']: self.assertIn(item['id'], expected_ids) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + def test_search_by_title(self): """Items can be searched by title.""" item = mpmodels.MediaItem.objects.first() @@ -907,6 +926,25 @@ def test_basic_list(self): for item in response_data['results']: self.assertIn(item['id'], expected_ids) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + class ChannelViewTestCase(ViewTestCase): def setUp(self): @@ -1081,6 +1119,25 @@ def test_create_requires_channel_user_can_edit(self): response = self.view(request) self.assertEqual(response.status_code, 400) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + class PlaylistViewTestCase(ViewTestCase): def setUp(self): diff --git a/api/views.py b/api/views.py index 72285af6..cf359877 100644 --- a/api/views.py +++ b/api/views.py @@ -10,15 +10,17 @@ from django.db import models from django.http import Http404 from django.shortcuts import redirect, render +from django.utils.decorators import method_decorator from django_filters import rest_framework as df_filters -from drf_yasg import inspectors, openapi -from rest_framework import generics, pagination, filters, views +from drf_yasg import inspectors, openapi, utils as yasg_utils +from rest_framework import generics, filters, views from rest_framework.exceptions import ParseError import requests import mediaplatform.models as mpmodels from mediaplatform_jwp.api import delivery +from . import pagination as api_pagination from . import permissions from . import serializers @@ -34,8 +36,9 @@ POSTER_IMAGE_VALID_EXTENSIONS = ['jpg'] -class ListPagination(pagination.CursorPagination): +class ListPagination(api_pagination.ExtendedCursorPagination): page_size = 50 + page_size_query_param = 'page_size' class FullTextSearchFilter(filters.SearchFilter): @@ -261,6 +264,9 @@ def filter_playlist(self, queryset, name, value): return queryset.filter(id__in=value.media_items) +@method_decorator(name='get', decorator=yasg_utils.swagger_auto_schema( + paginator_inspectors=[api_pagination.ExtendedCursorPaginationInspector] +)) class MediaItemListView(MediaItemListMixin, generics.ListCreateAPIView): """ List and search Media items. If no other ordering is specified, results are returned in order diff --git a/doc/api.rst b/doc/api.rst index 6d2f8875..b629b498 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -25,6 +25,13 @@ Views :members: :member-order: bysource +DRF Extensions +-------------- + +.. automodule:: api.pagination + :members: + :member-order: bysource + Serializers -----------