Skip to content

Commit

Permalink
Fix merge issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Oct 28, 2024
1 parent 3c94ad5 commit db92968
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
10 changes: 10 additions & 0 deletions msccl/language/instruction_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def circular_dep_after_merge(op: Op, other_op: Op):
frontier.append(n)
frontier = frontier[1:]

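"""
Consider the case op2.prev = [op1, op3], op1.next = [op2] and op3.next = [op2], where op1 and op2
otherwise satisfy the conditions for being merged.
The merge is applied only if every previous op of op2 has already been visited at that point,
i.e. none of them has a step greater than the step of op1.
"""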
def all_prevs_visited_after_merge(op: Op, other_op: Op) -> bool:
    """Check that merging ``other_op`` into ``op`` leaves no predecessor behind.

    Example: ``op2.prev = [op1, op3]``, ``op1.next = [op2]`` and
    ``op3.next = [op2]``. Even if ``op1`` and ``op2`` are otherwise eligible
    to be merged, the merge is only acceptable when no op in ``op2.prev`` has
    a step greater than ``op1.step``.

    Args:
        op: The op that would absorb ``other_op``; its ``step`` is the bound.
        other_op: The op that would be merged away; its ``prev`` collection is
            inspected.

    Returns:
        ``True`` if no element of ``other_op.prev`` has a ``step`` strictly
        greater than ``op.step`` (including when ``other_op.prev`` is empty),
        ``False`` otherwise.
    """
    # Read the bound once, up front, exactly as the loop-based version did.
    step = op.step
    # NOTE(review): `step` presumably orders ops within a threadblock - confirm.
    return all(not prev.step > step for prev in other_op.prev)

def same_tb(op1: Op, op2: Op):
    """Tell whether two ops agree on both their ``tb`` and ``channel`` values.

    ``channel`` is only compared when the ``tb`` values are already equal.
    """
    tb_matches = op1.tb == op2.tb
    if not tb_matches:
        return tb_matches
    return op1.channel == op2.channel
Expand Down
5 changes: 5 additions & 0 deletions msccl/language/mscclpp/instruction_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
same_count,
same_buf_dst,
same_buf_src,
all_prevs_visited_after_merge,
)
from msccl.language.types import ChunkRef, ChannelType, MscclppInstruction as Instruction, Op, Threadblock

Expand Down Expand Up @@ -41,6 +42,7 @@ def try_merge_same_instructions(
and same_count(op, next_op)
and same_chan_type(op, next_op)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
# Append the source chunks from next_op
op.srcs.append(
Expand Down Expand Up @@ -85,6 +87,7 @@ def try_compact_instructions(
and same_chan_type(op, seq_op)
and same_count(op, seq_op)
and not circular_dep_after_merge(op, seq_op)
and all_prevs_visited_after_merge(op, seq_op)
):
# Append the source and destination chunks from seq_op
op.dsts.append(
Expand Down Expand Up @@ -124,6 +127,7 @@ def try_fuse_with_put(self, op: Op, next_op: Op, tb: Threadblock, queue: list) -
and next_op.channel_type == ChannelType.sm
and (op.channel_type == ChannelType.none or op.channel_type == ChannelType.sm)
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
if len(op.dsts) > 0 and op.dsts[0][0].buffer != next_op.dst.buffer:
return False
Expand Down Expand Up @@ -170,6 +174,7 @@ def try_fuse_instructions_using_proxy_channel(
and same_chan_type(op, next_op)
and op.channel_type == ChannelType.proxy
and not circular_dep_after_merge(op, next_op)
and all_prevs_visited_after_merge(op, next_op)
):
if op.inst == Instruction.put and next_op.inst == Instruction.signal:
op.inst = Instruction.put_with_signal
Expand Down

0 comments on commit db92968

Please sign in to comment.