Skip to content

Commit

Permalink
[tuner] Add support for TileAndFuse and multi-dim contractions (#771)
Browse files Browse the repository at this point in the history
This PR adds support for tuning contractions with multiple M, N, K, and
Batch dimensions, and adds support for tuning with the TileAndFuse
pipeline. A new flag is added called `--codegen-pipeline` that specifies
which codegen pipeline to target (`llvmgpu_vector_distribute` or
`llvmgpu_tile_and_fuse`).

---------

Signed-off-by: Max Dawkins <[email protected]>
Signed-off-by: Max Dawkins <[email protected]>
Co-authored-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 and Max Dawkins authored Jan 7, 2025
1 parent 3bf4faf commit 4f6f2b3
Show file tree
Hide file tree
Showing 10 changed files with 600 additions and 124 deletions.
3 changes: 2 additions & 1 deletion tuner/examples/test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ python -m examples.test <model_file_path> <benchmark_file_path> \
--test_num_dispatch_candidates=<num_dispatch_candidates> \
--test_num_model_candidates=<num_model_candidates> \
--test_hip_target=<hip_target> \
--num-candidates=<num_generated_candidates>
--num-candidates=<num_generated_candidates> \
--codegen-pipeline=<codegen_pipeline>
```
5 changes: 4 additions & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def generate_configs_and_td_specs(
tuner_context: TunerContext,
limit: int = 4096, # Max candidates to be generated
num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints
codegen_pipeline: iree_codegen.DispatchLoweringPassPipeline = iree_codegen.DispatchLoweringPassPipeline.LLVMGPUVectorDistribute,
) -> list[ir.Module]:
dispatch_tuner_registry = DispatchTunerRegistry(check_translation_info=False)
dispatch_tuner_registry.register(
Expand Down Expand Up @@ -221,7 +222,9 @@ def generate_configs_and_td_specs(
variant_op = variant_op_list[0]
mma_list = iree_codegen.query_mma_intrinsics(variant_op)
for i, config in enumerate(
generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
generate_solutions(
tuner_context, problem_size, num_subgroups, mma_list, codegen_pipeline
)
):
if i >= limit:
break
Expand Down
53 changes: 43 additions & 10 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import re
import logging
from dataclasses import astuple, dataclass
from dataclasses import astuple, dataclass, field
from enum import Enum
from typing import Optional
from typing import Any
Expand Down Expand Up @@ -67,31 +67,64 @@ def __str__(self) -> str:


@dataclass
class ContractionSizes:
    """Static extents of the contraction iteration space, grouped by kind.

    Each attribute lists, in order, the size of every iteration-space
    dimension of that kind. For example, for a simple batch mmt:
        linalg.generic ... indexing_maps = [
            affine_map<(b, m, n, k) -> (b, m, k)>,
            affine_map<(b, m, n, k) -> (b, n, k)>,
            affine_map<(b, m, n, k) -> (b, m, n)>,
        ] ...
          ins(%lhs: tensor<4x8x32xf16>, %rhs: tensor<4x16x32xf16>)
          outs(%acc: tensor<4x8x16xf16>)
    the ContractionSizes would be M=[8], N=[16], K=[32], B=[4].
    """

    # Sizes of the M dimensions (shared by LHS and result in the example).
    M: list[int]
    # Sizes of the N dimensions (shared by RHS and result in the example).
    N: list[int]
    # Sizes of the K (contracted) dimensions.
    K: list[int]
    # Sizes of the batch dimensions; defaults to [] for non-batched ops.
    B: list[int] = field(default_factory=list)


@dataclass
class ContractionDimensions:
    """Indices of the iteration-space dimensions, grouped by kind.

    Each attribute lists which positions of the iteration space act as
    M, N, K, or Batch dimensions. For example, for a simple batch mmt:
        linalg.generic ... indexing_maps = [
            affine_map<(b, m, n, k) -> (b, m, k)>,
            affine_map<(b, m, n, k) -> (b, n, k)>,
            affine_map<(b, m, n, k) -> (b, m, n)>,
        ]
    the ContractionDimensions would be m=[1], n=[2], k=[3], batch=[0].
    """

    # Iteration-space positions of the M dimensions.
    m: list[int]
    # Iteration-space positions of the N dimensions.
    n: list[int]
    # Iteration-space positions of the K (contracted) dimensions.
    k: list[int]
    # Iteration-space positions of the batch dimensions; [] if non-batched.
    batch: list[int] = field(default_factory=list)


@dataclass
class ProblemSize:
    """Full description of a contraction problem to be tuned.

    Bundles the iteration-space sizes, the operand/result tensor types,
    the dispatch kind, and the mapping of iteration-space dimensions to
    contraction roles (M/N/K/batch).
    """

    matmul_size: ContractionSizes
    lhs_type: ShapedType
    rhs_type: ShapedType
    res_type: ShapedType
    dispatch_kind: DispatchKind
    contraction_dims: ContractionDimensions

    @property
    def MNK(self) -> tuple[list[int], list[int], list[int]]:
        """Convenience accessor for the (M, N, K) size lists."""
        sizes = self.matmul_size
        return (sizes.M, sizes.N, sizes.K)


Expand Down Expand Up @@ -130,7 +163,7 @@ def get_lowering_config(
# A local variable to hold the transformed value.
promoted_value = value
match key:
case "workgroup" | "reduction":
case "workgroup" | "reduction" | "subgroup":
if isinstance(value, list):
promoted_value = ir.ArrayAttr.get(
[tuner_ctx.type.getI64(x) for x in value]
Expand Down
9 changes: 6 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,12 @@ def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None:
def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(2048, 1280, 1280),
common.ContractionSizes([2048], [1280], [1280]),
common.ShapedType([2048, 1280], tuner_ctx.type.f16),
common.ShapedType([1280, 1280], tuner_ctx.type.f16),
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.contraction,
common.ContractionDimensions([0], [1], [2]),
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand All @@ -138,11 +139,12 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:

assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(2048, 1280, 1280),
common.ContractionSizes([2048], [1280], [1280]),
common.ShapedType([2048, 1280], tuner_ctx.type.i8),
common.ShapedType([1280, 1280], tuner_ctx.type.i8),
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.contraction,
common.ContractionDimensions([0], [1], [2]),
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand All @@ -158,11 +160,12 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert (
common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
common.ContractionSizes([968], [320], [640], [64]),
common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.contraction,
common.ContractionDimensions([1], [2], [3], [0]),
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
Expand Down
Loading

0 comments on commit 4f6f2b3

Please sign in to comment.