From d2d31466c2e308d56a7187f43bdc911d4beb7344 Mon Sep 17 00:00:00 2001 From: c4ffein Date: Fri, 9 Aug 2024 00:48:17 +0200 Subject: [PATCH] test-case-headers --- ninja/testing/client.py | 12 +++++++++++- tests/test_test_client.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/ninja/testing/client.py b/ninja/testing/client.py index 0be8ce0a..00e78d3f 100644 --- a/ninja/testing/client.py +++ b/ninja/testing/client.py @@ -26,7 +26,12 @@ def build_absolute_uri(location: Optional[str] = None) -> str: class NinjaClientBase: __test__ = False # <- skip pytest - def __init__(self, router_or_app: Union[NinjaAPI, Router]) -> None: + def __init__( + self, + router_or_app: Union[NinjaAPI, Router], + headers: Optional[Dict[str, str]] = None, + ) -> None: + self.headers = headers or {} self.router_or_app = router_or_app def get( @@ -82,6 +87,11 @@ def request( request_params["body"] = json_dumps(json, cls=NinjaJSONEncoder) if data is None: data = {} + if self.headers or request_params.get("headers"): + request_params["headers"] = { + **self.headers, + **request_params.get("headers", {}), + } func, request, kwargs = self._resolve(method, path, data, request_params) return self._call(func, request, kwargs) # type: ignore diff --git a/tests/test_test_client.py b/tests/test_test_client.py index 67f15e72..e182c20e 100644 --- a/tests/test_test_client.py +++ b/tests/test_test_client.py @@ -27,6 +27,11 @@ def simple_get(request): return "test" +@router.get("/test-headers") +def get_headers(request): + return dict(request.headers) + + client = TestClient(router) @@ -78,3 +83,21 @@ def test_json_as_body(): ClientTestSchema.model_validate_json(request.body).model_dump_json() == schema_instance.model_dump_json() ) + + +headered_client = TestClient(router, headers={"A": "a", "B": "b"}) + + +def test_client_request_only_header(): + r = client.get("/test-headers", headers={"A": "na"}) + assert r.json() == {"A": "na"} + + +def test_headered_client_request_with_default_headers(): + r = headered_client.get("/test-headers") + assert r.json() == {"A": "a", "B": "b"} + + +def test_headered_client_request_with_overwritten_and_additional_headers(): + r = headered_client.get("/test-headers", headers={"A": "na", "C": "nc"}) + assert r.json() == {"A": "na", "B": "b", "C": "nc"}