Skip to content

Commit

Permalink
[tuner] Tweak tile and fuse constraints and test to reduce run time (#779)
Browse files Browse the repository at this point in the history

The `test_generate_tile_and_fuse_constraints*` tests have a long run
time because of the complex constraint problem they solve. This PR
reduces the complexity of that problem to improve the run time of the
tests.

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored Jan 8, 2025
1 parent 12446b3 commit 04b1819
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 52 deletions.
1 change: 1 addition & 0 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def main():
tuner_ctx,
args.limit,
args.num_subgroups,
iree_codegen.DispatchLoweringPassPipeline.LLVMGPUTileAndFuse,
)
for candidate_num, spec in enumerate(specs):
spec_dir = Path(args.output)
Expand Down
37 changes: 17 additions & 20 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,14 @@ def generate_tile_and_fuse_constraints(
m_tiles, n_tiles, k_tiles, subgroup_m_tiles, subgroup_n_tiles = tile_sizes
intrinsic_mn, intrinsic_k = intrinsic_size
wg_x, wg_y, wg_z = workgroup_size
wg_threads = z3.Int("wg_threads")
constraints = [wg_x == wg_threads, wg_y == 1, wg_z == 1]
wg_threads = wg_x
constraints = [wg_y == 1, wg_z == 1]
constraints += [subgroup_size == 64, wg_threads <= 1024]
constraints += [
get_mfma_intrinsic_constraints(
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics
)
]
subgroup_k_count = 1

constraints += [
m_tiles[-1] >= intrinsic_mn,
Expand All @@ -192,9 +191,9 @@ def generate_tile_and_fuse_constraints(
constraints += [m_shape % m == 0 for m, m_shape in zip(m_tiles, M)]
constraints += [n_shape % n == 0 for n, n_shape in zip(n_tiles, N)]
constraints += [k_shape % k == 0 for k, k_shape in zip(k_tiles[:-1], K[:-1])]
constraints += [m >= 0 for m in m_tiles]
constraints += [n >= 0 for n in n_tiles]
constraints += [k >= 0 for k in k_tiles]
constraints += [m >= 1 for m in m_tiles]
constraints += [n >= 1 for n in n_tiles]
constraints += [k >= 1 for k in k_tiles]
constraints += [K[-1] % (k_tiles[-1] * intrinsic_k) == 0]
constraints += [m <= m_shape for m, m_shape in zip(m_tiles, M)]
constraints += [n <= n_shape for n, n_shape in zip(n_tiles, N)]
Expand All @@ -203,29 +202,27 @@ def generate_tile_and_fuse_constraints(
for x in (subgroup_m_count, subgroup_n_count):
constraints += [x >= 1, x <= 32]

subgroup_m_tile_count = z3.Int("sg_m_tcnt")
subgroup_n_tile_count = z3.Int("sg_n_tcnt")
subgroup_k_tile_count = z3.Int("sg_k_tcnt")
for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
constraints += [x >= 1, x <= 32]
constraints += [math.prod(subgroup_m_tiles) == subgroup_m_tile_count]
constraints += [math.prod(subgroup_n_tiles) == subgroup_n_tile_count]
constraints += [
m % m_subgroup == 0 for m, m_subgroup in zip(m_tiles, subgroup_m_tiles)
m % m_subgroup == 0
for m, m_subgroup in zip(m_tiles[:-1], subgroup_m_tiles[:-1])
]
constraints += [
n % n_subgroup == 0 for n, n_subgroup in zip(n_tiles, subgroup_n_tiles)
n % n_subgroup == 0
for n, n_subgroup in zip(n_tiles[:-1], subgroup_n_tiles[:-1])
]
constraints += [m_subgroup > 0 for m_subgroup in subgroup_m_tiles]
constraints += [n_subgroup > 0 for n_subgroup in subgroup_n_tiles]
constraints += [m_tiles[-1] % (subgroup_m_tiles[-1] * intrinsic_mn) == 0]
constraints += [n_tiles[-1] % (subgroup_n_tiles[-1] * intrinsic_mn) == 0]
constraints += [m_subgroup >= 1 for m_subgroup in subgroup_m_tiles]
constraints += [n_subgroup >= 1 for n_subgroup in subgroup_n_tiles]

constraints += [
math.prod(m_tiles) == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn
math.prod(m_tiles)
== math.prod(subgroup_m_tiles) * subgroup_m_count * intrinsic_mn
]
constraints += [
math.prod(n_tiles) == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn
math.prod(n_tiles)
== math.prod(subgroup_n_tiles) * subgroup_n_count * intrinsic_mn
]
constraints += [math.prod(k_tiles) == subgroup_k_count * subgroup_k_tile_count]
subgroups = subgroup_m_count * subgroup_n_count
if num_subgroups > 0:
constraints += [subgroups == num_subgroups]
Expand Down
64 changes: 32 additions & 32 deletions tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,20 @@ def test_generate_tile_and_fuse_constraints_valid_input(
tuner_ctx: common.TunerContext,
) -> None:
matmul_size = common.ContractionSizes(
M=[4, 32],
N=[6, 64],
K=[8, 128],
B=[2, 16],
M=[32],
N=[64],
K=[128],
B=[2],
)
contraction_dims = common.ContractionDimensions(
m=[1, 5],
n=[2, 6],
k=[3, 7],
batch=[0, 4],
m=[1],
n=[2],
k=[3],
batch=[0],
)
lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16)
rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16)
res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32)
lhs_type = common.ShapedType([2, 32, 128], tuner_ctx.type.f16)
rhs_type = common.ShapedType([2, 64, 128], tuner_ctx.type.f16)
res_type = common.ShapedType([2, 32, 64], tuner_ctx.type.f32)
problem_size = common.ProblemSize(
matmul_size,
lhs_type,
Expand All @@ -148,13 +148,13 @@ def test_generate_tile_and_fuse_constraints_valid_input(
)
# Define input parameters as z3 Ints
m, n, k = (
[z3.Int("m0"), z3.Int("m1")],
[z3.Int("n0"), z3.Int("n1")],
[z3.Int("k0"), z3.Int("k1")],
[z3.Int("m0")],
[z3.Int("n0")],
[z3.Int("k0")],
)
subgroup_m, subgroup_n = (
[z3.Int("subgroup_m0"), z3.Int("subgroup_m1")],
[z3.Int("subgroup_n0"), z3.Int("subgroup_n1")],
[z3.Int("subgroup_m0")],
[z3.Int("subgroup_n0")],
)
subgroup_size = z3.Int("subgroup_size")
intrinsic_mn = z3.Int("intrinsic_mn")
Expand Down Expand Up @@ -198,20 +198,20 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
) -> None:
# Define input parameters that should lead to unsatisfiable constraints
matmul_size = common.ContractionSizes(
M=[4, 32],
N=[6, 64],
K=[8, 128],
B=[2, 16],
M=[32],
N=[64],
K=[128],
B=[2],
)
contraction_dims = common.ContractionDimensions(
m=[1, 5],
n=[2, 6],
k=[3, 7],
batch=[0, 4],
m=[1],
n=[2],
k=[3],
batch=[0],
)
lhs_type = common.ShapedType([2, 4, 8, 16, 32, 128], tuner_ctx.type.f16)
rhs_type = common.ShapedType([2, 6, 8, 16, 64, 128], tuner_ctx.type.f16)
res_type = common.ShapedType([2, 4, 6, 16, 32, 64], tuner_ctx.type.f32)
lhs_type = common.ShapedType([2, 32, 128], tuner_ctx.type.f16)
rhs_type = common.ShapedType([2, 64, 128], tuner_ctx.type.f16)
res_type = common.ShapedType([2, 32, 64], tuner_ctx.type.f32)
problem_size = common.ProblemSize(
matmul_size,
lhs_type,
Expand All @@ -222,13 +222,13 @@ def test_generate_tile_and_fuse_constraints_invalid_input(
)
# Define input parameters as z3 Ints
m, n, k = (
[z3.Int("m0"), z3.Int("m1")],
[z3.Int("n0"), z3.Int("n1")],
[z3.Int("k0"), z3.Int("k1")],
[z3.Int("m0")],
[z3.Int("n0")],
[z3.Int("k0")],
)
subgroup_m, subgroup_n = (
[z3.Int("subgroup_m0"), z3.Int("subgroup_m1")],
[z3.Int("subgroup_n0"), z3.Int("subgroup_n1")],
[z3.Int("subgroup_m0")],
[z3.Int("subgroup_n0")],
)
subgroup_size = z3.Int("subgroup_size")
intrinsic_mn = z3.Int("intrinsic_mn")
Expand Down

0 comments on commit 04b1819

Please sign in to comment.