Skip to content

Commit

Permalink
feat: Allow @api_view wrapped functions as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
bikemule authored and bikemule committed Dec 5, 2024
1 parent e059a4c commit dcbb34d
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 7 deletions.
23 changes: 17 additions & 6 deletions hybridrouter/hybridrouter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Optional, Type, Union, overload
from typing import Callable, Optional, Type, Union, overload

from django.urls import include, path, re_path
from django.urls.exceptions import NoReverseMatch
Expand Down Expand Up @@ -63,18 +63,25 @@ def register(
) -> None:
...

@overload
def register(
self, prefix: str, viewset: Type[Callable], basename: Optional[str] = None
) -> None:
...

def register(
self,
prefix: str,
viewset: Union[Type[APIView], Type[ViewSetMixin]],
viewset: Union[Type[APIView], Type[ViewSetMixin], Type[Callable]],
basename: Optional[str] = None,
) -> None:
"""
Registers an APIView or ViewSet with the specified prefix.
Registers an APIView, ViewSet, or @api_view-decorated function with the specified prefix.
Args:
prefix (str): URL prefix for the view or viewset.
viewset (Type[APIView] or Type[ViewSetMixin]): The APIView or ViewSet class.
viewset (Type[APIView] or Type[ViewSetMixin] or Type[Callable]):
A class (APIView or ViewSet) or function (@api_view-decorated function).
basename (str, optional): The base name for the view or viewset. Defaults to None.
"""
if basename is None:
Expand Down Expand Up @@ -148,9 +155,13 @@ def _build_urls(self, node, prefix, urls):
viewset_urls = self._get_viewset_urls(node.view, prefix, node.basename)
urls.extend(viewset_urls)
else:
# Add the basic view with a unique name
name = f"{node.basename}"
urls.append(path(f"{prefix}", node.view.as_view(), name=name))
# Only APIView has as_view, so try that first
try:
urls.append(path(f"{prefix}", node.view.as_view(), name=name))
except AttributeError:
# That didn't work, so it must be an @api_view-decorated function.
urls.append(path(f"{prefix}", node.view, name=name))
# If this node is a nested router, include it
elif node.is_nested_router:
urls.append(
Expand Down
29 changes: 28 additions & 1 deletion hybridrouter/tests/test_hybrid_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .conftest import recevoir_test_url_resolver
from .models import Item
from .views import ItemView
from .views import ItemView, item_view
from .viewsets import ItemViewSet, SlugItemViewSet


Expand All @@ -26,6 +26,7 @@ def create_urlconf(router):
def test_register_views_and_viewsets(hybrid_router, db):
# Enregistrer des vues simples
hybrid_router.register("items-view", ItemView, basename="item-view")
hybrid_router.register("apiitems-view", item_view, basename="apiitem-view")

# Enregistrer des ViewSets
hybrid_router.register("items-set", ItemViewSet, basename="item-set")
Expand All @@ -40,10 +41,12 @@ def test_register_views_and_viewsets(hybrid_router, db):

# Vérifier que les URL sont correctement générées
view_url = reverse("item-view")
api_view_url = reverse("apiitem-view")
list_url = reverse("item-set-list")
detail_url = reverse("item-set-detail", kwargs={"pk": 1})

assert view_url == "/items-view/"
assert api_view_url == "/apiitems-view/"
assert list_url == "/items-set/"
assert detail_url == "/items-set/1/"

Expand All @@ -52,6 +55,9 @@ def test_register_views_and_viewsets(hybrid_router, db):
response = client.get(view_url)
assert response.status_code == status.HTTP_200_OK

response = client.get(api_view_url)
assert response.status_code == status.HTTP_200_OK

Item.objects.create(id=1, name="Test Item", description="Item for testing.")

response = client.get(list_url)
Expand All @@ -61,6 +67,27 @@ def test_register_views_and_viewsets(hybrid_router, db):
assert response.status_code == status.HTTP_200_OK


@override_settings()
def test_register_only_api_views(hybrid_router, db):
# Enregistrer uniquement des vues simples
hybrid_router.register("simple-view", item_view, basename="simple-view")

urlconf = create_urlconf(hybrid_router)

with override_settings(ROOT_URLCONF=urlconf):
resolver = get_resolver(urlconf)
recevoir_test_url_resolver(resolver.url_patterns)

# Vérifier que l'URL est correctement générée
view_url = reverse("simple-view")
assert view_url == "/simple-view/"

# Vérifier que la vue fonctionne correctement
client = APIClient()
response = client.get(view_url)
assert response.status_code == status.HTTP_200_OK


@override_settings()
def test_register_only_views(hybrid_router, db):
# Enregistrer uniquement des vues simples
Expand Down
8 changes: 8 additions & 0 deletions hybridrouter/tests/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from rest_framework.decorators import api_view
from rest_framework.views import APIView
from rest_framework.response import Response
from .models import Item
Expand All @@ -9,3 +10,10 @@ def get(self, request):
items = Item.objects.all()
serializer = ItemSerializer(items, many=True)
return Response(serializer.data)


@api_view(["GET"])
def item_view(request):
items = Item.objects.all()
serializer = ItemSerializer(items, many=True)
return Response(serializer.data)

0 comments on commit dcbb34d

Please sign in to comment.