Skip to content

Commit

Permalink
Send/recv host and device frames in a message each
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Apr 21, 2020
1 parent 5bbd53e commit 1b540a3
Showing 1 changed file with 68 additions and 16 deletions.
84 changes: 68 additions & 16 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,34 @@ async def write(
hasattr(f, "__cuda_array_interface__") for f in frames
)
sizes = tuple(nbytes(f) for f in frames)
send_frames = [
each_frame
for each_frame, each_size in zip(frames, sizes)
if each_size
]
host_frames = host_array(
sum(
each_size
for is_cuda, each_size in zip(cuda_frames, sizes)
if not is_cuda
)
)
device_frames = device_array(
sum(
each_size
for is_cuda, each_size in zip(cuda_frames, sizes)
if is_cuda
)
)

# Pack frames
host_frames_view = memoryview(host_frames)
device_frames_view = as_numba_device_array(device_frames)
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
if each_size:
if is_cuda:
each_frame_view = as_numba_device_array(each_frame)
device_frames_view[:each_size] = each_frame_view[:]
device_frames_view = device_frames_view[each_size:]
else:
each_frame_view = memoryview(each_frame).cast("B")
host_frames_view[:each_size] = each_frame_view[:]
host_frames_view = host_frames_view[each_size:]

# Send meta data
await self.ep.send(struct.pack("Q", nframes))
Expand All @@ -216,8 +239,10 @@ async def write(
if any(cuda_frames):
synchronize_stream(0)

for each_frame in send_frames:
await self.ep.send(each_frame)
if nbytes(host_frames):
await self.ep.send(host_frames)
if nbytes(device_frames):
await self.ep.send(device_frames)
return sum(sizes)
except (ucp.exceptions.UCXBaseException):
self.abort()
Expand Down Expand Up @@ -248,21 +273,48 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
raise CommClosedError("While reading, the connection was closed")
else:
# Recv frames
frames = [
device_array(each_size) if is_cuda else host_array(each_size)
for is_cuda, each_size in zip(cuda_frames, sizes)
]
recv_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]
host_frames = host_array(
sum(
each_size
for is_cuda, each_size in zip(cuda_frames, sizes)
if not is_cuda
)
)
device_frames = device_array(
sum(
each_size
for is_cuda, each_size in zip(cuda_frames, sizes)
if is_cuda
)
)

# It is necessary to first populate `frames` with CUDA arrays and synchronize
# the default stream before starting receiving to ensure buffers have been allocated
if any(cuda_frames):
synchronize_stream(0)

for each_frame in recv_frames:
await self.ep.recv(each_frame)
if nbytes(host_frames):
await self.ep.recv(host_frames)
if nbytes(device_frames):
await self.ep.recv(device_frames)

frames = [
device_array(each_size) if is_cuda else host_array(each_size)
for is_cuda, each_size in zip(cuda_frames, sizes)
]
host_frames_view = memoryview(host_frames)
device_frames_view = as_numba_device_array(device_frames)
for each_frame, is_cuda, each_size in zip(frames, cuda_frames, sizes):
if each_size:
if is_cuda:
each_frame_view = as_numba_device_array(each_frame)
each_frame_view[:] = device_frames_view[:each_size]
device_frames_view = device_frames_view[each_size:]
else:
each_frame_view = memoryview(each_frame)
each_frame_view[:] = host_frames_view[:each_size]
host_frames_view = host_frames_view[each_size:]

msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down

0 comments on commit 1b540a3

Please sign in to comment.