Skip to content

Commit

Permalink
[tuner]: use lowering config binding (#629)
Browse files Browse the repository at this point in the history
This PR is relevant to the task in
#453 : use IREE bindings for
compilation info (incl., lowering_config and translation_info).

Remove data class `ReorderWorkgroupsStrategy`, and use lowering_config
binding.

---------

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Dec 3, 2024
1 parent 5b75ea1 commit c0ca2e2
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 173 deletions.
133 changes: 85 additions & 48 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,48 @@

tune_logger = logging.getLogger("tune")


# TODO: remove the argument 'workgroup_sizes' and 'reduction_sizes'.
def apply_configuration(
template: list[str], configuration: Configuration, tile_sizes: list[int]
template: list[str],
configuration: Configuration,
workgroup_sizes: list[int],
reduction_sizes: list[int],
) -> str:
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
)
expr1 = re.compile(
r"LLVMGPUVectorDistribute workgroup_size = \[.+\] subgroup_size = ([0-9]+),"
)
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
expr2 = re.compile(r"workgroup = \[([0-9]+)(, ([0-9]+))+\]")
expr3 = re.compile(r"reduction = \[([0-9]+)(, ([0-9]+))+\]")
expr4 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl4 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
repl2 = f"workgroup = {workgroup_sizes}"
repl3 = f"reduction = {reduction_sizes}"
repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'

new_mlir = ""
for line in template:
if "intrinsic =" in line:
line = re.sub(expr0, repl0, line)
if "LLVMGPUVectorDistribute " in line:
line = re.sub(expr1, repl1, line)
if "tile_sizes" in line:
if "workgroup" in line:
line = re.sub(expr2, repl2, line)
if "gpu_pipeline_options =" in line:
if "reduction" in line:
line = re.sub(expr3, repl3, line)
if "amdgpu-waves-per-eu" in line:
if "gpu_pipeline_options =" in line:
line = re.sub(expr4, repl4, line)
if "amdgpu-waves-per-eu" in line:
line = re.sub(expr5, repl5, line)
new_mlir += line

return new_mlir
Expand Down Expand Up @@ -115,7 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand All @@ -127,12 +139,12 @@ def get_transform_function_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
Expand All @@ -153,7 +165,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, get_mmt_tile_sizes(configuration)
template,
configuration,
get_mmt_workgroup_sizes(configuration),
get_mmt_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_mmt(problem_size, f"match_op", configuration),
Expand All @@ -163,13 +178,6 @@ def apply_params(


class ConvTuner(DispatchTuner, ConvParser):
# int64_t n = outputShape[0];
# int64_t oh = outputShape[1];
# int64_t ow = outputShape[2];
# int64_t oc = outputShape[3];
# int64_t fh = filterShape[0];
# int64_t fw = filterShape[1];
# int64_t ic = filterShape[2];
def get_transform_function_conv(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
Expand All @@ -185,7 +193,15 @@ def get_transform_function_conv(
filter = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{dynamic_batch_output_ty}>"

tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
workgroup_sizes = ", ".join(
map(str, self.get_conv_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, self.get_conv_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand All @@ -200,12 +216,12 @@ def get_transform_function_conv(
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %conv, %config : !transform.any_op, !transform.any_param
Expand All @@ -228,7 +244,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, self.get_conv_tile_sizes(configuration)
template,
configuration,
self.get_conv_workgroup_sizes(configuration),
self.get_conv_reduction_sizes(configuration),
)
embeddable = indent(
self.get_transform_function_conv(problem_size, f"match_op", configuration),
Expand All @@ -244,7 +263,15 @@ def get_transform_function_broadcast_rhs_mmt(
functionName: str,
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
workgroup_sizes = ", ".join(
map(str, get_batch_mmt_workgroup_sizes(configuration))
)
reduction_sizes = ", ".join(
map(str, get_batch_mmt_reduction_sizes(configuration))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand All @@ -261,12 +288,12 @@ def get_transform_function_broadcast_rhs_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{lhs_dynamic_batch}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
Expand All @@ -287,7 +314,10 @@ def apply_params_broadcast_rhs_mmt(
"// ",
)
modified += apply_configuration(
template, configuration, get_batch_mmt_tile_sizes(configuration)
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand Down Expand Up @@ -315,7 +345,8 @@ def apply_params(
apply_configuration(
template,
configuration,
get_contract_tile_sizes(configuration, self.tile_dims),
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
),
"",
)
Expand All @@ -328,7 +359,9 @@ def get_transform_function_batch_mmt(
functionName: str,
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand All @@ -341,12 +374,12 @@ def get_transform_function_batch_mmt(
transform.iree.match.cast_compatible_type %lhs = tensor<{problem_size.lhs_type}> : !transform.any_value
transform.iree.match.cast_compatible_type %rhs = tensor<{problem_size.rhs_type}> : !transform.any_value
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
Expand All @@ -368,7 +401,10 @@ def apply_params(
"// ",
)
modified += apply_configuration(
template, configuration, get_batch_mmt_tile_sizes(configuration)
template,
configuration,
get_batch_mmt_workgroup_sizes(configuration),
get_batch_mmt_reduction_sizes(configuration),
)

embeddable = indent(
Expand All @@ -392,9 +428,9 @@ def get_transform_function_batch_matmul(
input1 = f"tensor<{problem_size.rhs_type}>"
output = f"tensor<{problem_size.res_type}>"

tile_sizes = ", ".join(
map(str, get_contract_tile_sizes(configuration, tile_dims))
)
intrinsic = get_intrinsic(configuration)
subgroup_m_count = get_subgroup_m_count(configuration)
subgroup_n_count = get_subgroup_n_count(configuration)

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
Expand All @@ -409,12 +445,12 @@ def get_transform_function_batch_matmul(
outs(%out : {output}) -> {output}
}} : (!transform.any_op) -> (!transform.any_value, !transform.any_value)
%config = transform.param.constant #iree_codegen.compilation_info<
lowering_config = #iree_codegen.lowering_config<tile_sizes = [[{tile_sizes}]]>,
lowering_config = {configuration.lowering_config}>,
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
Expand All @@ -440,7 +476,8 @@ def apply_params(
modified += apply_configuration(
template,
configuration,
get_contract_tile_sizes(configuration, self.tile_dims),
get_contract_workgroup_sizes(configuration, self.tile_dims),
get_contract_reduction_sizes(configuration, self.tile_dims),
)

embeddable = indent(
Expand Down Expand Up @@ -548,7 +585,7 @@ def tune(
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
Expand Down
Loading

0 comments on commit c0ca2e2

Please sign in to comment.