From afbabdff70618c76d9f02a59d5d147b47c3fc051 Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Tue, 11 Jun 2024 17:44:14 -0700 Subject: [PATCH] fix slow FastAPI tests Signed-off-by: Achille Roussel --- src/dispatch/test/__init__.py | 21 ++++++++++++++------- tests/test_fastapi.py | 11 +++++++++-- tests/test_http.py | 16 ++++++++-------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/dispatch/test/__init__.py b/src/dispatch/test/__init__.py index 48a2f38..1b8c14a 100644 --- a/src/dispatch/test/__init__.py +++ b/src/dispatch/test/__init__.py @@ -13,7 +13,7 @@ import dispatch.experimental.durable.registry from dispatch.function import Client as BaseClient -from dispatch.function import ClientError +from dispatch.function import ClientError, Input, Output from dispatch.function import Registry as BaseRegistry from dispatch.http import Dispatch from dispatch.http import Server as BaseServer @@ -372,16 +372,23 @@ async def test_content_length_too_large(self): "message": "content length is too large", } - def test_simple_end_to_end(self): + @aiotest + async def test_call_function_no_input(self): + @self.dispatch.function + def my_function() -> str: + return "Hello World!" + + ret = await my_function() + assert ret == "Hello World!" + + @aiotest + async def test_call_function_with_input(self): @self.dispatch.function def my_function(name: str) -> str: return f"Hello world: {name}" - async def test(): - msg = await my_function("52") - assert msg == "Hello world: 52" - - self.loop.run_until_complete(test()) + ret = await my_function("52") + assert ret == "Hello world: 52" class ClientRequestContentLengthMissing(aiohttp.ClientRequest): diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 63d2d01..119f927 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -21,6 +21,7 @@ from fastapi.testclient import TestClient import dispatch +from dispatch.asyncio import Runner from dispatch.experimental.durable.registry import clear_functions from dispatch.fastapi import Dispatch from dispatch.function import Arguments, Error, Function, Input, Output @@ -54,13 +55,19 @@ def dispatch_test_init(self, api_key: str, api_url: str) -> Dispatch: config = uvicorn.Config(app, host=host, port=port) self.sockets = [sock] self.uvicorn = uvicorn.Server(config) + self.runner = Runner() + self.event = asyncio.Event() return reg def dispatch_test_run(self): - self.uvicorn.run(self.sockets) + loop = self.runner.get_loop() + loop.create_task(self.uvicorn.serve(self.sockets)) + self.runner.run(self.event.wait()) + self.runner.close() def dispatch_test_stop(self): - self.uvicorn.should_exit = True + loop = self.runner.get_loop() + loop.call_soon_threadsafe(self.event.set) def create_dispatch_instance(app: fastapi.FastAPI, endpoint: str): diff --git a/tests/test_http.py b/tests/test_http.py index a81a4a0..fb07ec8 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -4,6 +4,7 @@ from http.server import HTTPServer import dispatch.test +from dispatch.asyncio import Runner from dispatch.function import Registry from dispatch.http import Dispatch, Server @@ -52,19 +53,18 @@ def dispatch_test_init(self, api_key: str, api_url: str) -> Registry: api_url=api_url, ) - self.aioloop = asyncio.new_event_loop() + self.aiowait = asyncio.Event() + self.aioloop = Runner() self.aiohttp = Server(host, port, Dispatch(reg)) - self.aioloop.run_until_complete(self.aiohttp.start()) + self.aioloop.run(self.aiohttp.start()) reg.endpoint = f"http://{self.aiohttp.host}:{self.aiohttp.port}" return reg def dispatch_test_run(self): - self.aioloop.run_forever() + self.aioloop.run(self.aiowait.wait()) + self.aioloop.run(self.aiohttp.stop()) + self.aioloop.close() def dispatch_test_stop(self): - def stop(): - self.aiohttp.stop() - self.aioloop.stop() - - self.aioloop.call_soon_threadsafe(stop) + self.aioloop.get_loop().call_soon_threadsafe(self.aiowait.set)