From feb700188fb3ded5a61334736e512cbded093491 Mon Sep 17 00:00:00 2001 From: Thomas Pelletier Date: Tue, 30 Jan 2024 17:40:57 -0500 Subject: [PATCH] Fix status serialization --- src/dispatch/coroutine.py | 4 +++- tests/test_fastapi.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/dispatch/coroutine.py b/src/dispatch/coroutine.py index ac35308a..bb05bdc0 100644 --- a/src/dispatch/coroutine.py +++ b/src/dispatch/coroutine.py @@ -203,9 +203,10 @@ def error(cls, error: Error) -> Output: """Terminally exit the coroutine with the provided error.""" return Output( coroutine_pb2.ExecuteResponse( + status=error.status._proto, exit=coroutine_pb2.Exit( result=coroutine_pb2.Result(error=error._as_proto()) - ) + ), ) ) @@ -298,6 +299,7 @@ def __init__(self, status: Status, type: str | None, message: str | None): self.type = type self.message = message + self.status = status @classmethod def from_exception(cls, ex: Exception, status: Status | None = None) -> Error: diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 8b4dc4c1..33f0c31b 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -269,6 +269,7 @@ def mycoro(input: Input) -> Output: resp = self.execute(mycoro) self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) self.assertEqual("division by zero", resp.exit.result.error.message) + self.assertEqual(Status.TEMPORARY_ERROR, resp.status) def test_coroutine_unexpected_exception(self): @self.app.dispatch_coroutine() @@ -279,3 +280,14 @@ def mycoro(input: Input) -> Output: resp = self.execute(mycoro) self.assertEqual("ZeroDivisionError", resp.exit.result.error.type) self.assertEqual("division by zero", resp.exit.result.error.message) + self.assertEqual(Status.TEMPORARY_ERROR, resp.status) + + def test_specific_status(self): + @self.app.dispatch_coroutine() + def mycoro(input: Input) -> Output: + return Output.error(Error(Status.THROTTLED, "foo", "bar")) + + resp = self.execute(mycoro) + self.assertEqual("foo", resp.exit.result.error.type) + self.assertEqual("bar", resp.exit.result.error.message) + self.assertEqual(Status.THROTTLED, resp.status)