Skip to content

Commit

Permalink
Passing request to (#512)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalik committed Jul 22, 2023
1 parent 0adfb9e commit a718d73
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
7 changes: 7 additions & 0 deletions docs/docs/guides/response/pagination.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ def list_users(request):
return User.objects.all()
```

Tip: You can access request object from params:

```Python
def paginate_queryset(self, queryset, pagination: Input, **params):
request = params["request"]
```

### Output attribute

By defult page items are placed to `'items'` attribute. To override this behaviour use `items_attribute`:
Expand Down
7 changes: 4 additions & 3 deletions ninja/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, List, Optional, Tuple, Type

from django.db.models import QuerySet
from django.http import HttpRequest
from django.utils.module_loading import import_string

from ninja import Field, Query, Router, Schema
Expand Down Expand Up @@ -133,15 +134,15 @@ def _inject_pagination(
paginator: PaginationBase = paginator_class(**paginator_params)

@wraps(func)
def view_with_pagination(*args: Tuple[Any], **kwargs: Any) -> Any:
def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
pagination_params = kwargs.pop("ninja_pagination")
if paginator.pass_parameter:
kwargs[paginator.pass_parameter] = pagination_params

items = func(*args, **kwargs)
items = func(request, **kwargs)

result = paginator.paginate_queryset(
items, pagination=pagination_params, **kwargs
items, pagination=pagination_params, request=request, **kwargs
)
if paginator.Output:
result[paginator.items_attribute] = list(result[paginator.items_attribute])
Expand Down
45 changes: 41 additions & 4 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def paginate_queryset(self, items, pagination: Input, **params):


class NoOutputPagination(PaginationBase):
# only offset param, defaults to 5 per page but without Output schema
# Outputs items without count attribute
class Input(Schema):
skip: int

Expand Down Expand Up @@ -66,6 +66,28 @@ def paginate_queryset(self, items, pagination: Input, **params):
}


class NextPrevPagination(PaginationBase):
# only offset param, defaults to 5 per page
class Input(Schema):
skip: int

class Output(Schema):
items: List[Any]
next: str = None
prev: str = None

def paginate_queryset(self, items, pagination: Input, request, **params):
skip = pagination.skip
prev_skip = skip - 5
if prev_skip < 0:
prev_skip = 0
return {
"items": items[skip : skip + 5],
"next": request.build_absolute_uri(f"?skip={skip+5}"),
"prev": request.build_absolute_uri(f"?skip={prev_skip}"),
}


@api.get("/items_1", response=List[int])
@paginate # WITHOUT brackets (should use default pagination)
def items_1(request, **kwargs):
Expand Down Expand Up @@ -106,7 +128,7 @@ def items_6(request, **kwargs):
@api.get("/items_7", response=List[int])
@paginate(NoOutputPagination)
def items_7(request):
return [7] * 7
return list(range(15))


@api.get("/items_8", response=List[int])
Expand All @@ -115,6 +137,12 @@ def items_8(request):
return list(range(1000))


@api.get("/items_9", response=List[int])
@paginate(NextPrevPagination)
def items_9(request):
return list(range(100))


client = TestClient(api)


Expand Down Expand Up @@ -269,8 +297,8 @@ def test_case6_pass_param_kwargs():


def test_case7():
response = client.get("/items_7?skip=5").json()
assert response == [7, 7]
response = client.get("/items_7?skip=10").json()
assert response == [10, 11, 12, 13, 14]

schema = api.get_openapi_schema()["paths"]["/api/items_7"]["get"]
response = schema["responses"][200]["content"]["application/json"]["schema"]
Expand All @@ -287,6 +315,15 @@ def test_case8():
assert response == {"results": [5, 6, 7, 8, 9], "count": 1000, "skip": 5}


def test_case9():
response = client.get("/items_9?skip=5").json()
assert response == {
"items": [5, 6, 7, 8, 9],
"next": "http://testlocation/?skip=10",
"prev": "http://testlocation/?skip=0",
}


def test_config_error_None():
with pytest.raises(ConfigError):

Expand Down

0 comments on commit a718d73

Please sign in to comment.