Skip to content

Commit

Permalink
Add API key auth for DRF
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon committed Aug 21, 2023
1 parent e4cad36 commit 16f0323
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 22 deletions.
8 changes: 6 additions & 2 deletions apitally/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import datetime, timedelta
from hashlib import scrypt
from math import floor
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, cast
from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union, cast
from uuid import UUID, uuid4


Expand Down Expand Up @@ -138,7 +138,11 @@ class KeyInfo:
def is_expired(self) -> bool:
return self.expires_at is not None and self.expires_at < datetime.now()

def check_scopes(self, scopes: List[str]) -> bool:
def check_scopes(self, scopes: Union[List[str], str]) -> bool:
if isinstance(scopes, str):
scopes = [scopes]

Check warning on line 143 in apitally/client/base.py

View check run for this annotation

Codecov / codecov/patch

apitally/client/base.py#L143

Added line #L143 was not covered by tests
if not isinstance(scopes, list):
raise ValueError("scopes must be a string or a list of strings")

Check warning on line 145 in apitally/client/base.py

View check run for this annotation

Codecov / codecov/patch

apitally/client/base.py#L145

Added line #L145 was not covered by tests
return all(scope in self.scopes for scope in scopes)

@classmethod
Expand Down
8 changes: 1 addition & 7 deletions apitally/django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
from ninja import NinjaAPI


__all__ = [
"ApitallyMiddleware",
"AuthorizationAPIKeyHeader",
"KeyInfo",
]
__all__ = ["ApitallyMiddleware", "AuthorizationAPIKeyHeader", "KeyInfo"]


class ApitallyMiddleware(_ApitallyMiddleware):
Expand Down Expand Up @@ -60,8 +56,6 @@ def authenticate(self, request: HttpRequest, key: Optional[str]) -> Optional[Key
raise InvalidAPIKey()
if not key_info.check_scopes(self.scopes):
raise PermissionDenied()
if not hasattr(request, "key_info"):
setattr(request, "key_info", key_info)
return key_info


Expand Down
35 changes: 35 additions & 0 deletions apitally/django_rest_framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from rest_framework.permissions import BasePermission

from apitally.client.base import KeyInfo
from apitally.client.threading import ApitallyClient
from apitally.django import ApitallyMiddleware


if TYPE_CHECKING:
from django.http import HttpRequest
from rest_framework.views import APIView


__all__ = ["ApitallyMiddleware", "HasAPIKey", "KeyInfo"]


class HasAPIKey(BasePermission): # type: ignore[misc]
def has_permission(self, request: HttpRequest, view: APIView) -> bool:
authorization = request.headers.get("Authorization")
if not authorization:
return False
scheme, _, param = authorization.partition(" ")
if scheme.lower() != "apikey":
return False
key_info = ApitallyClient.get_instance().key_registry.get(param)
if key_info is None:
return False
if hasattr(view, "required_scopes") and not key_info.check_scopes(view.required_scopes):
return False
if not hasattr(request, "key_info"):
setattr(request, "key_info", key_info)
return True
18 changes: 18 additions & 0 deletions tests/django_urls.py → tests/django_rest_framework_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,41 @@
from django.urls import path
from rest_framework.views import APIView

from apitally.django_rest_framework import HasAPIKey


class FooView(APIView):
permission_classes = [HasAPIKey]

def get(self, request: HttpRequest) -> HttpResponse:
return HttpResponse("foo")


class FooBarView(APIView):
permission_classes = [HasAPIKey]
required_scopes = ["foo"]

def get(self, request: HttpRequest, bar: int) -> HttpResponse:
return HttpResponse(f"foo: {bar}")


class BarView(APIView):
permission_classes = [HasAPIKey]
required_scopes = ["bar"]

def post(self, request: HttpRequest) -> HttpResponse:
return HttpResponse("bar")


class BazView(APIView):
permission_classes = [HasAPIKey]

def put(self, request: HttpRequest) -> HttpResponse:
raise ValueError("baz")


urlpatterns = [
path("foo/", FooView.as_view()),
path("foo/<int:bar>/", FooBarView.as_view()),
path("bar/", BarView.as_view()),
path("baz/", BazView.as_view()),
Expand Down
13 changes: 7 additions & 6 deletions tests/test_django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from importlib.util import find_spec
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING

import pytest
from pytest_mock import MockerFixture
Expand All @@ -18,9 +18,9 @@


@pytest.fixture(scope="module", autouse=True)
def setup(module_mocker: MockerFixture) -> Iterator[None]:
def setup(module_mocker: MockerFixture) -> None:
import django
from django.apps.registry import Apps
from django.apps.registry import apps
from django.conf import settings
from django.utils.functional import empty

Expand All @@ -29,7 +29,10 @@ def setup(module_mocker: MockerFixture) -> Iterator[None]:
module_mocker.patch("apitally.client.threading.ApitallyClient.send_app_info")
module_mocker.patch("apitally.django.ApitallyMiddleware.config", None)

module_mocker.patch("django.apps.registry.apps", Apps())
settings._wrapped = empty
apps.app_configs.clear()
apps.loading = False
apps.ready = False

settings.configure(
ROOT_URLCONF="tests.django_ninja_urls",
Expand All @@ -44,8 +47,6 @@ def setup(module_mocker: MockerFixture) -> Iterator[None]:
},
)
django.setup()
yield
settings._wrapped = empty


@pytest.fixture(scope="module")
Expand Down
54 changes: 47 additions & 7 deletions tests/test_django.py → tests/test_django_rest_framework.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING

import pytest
from pytest_mock import MockerFixture
Expand All @@ -13,10 +13,13 @@
if TYPE_CHECKING:
from rest_framework.test import APIClient

from apitally.client.base import KeyRegistry


@pytest.fixture(scope="module", autouse=True)
def setup(module_mocker: MockerFixture) -> Iterator[None]:
def setup(module_mocker: MockerFixture) -> None:
import django
from django.apps.registry import apps
from django.conf import settings
from django.utils.functional import empty

Expand All @@ -25,12 +28,17 @@ def setup(module_mocker: MockerFixture) -> Iterator[None]:
module_mocker.patch("apitally.client.threading.ApitallyClient.send_app_info")
module_mocker.patch("apitally.django.ApitallyMiddleware.config", None)

settings._wrapped = empty
apps.app_configs.clear()
apps.loading = False
apps.ready = False

settings.configure(
ROOT_URLCONF="tests.django_urls",
ROOT_URLCONF="tests.django_rest_framework_urls",
ALLOWED_HOSTS=["testserver"],
SECRET_KEY="secret",
MIDDLEWARE=[
"apitally.django.ApitallyMiddleware",
"apitally.django_rest_framework.ApitallyMiddleware",
],
INSTALLED_APPS=[
"django.contrib.auth",
Expand All @@ -43,8 +51,6 @@ def setup(module_mocker: MockerFixture) -> Iterator[None]:
},
)
django.setup()
yield
settings._wrapped = empty


@pytest.fixture(scope="module")
Expand All @@ -56,6 +62,7 @@ def client() -> APIClient:

def test_middleware_requests_ok(client: APIClient, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestLogger.log_request")
mocker.patch("apitally.django_rest_framework.HasAPIKey.has_permission", return_value=True)

response = client.get("/foo/123/")
assert response.status_code == 200
Expand All @@ -75,6 +82,7 @@ def test_middleware_requests_ok(client: APIClient, mocker: MockerFixture):

def test_middleware_requests_error(client: APIClient, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestLogger.log_request")
mocker.patch("apitally.django_rest_framework.HasAPIKey.has_permission", return_value=True)

response = client.put("/baz/")
assert response.status_code == 500
Expand All @@ -86,14 +94,46 @@ def test_middleware_requests_error(client: APIClient, mocker: MockerFixture):
assert mock.call_args.kwargs["response_time"] > 0


def test_api_key_auth(client: APIClient, key_registry: KeyRegistry, mocker: MockerFixture):
mock = mocker.patch("apitally.django_rest_framework.ApitallyClient.get_instance")
mock.return_value.key_registry = key_registry

# Unauthenticated
response = client.get("/foo/123/")
assert response.status_code == 403

# Invalid auth scheme
headers = {"Authorization": "Bearer invalid"}
response = client.get("/foo/123/", headers=headers) # type: ignore[arg-type]
assert response.status_code == 403

# Invalid API key
headers = {"Authorization": "ApiKey invalid"}
response = client.get("/foo/123/", headers=headers) # type: ignore[arg-type]
assert response.status_code == 403

# Valid API key, no scope required
headers = {"Authorization": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"}
response = client.get("/foo/", headers=headers) # type: ignore[arg-type]
assert response.status_code == 200

# Valid API key with required scope
response = client.get("/foo/123/", headers=headers) # type: ignore[arg-type]
assert response.status_code == 200

# Valid API key without required scope
response = client.post("/bar/", headers=headers) # type: ignore[arg-type]
assert response.status_code == 403


def test_get_app_info():
from django.urls import get_resolver

from apitally.django import _extract_views_from_url_patterns, _get_app_info

views = _extract_views_from_url_patterns(get_resolver().url_patterns)
app_info = _get_app_info(views=views, app_version="1.2.3")
assert len(app_info["paths"]) == 3
assert len(app_info["paths"]) == 4
assert app_info["versions"]["django"]
assert app_info["versions"]["app"] == "1.2.3"
assert app_info["framework"] == "django"

0 comments on commit 16f0323

Please sign in to comment.