diff --git a/distributed/client.py b/distributed/client.py index f01eb31190..290feabd01 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1502,7 +1502,17 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): if self._previous_as_current: - _current_client.reset(self._previous_as_current) + try: + _current_client.reset(self._previous_as_current) + except ValueError as e: + if not e.args[0].endswith(" was created in a different Context"): + raise # pragma: nocover + warnings.warn( + "It is deprecated to enter and exit the Client context " + "manager from different tasks", + DeprecationWarning, + stacklevel=2, + ) await self._close( # if we're handling an exception, we assume that it's more # important to deliver that exception than shutdown gracefully. @@ -1512,7 +1522,17 @@ async def __aexit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback): if self._previous_as_current: - _current_client.reset(self._previous_as_current) + try: + _current_client.reset(self._previous_as_current) + except ValueError as e: + if not e.args[0].endswith(" was created in a different Context"): + raise # pragma: nocover + warnings.warn( + "It is deprecated to enter and exit the Client context " + "manager from different threads", + DeprecationWarning, + stacklevel=2, + ) self.close() def __del__(self): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 2ed731a160..fa60de1d88 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1274,6 +1274,33 @@ async def client_2(): await asyncio.gather(client_1(), client_2()) +@gen_cluster(client=False, nthreads=[]) +async def test_context_manager_used_from_different_tasks(s): + c = Client(s.address, asynchronous=True) + await asyncio.create_task(c.__aenter__()) + with pytest.warns( + DeprecationWarning, + match=r"It is deprecated to enter and exit the Client context manager " + "from different tasks", + ): + await asyncio.create_task(c.__aexit__(None, None, None)) + + +def test_context_manager_used_from_different_threads(s, loop): + c = Client(s["address"]) + with ( + concurrent.futures.ThreadPoolExecutor(1) as tp1, + concurrent.futures.ThreadPoolExecutor(1) as tp2, + ): + tp1.submit(c.__enter__).result() + with pytest.warns( + DeprecationWarning, + match=r"It is deprecated to enter and exit the Client context manager " + "from different threads", + ): + tp2.submit(c.__exit__, None, None, None).result() + + def test_global_clients(loop): assert _get_global_client() is None with pytest.raises(