From 4f6f2b3e177980334af013cb11ef4f36a3f9d342 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Tue, 7 Jan 2025 15:44:08 -0500 Subject: [PATCH] [tuner] Add support for TileAndFuse and multi-dim contractions (#771) This PR adds support for tuning contractions with multiple M, N, K, and Batch dimensions, and adds support for tuning with the TileAndFuse pipeline. A new flag is added called `--codegen-pipeline` that specifies which codegen pipeline to target (`llvmgpu_vector_distribute` or `llvmgpu_tile_and_fuse`). --------- Signed-off-by: Max Dawkins Signed-off-by: Max Dawkins Co-authored-by: Max Dawkins --- tuner/examples/test/README.md | 3 +- tuner/tuner/candidate_gen.py | 5 +- tuner/tuner/common.py | 53 +++- tuner/tuner/common_test.py | 9 +- tuner/tuner/dispatch_constraints.py | 303 ++++++++++++++++++----- tuner/tuner/dispatch_constraints_test.py | 241 ++++++++++++++++-- tuner/tuner/dispatch_parser.py | 46 ++-- tuner/tuner/dispatch_parser_test.py | 38 ++- tuner/tuner/libtuner.py | 24 ++ tuner/tuner/op_matchers.py | 2 +- 10 files changed, 600 insertions(+), 124 deletions(-) diff --git a/tuner/examples/test/README.md b/tuner/examples/test/README.md index 47ae7a8fe..850a161da 100644 --- a/tuner/examples/test/README.md +++ b/tuner/examples/test/README.md @@ -36,5 +36,6 @@ python -m examples.test \ --test_num_dispatch_candidates= \ --test_num_model_candidates= \ --test_hip_target= \ - --num-candidates= + --num-candidates= \ + --codegen-pipeline= ``` diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index b6264792e..ff7019ee0 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -194,6 +194,7 @@ def generate_configs_and_td_specs( tuner_context: TunerContext, limit: int = 4096, # Max candidates to be generated num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints + codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute, ) -> list[ir.Module]: dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False) dispatch_tuner_registry.register( @@ -221,7 +222,9 @@ def generate_configs_and_td_specs( variant_op = variant_op_list[0] mma_list = iree_codegen.query_mma_intrinsics(variant_op) for i, config in enumerate( - generate_solutions(tuner_context, problem_size, num_subgroups, mma_list) + generate_solutions( + tuner_context, problem_size, num_subgroups, mma_list, codegen_pipeline + ) ): if i >= limit: break diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 54051df47..45bcb0d75 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -6,7 +6,7 @@ import re import logging -from dataclasses import astuple, dataclass +from dataclasses import astuple, dataclass, field from enum import Enum from typing import Optional from typing import Any @@ -67,31 +67,64 @@ def __str__(self) -> str: @dataclass -class MatmulSize: - M: int - N: int - K: int - B: int = 1 +class ContractionSizes: + """ + Represents the size of the iteration space along each contraction dimension. + For example, the following is a simple batch mmt: + linalg.generic ... indexing_maps = [ + affine_map<(b, m, n, k) -> (b, m, k)>, + affine_map<(b, m, n, k) -> (b, n, k)>, + affine_map<(b, m, n, k) -> (b, m, n)>, + ] ... + ins(%lhs: tensor<4x8x32xf16>, %rhs: tensor<4x16x32xf16>) + outs(%acc: tensor<4x8x16xf16>) + The ContractionSizes would be: + M = [8] + N = [16] + K = [32] + B = [4] + """ + + M: list[int] + N: list[int] + K: list[int] + B: list[int] = field(default_factory=list) @dataclass class ContractionDimensions: - batch: list[int] + """ + Stores which dimensions of the iteration space belong to M, N, K, or Batch. + For example, the following is a simple batch mmt: + linalg.generic ... indexing_maps = [ + affine_map<(b, m, n, k) -> (b, m, k)>, + affine_map<(b, m, n, k) -> (b, n, k)>, + affine_map<(b, m, n, k) -> (b, m, n)>, + ] + The ContractionDimensions would be: + M = [1] + N = [2] + K = [3] + B = [0] + """ + m: list[int] n: list[int] k: list[int] + batch: list[int] = field(default_factory=list) @dataclass class ProblemSize: - matmul_size: MatmulSize + matmul_size: ContractionSizes lhs_type: ShapedType rhs_type: ShapedType res_type: ShapedType dispatch_kind: DispatchKind + contraction_dims: ContractionDimensions @property - def MNK(self) -> tuple[int, int, int]: + def MNK(self) -> tuple[list[int], list[int], list[int]]: return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) @@ -130,7 +163,7 @@ def get_lowering_config( # A local variable to hold the transformed value. promoted_value = value match key: - case "workgroup" | "reduction": + case "workgroup" | "reduction" | "subgroup": if isinstance(value, list): promoted_value = ir.ArrayAttr.get( [tuner_ctx.type.getI64(x) for x in value] diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index b23360ccc..eba5b35e1 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -119,11 +119,12 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None: def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( - common.MatmulSize(2048, 1280, 1280), + common.ContractionSizes([2048], [1280], [1280]), common.ShapedType([2048, 1280], tuner_ctx.type.f16), common.ShapedType([1280, 1280], tuner_ctx.type.f16), common.ShapedType([2048, 1280], tuner_ctx.type.f32), common.DispatchKind.contraction, + common.ContractionDimensions([0], [1], [2]), ), [ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, @@ -138,11 +139,12 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( - common.MatmulSize(2048, 1280, 1280), + common.ContractionSizes([2048], [1280], [1280]), common.ShapedType([2048, 1280], tuner_ctx.type.i8), common.ShapedType([1280, 1280], tuner_ctx.type.i8), common.ShapedType([2048, 1280], tuner_ctx.type.i32), common.DispatchKind.contraction, + common.ContractionDimensions([0], [1], [2]), ), [ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, @@ -158,11 +160,12 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert ( common.get_compatible_mfma_intrinsics( common.ProblemSize( - common.MatmulSize(968, 320, 640, 64), + common.ContractionSizes([968], [320], [640], [64]), common.ShapedType([64, 968, 640], tuner_ctx.type.f32), common.ShapedType([64, 640, 320], tuner_ctx.type.f32), common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.contraction, + common.ContractionDimensions([1], [2], [3], [0]), ), [ iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index f6de5179d..50a36d02f 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -7,6 +7,7 @@ # Given an input dispatch, this code modifies the hyperparameters # in the code and runs it. +import math import z3 # type: ignore from typing import Iterator @@ -63,33 +64,37 @@ def get_dispatch_constraints( def calculate_shared_memory_usage_in_bytes( problem_size: ProblemSize, - m: int | z3.ArithRef, - n: int | z3.ArithRef, - k: int | z3.ArithRef, + m: list[int] | list[z3.ArithRef], + n: list[int] | list[z3.ArithRef], + k: list[int] | list[z3.ArithRef], ) -> int | z3.ArithRef: - lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) - rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) + lhs_memory = problem_size.lhs_type.bitwidth // 8 + for size in m + k: + lhs_memory *= size + rhs_memory = problem_size.rhs_type.bitwidth // 8 + for size in n + k: + rhs_memory *= size return lhs_memory + rhs_memory -def generate_constraints( +def generate_vector_distribute_constraints( problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, + tile_sizes: list[list[z3.ArithRef]], + num_subgroups: int, + subgroup_size: z3.ArithRef, + intrinsic_size: list[z3.ArithRef], + workgroup_size: list[z3.ArithRef], + subgroup_m_count: z3.ArithRef, + subgroup_n_count: z3.ArithRef, + waves_per_eu: z3.ArithRef, mma_intrinsics: list[iree_gpu.MMAIntrinsic], ): M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, + problem_size.matmul_size.M[-1], + problem_size.matmul_size.N[-1], + problem_size.matmul_size.K[-1], ) - m, n, k = tile_sizes + m_vars, n_vars, k_vars = tile_sizes intrinsic_mn, intrinsic_k = intrinsic_size wg_x, wg_y, wg_z = workgroup_size wg_threads = z3.Int("wg_threads") @@ -101,6 +106,10 @@ def generate_constraints( ) ] subgroup_k_count = 1 + m = m_vars[-1] + n = n_vars[-1] + k = k_vars[-1] + constraints += [v == 1 for v in m_vars[:-1] + n_vars[:-1] + k_vars[:-1]] constraints += [ m >= intrinsic_mn, m <= 512, @@ -136,7 +145,7 @@ def generate_constraints( constraints += [waves_per_eu == 2] # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, [m], [n], [k]) constraints += [shared_memory <= 65536] constraints += get_dispatch_constraints(problem_size, m, n, k) @@ -144,6 +153,96 @@ def generate_constraints( return constraints +def generate_tile_and_fuse_constraints( + problem_size: ProblemSize, + tile_sizes: list[list[z3.ArithRef]], + num_subgroups: int, + subgroup_size: z3.ArithRef, + intrinsic_size: list[z3.ArithRef], + workgroup_size: list[z3.ArithRef], + subgroup_m_count: z3.ArithRef, + subgroup_n_count: z3.ArithRef, + waves_per_eu: z3.ArithRef, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], +): + M, N, K = problem_size.MNK + m_tiles, n_tiles, k_tiles, subgroup_m_tiles, subgroup_n_tiles = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_x == wg_threads, wg_y == 1, wg_z == 1] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics + ) + ] + subgroup_k_count = 1 + + constraints += [ + m_tiles[-1] >= intrinsic_mn, + m_tiles[-1] % intrinsic_mn == 0, + n_tiles[-1] >= intrinsic_mn, + n_tiles[-1] % intrinsic_mn == 0, + k_tiles[-1] * intrinsic_k <= K[-1], + math.prod(m_tiles) <= 512, + math.prod(n_tiles) <= 512, + math.prod(k_tiles) <= 512 / intrinsic_k, + ] + constraints += [m_shape % m == 0 for m, m_shape in zip(m_tiles, M)] + constraints += [n_shape % n == 0 for n, n_shape in zip(n_tiles, N)] + constraints += [k_shape % k == 0 for k, k_shape in zip(k_tiles[:-1], K[:-1])] + constraints += [m >= 0 for m in m_tiles] + constraints += [n >= 0 for n in n_tiles] + constraints += [k >= 0 for k in k_tiles] + constraints += [K[-1] % (k_tiles[-1] * intrinsic_k) == 0] + constraints += [m <= m_shape for m, m_shape in zip(m_tiles, M)] + constraints += [n <= n_shape for n, n_shape in zip(n_tiles, N)] + constraints += [k <= k_shape for k, k_shape in zip(k_tiles[:-1], K[:-1])] + constraints += [(k_tiles[-1] * intrinsic_k) <= K[-1]] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + constraints += [math.prod(subgroup_m_tiles) == subgroup_m_tile_count] + constraints += [math.prod(subgroup_n_tiles) == subgroup_n_tile_count] + constraints += [ + m % m_subgroup == 0 for m, m_subgroup in zip(m_tiles, subgroup_m_tiles) + ] + constraints += [ + n % n_subgroup == 0 for n, n_subgroup in zip(n_tiles, subgroup_n_tiles) + ] + constraints += [m_subgroup > 0 for m_subgroup in subgroup_m_tiles] + constraints += [n_subgroup > 0 for n_subgroup in subgroup_n_tiles] + + constraints += [ + math.prod(m_tiles) == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn + ] + constraints += [ + math.prod(n_tiles) == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn + ] + constraints += [math.prod(k_tiles) == subgroup_k_count * subgroup_k_tile_count] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + constraints += [subgroups >= 1, subgroups <= 10] + constraints += [wg_threads == subgroups * subgroup_size] + + constraints += [waves_per_eu == 2] + + shared_memory = calculate_shared_memory_usage_in_bytes( + problem_size, m_tiles, n_tiles, k_tiles + ) + constraints += [shared_memory * intrinsic_k <= 65536] + + return constraints + + def getMMAAttr( output_type: ir.IntegerType | ir.FloatType, m: int, @@ -178,10 +277,16 @@ def generate_solutions( problem_size: ProblemSize, num_subgrups: int, mma_intrinsics: list[iree_gpu.MMAIntrinsic], + codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute, ) -> Iterator[iree_codegen.CompilationInfoAttr]: M, N, K = problem_size.MNK tuner_ctx.logger.info(f"{M},{N},{K}") - m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + m_vars = [z3.Int(f"m{i}") for i in range(len(M))] + n_vars = [z3.Int(f"n{i}") for i in range(len(N))] + k_vars = [z3.Int(f"k{i}") for i in range(len(K))] + subgroup_m_vars = [z3.Int(f"subgroup_m{i}") for i in range(len(M))] + subgroup_n_vars = [z3.Int(f"subgroup_n{i}") for i in range(len(N))] + # m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") intrinsic_k = z3.Int("intrinsic_k") @@ -189,34 +294,52 @@ def generate_solutions( sg_m_cnt = z3.Int("sg_m_cnt") sg_n_cnt = z3.Int("sg_n_cnt") waves_per_eu = z3.Int("waves_per_eu") - all_vars = [ - m, - n, - k, - subgroup_size, - intrinsic_mn, - intrinsic_k, - wg_x, - wg_y, - wg_z, - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ] + all_vars = ( + m_vars + + n_vars + + k_vars + + [ + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + ) solver = z3.Solver() - constraints = generate_constraints( - problem_size, - [m, n, k], - num_subgrups, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - mma_intrinsics, - ) + match codegen_pipeline: + case iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute: + constraints = generate_vector_distribute_constraints( + problem_size, + [m_vars, n_vars, k_vars], + num_subgrups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + mma_intrinsics, + ) + constraints += [v == 0 for v in subgroup_m_vars + subgroup_n_vars] + case iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse: + constraints = generate_tile_and_fuse_constraints( + problem_size, + [m_vars, n_vars, k_vars, subgroup_m_vars, subgroup_n_vars], + num_subgrups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + mma_intrinsics, + ) solver.add(z3.simplify(z3.And(constraints))) tuner_ctx.logger.debug(f"Initial constraints: {solver}") @@ -232,21 +355,80 @@ def generate_solutions( problem_size.lhs_type.element_type, problem_size.rhs_type.element_type, ) - workgroup_tiles = [lookup(m), lookup(n), 0] - reduction_tiles = [0, 0, lookup(k)] - if problem_size.dispatch_kind == DispatchKind.conv: - workgroup_tiles = [1, 1, lookup(m), lookup(n), 0, 0, 0] - reduction_tiles = [0, 0, 0, 0, 1, 1, lookup(k)] - lowering_config = get_lowering_config( - tuner_ctx=tuner_ctx, - mma_kind=mma_attr, - workgroup=workgroup_tiles, - reduction=reduction_tiles, - subgroup_m_count=lookup(sg_m_cnt), - subgroup_n_count=lookup(sg_n_cnt), + + def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes): + for dim, size in zip(contraction_dims, csizes): + tile_sizes[dim] = size + + # Get workgroup tile sizes. + workgroup_tile_sizes = [0] * ( + len(M) + len(N) + len(K) + len(problem_size.contraction_dims.batch) ) + set_cdim_tile_sizes( + workgroup_tile_sizes, + problem_size.contraction_dims.m, + [lookup(v) for v in m_vars], + ) + set_cdim_tile_sizes( + workgroup_tile_sizes, + problem_size.contraction_dims.n, + [lookup(v) for v in n_vars], + ) + set_cdim_tile_sizes( + workgroup_tile_sizes, + problem_size.contraction_dims.batch, + [1] * len(problem_size.contraction_dims.batch), + ) + + # Get subgroup tile sizes. + subgroup_tile_sizes = [0] * ( + len(M) + len(N) + len(K) + len(problem_size.contraction_dims.batch) + ) + set_cdim_tile_sizes( + subgroup_tile_sizes, + problem_size.contraction_dims.m, + [lookup(v) for v in subgroup_m_vars], + ) + set_cdim_tile_sizes( + subgroup_tile_sizes, + problem_size.contraction_dims.n, + [lookup(v) for v in subgroup_n_vars], + ) + set_cdim_tile_sizes( + subgroup_tile_sizes, + problem_size.contraction_dims.batch, + [1] * len(problem_size.contraction_dims.batch), + ) + + # Get reduction tile sizes. + reduction_tile_sizes = [0] * ( + len(M) + len(N) + len(K) + len(problem_size.contraction_dims.batch) + ) + set_cdim_tile_sizes( + reduction_tile_sizes, + problem_size.contraction_dims.k, + [lookup(v) for v in k_vars], + ) + + # Create the LoweringConfigAttr. + lowering_config_args = { + "tuner_ctx": tuner_ctx, + "mma_kind": mma_attr, + "workgroup": workgroup_tile_sizes, + "reduction": reduction_tile_sizes, + "subgroup_m_count": lookup(sg_m_cnt), + "subgroup_n_count": lookup(sg_n_cnt), + } + if ( + codegen_pipeline + == iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse + ): + lowering_config_args["subgroup"] = subgroup_tile_sizes + lowering_config = get_lowering_config(**lowering_config_args) + + # Create the TranslationInfoAttr pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get( - iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + codegen_pipeline ) pipeline_options = iree_gpu.PipelineOptionsAttr.get() config_dict = get_translation_info_config( @@ -259,9 +441,12 @@ def generate_solutions( lookup(subgroup_size), config_dict, ) + + # Create the CompilationInfoAttr. compilation_info = iree_codegen.CompilationInfoAttr.get( lowering_config, translation_info ) + solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) i += 1 yield compilation_info diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 5c82f555f..d31a76e90 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -31,12 +31,18 @@ def tuner_ctx() -> Generator[common.TunerContext, None, None]: def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: - matmul_size = common.MatmulSize(2048, 3840, 1280) + matmul_size = common.ContractionSizes([2048], [3840], [1280]) + contraction_dims = common.ContractionDimensions([0], [1], [2]) lhs_type = common.ShapedType([2048, 1280], tuner_ctx.type.f16) rhs_type = common.ShapedType([3840, 1280], tuner_ctx.type.f16) res_type = common.ShapedType([2048, 3840], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) configs = dispatch_constraints.generate_solutions( tuner_ctx, @@ -54,56 +60,235 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) + matmul_size = common.ContractionSizes([1024], [1024], [1024]) + contraction_dims = common.ContractionDimensions([0], [1], [2]) lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( - problem_size, 512, 64, 128 + problem_size, [512], [64], [128] ) == 147456 ) lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( - problem_size, 512, 64, 128 + problem_size, [512], [64], [128] ) == 81920 ) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) assert ( dispatch_constraints.calculate_shared_memory_usage_in_bytes( - problem_size, 128, 64, 32 + problem_size, [128], [64], [32] ) == 12288 ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, [2, 64], [4, 16], [8, 4] + ) + == 12288 + ) + + +def test_generate_tile_and_fuse_constraints_valid_input( + tuner_ctx: common.TunerContext, +) -> None: + matmul_size = common.ContractionSizes( + M=[4, 32], + N=[6, 64], + K=[8, 128], + B=[2, 16], + ) + contraction_dims = common.ContractionDimensions( + m=[1, 5], + n=[2, 6], + k=[3, 7], + batch=[0, 4], + ) + lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16) + rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16) + res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, + ) + # Define input parameters as z3 Ints + m, n, k = ( + [z3.Int("m0"), z3.Int("m1")], + [z3.Int("n0"), z3.Int("n1")], + [z3.Int("k0"), z3.Int("k1")], + ) + subgroup_m, subgroup_n = ( + [z3.Int("subgroup_m0"), z3.Int("subgroup_m1")], + [z3.Int("subgroup_n0"), z3.Int("subgroup_n1")], + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_tile_and_fuse_constraints( + problem_size, + [m, n, k, subgroup_m, subgroup_n], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == z3.sat + + +def test_generate_tile_and_fuse_constraints_invalid_input( + tuner_ctx: common.TunerContext, +) -> None: + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = common.ContractionSizes( + M=[4, 32], + N=[6, 64], + K=[8, 128], + B=[2, 16], + ) + contraction_dims = common.ContractionDimensions( + m=[1, 5], + n=[2, 6], + k=[3, 7], + batch=[0, 4], + ) + lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16) + rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16) + res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32) + problem_size = common.ProblemSize( + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, + ) + # Define input parameters as z3 Ints + m, n, k = ( + [z3.Int("m0"), z3.Int("m1")], + [z3.Int("n0"), z3.Int("n1")], + [z3.Int("k0"), z3.Int("k1")], + ) + subgroup_m, subgroup_n = ( + [z3.Int("subgroup_m0"), z3.Int("subgroup_m1")], + [z3.Int("subgroup_n0"), z3.Int("subgroup_n1")], + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_tile_and_fuse_constraints( + problem_size, + [m, n, k, subgroup_m, subgroup_n], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + [ + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, + ], + ) + constraints.append(m[0] > 1000) # Adding an additional unsatisfiable constraint + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == z3.unsat + -def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) +def test_generate_vector_distribute_constraints_valid_input( + tuner_ctx: common.TunerContext, +) -> None: + matmul_size = common.ContractionSizes([1024], [1024], [1024]) + contraction_dims = common.ContractionDimensions([0], [1], [2]) lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) # Define input parameters as z3 Ints m, n, k = ( - dispatch_constraints.z3.Int("m"), - z3.Int("n"), - z3.Int("k"), + [z3.Int("m")], + [z3.Int("n")], + [z3.Int("k")], ) subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -117,7 +302,7 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non sg_n_cnt = z3.Int("sg_n_cnt") waves_per_eu = z3.Int("waves_per_eu") - constraints = dispatch_constraints.generate_constraints( + constraints = dispatch_constraints.generate_vector_distribute_constraints( problem_size, [m, n, k], 4, @@ -142,19 +327,27 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non assert solver.check() == z3.sat -def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> None: +def test_generate_vector_distribute_constraints_invalid_input( + tuner_ctx: common.TunerContext, +) -> None: # Define input parameters that should lead to unsatisfiable constraints - matmul_size = common.MatmulSize(1024, 1024, 1024) + matmul_size = common.ContractionSizes([1024], [1024], [1024]) + contraction_dims = common.ContractionDimensions([0], [1], [2]) lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.contraction + matmul_size, + lhs_type, + rhs_type, + res_type, + common.DispatchKind.contraction, + contraction_dims, ) m, n, k = ( - z3.Int("m"), - z3.Int("n"), - z3.Int("k"), + [z3.Int("m")], + [z3.Int("n")], + [z3.Int("k")], ) subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -168,7 +361,7 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N sg_n_cnt = z3.Int("sg_n_cnt") waves_per_eu = z3.Int("waves_per_eu") - constraints = dispatch_constraints.generate_constraints( + constraints = dispatch_constraints.generate_vector_distribute_constraints( problem_size, [m, n, k], 4, @@ -185,7 +378,7 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], ) - constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + constraints.append(m[0] > 1000) # Adding an additional unsatisfiable constraint solver = z3.Solver() solver.add(constraints) diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 502968ea8..ad7ba3a79 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -60,31 +60,39 @@ def get_shapes(self, template: list[str]) -> ProblemSize: ir_module = ir.Module.parse("\n".join(template)) contraction_op = match_root_op(ir_module, matcher) assert contraction_op is not None, f"contraction op not found" - cdims = matcher.contraction_dimensions - assert cdims, "no contraction dimensions" + contraction_dims = matcher.contraction_dimensions + assert contraction_dims, "no contraction dimensions" assert matcher.lhs_dims, "no lhs dimensions" assert matcher.rhs_dims, "no rhs dimensions" assert matcher.res_dims, "no result dimensions" - assert len(cdims.batch) <= 1, f"must have at most 1 batch dimension" - assert len(cdims.m) == 1, f"must have a single m dimension" - assert len(cdims.n) == 1, f"must have a single n dimension" - assert len(cdims.k) == 1, f"must have a single k dimension" lhs_type = ir.RankedTensorType(contraction_op.operands[0].type) rhs_type = ir.RankedTensorType(contraction_op.operands[1].type) res_type = ir.RankedTensorType(contraction_op.operands[2].type) - matmul_size = MatmulSize( - lhs_type.shape[matcher.lhs_dims.index(cdims.m[0])], - rhs_type.shape[matcher.rhs_dims.index(cdims.n[0])], - lhs_type.shape[matcher.lhs_dims.index(cdims.k[0])], + matmul_size = ContractionSizes( + M=[ + lhs_type.shape[matcher.lhs_dims.index(dim)] + for dim in contraction_dims.m + ], + N=[ + rhs_type.shape[matcher.rhs_dims.index(dim)] + for dim in contraction_dims.n + ], + K=[ + lhs_type.shape[matcher.lhs_dims.index(dim)] + for dim in contraction_dims.k + ], + B=[ + lhs_type.shape[matcher.lhs_dims.index(dim)] + for dim in contraction_dims.batch + ], ) - if len(cdims.batch) == 1: - matmul_size.B = lhs_type.shape[matcher.lhs_dims.index(cdims.batch[0])] return ProblemSize( matmul_size, lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), res_type=ShapedType(res_type.shape, res_type.element_type), dispatch_kind=DispatchKind.contraction, + contraction_dims=contraction_dims, ) @@ -115,14 +123,18 @@ def get_shapes(self, template: list[str]) -> ProblemSize: res_type = ir.RankedTensorType(conv_op.operands[2].type) dim_info = ConvDimInfo.from_rhs_res(rhs_type, res_type) return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, + matmul_size=ContractionSizes( + M=[dim_info.n, dim_info.oh, dim_info.ow], + N=[dim_info.oc], + K=[dim_info.fh, dim_info.fw, dim_info.ic], ), lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type), rhs_type=ShapedType(rhs_type.shape, rhs_type.element_type), res_type=ShapedType(res_type.shape, res_type.element_type), dispatch_kind=DispatchKind.conv, + contraction_dims=ContractionDimensions( + m=[0, 1, 2], + n=[3], + k=[4, 5, 6], + ), ) diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index c35b17bed..7ddb0bb84 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -78,10 +78,10 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: assert mmt_op is not None assert isinstance(mmt_op.opview, linalg.GenericOp) shapes: common.ProblemSize = parser.get_shapes(transpose_b_str.splitlines()) - assert shapes.matmul_size.B == 1 - assert shapes.matmul_size.M == 16 - assert shapes.matmul_size.N == 32 - assert shapes.matmul_size.K == 64 + assert shapes.matmul_size.B == [] + assert shapes.matmul_size.M == [16] + assert shapes.matmul_size.N == [32] + assert shapes.matmul_size.K == [64] assert shapes.lhs_type.shape == [16, 64] assert isinstance(shapes.lhs_type.element_type, ir.F16Type) assert shapes.rhs_type.shape == [32, 64] @@ -102,10 +102,32 @@ def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None: module = ir.Module.parse(bmm_transposed_inputs_str, context) mmt_op = parser.get_contraction_operation(module) shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines()) - assert shapes.matmul_size.B == 5 - assert shapes.matmul_size.M == 8 - assert shapes.matmul_size.N == 40 - assert shapes.matmul_size.K == 128 + assert shapes.matmul_size.B == [5] + assert shapes.matmul_size.M == [8] + assert shapes.matmul_size.N == [40] + assert shapes.matmul_size.K == [128] + + with ir.Location.unknown(): + bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format( + lhs_type=ir.RankedTensorType.get( + [16, 8, 15, 16, 64, 256], ir.F16Type.get() + ), + rhs_type=ir.RankedTensorType.get( + [16, 9, 15, 16, 128, 256], ir.F16Type.get() + ), + res_type=ir.RankedTensorType.get([16, 8, 9, 16, 64, 128], ir.F32Type.get()), + lhs_map="affine_map<(b0, m0, n0, k0, b1, m1, n1, k1) -> (b0, m0, k0, b1, m1, k1)>", + rhs_map="affine_map<(b0, m0, n0, k0, b1, m1, n1, k1) -> (b0, n0, k0, b1, n1, k1)>", + res_map="affine_map<(b0, m0, n0, k0, b1, m1, n1, k1) -> (b0, m0, n0, b1, m1, n1)>", + iterator_types='["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "parallel", "reduction"]', + ) + module = ir.Module.parse(bmm_transposed_inputs_str, context) + mmt_op = parser.get_contraction_operation(module) + shapes = parser.get_shapes(bmm_transposed_inputs_str.splitlines()) + assert shapes.matmul_size.B == [16, 16] + assert shapes.matmul_size.M == [8, 64] + assert shapes.matmul_size.N == [9, 128] + assert shapes.matmul_size.K == [15, 256] def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None: diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 4e2a97ec8..fab86c369 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -227,6 +227,11 @@ class ExecutionPhases(str, Enum): benchmark_models = "benchmark-models" +class CodegenPipelines(str, Enum): + llvmgpu_vector_distribute = "llvmgpu_vector_distribute" + llvmgpu_tile_and_fuse = "llvmgpu_tile_and_fuse" + + def parse_arguments( initial_parser: Optional[argparse.ArgumentParser] = None, ) -> argparse.Namespace: @@ -298,6 +303,12 @@ def parse_arguments( candidate_gen_args.add_argument( "--tile-dims", help="Map of tile size matmul dims", type=str, default="mnk" ) + general_args.add_argument( + "--codegen-pipeline", + choices=[x.value for x in CodegenPipelines], + default=CodegenPipelines.llvmgpu_vector_distribute, + help="Codegen pipeline to tune for", + ) return parser.parse_args() @@ -499,7 +510,9 @@ def run_iree_benchmark_module_command(benchmark_pack: BenchmarkPack): ) times = [] + logging.debug(f"candidate {candidate_id} benchmark_results: {benchmark_results}") for benchmark_result in benchmark_results: + logging.debug(f"candidate {candidate_id} benchmark_result: {benchmark_result}") benchmark_name = benchmark_result.benchmark_name # With multiple benchmark results, there will be `real_time_mean`, but # not with single iteration benchmark results, so ignore the mean time @@ -601,6 +614,16 @@ def find_collisions( return collisions_exist, hash_values +def get_iree_codegen_pipeline(pipeline: CodegenPipelines): + match pipeline: + case CodegenPipelines.llvmgpu_vector_distribute: + return iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute + case CodegenPipelines.llvmgpu_tile_and_fuse: + return iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse + case _: + assert False, "unexpected codegen pipeline" + + def generate_candidate_specs( args: argparse.Namespace, path_config: PathConfig, @@ -628,6 +651,7 @@ def generate_candidate_specs( tuner_context=tuning_client.tuner_context, limit=args.num_candidates, num_subgroups=args.num_subgroups, + codegen_pipeline=get_iree_codegen_pipeline(args.codegen_pipeline), ) logging.debug("candidate_gen.py ends") handle_error( diff --git a/tuner/tuner/op_matchers.py b/tuner/tuner/op_matchers.py index db953fbb3..f3966b97d 100644 --- a/tuner/tuner/op_matchers.py +++ b/tuner/tuner/op_matchers.py @@ -170,10 +170,10 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool: return False self.contraction_dimensions = ContractionDimensions( - batch=batch_dims, m=m_dims, n=n_dims, k=k_dims, + batch=batch_dims, ) self.lhs_dims = lhs_dims self.rhs_dims = rhs_dims