diff --git a/rest_framework-stubs/decorators.pyi b/rest_framework-stubs/decorators.pyi index 30ef41896..cfc278ed1 100644 --- a/rest_framework-stubs/decorators.pyi +++ b/rest_framework-stubs/decorators.pyi @@ -10,7 +10,7 @@ from typing import ( from django.db.models import QuerySet from django.http.response import HttpResponseBase from rest_framework.authentication import BaseAuthentication -from rest_framework.filters import _FilterBackendProtocol +from rest_framework.filters import FilterBackendProtocol from rest_framework.parsers import BaseParser from rest_framework.permissions import _PermissionClass from rest_framework.renderers import BaseRenderer @@ -112,7 +112,7 @@ def action( authentication_classes: _AuthClassesParam = ..., renderer_classes: _RenderClassesParam = ..., parser_classes: _ParserClassesParam = ..., - filter_backends: Sequence[type[_FilterBackendProtocol]] = ..., + filter_backends: Sequence[type[FilterBackendProtocol]] = ..., lookup_field: str = ..., lookup_url_kwarg: str | None = ..., queryset: QuerySet[Any] = ..., diff --git a/rest_framework-stubs/filters.pyi b/rest_framework-stubs/filters.pyi index b9476d694..af7b256bf 100644 --- a/rest_framework-stubs/filters.pyi +++ b/rest_framework-stubs/filters.pyi @@ -10,13 +10,12 @@ from rest_framework.request import Request from rest_framework.views import APIView _MT = TypeVar("_MT", bound=Model) -_Q = TypeVar("_Q", bound=QuerySet[Any]) -class _FilterBackendProtocol(Protocol): - def filter_queryset(self, request: Any, queryset: _Q, view: APIView) -> _Q: ... +class FilterBackendProtocol(Protocol): + def filter_queryset(self, request: Request, queryset: QuerySet[_MT], view: APIView) -> QuerySet[_MT]: ... class BaseFilterBackend: - def filter_queryset(self, request: Any, queryset: _Q, view: APIView) -> _Q: ... + def filter_queryset(self, request: Request, queryset: QuerySet[_MT], view: APIView) -> QuerySet[_MT]: ... def get_schema_fields(self, view: APIView) -> list[Any]: ... def get_schema_operation_parameters(self, view: APIView) -> Any: ... diff --git a/rest_framework-stubs/generics.pyi b/rest_framework-stubs/generics.pyi index 91eceaafe..e1f36379d 100644 --- a/rest_framework-stubs/generics.pyi +++ b/rest_framework-stubs/generics.pyi @@ -5,7 +5,7 @@ from django.db.models import Manager, Model from django.db.models.query import QuerySet from django.http.response import HttpResponse from rest_framework import mixins, views -from rest_framework.filters import _FilterBackendProtocol +from rest_framework.filters import FilterBackendProtocol from rest_framework.pagination import BasePagination from rest_framework.request import Request from rest_framework.response import Response @@ -24,7 +24,7 @@ class GenericAPIView(views.APIView): serializer_class: type[BaseSerializer] | None = ... lookup_field: str = ... lookup_url_kwarg: str | None = ... - filter_backends: Sequence[type[_FilterBackendProtocol]] = ... + filter_backends: Sequence[type[FilterBackendProtocol]] = ... pagination_class: type[BasePagination] | None = ... def get_object(self) -> Any: ... def get_serializer(self, *args: Any, **kwargs: Any) -> BaseSerializer: ... diff --git a/tests/test_restframework.py b/tests/test_restframework.py index 0a9733be8..684cb45ea 100644 --- a/tests/test_restframework.py +++ b/tests/test_restframework.py @@ -1,4 +1,8 @@ +from typing import Sequence + from rest_framework.decorators import permission_classes +from rest_framework.filters import FilterBackendProtocol +from rest_framework.generics import GenericAPIView from rest_framework.permissions import IsAdminUser, IsAuthenticated from rest_framework.test import APIClient from typing_extensions import assert_type @@ -15,3 +19,26 @@ def test_test_client_types() -> None: def test_decorator_types() -> None: permission_classes([IsAuthenticated]) permission_classes([IsAuthenticated | IsAdminUser]) + + +def test_filter_backends_types() -> None: + """ + django-filter does not use typing, so we're testing if untyped code also works fine + with our types. + + https://github.com/carltongibson/django-filter/pull/1585 + """ + + class DjangoFilterBackend: + def filter_queryset(self, request, queryset, view): # type: ignore + filterset = self.get_filterset(request, queryset, view) # type: ignore + return filterset.qs + + class BookAPIView(GenericAPIView): + filter_backends = (DjangoFilterBackend,) + + # mypy and pyright works differently here: + # mypy: Expression is of type "tuple[type[DjangoFilterBackend]]", not "Sequence[type[FilterBackendProtocol]]" [assert-type] + # Disable this check for now in mypy. + # At least assigning to filter_backends works fine, that's our main concern here. + assert_type(BookAPIView.filter_backends, Sequence[type[FilterBackendProtocol]]) # type: ignore[assert-type]