diff --git a/msccl/language/mscclpp/ir.py b/msccl/language/mscclpp/ir.py index 2981db3..d5dbae8 100644 --- a/msccl/language/mscclpp/ir.py +++ b/msccl/language/mscclpp/ir.py @@ -174,6 +174,10 @@ def get_channel_ids(chunk_list, tb_channel_dict, src_buffer, dst_buffer, chan_ty def remove_empty_fields(d): return {k: v for k, v in d.items() if v not in [None, "", [], {}]} + max_scratch = max(gpu.scratch_chunks for gpu in program.gpus) + max_input = max(gpu.input_chunks for gpu in program.gpus) + max_output = max(gpu.output_chunks for gpu in program.gpus) + for id, gpu in enumerate(program.gpus): gpu_instance = { "id": id, @@ -196,6 +200,22 @@ def remove_empty_fields(d): gpu_instance["channels"].append(obj) gpu_instance["channels"] = list(filter(lambda x: x["type"] != "none", gpu_instance["channels"])) gpu_instance["channels"] = sorted(gpu_instance["channels"], key=lambda x: (x["srcbuff"], x["dstbuff"])) + + # render GPU NVLS channels + for i, chan in enumerate(gpu_instance["channels"]): + if chan["type"] == "nvls": + buff = chan["srcbuff"] + buffer_size = ( + max_input + if buff == Buffer.input.value + else max_output if buff == Buffer.output.value else max_scratch + ) + gpu_instance["channels"][i] = { + "buff": chan["srcbuff"], + "type": chan["type"], + "rankGroups": [{"size": buffer_size, "ranks": ranks} for ranks in chan["connectedTo"]], + } + for tb in gpu.threadblocks: if tb.id < 0: continue