diff --git a/test/TritonIntelGPU/match-target-size.mlir b/test/TritonIntelGPU/match-target-size.mlir index f5452d4a5e..2c26411a07 100644 --- a/test/TritonIntelGPU/match-target-size.mlir +++ b/test/TritonIntelGPU/match-target-size.mlir @@ -142,6 +142,51 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16 // ----- +// COM: Test SCF canonicalization: ensure loop canonicalization can be applied to dependent loops +tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16>, + %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) { + // CHECK-LABEL: @simplify_scf_for + // CHECK-NOT: triton_intel_gpu.glue + // CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[INIT1:%.*]] = %arg0, [[INIT2:%.*]] = %arg1) + // CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 { + // CHECK-NEXT: scf.yield [[INIT2]], [[INIT1]] : tensor<16x8xf16>, tensor<16x8xf16> + // CHECK-NEXT: } + // CHECK: [[RES2:%.*]]:2 = scf.for {{.*}} iter_args([[INIT3:%.*]] = [[RES]]#0, [[INIT4:%.*]] = [[RES]]#1) + // CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 { + // CHECK-NEXT: scf.yield [[INIT3]], [[INIT4]] : tensor<16x8xf16>, tensor<16x8xf16> + // CHECK-NEXT: } + // CHECK-NEXT: [[GLUE:%.*]] = triton_intel_gpu.glue [[RES2]]#1, [[RES2]]#0 + // CHECK-SAME: : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + // CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2 + // CHECK-NEXT: tt.store [[PTR]], [[GLUE]] + %lb = arith.constant 0 : i32 + %ub = arith.constant 32 : i32 + %st = arith.constant 1 : i32 + %c1_i64 = arith.constant 1 : i64 + %cst = arith.constant dense<42.0> : tensor<16x16xf16> + %glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + %res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 { + %e1 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %e2 = triton_intel_gpu.extract %arg[0] : 
tensor<16x16xf16> -> tensor<16x8xf16> + %g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + scf.yield %g1 : tensor<16x16xf16> + } + %res2 = scf.for %iv = %lb to %ub step %st iter_args(%arg = %res) -> (tensor<16x16xf16>) : i32 { + %e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16> + %e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + scf.yield %g1 : tensor<16x16xf16> + } + %e3 = triton_intel_gpu.extract %res2[0] : tensor<16x16xf16> -> tensor<16x8xf16> + %e4 = triton_intel_gpu.extract %res2[1] : tensor<16x16xf16> -> tensor<16x8xf16> + %g2 = triton_intel_gpu.glue %e4, %e3 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> + %ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>> + tt.store %ptr, %g2 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>> + tt.return +} + +// ----- + // COM: Test transformation for int8 datatype // CHECK-LABEL: @matmul_kernel_with_block_pointers_int8 diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index 089a498c37..bc411670d2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -512,31 +512,31 @@ class ScfPattern : public OpRewritePattern<scf::ForOp> { rewriter.create<scf::YieldOp>(yield.getLoc(), newValues); rewriter.eraseOp(yield); - // Replace uses of the original loop results with the new loop results. 
- userIndexMap.clear(); + rewriter.setInsertionPointAfter(newForOp); + idx = 0; for (auto [result, init] : llvm::zip(forOp.getResults(), forOp.getInits())) { Operation *definingOp = init.getDefiningOp(); + + // Loop-carried value was not split by this pattern, just rewire all users + // to the new scf.for operation. if (!isa_and_nonnull<ttgi::GlueOp>(definingOp)) { - userIndexMap[result] = idx++; + result.replaceAllUsesWith(newForOp.getResults()[idx]); + ++idx; continue; } + // Re-glue individual results together _after_ the loop. This enables + // canonicalization of extract ops and dependent loops. auto glue = cast<ttgi::GlueOp>(definingOp); - for (Operation *user : result.getUsers()) { - if (auto extract = dyn_cast<ttgi::ExtractOp>(user)) { - userIndexMap[extract] = idx + extract.getIndex(); - deleteList.push_back(extract.getOperation()); - } - } - + auto reglue = rewriter.create<ttgi::GlueOp>( + forOp->getLoc(), result.getType(), + newForOp->getResults().slice(idx, glue.getOperands().size())); + result.replaceAllUsesWith(reglue); idx += glue->getOperands().size(); } - for (auto [user, idx] : userIndexMap) - user.replaceAllUsesWith(newForOp.getResults()[idx]); - for (Operation *deleteOp : deleteList) rewriter.eraseOp(deleteOp); @@ -567,10 +567,11 @@ class ScfPattern : public OpRewritePattern<scf::ForOp> { return false; } - // Bail out if the loop result is not used by an 'extract' operation. + // Bail out if the loop result is not used by an 'extract' operation, or + // another loop. if (forOp->getNumResults() == 1 && llvm::any_of(forOp.getResult(0).getUsers(), [](Operation *user) { - return !isa<ttgi::ExtractOp>(user); + return !isa<ttgi::ExtractOp, scf::ForOp>(user); })) return false;