Skip to content

Commit

Permalink
Increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon committed Sep 6, 2023
1 parent d4487c9 commit 1e27f03
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 59 deletions.
4 changes: 2 additions & 2 deletions apitally/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ def get_view(self, request: HttpRequest) -> Optional[DjangoViewInfo]:
return next((view for view in self.views if view.pattern == resolver_match.route), None)

def get_consumer(self, request: HttpRequest) -> Optional[str]:
if hasattr(request, "consumer_identifier"):
return str(request.consumer_identifier)
if self.config is not None and self.config.identify_consumer_func is not None:
consumer_identifier = self.config.identify_consumer_func(request)
if consumer_identifier is not None:
return str(consumer_identifier)
if hasattr(request, "consumer_identifier"):
return str(request.consumer_identifier)
if hasattr(request, "auth") and isinstance(request.auth, KeyInfo):
return f"key:{request.auth.key_id}"
return None
Expand Down
2 changes: 1 addition & 1 deletion apitally/django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _get_api(views: List[DjangoViewInfo]) -> NinjaAPI:
return next(
(view.func.__self__.api for view in views if view.is_ninja_path_view and hasattr(view.func, "__self__"))
)
except StopIteration:
except StopIteration: # pragma: no cover
raise RuntimeError("Could not find NinjaAPI instance")


Expand Down
2 changes: 1 addition & 1 deletion apitally/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
detail="Permission denied",
)
if key_info is not None:
request.state.consumer_identifier = f"key:{key_info.key_id}"
request.state.key_info = key_info
return key_info


Expand Down
8 changes: 1 addition & 7 deletions apitally/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from functools import wraps
from threading import Timer
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

import flask
from flask import Flask, g, make_response, request
Expand Down Expand Up @@ -35,12 +35,10 @@ def __init__(
sync_interval: float = 60,
openapi_url: Optional[str] = None,
filter_unhandled_paths: bool = True,
identify_consumer_func: Optional[Callable[[], Optional[str]]] = None,
) -> None:
self.app = app
self.wsgi_app = app.wsgi_app
self.filter_unhandled_paths = filter_unhandled_paths
self.identify_consumer_func = identify_consumer_func
self.client = ApitallyClient(
client_id=client_id, env=env, sync_api_keys=sync_api_keys, sync_interval=sync_interval
)
Expand Down Expand Up @@ -93,10 +91,6 @@ def get_rule(self, environ: WSGIEnvironment) -> Tuple[str, bool]:
return environ["PATH_INFO"], False

def get_consumer(self) -> Optional[str]:
if self.identify_consumer_func is not None:
consumer_identifier = self.identify_consumer_func()
if consumer_identifier is not None:
return str(consumer_identifier)
if "consumer_identifier" in g:
return str(g.consumer_identifier)
if "key_info" in g and isinstance(g.key_info, KeyInfo):
Expand Down
33 changes: 16 additions & 17 deletions apitally/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,32 +97,29 @@ async def log_request(
and response is not None
and response.headers.get("Content-Type") == "application/json"
):
try:
body = await self.get_response_json(response)
if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list):
# Log FastAPI / Pydantic validation errors
self.client.validation_error_logger.log_validation_errors(
consumer=consumer,
method=request.method,
path=path_template,
detail=body["detail"],
)
except json.JSONDecodeError:
pass
body = await self.get_response_json(response)
if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list):
# Log FastAPI / Pydantic validation errors
self.client.validation_error_logger.log_validation_errors(
consumer=consumer,
method=request.method,
path=path_template,
detail=body["detail"],
)

@staticmethod
async def get_response_json(response: Response) -> Any:
if hasattr(response, "body"):
try:
return json.loads(response.body)
except json.JSONDecodeError:
except json.JSONDecodeError: # pragma: no cover
return None
elif hasattr(response, "body_iterator"):
try:
response_body = [section async for section in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
return json.loads(b"".join(response_body))
except json.JSONDecodeError:
except json.JSONDecodeError: # pragma: no cover
return None

@staticmethod
Expand All @@ -134,12 +131,14 @@ def get_path_template(request: Request) -> Tuple[str, bool]:
return request.url.path, False

def get_consumer(self, request: Request) -> Optional[str]:
if hasattr(request.state, "consumer_identifier"):
return str(request.state.consumer_identifier)
if self.identify_consumer_func is not None:
consumer_identifier = self.identify_consumer_func(request)
if consumer_identifier is not None:
return str(consumer_identifier)
if hasattr(request.state, "consumer_identifier"):
return str(request.state.consumer_identifier)
if hasattr(request.state, "key_info") and isinstance(key_info := request.state.key_info, KeyInfo):
return f"key:{key_info.key_id}"
if "user" in request.scope and isinstance(user := request.scope["user"], APIKeyUser):
return f"key:{user.key_info.key_id}"
return None
Expand Down Expand Up @@ -217,7 +216,7 @@ def _get_routes(app: ASGIApp) -> List[BaseRoute]:
return app.routes
elif hasattr(app, "app"):
return _get_routes(app.app)
return []
return [] # pragma: no cover


def _get_versions(app_version: Optional[str]) -> Dict[str, str]:
Expand Down
1 change: 1 addition & 0 deletions tests/django_ninja_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def bar(request: HttpRequest) -> str:

@api.put("/baz", auth=APIKeyAuth())
def baz(request: HttpRequest) -> str:
request.consumer_identifier = "baz" # type: ignore[attr-defined]
raise ValueError("baz")


Expand Down
22 changes: 20 additions & 2 deletions tests/test_django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from importlib.util import find_spec
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import pytest
from pytest_mock import MockerFixture
Expand All @@ -12,11 +12,18 @@
pytest.skip("django-ninja is not available", allow_module_level=True)

if TYPE_CHECKING:
from django.http import HttpRequest
from django.test import Client

from apitally.client.base import KeyRegistry


def identify_consumer(request: HttpRequest) -> Optional[str]:
if consumer := request.GET.get("consumer"):
return consumer
return None


@pytest.fixture(scope="module", autouse=True)
def setup(module_mocker: MockerFixture) -> None:
import django
Expand Down Expand Up @@ -44,6 +51,7 @@ def setup(module_mocker: MockerFixture) -> None:
"client_id": "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9",
"env": "default",
"sync_api_keys": True,
"identify_consumer_func": "tests.test_django_ninja.identify_consumer",
},
)
django.setup()
Expand Down Expand Up @@ -128,7 +136,7 @@ def test_api_key_auth(client: Client, key_registry: KeyRegistry, mocker: MockerF
response = client.get("/api/foo", **headers) # type: ignore[arg-type]
assert response.status_code == 403

# Valid API key, no scope required, custom header
# Valid API key, no scope required, custom header, consumer identified by API key
headers = {"HTTP_APIKEY": "7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"}
response = client.get("/api/foo", **headers) # type: ignore[arg-type]
assert response.status_code == 200
Expand All @@ -139,10 +147,20 @@ def test_api_key_auth(client: Client, key_registry: KeyRegistry, mocker: MockerF
response = client.get("/api/foo/123", **headers) # type: ignore[arg-type]
assert response.status_code == 200

# Valid API key with required scope, consumer identified by custom function
response = client.get("/api/foo/123?consumer=foo", **headers) # type: ignore[arg-type]
assert response.status_code == 200
assert log_request_mock.call_args.kwargs["consumer"] == "foo"

# Valid API key without required scope
response = client.post("/api/bar", **headers) # type: ignore[arg-type]
assert response.status_code == 403

# Valid API key, consumer identifier from request object
response = client.put("/api/baz", **headers) # type: ignore[arg-type]
assert response.status_code == 500
assert log_request_mock.call_args.kwargs["consumer"] == "baz"


def test_get_app_info(mocker: MockerFixture):
from django.urls import get_resolver
Expand Down
30 changes: 24 additions & 6 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import pytest
from pytest_mock import MockerFixture
Expand All @@ -15,7 +15,13 @@

from apitally.client.base import KeyRegistry

from apitally.client.base import KeyInfo # import here to avoid pydantic error
# Global imports to avoid NameErrors during FastAPI dependency injection
try:
from fastapi import Request

from apitally.client.base import KeyInfo
except ImportError:
pass


CLIENT_ID = "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9"
Expand All @@ -30,8 +36,13 @@ def app(mocker: MockerFixture) -> FastAPI:

mocker.patch("apitally.client.asyncio.ApitallyClient._instance", None)

def identify_consumer(request: Request) -> Optional[str]:
if consumer := request.query_params.get("consumer"):
return consumer
return None

app = FastAPI()
app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV)
app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV, identify_consumer_func=identify_consumer)
api_key_auth_custom = APIKeyAuth(custom_header="ApiKey")

@app.get("/foo/")
Expand All @@ -43,7 +54,8 @@ def bar(key: KeyInfo = Security(api_key_auth, scopes=["bar"])):
return "bar"

@app.get("/baz/", dependencies=[Depends(api_key_auth_custom)])
def baz():
def baz(request: Request):
request.state.consumer_identifier = "baz"
return "baz"

return app
Expand Down Expand Up @@ -79,14 +91,20 @@ def test_api_key_auth(app: FastAPI, key_registry: KeyRegistry, mocker: MockerFix
response = client.get("/baz", headers={"ApiKey": "invalid"})
assert response.status_code == 403

# Valid API key with required scope
# Valid API key with required scope, consumer identified by API key
response = client.get("/foo", headers=headers)
assert response.status_code == 200
assert log_request_mock.call_args.kwargs["consumer"] == "key:1"

# Valid API key, no scope required, custom header
# Valid API key with required scope, identify consumer with custom function
response = client.get("/foo?consumer=foo", headers=headers)
assert response.status_code == 200
assert log_request_mock.call_args.kwargs["consumer"] == "foo"

# Valid API key, no scope required, custom header, consumer identifier from request.state object
response = client.get("/baz", headers=headers_custom)
assert response.status_code == 200
assert log_request_mock.call_args.kwargs["consumer"] == "baz"

# Valid API key without required scope
response = client.get("/bar", headers=headers)
Expand Down
Loading

0 comments on commit 1e27f03

Please sign in to comment.