Skip to content

Commit

Permalink
polish: Add comments
Browse files Browse the repository at this point in the history
Signed-off-by: Yuhan Ruan <[email protected]>
  • Loading branch information
AndyUB committed Nov 28, 2024
1 parent 9ae354e commit f60555a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
49 changes: 35 additions & 14 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,10 @@ def __init__(
self.output_type_hint: ChannelOutputType = task.dag_node.type_hint

# [TODO:andyub] One requires_nccl instead of three.
# We need a flag for NCCL read because currently we only support
# overlapping NCCL read with computation. The other two flags are kept
# for symmetry. We may be able to merge them into one flag after
# supporting overlapping NCCL collective with computation.
self.requires_nccl_read = task.dag_node.requires_nccl_read
self.requires_nccl_write = task.dag_node.requires_nccl_write
self.requires_nccl_collective = task.dag_node.requires_nccl_collective
Expand Down Expand Up @@ -710,7 +714,7 @@ def exec_operation(
# Channel closed. Exit the loop.
return True

# To overlap GPU communication for NCCL read, launch the NCCL recv operation,
# To overlap GPU communication for NCCL recv, launch the NCCL recv operation,
# skip the normal compute operation, and return the future without waiting.
if not self.requires_nccl_read or not overlap_gpu_communication:
input_data = self.reset_and_wait_intermediate_future()
Expand All @@ -721,7 +725,7 @@ def exec_operation(
val = input_data[i]
if isinstance(val, DAGOperationFuture):
resolved_future = val.wait()
# The only source of future is NCCL read.
# The only source of future is NCCL recv.
# The future wraps around a one-element list.
assert isinstance(resolved_future, list)
assert len(resolved_future) == 1
Expand Down Expand Up @@ -762,7 +766,7 @@ def exec_operation(
)

with self._send_stream:
# To overlap GPU communication for NCCL read, write the future as output to
# To overlap GPU communication for NCCL recv, write the future as output to
# the downstream task, which waits on the future in its compute operation.
if self.requires_nccl_read and overlap_gpu_communication:
output_val = self._intermediate_future
Expand Down Expand Up @@ -1027,9 +1031,25 @@ def _add_node(self, node: "ray.dag.DAGNode") -> None:
self.dag_node_to_idx[node] = idx
self.counter += 1

def _add_nccl_p2p_nodes(self) -> None:
def _update_nccl_p2p_nodes(self) -> None:
"""
Add NCCL P2P send/recv nodes to the DAG.
Find DAG nodes that are involved in NCCL send/recv operations. Create nodes
to represent these operations and add them to the DAG.
Check for errors as well:
1. The driver cannot participate in NCCL send/recv operations.
2. An actor must be present for an NCCL send/recv operation.
3. NcclSendNode and NcclRecvNode should not be directly added to the DAG.
Example:
a.foo -(NCCL)-> b.bar
is transformed to:
a.foo -(IPC)-> a.nccl_send -(NCCL)-> b.nccl_recv -(IPC)-> b.bar
where IPC is IntraProcessChannel.
"""
from ray.dag import (
DAGNode,
Expand All @@ -1055,12 +1075,13 @@ def get_class_method_node_bind_index(node: ClassMethodNode) -> int:
nccl_send_nodes: Dict[DAGNode, _NcclSendNode] = dict()
nccl_recv_nodes: Dict[DAGNode, Dict[int, _NcclRecvNode]] = defaultdict(dict)

# Gather NCCL P2P send nodes.
# Find all DAG nodes that are NCCL P2P senders. Create and cache a
# NcclSendNode for each of them.
for task in self.idx_to_task.values():
if isinstance(task.dag_node, _NcclP2PNode):
raise ValueError(
"Please use type hints to specify NCCL transport instead of "
"adding NCCLSendNode or NCCLRecvNode to the DAG"
"adding NcclSendNode or NcclRecvNode to the DAG"
)

if not task.dag_node.type_hint.requires_nccl():
Expand Down Expand Up @@ -1096,7 +1117,8 @@ def get_class_method_node_bind_index(node: ClassMethodNode) -> int:
},
)

# Gather NCCL P2P recv nodes.
# Find all DAG nodes that are NCCL P2P receivers. Create and cache a
# NcclRecvNode for each of them.
for task in self.idx_to_task.values():
for arg_idx, arg in enumerate(task.args):
if not isinstance(arg, DAGNode) or not arg.type_hint.requires_nccl():
Expand Down Expand Up @@ -1129,17 +1151,14 @@ def get_class_method_node_bind_index(node: ClassMethodNode) -> int:
)
nccl_recv_nodes[task.dag_node][arg_idx] = recv_node

# Add NCCL P2P send nodes to the DAG.
# Add the newly created NcclSendNodes to the DAG.
for dag_node, send_node in nccl_send_nodes.items():
type_hint = dag_node.type_hint
dag_node.with_type_hint(ChannelOutputType())
send_node.with_type_hint(type_hint)
if dag_node.is_adag_output_node:
dag_node.is_adag_output_node = False
send_node.is_adag_output_node = True
self._add_node(send_node)

# Add NCCL P2P recv nodes to the DAG.
# Add the newly created NcclRecvNodes to the DAG.
for dag_node in nccl_recv_nodes:
new_args: List[Any] = list(dag_node._bound_args)
for arg_idx, recv_node in nccl_recv_nodes[dag_node].items():
Expand All @@ -1165,7 +1184,9 @@ def _preprocess(self) -> None:
)
from ray.dag.collective_node import _CollectiveGroup

self._add_nccl_p2p_nodes()
# Because type hints can be added or removed, we need to update
# the nodes that are involved in NCCL P2P operations at compile time.
self._update_nccl_p2p_nodes()

self.input_task_idx, self.output_task_idx = None, None
self.actor_task_count.clear()
Expand Down
3 changes: 2 additions & 1 deletion python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
# Whether this node calls `experimental_compile`.
self.is_adag_output_node = False

# [CL]
# Whether this node requires NCCL read/write/collective operations.
# [TODO:andyub] Merge these into a single requires_nccl flag.
self._requires_nccl_read = False
self._requires_nccl_write = False
self._requires_nccl_collective = False
Expand Down
23 changes: 10 additions & 13 deletions python/ray/dag/dag_node_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,19 @@ def _select_next_nodes(

if top_priority_node is None:
return None
next_nodes: Set[_DAGOperationGraphNode] = {
heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id])
}

heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id])

if top_priority_node.sync_group is not None:
next_nodes = []
for idx in top_priority_node.sync_group.task_idxs:
node = graph[idx]
assert node.is_ready
next_nodes.add(node)
next_nodes.append(node)
else:
next_nodes = [top_priority_node]

return list(next_nodes)
return next_nodes


def _build_dag_node_operation_graph(
Expand Down Expand Up @@ -342,8 +344,6 @@ def _build_dag_node_operation_graph(
isinstance(task.dag_node, ClassMethodNode)
and task.dag_node.is_class_method_output
):
# [CL]
# TODO(wxdeng): Handle the case where the task is a class method output.
continue
for downstream_task_idx in task.downstream_task_idxs:
downstream_dag_node = idx_to_task[downstream_task_idx].dag_node
Expand All @@ -353,9 +353,6 @@ def _build_dag_node_operation_graph(
isinstance(downstream_dag_node, ClassMethodNode)
and downstream_dag_node.is_class_method_output
):
# [CL]
# TODO(wxdeng): Handle the case where the downstream task is
# a class method output.
continue
if graph[task_idx].requires_nccl_write:
assert graph[downstream_task_idx].requires_nccl_read
Expand Down Expand Up @@ -614,14 +611,14 @@ def _generate_overlapped_execution_schedule(
"ray.actor.ActorHandle", List[_DAGOperationGraphNode]
] = copy.deepcopy(actor_to_execution_schedule)
for overlapped_schedule in actor_to_overlapped_schedule.values():
# Swap each NCCL read operation with its previous compute node to overlap
# the NCCL read operation with computation. The index starts at 1 because
# the first node has no previous node.
for i in range(1, len(overlapped_schedule)):
if (
overlapped_schedule[i].requires_nccl_read
and not overlapped_schedule[i - 1].requires_nccl_op
):
# For each NCCL read operation (i.e., recv), swap it with the
# immediately preceding node when that node is not an NCCL operation,
# so that the NCCL read operation can be overlapped with computation.
overlapped_schedule[i], overlapped_schedule[i - 1] = (
overlapped_schedule[i - 1],
overlapped_schedule[i],
Expand Down

0 comments on commit f60555a

Please sign in to comment.