Skip to content

Commit

Permalink
Fix a few tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jun 13, 2024
1 parent 84c5a34 commit b7acacf
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 32 deletions.
5 changes: 4 additions & 1 deletion src/dispatch/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,10 @@ def _any_unpickle(any: google.protobuf.any_pb2.Any) -> Any:
any.Unpack(b)
return pickle.loads(b.value)

raise InvalidArgumentError("unsupported pickled value container")
elif not any.type_url and not any.value:
return None

raise InvalidArgumentError(f"unsupported pickled value container: {any.type_url}")


def _pb_any_pickle(value: Any) -> google.protobuf.any_pb2.Any:
Expand Down
17 changes: 5 additions & 12 deletions tests/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from dispatch.fastapi import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.proto import _pb_any_pickle as any_pickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand Down Expand Up @@ -91,23 +92,17 @@ async def my_function(input: Input) -> Output:
)

client = create_endpoint_client(app)
pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
input=any_pickle("Hello World!"),
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)
output = any_unpickle(resp.exit.result.output)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

Expand Down Expand Up @@ -229,10 +224,8 @@ def execute(
req = function_pb.RunRequest(function=func.name)

if input is not None:
input_bytes = pickle.dumps(input)
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=input_bytes))
req.input.CopyFrom(input_any)
any = any_pickle(input)
req.input.CopyFrom(any)
if state is not None:
req.poll_result.coroutine_state = state
if calls is not None:
Expand Down
12 changes: 3 additions & 9 deletions tests/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dispatch.flask import Dispatch
from dispatch.function import Arguments, Error, Function, Input, Output
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.proto import _pb_any_pickle as any_pickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import (
Expand Down Expand Up @@ -56,23 +57,16 @@ async def my_function(input: Input) -> Output:
)

client = create_endpoint_client(app)
pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
function=my_function.name, input=any_pickle("Hello World!")
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)
output = any_unpickle(resp.exit.result.output)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

Expand Down
13 changes: 3 additions & 10 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
from dispatch.http import Dispatch
from dispatch.proto import _any_unpickle as any_unpickle
from dispatch.proto import _pb_any_pickle as any_pickle
from dispatch.sdk.v1 import call_pb2 as call_pb
from dispatch.sdk.v1 import function_pb2 as function_pb
from dispatch.signature import parse_verification_key, public_key_from_pem
Expand Down Expand Up @@ -91,22 +92,14 @@ async def my_function(input: Input) -> Output:
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
client = EndpointClient(http_client)

pickled = pickle.dumps("Hello World!")
input_any = google.protobuf.any_pb2.Any()
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))

req = function_pb.RunRequest(
function=my_function.name,
input=input_any,
function=my_function.name, input=any_pickle("Hello World!")
)

resp = client.run(req)

self.assertIsInstance(resp, function_pb.RunResponse)

resp.exit.result.output.Unpack(
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
)
output = pickle.loads(output_bytes.value)
output = any_unpickle(resp.exit.result.output)

self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

0 comments on commit b7acacf

Please sign in to comment.