Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Nov 5, 2024
1 parent dea16a6 commit 6ac918b
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions msccl/language/mscclpp/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty
def remove_empty_fields(d):
return {k: v for k, v in d.items() if v not in [None, "", [], {}]}

max_scratch = max(gpu.scratch_chunks for gpu in program.gpus)
max_input = max(gpu.input_chunks for gpu in program.gpus)
max_output = max(gpu.output_chunks for gpu in program.gpus)

for id, gpu in enumerate(program.gpus):
gpu_instance = {
"id": id,
Expand All @@ -196,6 +200,22 @@ def remove_empty_fields(d):
gpu_instance["channels"].append(obj)
gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"]))
gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"]))

# render GPU NVLS channels
for i, chan in enumerate(gpu_instance["channels"]):
if chan["type"] == "nvls":
buff = chan["srcbuff"]
buffer_size = (
max_input
if buff == Buffer.input.value
else max_output if buff == Buffer.output.value else max_scratch
)
gpu_instance["channels"][i] = {
"buff": chan["srcbuff"],
"type": chan["type"],
"rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]],
}

for tb in gpu.threadblocks:
if tb.id < 0:
continue
Expand Down

0 comments on commit 6ac918b

Please sign in to comment.