Added instr.sched options to tune_gemm.py
ravil-mobile committed Oct 24, 2024
1 parent b36e072 commit 66eb96c
Showing 2 changed files with 23 additions and 16 deletions.
26 changes: 15 additions & 11 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
@@ -62,6 +62,7 @@ def get_full_tuning_space():
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]
sched_variants = ["\"default\"", "\"iglp0\"", "\"ck_v3\""]

for block_m in block_mn_range:
for block_n in block_mn_range:
@@ -72,13 +73,15 @@ def get_full_tuning_space():
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})
for sched_variant in sched_variants:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack,
'instruction_sched_variant': sched_variant
})

return configs

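For illustration, the new sched_variants loop multiplies the exhaustive tuning space by three, and every generated entry now carries an instruction_sched_variant key. A minimal, self-contained sketch with hypothetical block sizes (not values taken from this commit):

from itertools import product

sched_variants = ["\"default\"", "\"iglp0\"", "\"ck_v3\""]

# One hypothetical entry of the enlarged space; the variant is stored with
# embedded quotes so it can later be pasted verbatim into generated code.
example_config = {
    'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64,
    'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 2,
    'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2,
    'instruction_sched_variant': "\"ck_v3\"",
}
assert example_config['instruction_sched_variant'].strip('"') in ("default", "iglp0", "ck_v3")

# Only the innermost loops are shown; the added loop scales the space by 3.
inner_before = len(list(product([16, 32], [1, 2])))  # matrix_instr_nonkdim x kpack
inner_after = len(list(product([16, 32], sched_variants, [1, 2])))
assert inner_after == 3 * inner_before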
@@ -346,7 +349,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,


def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu,
mfmaInstrSize, kpack, use_bias):
mfmaInstrSize, kpack, use_bias, sched_variant):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
#assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -363,12 +366,13 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds,
instruction_sched_variant=sched_variant)
return c


def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)
use_bias = bias_vector
torch.manual_seed(0)
@@ -384,7 +388,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
waves_per_eu, mfmaInstrSize, kpack, use_bias)
waves_per_eu, mfmaInstrSize, kpack, use_bias, sched_variant)
torch_output = torch.matmul(a_fp16, b_fp16)
if use_bias:
torch_output += bias_fp16[:, None]
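Taken together, the tune_gemm.py changes just thread one extra value from the config dict through to the kernel launch. A toy, self-contained sketch of that plumbing (matmul and the Triton kernel are replaced by stand-ins here, since the kernel itself is outside this diff):

def toy_read_config(config):
    # Mirrors the widened unpacking in test_correctness(): the returned
    # tuple now ends with the scheduling variant.
    keys = ('BLOCK_SIZE_M', 'BLOCK_SIZE_N', 'BLOCK_SIZE_K', 'GROUP_SIZE_M',
            'SPLIT_K', 'num_warps', 'num_stages', 'waves_per_eu',
            'matrix_instr_nonkdim', 'kpack', 'instruction_sched_variant')
    return tuple(config[k] for k in keys)

def toy_matmul(*gemm_params, sched_variant):
    # Stand-in for matmul(): the only change is that sched_variant is
    # accepted and forwarded to the launch as instruction_sched_variant.
    return {'instruction_sched_variant': sched_variant}

cfg = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64,
       'GROUP_SIZE_M': 1, 'SPLIT_K': 1, 'num_warps': 4, 'num_stages': 2,
       'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 1,
       'instruction_sched_variant': "\"default\""}
*gemm_params, sched_variant = toy_read_config(cfg)
launch_kwargs = toy_matmul(*gemm_params, sched_variant=sched_variant)
assert launch_kwargs == {'instruction_sched_variant': "\"default\""}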
13 changes: 8 additions & 5 deletions python/perf-kernels/tools/tune_gemm/utils/file_generator.py
@@ -21,15 +21,16 @@ def read_config(config):
waves_per_eu = config.get('waves_per_eu')
mfma_instr_size = config.get('matrix_instr_nonkdim')
kpack = config.get('kpack')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack
sched_variant = config.get('instruction_sched_variant')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack, sched_variant


def gen_configStr(config):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

## {M}_{N}_{K} is removed since the same kernel can be used for different gemm sizes
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}"
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant[1:-1].upper()}"

return configStr

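For a concrete (hypothetical) config this appends a _sched... suffix to the kernel name; the slice [1:-1] drops the embedded quotes before upper-casing:

# Hypothetical illustration of the new kernel-name suffix.
sched_variant = "\"ck_v3\""
suffix = f"_sched{sched_variant[1:-1].upper()}"
assert suffix == "_schedCK_V3"
# e.g. BM64_BN64_BK64_GM1_SK1_nW4_nS2_EU0_kP1_mfma16_schedCK_V3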
@@ -69,7 +70,7 @@ def generate_matmul_kernels(configs):
## construct the configStr and generate the wrapper function matmul_{configStr}()
## If `warmup` is set, the generated kernel will be **compiled**
def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

configStr = gen_configStr(config)
@@ -112,6 +113,7 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
EVEN_K = {EVEN_K},
GRID_MN = grid_mn,
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},
grid=(1,),
)
return None
@@ -145,7 +147,8 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
BIAS = {use_bias},
EVEN_K = {EVEN_K},
GRID_MN = grid[0],
NUM_XCDS = {num_xcds}
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},
)
return c
"""
