From b4e2db4a5788514c1d1408489533ffdd2c1f7ab1 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Fri, 14 Jun 2024 16:41:10 -0700 Subject: [PATCH] more fixes after rebase Signed-off-by: Achille Roussel --- src/dispatch/proto.py | 12 +-- src/dispatch/test/__init__.py | 1 + tests/dispatch/test_scheduler.py | 20 +---- tests/test_aiohttp.py | 124 ------------------------------- 4 files changed, 3 insertions(+), 154 deletions(-) delete mode 100644 tests/test_aiohttp.py diff --git a/src/dispatch/proto.py b/src/dispatch/proto.py index 9576892..9af3631 100644 --- a/src/dispatch/proto.py +++ b/src/dispatch/proto.py @@ -78,17 +78,7 @@ def __init__(self, req: function_pb.RunRequest): self._has_input = req.HasField("input") if self._has_input: - if req.input.Is(google.protobuf.wrappers_pb2.BytesValue.DESCRIPTOR): - input_pb = google.protobuf.wrappers_pb2.BytesValue() - req.input.Unpack(input_pb) - input_bytes = input_pb.value - try: - self._input = pickle.loads(input_bytes) - except Exception: - self._input = input_bytes - else: - self._input = _pb_any_unpack(req.input) - + self._input = _pb_any_unpack(req.input) else: if req.poll_result.coroutine_state: raise IncompatibleStateError # coroutine_state is deprecated diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 27c0165..e6de18f 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -231,6 +231,7 @@ def make_request(call: Call) -> RunRequest: root_dispatch_id=root_dispatch_id, poll_result=PollResult( coroutine_state=res.poll.coroutine_state, + typed_coroutine_state=res.poll.typed_coroutine_state, results=results, ), ) diff --git a/tests/dispatch/test_scheduler.py b/tests/dispatch/test_scheduler.py index e82d0db..c15ed84 100644 --- a/tests/dispatch/test_scheduler.py +++ b/tests/dispatch/test_scheduler.py @@ -464,7 +464,7 @@ async def resume( poll = assert_poll(prev_output) input = Input.from_poll_results( main.__qualname__, - poll.coroutine_state, + any_unpickle(poll.typed_coroutine_state), call_results, Error.from_exception(poll_error) if poll_error else None, ) @@ -492,30 +492,12 @@ def assert_exit_result_value(output: Output, expect: Any): assert expect == any_unpickle(result.output) -<<<<<<< HEAD - def resume( - self, - main: Callable, - prev_output: Output, - call_results: List[CallResult], - poll_error: Optional[Exception] = None, - ): - poll = self.assert_poll(prev_output) - input = Input.from_poll_results( - main.__qualname__, - any_unpickle(poll.typed_coroutine_state), - call_results, - Error.from_exception(poll_error) if poll_error else None, - ) - return self.runner.run(OneShotScheduler(main).run(input)) -======= def assert_exit_result_error( output: Output, expect: Type[Exception], message: Optional[str] = None ): result = assert_exit_result(output) assert not result.HasField("output") assert result.HasField("error") ->>>>>>> 626d02d (aiohttp: refactor internals to use asyncio throughout the SDK) error = Error._from_proto(result.error).to_exception() assert error.__class__ == expect diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py deleted file mode 100644 index 1f53543..0000000 --- a/tests/test_aiohttp.py +++ /dev/null @@ -1,124 +0,0 @@ -import asyncio -import base64 -import os -import pickle -import struct -import threading -import unittest -from typing import Any, Tuple -from unittest import mock - -import fastapi -import google.protobuf.any_pb2 -import google.protobuf.wrappers_pb2 -import httpx -from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey - -import dispatch.test.httpx -from dispatch.aiohttp import Dispatch, Server -from dispatch.asyncio import Runner -from dispatch.experimental.durable.registry import clear_functions -from dispatch.function import Arguments, Error, Function, Input, Output, Registry -from dispatch.proto import _any_unpickle as any_unpickle -from dispatch.proto import _pb_any_pickle as any_pickle -from dispatch.sdk.v1 import call_pb2 as call_pb -from dispatch.sdk.v1 import function_pb2 as function_pb -from dispatch.signature import parse_verification_key, public_key_from_pem -from dispatch.status import Status -from dispatch.test import EndpointClient - -public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----" -public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----" -public_key = public_key_from_pem(public_key_pem) -public_key_bytes = public_key.public_bytes_raw() -public_key_b64 = base64.b64encode(public_key_bytes) - -from datetime import datetime - - -def run(runner: Runner, server: Server, ready: threading.Event): - try: - with runner: - runner.run(serve(server, ready)) - except RuntimeError as e: - pass # silence errors triggered by stopping the loop after tests are done - - -async def serve(server: Server, ready: threading.Event): - async with server: - ready.set() # allow the test to continue after the server started - await asyncio.Event().wait() - - -class TestAIOHTTP(unittest.TestCase): - def setUp(self): - ready = threading.Event() - self.runner = Runner() - - host = "127.0.0.1" - port = 9997 - - self.endpoint = f"http://{host}:{port}" - self.dispatch = Dispatch( - Registry( - endpoint=self.endpoint, - api_key="0000000000000000", - api_url="http://127.0.0.1:10000", - ), - ) - - self.client = httpx.Client(timeout=1.0) - self.server = Server(host, port, self.dispatch) - self.thread = threading.Thread( - target=lambda: run(self.runner, self.server, ready) - ) - self.thread.start() - ready.wait() - - def tearDown(self): - loop = self.runner.get_loop() - loop.call_soon_threadsafe(loop.stop) - self.thread.join(timeout=1.0) - self.client.close() - - def test_content_length_missing(self): - resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run") - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is required"}' - ) - - def test_content_length_too_large(self): - resp = self.client.post( - f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run", - data={"msg": "a" * 16_000_001}, - ) - body = resp.read() - self.assertEqual(resp.status_code, 400) - self.assertEqual( - body, b'{"code":"invalid_argument","message":"content length is too large"}' - ) - - def test_simple_request(self): - @self.dispatch.registry.primitive_function - async def my_function(input: Input) -> Output: - return Output.value( - f"You told me: '{input.input}' ({len(input.input)} characters)" - ) - - http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint)) - client = EndpointClient(http_client) - - req = function_pb.RunRequest( - function=my_function.name, - input=any_pickle("Hello World!"), - ) - - resp = client.run(req) - - self.assertIsInstance(resp, function_pb.RunResponse) - - output = any_unpickle(resp.exit.result.output) - - self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")