Skip to content

Commit

Permalink
Fix mfma cnt
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Oct 4, 2024
1 parent 995144f commit d47adce
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions python/perf-kernels/tools/tune_gemm/process_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def gen_all_clk(code_fullname, trace_fullname):
code_list = code_data['code']

found_1st_barrier = False
mfma_cnt = 0
mfma_dsRead_cnt = 0
mfma_cnt_total = 0
should_cnt = False
## Find the s_barriers
for i in range(len(code_list)):
Expand All @@ -68,10 +69,14 @@ def gen_all_clk(code_fullname, trace_fullname):
## This is barrier2 or barrier3
should_cnt = False
if "mfma" in code_list[i][0] and should_cnt:
mfma_cnt += 1
mfma_dsRead_cnt += 1
if "mfma" in code_list[i][0]:
mfma_cnt_total += 1

mfma_dsRead_cnt = mfma_cnt
mfma_dsWrite_cnt = 128 - mfma_cnt
## /= 2 because the last iteration of local_load and tt.dot
## is peeled off by stream-pipeliner
mfma_cnt_total /= 2
mfma_dsWrite_cnt = mfma_cnt_total - mfma_dsRead_cnt

if len(marker_barrier) != 3:
print(f"Not 3 barriers?? Found {len(marker_barrier)}")
Expand Down Expand Up @@ -121,7 +126,16 @@ def gen_all_clk(code_fullname, trace_fullname):
if len1 == 0 or len2 == 0 or len3 == 0:
incomplete = True

return firstInstr_clk, instrAfterBarrier1_clk, instrAfterBarrier2_clk, instrAfterBarrier3_clk, lastInstr_clk, mfma_dsRead_cnt, mfma_dsWrite_cnt, incomplete
#print(f"{firstInstr_clk}")
#print(f"{instrAfterBarrier1_clk}")
#print(f"{instrAfterBarrier2_clk}")
#print(f"{instrAfterBarrier3_clk}")
#print(f"{lastInstr_clk}")
#print(f"{mfma_dsRead_cnt}")
#print(f"{mfma_dsWrite_cnt}")
#print(f"{incomplete}")

return firstInstr_clk, instrAfterBarrier1_clk, instrAfterBarrier2_clk, instrAfterBarrier3_clk, lastInstr_clk, mfma_dsRead_cnt, int(mfma_dsWrite_cnt), incomplete


def gen_coarse_clk(instr0_clk, bar1_clk, bar3_clk, instr9_clk):
Expand Down Expand Up @@ -225,6 +239,9 @@ def main():
continue
trace_filename = f"se{se}_sm{sm}_sl{sl}_wv{wid}.json"
trace_fullname = os.path.join(trace_dir, trace_filename)
if not os.path.isfile(trace_fullname):
#print(f"trace file not found {trace_fullname}")
return
pro, loop, epi, iter_clk, lat1, lat2, lat_sum, idle1, idle2, incomplete = parse_trace(code_fullname, trace_fullname)
if incomplete:
continue
Expand Down

0 comments on commit d47adce

Please sign in to comment.