diff --git a/byte_micro_perf/backends/utils.py b/byte_micro_perf/backends/utils.py index 7ec9684d..36ee634e 100644 --- a/byte_micro_perf/backends/utils.py +++ b/byte_micro_perf/backends/utils.py @@ -36,6 +36,9 @@ def dump_communication_ops_report( mb = dtype_size * size / 1024 / 1024 algo_bw = dtype_size * size / latency / 1e3 bus_bw = algo_bw * (group_size - 1) / group_size + + if op_name == "broadcast": + bus_bw = algo_bw if op_name == "allreduce": bus_bw *= 2