diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 615f83b6b0..f722d0f481 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -416,16 +416,23 @@ async def test_closed_worker_during_transfer(c, s, a, b): config={"distributed.scheduler.allowed-failures": 0}, ) async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): + await c.register_plugin(BlockedShuffleReceiveShuffleWorkerPlugin(), name="shuffle") df = dask.datasets.timeseries( start="2000-01-01", - end="2000-03-01", + end="2000-02-01", dtypes={"x": float, "y": float}, freq="10 s", ) + shuffle_extA = a.plugins["shuffle"] + shuffle_extB = b.plugins["shuffle"] with dask.config.set({"dataframe.shuffle.method": "p2p"}): out = df.shuffle("x") out = c.compute(out.x.size) - await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + await asyncio.gather( + shuffle_extA.in_shuffle_receive.wait(), shuffle_extB.in_shuffle_receive.wait() + ) + shuffle_extA.block_shuffle_receive.set() + shuffle_extB.block_shuffle_receive.set() await assert_worker_cleanup(b, close=True) with pytest.raises(KilledWorker): diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index ddfd4f7738..ece5ef3fdc 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -39,6 +39,7 @@ SecedeEvent, TaskFinishedMsg, UpdateDataEvent, + WorkerState, ) @@ -825,8 +826,11 @@ async def release_all_futures(): @pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"]) -@pytest.mark.parametrize("close_worker", [False, True]) -@gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "500ms"}) +@pytest.mark.parametrize("close_worker", [True]) +@gen_cluster( + client=True, + config={"distributed.comm.timeouts.connect": "500ms"}, +) async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( c, s, a, x, intermediate_state, close_worker ): @@ -839,10 +843,34 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2") await fut2 - async with BlockedGatherDep(s.address, name="b") as b: + class InstrumentedWorkerState(WorkerState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fut2_in_flight = asyncio.Event() + self.fut2_in_intermediate = asyncio.Event() + + def _transition(self, ts, finish, *args, **kwargs): + def _verify_state(finish): + if ts.key == fut2.key: + if isinstance(finish, tuple) and finish[0] == "flight": + self.fut2_in_flight.set() + if self.fut2_in_flight.is_set() and finish == intermediate_state: + self.fut2_in_intermediate.set() + + # The expected state might be either the requested one or the + # actual, final state + _verify_state(finish) + try: + return super()._transition(ts, finish, *args, **kwargs) + finally: + _verify_state(ts.state) + + async with BlockedGatherDep( + s.address, name="b", WorkerStateClass=InstrumentedWorkerState + ) as b: fut3 = c.submit(inc, fut2, workers=[b.address], key="f3") - await wait_for_state(fut2.key, "flight", b) + await b.state.fut2_in_flight.wait() s.set_restrictions(worker={fut1B.key: a.address, fut2.key: b.address}) @@ -855,7 +883,7 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( stimulus_id="remove-worker", ) - await wait_for_state(fut2.key, intermediate_state, b, interval=0) + await b.state.fut2_in_intermediate.wait() b.block_gather_dep.set() await fut3 diff --git a/distributed/worker.py b/distributed/worker.py index 7e3fecb9b2..0cb84dd252 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -524,6 +524,7 @@ def __init__( ################################### # Parameters to Server scheduler_sni: str | None = None, + WorkerStateClass: type = WorkerState, **kwargs, ): if reconnect is not None: @@ -770,7 +771,7 @@ def __init__( transfer_incoming_bytes_limit = int( self.memory_manager.memory_limit * transfer_incoming_bytes_fraction ) - state = WorkerState( + state = WorkerStateClass( nthreads=nthreads, data=self.memory_manager.data, threads=self.threads,