Skip to content

Commit

Permalink
Replaced nested for-loops in tune_gemm.py with itertools.product
Browse files Browse the repository at this point in the history
  • Loading branch information
ravil-mobile committed Oct 30, 2024
1 parent 5230674 commit cdeffe9
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datetime import datetime
import multiprocessing
import pandas as pd
import itertools

from utils.file_generator import (
gen_configStr,
Expand Down Expand Up @@ -55,7 +56,7 @@ def get_full_tuning_space():
split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 2, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# For now we see better perf with num_stages=2 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [2]
Expand All @@ -64,24 +65,17 @@ def get_full_tuning_space():
kpack_range = [1, 2]
sched_variants = ["\"default\""]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for matrix_instr_nonkdim in matrix_instr_nonkdim_range:
for sched_variant in sched_variants:
for kpack in kpack_range:
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K':
block_k, 'GROUP_SIZE_M': group_m, 'SPLIT_K': split_k, 'num_warps':
num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack,
'instruction_sched_variant': sched_variant
})
space = itertools.product(block_mn_range, block_mn_range, block_k_range, num_warps_range, group_m_range,
split_k_range, num_stage_range, waves_per_eu_range, matrix_instr_nonkdim_range,
sched_variants, kpack_range)

for instance in space:
block_m, block_n, block_k, num_warps, group_m, split_k, num_stages, waves_per_eu, matrix_instr_nonkdim, sched_variant, kpack = instance
configs.append({
'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': block_k, 'GROUP_SIZE_M': group_m,
'SPLIT_K': split_k, 'num_warps': num_warps, 'num_stages': num_stages, 'waves_per_eu': waves_per_eu,
'matrix_instr_nonkdim': matrix_instr_nonkdim, 'kpack': kpack, 'instruction_sched_variant': sched_variant
})

return configs

Expand Down

0 comments on commit cdeffe9

Please sign in to comment.