From 69efe7985d248ff2967edff0f814ff8e78f7a0b9 Mon Sep 17 00:00:00 2001 From: kevinsouthByteDance Date: Mon, 19 Aug 2024 11:03:56 +0000 Subject: [PATCH] [micro_perf] fix parsing logic of group gemm --- byte_micro_perf/core/perf_engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/byte_micro_perf/core/perf_engine.py b/byte_micro_perf/core/perf_engine.py index db59bb1c..43cc01d6 100644 --- a/byte_micro_perf/core/perf_engine.py +++ b/byte_micro_perf/core/perf_engine.py @@ -168,10 +168,12 @@ def parse_workload(workload): kn = input_shape_group.get("KN", []) if k and n: kn.append([list(shape) for shape in itertools.product(k, n)]) - for group in groups: - for batch in batches: - for _kn in kn: - shape_list.append([[[group * batch, _kn[0]], [_kn[0], _kn[1]]]]) + for batch in batches: + for _kn in kn: + group_input_shape_list = [] + for group in groups: + group_input_shape_list.append([[group * batch, _kn[0]], [_kn[0], _kn[1]]]) + shape_list.append(group_input_shape_list) # gemm else: if m and n and k: