Skip to content

Commit

Permalink
fix bugs
Browse files (browse the repository at this point in the history)
  • Loading branch information
fsx950223 committed Oct 22, 2024
1 parent f8d19e7 commit ecb691c
Showing 1 changed file with 29 additions and 48 deletions.
77 changes: 29 additions & 48 deletions tensorflow/tools/hipblaslt/tensile_config_generator.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -168,30 +168,31 @@ def extract_dtype(match):
unique_gemms_subgroups[i%args.gpus] = [(k, v)]

def find_matmul_instruction(mfma_instruction, size, CU):
for m_tiles in reversed(range(1, CU+1)):
if size[0] // m_tiles > 256:
continue
wave_tile_m = math.ceil(size[0] // m_tiles / mfma_instruction[0])
if wave_tile_m <= 0:
continue
for n_tiles in reversed(range(1, CU+1)):
if size[1] // n_tiles > 256:
for bm in range(int(math.log(mfma_instruction[3],2))+1):
for m_tiles in reversed(range(1, CU+1)):
if size[0] // m_tiles > 256:
continue
wave_tile_n = math.ceil(size[1] // n_tiles / mfma_instruction[1])
if wave_tile_n <= 0:
wave_tile_m = math.ceil(size[0] // m_tiles / mfma_instruction[0])
if wave_tile_m <= 0:
continue
matmul_instruction = mfma_instruction + [1, 1, 1, 1, 1]
for k in reversed(range(3)):
if wave_tile_m // (2**k) > 0:
matmul_instruction[-4] = wave_tile_m // (2**k)
matmul_instruction[-2] = 2**k

for l in reversed(range(3)):
if wave_tile_n // (2**l) > 0:
matmul_instruction[-3] = wave_tile_n // (2**l)
matmul_instruction[-1] = 2**l
for n_tiles in reversed(range(1, CU+1)):
if size[1] // n_tiles > 256:
continue
wave_tile_n = math.ceil(size[1] // n_tiles / mfma_instruction[1])
if wave_tile_n <= 0:
continue
matmul_instruction = mfma_instruction + [2**bm, 1, 1, 1, 1]
for k in reversed(range(3)):
if wave_tile_m // (2**k) >= 1 and wave_tile_m // (2**k) <= 32:
matmul_instruction[-4] = wave_tile_m // (2**k)
matmul_instruction[-2] = 2**k

for l in reversed(range(3)):
if wave_tile_n // (2**l) >= 1 and wave_tile_n // (2**l) <= 32:
matmul_instruction[-3] = wave_tile_n // (2**l)
matmul_instruction[-1] = 2**l

return matmul_instruction
yield matmul_instruction


for gpu_idx, unique_gemms_subgroup in enumerate(unique_gemms_subgroups):
Expand All @@ -216,38 +217,18 @@ def find_matmul_instruction(mfma_instruction, size, CU):
if mfma_instruction is None:
continue
if args.fast:
matmul_instruction = find_matmul_instruction(mfma_instruction, size, CU)
matmul_instruction = next(find_matmul_instruction(mfma_instruction, size, CU))
if matmul_instruction is not None:
if dtype_str not in matmul_instructions:
matmul_instructions[dtype_str] = dict()
matmul_instructions[dtype_str][str(matmul_instruction)] = matmul_instruction
else:
for m_tiles in reversed(range(1, CU+1)):
if size[0] // m_tiles > 256:
continue
wave_tile_m = math.ceil(size[0] // m_tiles / mfma_instruction[0])
if wave_tile_m <= 0:
continue
for n_tiles in reversed(range(1, CU+1)):
if size[1] // n_tiles > 256:
continue
wave_tile_n = math.ceil(size[1] // n_tiles / mfma_instruction[1])
if wave_tile_n <= 0:
continue
matmul_instruction = mfma_instruction+[1, 1, 1, 1, 1]
for k in reversed(range(3)):
if wave_tile_m // (2**k) > 0:
matmul_instruction[-4] = wave_tile_m//(2**k)
matmul_instruction[-2] = 2**k

for l in reversed(range(3)):
if wave_tile_n // (2**l) > 0:
matmul_instruction[-3] = wave_tile_n//(2**l)
matmul_instruction[-1] = 2**l

if dtype_str not in matmul_instructions:
matmul_instructions[dtype_str] = dict()
matmul_instructions[dtype_str][str(matmul_instruction)] = matmul_instruction
matmul_instruction_gen = find_matmul_instruction(mfma_instruction, size, CU)
for matmul_instruction in matmul_instruction_gen:
if matmul_instruction is not None:
if dtype_str not in matmul_instructions:
matmul_instructions[dtype_str] = dict()
matmul_instructions[dtype_str][str(matmul_instruction)] = matmul_instruction


if dtype_str in gemm_group:
Expand Down

0 comments on commit ecb691c

Please sign in to comment.