diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 6979cdd934..7761afef7a 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -190,7 +190,11 @@ async def write( ] # Send meta data + + # Send # of frames (uint64) await self.ep.send(struct.pack("Q", nframes)) + # Send which frames are CUDA (bool) and + # how large each frame is (uint64) await self.ep.send( struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes) ) @@ -222,11 +226,15 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): try: # Recv meta data + + # Recv # of frames (uint64) nframes_fmt = "Q" nframes = host_array(struct.calcsize(nframes_fmt)) await self.ep.recv(nframes) (nframes,) = struct.unpack(nframes_fmt, nframes) + # Recv which frames are CUDA (bool) and + # how large each frame is (uint64) header_fmt = nframes * "?" + nframes * "Q" header = host_array(struct.calcsize(header_fmt)) await self.ep.recv(header)