diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py index 93968d44..aceba919 100644 --- a/byte_micro_perf/core/perf_engine.py +++ b/byte_micro_perf/core/perf_engine.py @@ -182,7 +182,7 @@ def parse_workload(workload): -ConfigInstance = namedtuple("ConfigInstance", ["dtype", "tensor_shapes", "index"]) +ConfigInstance = namedtuple("ConfigInstance", ["dtype", "tensor_shapes", "index", "total"]) ResultItem = namedtuple("ResultItem", ["config", "report"]) @@ -261,7 +261,7 @@ def start_engine(self) -> None: case_index = 0 for dtype in dtype_list: for shape in shape_list: - test_list.append(ConfigInstance(dtype, shape, case_index)) + test_list.append(ConfigInstance(dtype, shape, case_index + 1, len(dtype_list) * len(shape_list))) case_index = case_index + 1 try: @@ -379,7 +379,6 @@ def perf_func(self, rank: int, *args): test_dtype = test_instance.dtype test_shape = test_instance.tensor_shapes - print(f"rank {rank}, {test_instance}") """ input_shape could be: @@ -399,6 +398,15 @@ def perf_func(self, rank: int, *args): if reports and "Error" not in reports: result_list.append(ResultItem(test_instance, reports)) + latency = reports.get("Avg latency(us)", 0) + kernel_bw = reports.get("Kernel bandwidth(GB/s)", 0) + bus_bw = reports.get("Bus bandwidth(GB/s)", 0) + + print(f"rank {rank}, {test_instance}, latency: {latency}\nkernel_bw: {kernel_bw}, bus_bw: {bus_bw}") + else: + print(f"rank {rank}, {test_instance}, error") + + output_result_list = [] if world_size > 1: all_result_list = backend_instance.all_gather_object(result_list) @@ -411,8 +419,6 @@ def perf_func(self, rank: int, *args): for test_instance in test_list: test_dtype = test_instance.dtype test_shape = test_instance.tensor_shapes - if rank == 0: - print(f"rank {rank}, {test_instance}") """ input_shape could be: @@ -432,6 +438,15 @@ def perf_func(self, rank: int, *args): if reports and "Error" not in reports: result_list.append(ResultItem(test_instance, reports)) + latency = reports.get("Avg latency(us)", 0) + kernel_bw = reports.get("Kernel bandwidth(GB/s)", 0) + bus_bw = reports.get("Bus bandwidth(GB/s)", 0) + if rank == 0: + print(f"rank {rank}, {test_instance}, latency: {latency}\nkernel_bw: {kernel_bw}, bus_bw: {bus_bw}") + else: + if rank == 0: + print(f"rank {rank}, {test_instance}, error") + # destroy dist if world_size > 1: