[MatchTargetSize] Extend for-loop canonicalization pattern (#2045)
Support canonicalization of dependent `scf.for` loops by re-gluing
individual results after the loop.

See #1947 for
more context / the complete PoC.
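
For intuition, a condensed before/after sketch of the situation this addresses, adapted from the new lit test added below (SSA names and the `...` placeholders are illustrative, not literal IR):

    %glue = triton_intel_gpu.glue %a, %b : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
    %r0 = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 { ... }
    %r1 = scf.for %iv = %lb to %ub step %st iter_args(%arg = %r0) -> (tensor<16x16xf16>) : i32 { ... }

After the pattern runs, both loops carry the two 16x8 halves as separate iter_args, the glue/extract pairs between them fold away, and a single `triton_intel_gpu.glue` of the second loop's results is recreated afterwards for the remaining users (here, the final store).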

---------

Signed-off-by: Julian Oppermann <[email protected]>
jopperm authored Sep 9, 2024
1 parent 254449c commit 66bf5d8
Showing 2 changed files with 61 additions and 15 deletions.
45 changes: 45 additions & 0 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -225,6 +225,51 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// -----

// COM: Test SCF canonicalization: ensure loop canonicalization can be applied to dependent loops
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK-NOT: triton_intel_gpu.glue
// CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[INIT1:%.*]] = %arg0, [[INIT2:%.*]] = %arg1)
// CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 {
// CHECK-NEXT: scf.yield [[INIT2]], [[INIT1]] : tensor<16x8xf16>, tensor<16x8xf16>
// CHECK-NEXT: }
// CHECK: [[RES2:%.*]]:2 = scf.for {{.*}} iter_args([[INIT3:%.*]] = [[RES]]#0, [[INIT4:%.*]] = [[RES]]#1)
// CHECK-SAME: -> (tensor<16x8xf16>, tensor<16x8xf16>) : i32 {
// CHECK-NEXT: scf.yield [[INIT3]], [[INIT4]] : tensor<16x8xf16>, tensor<16x8xf16>
// CHECK-NEXT: }
// CHECK-NEXT: [[GLUE:%.*]] = triton_intel_gpu.glue [[RES2]]#1, [[RES2]]#0
// CHECK-SAME: : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
// CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2
// CHECK-NEXT: tt.store [[PTR]], [[GLUE]]
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%cst = arith.constant dense<42.0> : tensor<16x16xf16>
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%res2 = scf.for %iv = %lb to %ub step %st iter_args(%arg = %res) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%e3 = triton_intel_gpu.extract %res2[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e4 = triton_intel_gpu.extract %res2[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g2 = triton_intel_gpu.glue %e4, %e3 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %g2 {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

// COM: Test transformation for int8 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers_int8
31 changes: 16 additions & 15 deletions third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp
@@ -528,31 +528,31 @@ class ScfPattern : public OpRewritePattern<scf::ForOp> {
     rewriter.create<scf::YieldOp>(yield.getLoc(), newValues);
     rewriter.eraseOp(yield);
 
     // Replace uses of the original loop results with the new loop results.
-    userIndexMap.clear();
+    rewriter.setInsertionPointAfter(newForOp);
+
     idx = 0;
     for (auto [result, init] :
          llvm::zip(forOp.getResults(), forOp.getInits())) {
       Operation *definingOp = init.getDefiningOp();
 
+      // Loop-carried value was not split by this pattern, just rewire all users
+      // to the new scf.for operation.
       if (!isa_and_nonnull<ttgi::GlueOp>(definingOp)) {
-        userIndexMap[result] = idx++;
+        result.replaceAllUsesWith(newForOp.getResults()[idx]);
+        ++idx;
         continue;
       }
 
+      // Re-glue individual results together _after_ the loop. This enables
+      // canonicalization of extract ops and dependent loops.
       auto glue = cast<ttgi::GlueOp>(definingOp);
-      for (Operation *user : result.getUsers()) {
-        if (auto extract = dyn_cast<ttgi::ExtractOp>(user)) {
-          userIndexMap[extract] = idx + extract.getIndex();
-          deleteList.push_back(extract.getOperation());
-        }
-      }
+      auto reglue = rewriter.create<ttgi::GlueOp>(
+          forOp->getLoc(), result.getType(),
+          newForOp->getResults().slice(idx, glue.getOperands().size()));
+      result.replaceAllUsesWith(reglue);
       idx += glue->getOperands().size();
     }
 
-    for (auto [user, idx] : userIndexMap)
-      user.replaceAllUsesWith(newForOp.getResults()[idx]);
-
     for (Operation *deleteOp : deleteList)
       rewriter.eraseOp(deleteOp);
 
@@ -583,10 +583,11 @@ class ScfPattern : public OpRewritePattern<scf::ForOp> {
       return false;
     }
 
-    // Bail out if the loop result is not used by an 'extract' operation.
+    // Bail out if the loop result is not used by an 'extract' operation, or
+    // another loop.
     if (forOp->getNumResults() == 1 &&
         llvm::any_of(forOp.getResult(0).getUsers(), [](Operation *user) {
-          return !isa<ttgi::ExtractOp>(user);
+          return !isa<ttgi::ExtractOp, scf::ForOp>(user);
         }))
       return false;

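As a concrete illustration of the re-glue logic in the first hunk above (hypothetical values: a loop-carried value that was glued from two 16x8 halves, reached when idx is 1), the new loop's results at positions 1 and 2 are sliced and glued back into the original 16x16 type:

    %reglue = triton_intel_gpu.glue %newFor#1, %newFor#2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>

All uses of the corresponding original loop result are redirected to this glue, and idx advances by the number of glued operands (here 2); the extract ops and dependent loops consuming the result can then be simplified by the canonicalizations this change targets.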
