Skip to content

Commit

Permalink
Merge pull request #77 from stealthrocket/fix-fastapi-error-responses
Browse files Browse the repository at this point in the history
FastAPI: return error responses formatted for the connectrpc protocol
  • Loading branch information
achille-roussel authored Feb 20, 2024
2 parents 129293e + f9a8300 commit 4624d5f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 24 deletions.
54 changes: 35 additions & 19 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,47 @@ def __init__(
app.mount("/dispatch.sdk.v1.FunctionService", function_service)


class _GRPCResponse(fastapi.Response):
class _ConnectResponse(fastapi.Response):
media_type = "application/grpc+proto"


class _ConnectError(fastapi.HTTPException):
__slots__ = ("status", "code", "message")

def __init__(self, status, code, message):
super().__init__(status)
self.status = status
self.code = code
self.message = message


def _new_app(function_registry: Dispatch, verification_key: Ed25519PublicKey | None):
app = fastapi.FastAPI()

@app.exception_handler(_ConnectError)
async def on_error(request: fastapi.Request, exc: _ConnectError):
# https://connectrpc.com/docs/protocol/#error-end-stream
return fastapi.responses.JSONResponse(
status_code=exc.status, content={"code": exc.code, "message": exc.message}
)

@app.post(
# The endpoint for execution is hardcoded at the moment. If the service
# gains more endpoints, this should be turned into a dynamic dispatch
# like the official gRPC server does.
"/Run",
response_class=_GRPCResponse,
response_class=_ConnectResponse,
)
async def execute(request: fastapi.Request):
# Raw request body bytes are only available through the underlying
# starlette Request object's body method, which returns an awaitable,
# forcing execute() to be async.
data: bytes = await request.body()

logger.debug("handling run request with %d byte body", len(data))

if verification_key is not None:
if verification_key is None:
logger.debug("skipping request signature verification")
else:
signed_request = Request(
method=request.method,
url=str(request.url),
Expand All @@ -169,29 +187,28 @@ async def execute(request: fastapi.Request):
max_age = timedelta(minutes=5)
try:
verify_request(signed_request, verification_key, max_age)
except (InvalidSignature, ValueError):
logger.error("failed to verify request signature", exc_info=True)
raise fastapi.HTTPException(
status_code=403, detail="request signature is invalid"
)
else:
logger.debug("skipping request signature verification")
except ValueError as e:
raise _ConnectError(401, "unauthenticated", str(e))
except InvalidSignature as e:
# The http_message_signatures package sometimes wraps does not
# attach a message to the exception, so we set a default to
# have some context about the reason for the error.
message = str(e) or "invalid signature"
raise _ConnectError(403, "permission_denied", message)

req = function_pb.RunRequest.FromString(data)

if not req.function:
raise fastapi.HTTPException(status_code=400, detail="function is required")
raise _ConnectError(400, "invalid_argument", "function is required")

try:
func = function_registry._functions[req.function]
except KeyError:
logger.debug("function '%s' not found", req.function)
raise fastapi.HTTPException(
status_code=404, detail=f"Function '{req.function}' does not exist"
raise _ConnectError(
404, "not_found", f"function '{req.function}' does not exist"
)

input = Input(req)

logger.info("running function '%s'", req.function)
try:
output = func._primitive_call(input)
Expand All @@ -203,8 +220,8 @@ async def execute(request: fastapi.Request):
# so indicates a problem, and we return a 500 rather than attempt
# to catch and categorize the error here.
logger.error("function '%s' fatal error", req.function, exc_info=True)
raise fastapi.HTTPException(
status_code=500, detail=f"function '{req.function}' fatal error"
raise _ConnectError(
500, "internal", f"function '{req.function}' fatal error"
)
else:
response = output._message
Expand Down Expand Up @@ -241,7 +258,6 @@ async def execute(request: fastapi.Request):
)

logger.debug("finished handling run request with status %s", status.name)

return fastapi.Response(content=response.SerializeToString())

return app
7 changes: 5 additions & 2 deletions src/dispatch/signature/digest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hmac

import http_sfv
from http_message_signatures import InvalidSignature


def generate_content_digest(body: str | bytes) -> str:
Expand Down Expand Up @@ -34,7 +35,9 @@ def verify_content_digest(digest_header: str | bytes, body: str | bytes):
digest = parsed_header["sha-256"].value
expect_digest = hashlib.sha256(body).digest()
else:
raise ValueError("missing content digest")
raise ValueError("missing content digest in http request header")

if not hmac.compare_digest(digest, expect_digest):
raise ValueError("unexpected content digest")
raise InvalidSignature(
"digest of the request body does not match the Content-Digest header"
)
5 changes: 4 additions & 1 deletion tests/dispatch/signature/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def test_signature_too_old(self):
def test_content_digest_invalid(self):
sign_request(self.request, private_key, datetime.now())
self.request.body = "foo"
with self.assertRaisesRegex(ValueError, "unexpected content digest"):
with self.assertRaisesRegex(
InvalidSignature,
"digest of the request body does not match the Content-Digest header",
):
verify_request(self.request, public_key, max_age=timedelta(minutes=1))

def test_signature_coverage(self):
Expand Down
26 changes: 24 additions & 2 deletions tests/test_full.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

import fastapi
import httpx
from fastapi.testclient import TestClient

from dispatch import Call, Input, Output
Expand Down Expand Up @@ -39,8 +40,10 @@ def setUp(self):
api_url="http://127.0.0.1:10000",
)

http_client = TestClient(self.app, base_url="http://dispatch-service")
self.app_client = function_service.client(http_client, signing_key=private_key)
self.http_client = TestClient(self.app, base_url="http://dispatch-service")
self.app_client = function_service.client(
self.http_client, signing_key=private_key
)

self.server = ServerTest()
# shortcuts
Expand Down Expand Up @@ -68,3 +71,22 @@ def my_function(name: str) -> str:
# Validate results.
resp = self.servicer.response_for(dispatch_id)
self.assertEqual(any_unpickle(resp.exit.result.output), "Hello world: 52")

def test_simple_missing_signature(self):
@self.dispatch.function()
def my_function(name: str) -> str:
return f"Hello world: {name}"

[dispatch_id] = self.client.dispatch([my_function.build_call(52)])

self.app_client = function_service.client(self.http_client) # no signing key
try:
self.execute()
except httpx.HTTPStatusError as e:
assert e.response.status_code == 403
assert e.response.json() == {
"code": "permission_denied",
"message": 'Expected "Signature-Input" header field to be present',
}
else:
assert False, "Expected HTTPStatusError"

0 comments on commit 4624d5f

Please sign in to comment.