From 24bcf215f2fb251215ae8016cf4f68ae266ede25 Mon Sep 17 00:00:00 2001 From: Rich Wareham Date: Fri, 12 Oct 2018 12:23:11 +0100 Subject: [PATCH 1/4] api: add custom paginator for counting resources Add a custom DRF paginator which is based on CursorPaginator byt can optionally add a count of the number of results. This is disabled by default because doing a full count can be inefficient and isn't always needed. Add appropriate magic to make sure that the Swagger API documentation reflects the new change. --- api/pagination.py | 95 +++++++++++++++++++++++++++++++++++++++++++++++ doc/api.rst | 7 ++++ 2 files changed, 102 insertions(+) create mode 100644 api/pagination.py diff --git a/api/pagination.py b/api/pagination.py new file mode 100644 index 00000000..64c08026 --- /dev/null +++ b/api/pagination.py @@ -0,0 +1,95 @@ +""" +Custom paginators + +""" +from drf_yasg import openapi +from drf_yasg.inspectors import DjangoRestResponsePagination +from rest_framework.compat import coreapi, coreschema +from rest_framework import pagination + + +class ExtendedCursorPagination(pagination.CursorPagination): + """ + Custom DRF paginator based on the standard CursorPagination with the extra wrinkle of allowing + a total object count to be added to the result. By default this parameter is called + "include_count" and as long as it is set to "true", a count will be returned. + + """ + include_count_query_param = 'include_count' + + def paginate_queryset(self, queryset, request, view=None): + if self.get_should_include_count(request): + self.queryset_count = queryset.count() + else: + self.queryset_count = None + return super().paginate_queryset(queryset, request, view) + + def get_paginated_response(self, data): + response = super().get_paginated_response(data) + + if self.queryset_count is not None: + response.data['count'] = self.queryset_count + + return response + + def get_should_include_count(self, request): + if self.include_count_query_param: + # This rather odd error handling logic is copied from DRF's implementation of similar + # methods. + try: + value = request.query_params[self.include_count_query_param] + return str(value).lower() == 'true' + except (KeyError, ValueError): + pass + return False + + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + + # Explicitly get super class' fields as a list since the API just mandates that they be an + # iterable. + fields = list(super().get_schema_fields(view)) + + if self.include_count_query_param: + fields.append(coreapi.Field( + name=self.include_count_query_param, + required=False, + location='query', + schema=coreschema.Boolean( + title='Include count of resources', + description=( + 'Include total resource count in response. ' + 'By default the count is not included for performance reasons.' + ), + ), + )) + + return fields + + +class ExtendedCursorPaginationInspector(DjangoRestResponsePagination): + """ + Inspector for DRF YASG which understands :py:class:`~.ExtendedCursorPagination`. Either add + this to the `DEFAULT_PAGINATOR_INSPECTORS` setting for drf-yasg or decorate your view: + + .. code:: + + from django.utils.decorators import method_decorator + from drf_yasg import utils as yasg_utils + + # ... + + @method_decorator(name='get', decorator=yasg_utils.swagger_auto_schema( + paginator_inspectors=[ExtendedCursorPaginationInspector] + )) + class MyCountView(ListAPIView): + pagination_class = ExtendedCursorPagination + + """ + def get_paginated_response(self, paginator, response_schema): + schema = None + if isinstance(paginator, ExtendedCursorPagination): + schema = super().get_paginated_response(paginator, response_schema) + schema['properties']['count'] = openapi.Schema(type=openapi.TYPE_INTEGER) + return schema diff --git a/doc/api.rst b/doc/api.rst index 6d2f8875..b629b498 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -25,6 +25,13 @@ Views :members: :member-order: bysource +DRF Extensions +-------------- + +.. automodule:: api.pagination + :members: + :member-order: bysource + Serializers ----------- From e15b17a5ac028b37528e7497c2a3c73d7585da97 Mon Sep 17 00:00:00 2001 From: Rich Wareham Date: Fri, 12 Oct 2018 14:31:33 +0100 Subject: [PATCH 2/4] api: make use of new pagination class Make use of the new pagination class allowing all resource list views to optionally include a total resource count. --- api/tests/test_views.py | 11 +++++++++++ api/views.py | 11 ++++++++--- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/api/tests/test_views.py b/api/tests/test_views.py index 40c206ba..135c821e 100644 --- a/api/tests/test_views.py +++ b/api/tests/test_views.py @@ -267,6 +267,17 @@ def test_token_auth_list(self): for item in response_data['results']: self.assertIn(item['id'], expected_ids) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + class MediaItemViewTestCase(ViewTestCase): def setUp(self): diff --git a/api/views.py b/api/views.py index e13a7ecc..d7d212cc 100644 --- a/api/views.py +++ b/api/views.py @@ -9,15 +9,17 @@ from django.db import models from django.http import Http404 from django.shortcuts import redirect, render +from django.utils.decorators import method_decorator from django_filters import rest_framework as df_filters -from drf_yasg import inspectors, openapi -from rest_framework import generics, pagination, filters, views +from drf_yasg import inspectors, openapi, utils as yasg_utils +from rest_framework import generics, filters, views from rest_framework.exceptions import ParseError import requests import mediaplatform.models as mpmodels from mediaplatform_jwp.api import delivery +from . import pagination as api_pagination from . import permissions from . import serializers @@ -33,7 +35,7 @@ POSTER_IMAGE_VALID_EXTENSIONS = ['jpg'] -class ListPagination(pagination.CursorPagination): +class ListPagination(api_pagination.ExtendedCursorPagination): page_size = 50 @@ -249,6 +251,9 @@ def filter_playlist(self, queryset, name, value): return queryset.filter(id__in=value.media_items) +@method_decorator(name='get', decorator=yasg_utils.swagger_auto_schema( + paginator_inspectors=[api_pagination.ExtendedCursorPaginationInspector] +)) class MediaItemListView(MediaItemListMixin, generics.ListCreateAPIView): """ Endpoint to retrieve a list of media. From 8cbda3a1c25749d317fc81868bf97f248a167c0d Mon Sep 17 00:00:00 2001 From: Rich Wareham Date: Fri, 12 Oct 2018 14:32:05 +0100 Subject: [PATCH 3/4] api: allow page size of lists to be configured Allow the page size of list responses to be specified. This is to allow the case where one doesn't really care about the resources but one wants the count. In this case, one can just set page_size to 1. DRF diasallows setting page_size to 0. --- api/views.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/views.py b/api/views.py index d7d212cc..ab26f5d3 100644 --- a/api/views.py +++ b/api/views.py @@ -37,6 +37,7 @@ class ListPagination(api_pagination.ExtendedCursorPagination): page_size = 50 + page_size_query_param = 'page_size' class ViewMixinBase: From e7db1f2885c71639363a8b8526faa36b4199fe15 Mon Sep 17 00:00:00 2001 From: Rich Wareham Date: Mon, 15 Oct 2018 13:38:09 +0100 Subject: [PATCH 4/4] api: add tests for include_count when authorised Add tests which exercise include_count when the user is logged in. --- api/tests/test_views.py | 46 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/api/tests/test_views.py b/api/tests/test_views.py index fc1d670d..0a779fcd 100644 --- a/api/tests/test_views.py +++ b/api/tests/test_views.py @@ -273,6 +273,14 @@ def test_include_count(self): self.assertIn('count', response_data) self.assertGreater(response_data['count'], 0) + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + def test_count_not_include_by_default(self): """Not asking for a coount should not include one.""" response_data = self.view(self.get_request).data @@ -918,6 +926,25 @@ def test_basic_list(self): for item in response_data['results']: self.assertIn(item['id'], expected_ids) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + class ChannelViewTestCase(ViewTestCase): def setUp(self): @@ -1092,6 +1119,25 @@ def test_create_requires_channel_user_can_edit(self): response = self.view(request) self.assertEqual(response.status_code, 400) + def test_include_count(self): + """Asking to include a count should return a count of resources.""" + response_data = self.view(self.factory.get('/?include_count=true')).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_auth_include_count(self): + """Asking to include a count should return a count of resources when logged in.""" + request = self.factory.get('/?include_count=true') + force_authenticate(request, user=self.user) + response_data = self.view(request).data + self.assertIn('count', response_data) + self.assertGreater(response_data['count'], 0) + + def test_count_not_include_by_default(self): + """Not asking for a coount should not include one.""" + response_data = self.view(self.get_request).data + self.assertNotIn('count', response_data) + class PlaylistViewTestCase(ViewTestCase): def setUp(self):