Added instr.sched options to tune_gemm.py #649

Open · wants to merge 6 commits into base: main_perf
1 change: 1 addition & 0 deletions python/perf-kernels/tools/tune_gemm/README.md
@@ -323,6 +323,7 @@ will be added later.
- Switched back to rocprofv1. Check [ticket#228](https://github.com/ROCm/triton-internal/issues/228) for more details.
- Improved the post-processing logic to filter out the "spikes" in the profiling results.
- Reduced the number of iterations in both tuning and benchmark mode (120 and 200).
- Extended the parameter tuning space with instruction scheduling variants for the main gemm loop (k-loop).


# One config running script
12 changes: 12 additions & 0 deletions python/perf-kernels/tools/tune_gemm/matmul_kernel.py
@@ -7,6 +7,15 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BIAS: tl.constexpr,
EVEN_K: tl.constexpr, GRID_MN: tl.constexpr, NUM_XCDS: tl.constexpr):

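# Compiler hints: the strides are positive, which lets the backend simplify address arithmetic.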
tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
tl.assume(stride_bias > 0)

pid = tl.program_id(axis=0)
pid_z = tl.program_id(1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -33,6 +42,9 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)

if SPLIT_K == 1:
offs_k = tl.arange(0, BLOCK_SIZE_K)
else:
1 change: 1 addition & 0 deletions python/perf-kernels/tools/tune_gemm/test_regression.py
@@ -101,6 +101,7 @@ def teardown_class(self):
},
], ids=lambda val: f"Config: {val}")
def test_matmul_performance_regression(self, config, record_property):
config.setdefault('instruction_sched_variant', 'default')

M, N, K, col_a, col_b, runConfig = tune_gemm.process_item(deepcopy(config))

48 changes: 23 additions & 25 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
@@ -15,6 +15,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
@@ -55,30 +56,26 @@ def get_full_tuning_space():
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 2, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# For now we see better perf with num_stages=2 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [0]
num_stage_range = [2]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]
sched_variants = ["\"default\""]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack
})
space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
split_k_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
sched_variants, kpack_range)
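# Note: the unpacking order below must match the argument order passed to itertools.product above.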

for instance in space:
block_m, block_n, block_k, num_warps, group_m, split_k, num_stages, waves_per_eu, matrix_instr_nonkdim, sched_variant, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack, 'instruction_sched_variant': sched_variant
})

return configs

@@ -144,7 +141,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
LDS = LDS if not num_stages else LDS * num_stages
LDS = LDS if not num_stages else LDS * (num_stages - 1)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
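For reference, a quick worked example of this pruning rule (a sketch, assuming fp16 inputs at 2 bytes per element; not part of the diff):

BLOCK_SIZE_M = BLOCK_SIZE_N = 128
BLOCK_SIZE_K = 64
num_stages = 2
elemBytes_a = elemBytes_b = 2  # fp16
LDS = BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b  # 32768 bytes
LDS = LDS if not num_stages else LDS * (num_stages - 1)  # one buffered stage -> still 32768 bytes
assert LDS <= 65536  # fits in the 64 KiB LDS budget, so this config is kept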
Expand Down Expand Up @@ -346,7 +343,7 @@ def gen_rotating_tensors(M, N, K, dtype_a, need_Trans_a, dtype_b, need_Trans_b,


def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu,
mfmaInstrSize, kpack, use_bias):
mfmaInstrSize, kpack, use_bias, sched_variant):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
#assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -363,12 +360,13 @@ def matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps
c.stride(1), stride_bias=stride_bias, BLOCK_SIZE_M=block_m, BLOCK_SIZE_N=block_n,
BLOCK_SIZE_K=block_k, GROUP_SIZE_M=group_m, SPLIT_K=split_k, num_warps=num_warps,
num_stages=num_stages, waves_per_eu=waves_per_eu, matrix_instr_nonkdim=mfmaInstrSize,
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds)
kpack=kpack, BIAS=use_bias, EVEN_K=EVEN_K, GRID_MN=grid[0], NUM_XCDS=num_xcds,
instruction_sched_variant=sched_variant)
return c


def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type, config, bias_vector, verbose):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)
use_bias = bias_vector
torch.manual_seed(0)
@@ -384,7 +382,7 @@ def test_correctness(M, N, K, col_a, col_b, dtype_a, dtype_b, dtype_c, init_type
# Allocates output.
c = torch.zeros((M, N), device=a.device, dtype=tl_to_torch_types[name_to_tl_types[dtype_c]])
triton_output = matmul(a, b, c, bias, block_m, block_n, block_k, group_m, split_k, num_warps, num_stages,
waves_per_eu, mfmaInstrSize, kpack, use_bias)
waves_per_eu, mfmaInstrSize, kpack, use_bias, sched_variant)
torch_output = torch.matmul(a_fp16, b_fp16)
if use_bias:
torch_output += bias_fp16[:, None]
@@ -649,11 +647,11 @@ def main():
formatted_tflops = format_output(tri_tflops)
minTime = format_output(minTime)
if not run_bench:
print(f'TFLOPS: {formatted_tflops} time(us): {minTime}', end=" ", flush=True)
print(f'\nTFLOPS: {formatted_tflops}; time(us): {minTime}', end=" ", flush=True)

bestConfig_compact_str = gen_configStr(bestConfig)
if not run_bench:
print(f'best_config: {bestConfig_compact_str}', end=" ", flush=True)
print(f'\nbest_config: {bestConfig_compact_str}', end=" ", flush=True)

Do you have an example output after adding '\n'?

Yes, sure.

> ./tune_gemm.py --gemm_size_file ~/tuning/input.yaml --gpu_ids 3,4,5 --jobs 32 --o ~/tuning/output.yaml
Tuning 1 gemm sizes starts at: 2024-10-29 14:26:32.618604
SIZE: 4864 8192 4160 TN nConfigs: 720 
TFLOPS: 516.47; time(us): 641.89 
best_config: BM128_BN128_BK64_GM8_SK1_nW4_nS2_EU0_kP2_mfma16_schedDEFAULT 
>>> Elapsed time: 0:04:11.238153 = 0:00:20.441773 (compile) + 0:03:49.947198 (profile) + 0:00:00.681876 (post processing)
Tuning ends at: 2024-10-29 14:30:44.031012
Total tuning time (h:m:s): 0:04:11.412408


# write best config to tuning_results.yaml
if run_bench:
13 changes: 8 additions & 5 deletions python/perf-kernels/tools/tune_gemm/utils/file_generator.py
@@ -21,15 +21,16 @@ def read_config(config):
waves_per_eu = config.get('waves_per_eu')
mfma_instr_size = config.get('matrix_instr_nonkdim')
kpack = config.get('kpack')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack
sched_variant = config.get('instruction_sched_variant')
return block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfma_instr_size, kpack, sched_variant


def gen_configStr(config):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

## {M}_{N}_{K} is removed since the same kernel can be used for different gemm sizes
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}"
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_variant[1:-1].upper()}"

Now we are using local-prefetch, but we cannot have - in kernel names. Can you also convert - into _?
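One possible way to do that (an untested sketch, not part of this diff) is to sanitize the variant name before it is embedded in the kernel name:

sched_str = sched_variant[1:-1].upper().replace('-', '_')  # e.g. "local-prefetch" -> LOCAL_PREFETCH
configStr = f"BM{block_m}_BN{block_n}_BK{block_k}_GM{group_m}_SK{split_k}_nW{num_warps}_nS{num_stages}_EU{waves_per_eu}_kP{kpack}_mfma{mfmaInstrSize}_sched{sched_str}"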


return configStr

@@ -69,7 +70,7 @@ def generate_matmul_kernels(configs):
## construct the configStr and generate the wrapper function matmul_{configStr}()
## If `warmup` is set, the generated kernel will be **compiled**
def gen_kernel_and_configStr_from_config(config, EVEN_K, dtype_a, dtype_b, dtype_c, bias_size, warmup):
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack = read_config(
block_m, block_n, block_k, group_m, split_k, num_warps, num_stages, waves_per_eu, mfmaInstrSize, kpack, sched_variant = read_config(
config)

configStr = gen_configStr(config)
Expand Down Expand Up @@ -112,6 +113,7 @@ def matmul_{configStr}(M, N, K, am, ak, bk, bn, cm, cn, biasn):
EVEN_K = {EVEN_K},
GRID_MN = grid_mn,
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},

Missing quotes ' '
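If the value can arrive without embedded quotes (for example when a config is read from a YAML file), one possible fix, sketched here and not part of the diff, is to normalize the value and quote it explicitly in the generated source:

sched_variant = sched_variant.strip('"')  # normalize: drop any quotes already embedded in the value
instruction_sched_variant = '{sched_variant}',  # the generated wrapper then always receives a quoted string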

grid=(1,),
)
return None
@@ -145,7 +147,8 @@ def matmul_{configStr}(a, b, c, bias, M, N, K, am, ak, bk, bn, cm, cn, biasn):
BIAS = {use_bias},
EVEN_K = {EVEN_K},
GRID_MN = grid[0],
NUM_XCDS = {num_xcds}
NUM_XCDS = {num_xcds},
instruction_sched_variant = {sched_variant},

Missing quote ' '

)
return c
"""