diff --git a/tests/test_client.py b/tests/test_client.py index 7c4d422..e6fa5ab 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -34,18 +34,29 @@ def test_can_be_constructed_on_https(): Client(api_url="https://example.com", api_key="foo") +# On Python 3.8/3.9, pytest.mark.asyncio doesn't work with mock.patch.dict, +# so we have to use the old-fashioned way of setting the environment variable +# and then cleaning it up manually. +# +# @mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) @pytest.mark.asyncio -@mock.patch.dict(os.environ, {"DISPATCH_API_KEY": "0000000000000000"}) async def test_api_key_from_env(): - async with server() as api: - client = Client(api_url=api.url) - - with pytest.raises( - PermissionError, - match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", - ) as mc: - await client.dispatch([Call(function="my-function", input=42)]) - + prev_api_key = os.environ.get("DISPATCH_API_KEY") + try: + os.environ["DISPATCH_API_KEY"] = "0000000000000000" + async with server() as api: + client = Client(api_url=api.url) + + with pytest.raises( + PermissionError, + match=r"Dispatch received an invalid authentication token \(check DISPATCH_API_KEY is correct\)", + ) as mc: + await client.dispatch([Call(function="my-function", input=42)]) + finally: + if prev_api_key is None: + del os.environ["DISPATCH_API_KEY"] + else: + os.environ["DISPATCH_API_KEY"] = prev_api_key @pytest.mark.asyncio async def test_api_key_from_arg(): diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 9c067ef..a317a9b 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -57,7 +57,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.sockets = [sock] self.uvicorn = uvicorn.Server(config) self.runner = Runner() - if sys.version_info >= (3, 9): + if sys.version_info >= (3, 10): self.event = asyncio.Event() else: self.event = asyncio.Event(loop=self.runner.get_loop()) diff --git a/tests/test_http.py b/tests/test_http.py index 69ad365..cba4742 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -46,7 +46,7 @@ def dispatch_test_init(self, reg: Registry) -> str: self.aiohttp = Server(host, port, Dispatch(reg)) self.aioloop.run(self.aiohttp.start()) - if sys.version_info >= (3, 9): + if sys.version_info >= (3, 10): self.aiowait = asyncio.Event() else: self.aiowait = asyncio.Event(loop=self.aioloop.get_loop())