Skip to content

Commit

Permalink
Merge pull request bytedance#92 from bytedance/xuzhenglin/fix_gemm
Browse files Browse the repository at this point in the history
[micro_perf] fix parsing logic of group gemm
  • Loading branch information
YJessicaGao authored Aug 19, 2024
2 parents fe25ade + 69efe79 commit 9e4c3c7
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions byte_micro_perf/core/perf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ def parse_workload(workload):
kn = input_shape_group.get("KN", [])
if k and n:
kn.append([list(shape) for shape in itertools.product(k, n)])
for group in groups:
for batch in batches:
for _kn in kn:
shape_list.append([[[group * batch, _kn[0]], [_kn[0], _kn[1]]]])
for batch in batches:
for _kn in kn:
group_input_shape_list = []
for group in groups:
group_input_shape_list.append([[group * batch, _kn[0]], [_kn[0], _kn[1]]])
shape_list.append(group_input_shape_list)
# gemm
else:
if m and n and k:
Expand Down

0 comments on commit 9e4c3c7

Please sign in to comment.