Skip to content

Commit

Permalink
feat: add handler wrapper function (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
julianolf authored Nov 29, 2023
1 parent 6d955eb commit 4a5a378
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 20 deletions.
14 changes: 13 additions & 1 deletion src/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any, Callable, Dict
from typing import Any, Awaitable, Callable, Dict
from uuid import uuid4

from fastapi import Request, Response, routing

logger = logging.getLogger(__name__)

Handler = Callable[[Dict[str, Any], Any], Dict[str, Any]]
Endpoint = Callable[[Request], Awaitable[Response]]


class ApiGatewayResponse(Response):
Expand Down Expand Up @@ -51,6 +52,17 @@ async def event_builder(request: Request) -> Dict[str, Any]:
return event


def handler(func: Handler) -> Endpoint:
async def wrapper(request: Request) -> Response:
event = await event_builder(request)
result = func(event, None)
response = ApiGatewayResponse(result)

return response

return wrapper


def default_endpoint(request: Request) -> Response:
logger.error(f"Executing default endpoint: {request.scope}")

Expand Down
63 changes: 44 additions & 19 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@
import routing


def build_request():
async def receive():
return {"type": "http.request", "body": b'{"message": "test"}'}

scope = {
"type": "http",
"http_version": "1.1",
"root_path": "",
"path": "/test",
"method": "GET",
"query_string": [],
"path_params": {},
"client": ("127.0.0.1", 80),
"app": FastAPI(),
"headers": [
(b"content-type", b"application/json"),
(b"user-agent", b"python/unittest"),
],
}

return Request(scope, receive)


class TestAPIRoute(unittest.TestCase):
def test_default_endpoint(self):
with self.assertLogs(logger=routing.__name__, level="ERROR") as logs:
Expand Down Expand Up @@ -76,25 +99,7 @@ def test_response_creation_fails_when_missing_required_attribute(self):

class TestEventBuilder(unittest.IsolatedAsyncioTestCase):
async def test_event_builder(self):
async def receive():
return {"type": "http.request", "body": b'{"message": "test"}'}

scope = {
"type": "http",
"http_version": "1.1",
"root_path": "",
"path": "/test",
"method": "GET",
"query_string": [],
"path_params": {},
"client": ("127.0.0.1", 80),
"app": FastAPI(),
"headers": [
(b"content-type", b"application/json"),
(b"user-agent", b"python/unittest"),
],
}
request = Request(scope, receive)
request = build_request()
expected_keys = {
"body",
"path",
Expand All @@ -110,3 +115,23 @@ async def receive():

self.assertIsInstance(event, dict)
self.assertEqual(set(event.keys()), expected_keys)


class TestHandler(unittest.IsolatedAsyncioTestCase):
async def test_handler(self):
request = build_request()

def echo(event, _):
return {
"statusCode": HTTPStatus.OK.value,
"body": event["body"],
"headers": event["headers"],
}

endpoint = routing.handler(echo)
response = await endpoint(request)

self.assertIsInstance(response, Response)
self.assertEqual(response.status_code, HTTPStatus.OK.value)
self.assertEqual(response.body.decode(), '{"message": "test"}')
self.assertEqual(response.headers["content-type"], "application/json")

0 comments on commit 4a5a378

Please sign in to comment.