Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix paginate Decorator Typing #1127

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 82 additions & 23 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
import inspect
from abc import ABC, abstractmethod
from functools import partial, wraps
from typing import Any, Callable, List, Optional, Tuple, Type

from django.db.models import QuerySet
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

from django.db import models
from django.http import HttpRequest
from django.utils.module_loading import import_string
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeGuard
from typing_extensions import get_args as get_collection_args

from ninja import Field, Query, Router, Schema
from ninja.conf import settings
from ninja.constants import NOT_SET
from ninja.constants import NOT_SET, NOT_SET_TYPE
from ninja.errors import ConfigError
from ninja.operation import Operation
from ninja.signature.details import is_collection_type
from ninja.utils import contribute_operation_args, contribute_operation_callback

Req = TypeVar("Req", bound=HttpRequest)
M = TypeVar("M", bound=models.Model)
P = ParamSpec("P")

ViewFn: TypeAlias = Callable[Concatenate[Req, P], Sequence[Any]]
PaginatedViewFn: TypeAlias = Callable[Concatenate[Req, P], Dict[str, Any]]


class PaginationBase(ABC):
class Input(Schema):
Expand All @@ -35,20 +55,21 @@ def __init__(self, *, pass_parameter: Optional[str] = None, **kwargs: Any) -> No
@abstractmethod
def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Any,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
pass # pragma: no cover

def _items_count(self, queryset: QuerySet) -> int:
def _items_count(self, queryset: Sequence[M]) -> int:
"""
Since lists are mainly compatible with QuerySets and can be passed to paginator.
We will first to try to use .count - and if not there will use a len
"""
try:
# forcing to find queryset.count instead of list.count:
return queryset.all().count()
# Avoid checking the type with `isinstance` because this might not work with
# monkey-patched QuerySets.
return queryset.all().count() # type: ignore
except AttributeError:
return len(queryset)

Expand All @@ -60,10 +81,10 @@ class Input(Schema):

def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Input,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
offset = pagination.offset
limit: int = min(pagination.limit, settings.PAGINATION_MAX_LIMIT)
return {
Expand All @@ -84,18 +105,40 @@ def __init__(

def paginate_queryset(
self,
queryset: QuerySet,
queryset: Sequence[M],
pagination: Input,
**params: Any,
) -> Any:
) -> Dict[str, Any]:
offset = (pagination.page - 1) * self.page_size
return {
"items": queryset[offset : offset + self.page_size],
"count": self._items_count(queryset),
} # noqa: E203


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
@overload
def paginate(
func_or_pgn_class: ViewFn[Req, P], **paginator_params: Any
) -> PaginatedViewFn[Req, P]:
...


@overload
def paginate(
func_or_pgn_class: Union[Type[PaginationBase], NOT_SET_TYPE] = NOT_SET,
**paginator_params: Any,
) -> Callable[[ViewFn[Req, P]], PaginatedViewFn[Req, P]]:
...


def paginate(
func_or_pgn_class: Union[
ViewFn[Req, P], Type[PaginationBase], NOT_SET_TYPE
] = NOT_SET,
**paginator_params: Any,
) -> Union[
PaginatedViewFn[Req, P], Callable[[ViewFn[Req, P]], PaginatedViewFn[Req, P]]
]:
"""
@api.get(...
@paginate
Expand All @@ -109,37 +152,53 @@ def my_view(request):

"""

isfunction = inspect.isfunction(func_or_pgn_class)
def _is_view_func(func: Any) -> TypeGuard[ViewFn[Req, P]]:
return inspect.isfunction(func_or_pgn_class)

isnotset = func_or_pgn_class == NOT_SET

pagination_class: Type[PaginationBase] = import_string(settings.PAGINATION_CLASS)

if isfunction:
if _is_view_func(func_or_pgn_class):
return _inject_pagination(func_or_pgn_class, pagination_class)

if not isnotset:
# Second check is redundant, but `TypeGuard` doesn't narrow the negative case.
# `TypeIs` should resolve this: https://peps.python.org/pep-0742/
if not isnotset and isinstance(func_or_pgn_class, type):
pagination_class = func_or_pgn_class

def wrapper(func: Callable) -> Any:
def wrapper(func: ViewFn[Req, P]) -> PaginatedViewFn[Req, P]:
return _inject_pagination(func, pagination_class, **paginator_params)

return wrapper


def _inject_pagination(
func: Callable,
func: ViewFn[Req, P],
paginator_class: Type[PaginationBase],
**paginator_params: Any,
) -> Callable:
paginator: PaginationBase = paginator_class(**paginator_params)
) -> PaginatedViewFn[Req, P]:
"""Inject pagination into the view function.

Args:
func: The view function.
paginator_class: The paginator class.
**paginator_params: Parameters for the paginator class.

Returns:
The view function with pagination injected into the response.
"""
paginator = paginator_class(**paginator_params)

@wraps(func)
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
def view_with_pagination(
request: Req, *args: P.args, **kwargs: P.kwargs
) -> Dict[str, Any]:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(request, **kwargs)
items = func(request, *args, **kwargs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To properly utilize ParamSpec (which lets us preserve the call signature), we need to pass in *args. I'm guessing this wouldn't be problematic, but if will be, I can rethink the approach here.


result = paginator.paginate_queryset(
items, pagination=pagination_params, request=request, **kwargs
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,6 @@ branch = true
fail_under = 100
skip_covered = true
show_missing = true
exclude_also = [
"@(typing\\.)?overload",
]
Loading