From a285c543b6534ce3d20f6a72111420248952109c Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Thu, 29 Aug 2024 15:01:13 +0100 Subject: [PATCH 1/4] [MatchTargetSize] Extend for-loop canonicalization patterh Signed-off-by: Julian Oppermann --- test/TritonIntelGPU/match-target-size.mlir | 45 +++++++++++++++++++ .../MatchTargetSize.cpp | 33 ++++++++++++-- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/test/TritonIntelGPU/match-target-size.mlir b/test/TritonIntelGPU/match-target-size.mlir index f5452d4a5e..2c26411a07 100644 --- a/test/TritonIntelGPU/match-target-size.mlir +++ b/test/TritonIntelGPU/match-target-size.mlir @@ -142,6 +142,51 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16 // ----- +// COM: Test SCF canonicalization: ensure loop canonicalization can be applied to dependendent loops +tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr, + %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) { + // CHECK-LABEL: @simplify_scf_for + // CHECK-NOT: triton_intel_gpu.glue + // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[INIT1:%.*]] = %arg0, [[INIT2:%.*]] = %arg1) + // CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 { + // CHECK-NEXT: scf.yield [[INIT2]], [[INIT1]] : tensor<16x8xf16>, tensor<16x8xf16> + // CHECK-NEXT: } + // CHECK: [[RES2:%.*]]:2 = scf.for {{.*}} iter_args([[INIT3:%.*]] = [[RES]]#0, [[INIT4:%.*]] = [[RES]]#1) + // CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 { + // CHECK-NEXT: scf.yield [[INIT3]], [[INIT4]] : tensor<16x8xf16>, tensor<16x8xf16> + // CHECK-NEXT: } + // CHECK-NEXT: [[GLUE:%.*]] = triton_intel_gpu.glue [[RES2]]#1, [[RES2]]#0 + // CHECK-SAME: : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + // CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2 + // CHECK-NEXT: tt.store [[PTR]], [[GLUE]] + %lb = arith.constant 0 : i32 + %ub = arith.constant 32 : i32 + %st = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + %cst = arith.constant dense<42.0> : tensor<16x16xf16> + %glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + %res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 { + %e1 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %e2 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16> + %g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + scf.yield %g1 : tensor<16x16xf16> + } + %res2 = scf.for %iv = %lb to %ub step %st iter_args(%arg = %res) -> (tensor<16x16xf16>) : i32 { + %e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16> + %e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + scf.yield %g1 : tensor<16x16xf16> + } + %e3 = triton_intel_gpu.extract %res2[0] : tensor<16x16xf16> -> tensor<16x8xf16> + %e4 = triton_intel_gpu.extract %res2[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %g2 = triton_intel_gpu.glue %e4, %e3 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + %ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array} : > + tt.store %ptr, %g2 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr> + tt.return +} + +// ----- + // COM: Test transformation for int8 datatype // CHECK-LABEL: @matmul_kernel_with_block_pointers_int8 diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index 089a498c37..db893b73eb 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -514,6 +514,13 @@ class ScfPattern : public OpRewritePattern { // Replace uses of the original loop results with the new loop results. userIndexMap.clear(); + + // If a results is used by another scf.for loop, we re-glue the individual + // results together to allow canonicalization of the dependent loop, too. + llvm::SmallDenseMap reglueMap; + + rewriter.setInsertionPointAfter(newForOp); + idx = 0; for (auto [result, init] : llvm::zip(forOp.getResults(), forOp.getInits())) { @@ -524,11 +531,25 @@ class ScfPattern : public OpRewritePattern { } auto glue = cast(definingOp); - for (Operation *user : result.getUsers()) { - if (auto extract = dyn_cast(user)) { + if (llvm::all_of(result.getUsers(), + [](auto *user) { return isa(user); })) { + for (Operation *user : result.getUsers()) { + auto extract = cast(user); userIndexMap[extract] = idx + extract.getIndex(); deleteList.push_back(extract.getOperation()); } + } else if (llvm::all_of(result.getUsers(), [](auto *user) { + return isa(user); + })) { + // We have encountered a glued operand (and already split its uses + // within this loop), but the corresponding result's user(s) are + // dependent loops, not extracts. Hence, we have to re-glue the results. + auto reglue = rewriter.create( + forOp->getLoc(), result.getType(), + newForOp->getResults().slice(idx, glue.getOperands().size())); + reglueMap[result] = reglue; + } else { + llvm::report_fatal_error("Unexpected users of loop result"); } idx += glue->getOperands().size(); @@ -537,6 +558,9 @@ class ScfPattern : public OpRewritePattern { for (auto [user, idx] : userIndexMap) user.replaceAllUsesWith(newForOp.getResults()[idx]); + for (auto [res, glue] : reglueMap) + res.replaceAllUsesWith(glue); + for (Operation *deleteOp : deleteList) rewriter.eraseOp(deleteOp); @@ -567,10 +591,11 @@ class ScfPattern : public OpRewritePattern { return false; } - // Bail out if the loop result is not used by an 'extract' operation. + // Bail out if the loop result is not used by an 'extract' operation, or + // another loop. if (forOp->getNumResults() == 1 && llvm::any_of(forOp.getResult(0).getUsers(), [](Operation *user) { - return !isa(user); + return !isa(user); })) return false; From 173dd19fc564a0ccb7bd60519c1266be2e0cd5c4 Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Thu, 29 Aug 2024 18:57:38 +0100 Subject: [PATCH 2/4] Nits. --- .../intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index db893b73eb..bd6446fcc7 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -531,14 +531,15 @@ class ScfPattern : public OpRewritePattern { } auto glue = cast(definingOp); - if (llvm::all_of(result.getUsers(), - [](auto *user) { return isa(user); })) { + if (llvm::all_of(result.getUsers(), [](Operation *user) { + return isa(user); + })) { for (Operation *user : result.getUsers()) { auto extract = cast(user); userIndexMap[extract] = idx + extract.getIndex(); deleteList.push_back(extract.getOperation()); } - } else if (llvm::all_of(result.getUsers(), [](auto *user) { + } else if (llvm::all_of(result.getUsers(), [](Operation *user) { return isa(user); })) { // We have encountered a glued operand (and already split its uses From 57c160b862e16e573c1b96f3d6ab7b3144d95518 Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Mon, 2 Sep 2024 15:25:59 +0100 Subject: [PATCH 3/4] Always insert glue ops after the loop. --- .../MatchTargetSize.cpp | 43 ++++--------------- 1 file changed, 9 insertions(+), 34 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index bd6446fcc7..4c85df5cdb 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -512,56 +512,31 @@ class ScfPattern : public OpRewritePattern { rewriter.create(yield.getLoc(), newValues); rewriter.eraseOp(yield); - // Replace uses of the original loop results with the new loop results. - userIndexMap.clear(); - - // If a results is used by another scf.for loop, we re-glue the individual - // results together to allow canonicalization of the dependent loop, too. - llvm::SmallDenseMap reglueMap; - rewriter.setInsertionPointAfter(newForOp); idx = 0; for (auto [result, init] : llvm::zip(forOp.getResults(), forOp.getInits())) { Operation *definingOp = init.getDefiningOp(); + + // Loop-carried value was not split by this pattern, just rewire all users + // to the new scf.for operation. if (!isa_and_nonnull(definingOp)) { - userIndexMap[result] = idx++; + result.replaceAllUsesWith(newForOp.getResults()[idx]); + ++idx; continue; } + // Re-glue individual results together _after_ the loop. This enables + // canonicalization of extract ops and dependent loops. auto glue = cast(definingOp); - if (llvm::all_of(result.getUsers(), [](Operation *user) { - return isa(user); - })) { - for (Operation *user : result.getUsers()) { - auto extract = cast(user); - userIndexMap[extract] = idx + extract.getIndex(); - deleteList.push_back(extract.getOperation()); - } - } else if (llvm::all_of(result.getUsers(), [](Operation *user) { - return isa(user); - })) { - // We have encountered a glued operand (and already split its uses - // within this loop), but the corresponding result's user(s) are - // dependent loops, not extracts. Hence, we have to re-glue the results. - auto reglue = rewriter.create( + auto reglue = rewriter.create( forOp->getLoc(), result.getType(), newForOp->getResults().slice(idx, glue.getOperands().size())); - reglueMap[result] = reglue; - } else { - llvm::report_fatal_error("Unexpected users of loop result"); - } - + result.replaceAllUsesWith(reglue); idx += glue->getOperands().size(); } - for (auto [user, idx] : userIndexMap) - user.replaceAllUsesWith(newForOp.getResults()[idx]); - - for (auto [res, glue] : reglueMap) - res.replaceAllUsesWith(glue); - for (Operation *deleteOp : deleteList) rewriter.eraseOp(deleteOp); From bf5a11bbb81f750e6e4cf8e0a45568bf7a7a32ec Mon Sep 17 00:00:00 2001 From: Julian Oppermann Date: Mon, 2 Sep 2024 15:27:51 +0100 Subject: [PATCH 4/4] Format. --- .../intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index 4c85df5cdb..bc411670d2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -531,8 +531,8 @@ class ScfPattern : public OpRewritePattern { // canonicalization of extract ops and dependent loops. auto glue = cast(definingOp); auto reglue = rewriter.create( - forOp->getLoc(), result.getType(), - newForOp->getResults().slice(idx, glue.getOperands().size())); + forOp->getLoc(), result.getType(), + newForOp->getResults().slice(idx, glue.getOperands().size())); result.replaceAllUsesWith(reglue); idx += glue->getOperands().size(); }