Skip to content

Commit

Permalink
AIRIsolateAsyncDmaLoopNests: Fixup incomplete backward slice when spl…
Browse files Browse the repository at this point in the history
…itting the loop body (Xilinx#869)

* A number of minor fixups around async loop splitting

* Remove the target_ops_sets.empty() check, since it is already covered by target_ops_sets.size() < 2

* Replace cloneDefiningOpsInRegion with getBackwardSliceInRegion, as the former could lead to repeated cloning of the backward slice

* Add an MLIR IR test for deep backward slice tracing
  • Loading branch information
erwei-xilinx authored Jan 21, 2025
1 parent d226668 commit 0061fad
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 51 deletions.
9 changes: 4 additions & 5 deletions mlir/include/air/Util/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,10 @@ bool hasNImpureOps(Block *block, unsigned N);
// terminator.
bool hasNElements(Block *block, unsigned N);

// Clone backward slices of a list of values.
SmallVector<Operation *> cloneDefiningOpsInRegion(OpBuilder builder,
Region *region,
SmallVectorImpl<Value> &opers,
IRMapping &remap);
// Get the backward slice of a vector of values, within a specified region.
void getBackwardSliceInRegion(OpBuilder builder, Region *region,
SmallVectorImpl<Value> &vals,
SetVector<Operation *> &backwardSlices);

// Buffer all allocations of memref directly within the func op's body into the
// func op's arguments.
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,8 +905,11 @@ struct AIRSpecializeAIRRtDmaWrapAndStrideInAffineFor
}

// Hoist any pure ops that the new channel op depends on.
(void)air::cloneDefiningOpsInRegion(rewriter, &for_op.getRegion(), opers,
remap);
llvm::SetVector<Operation *> backwardSlices;
air::getBackwardSliceInRegion(rewriter, &for_op.getRegion(), opers,
backwardSlices);
for (auto o : backwardSlices)
rewriter.clone(*o, remap);

auto new_dma = rewriter.create<airrt::DmaMemcpyNdOp>(
loc, tys, air::lookupOrDefaultRange(opers, remap));
Expand Down
18 changes: 10 additions & 8 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1953,9 +1953,11 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
llvm::concat<Value>(SmallVector<Value>{channel_op.getMemref()},
channel_op.getIndices(), offsets, wraps, strides));
IRMapping remap;
auto clonedOps = cloneDefiningOpsInRegion(rewriter, &for_op.getRegion(),
new_opers, remap);
for (auto cloned : clonedOps) {
llvm::SetVector<Operation *> backwardSlices;
air::getBackwardSliceInRegion(rewriter, &for_op.getRegion(), new_opers,
backwardSlices);
for (auto o : backwardSlices) {
auto cloned = rewriter.clone(*o, remap);
clearAsyncDependenciesOfAsyncOp(cloned);
for (auto token : deps)
addAsyncDependencyIfNew(cloned, token);
Expand Down Expand Up @@ -2138,9 +2140,11 @@ struct AIRSpecializeChannelWrapAndStrideInAffineFor
llvm::concat<Value>(SmallVector<Value>{channel_op.getMemref()},
channel_op.getIndices(), offsets, wraps, strides));
IRMapping remap;
auto clonedOps = cloneDefiningOpsInRegion(rewriter, &for_op.getRegion(),
new_opers, remap);
for (auto cloned : clonedOps) {
llvm::SetVector<Operation *> backwardSlices;
air::getBackwardSliceInRegion(rewriter, &for_op.getRegion(), new_opers,
backwardSlices);
for (auto o : backwardSlices) {
auto cloned = rewriter.clone(*o, remap);
clearAsyncDependenciesOfAsyncOp(cloned);
for (auto token : deps)
addAsyncDependencyIfNew(cloned, token);
Expand Down Expand Up @@ -4296,8 +4300,6 @@ struct IsolateAsyncDmaLoopNestInSCFForPattern
SmallVector<llvm::SetVector<Operation *>> target_ops_sets;

identifyTargetOpsInSCFFor(f, for_op, target_ops_sets);
if (target_ops_sets.empty())
return failure();
if (target_ops_sets.size() < 2)
return failure();

Expand Down
21 changes: 10 additions & 11 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,14 +732,11 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
for (auto &region : op->getRegions())
getUsedValuesDefinedAbove(region, region_opers);
region_opers.insert(op->getOperands().begin(), op->getOperands().end());
for (auto operand : region_opers) {
auto operandDepOp = operand.getDefiningOp();
if (!operandDepOp)
continue;
if (operandDepOp->getBlock() != for_op.getBody())
continue;
ops_to_be_cloned.insert(operandDepOp);
}
SmallVector<Value> region_opers_vec = region_opers.takeVector();
llvm::SetVector<Operation *> backwardSlices;
air::getBackwardSliceInRegion(rewriter, &for_op.getRegion(),
region_opers_vec, backwardSlices);
ops_to_be_cloned.insert(backwardSlices.begin(), backwardSlices.end());
ops_to_be_cloned.insert(op);
}
Operation *back_of_dep_chain;
Expand All @@ -755,8 +752,8 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
yield_operands)
->getResult(0)});

IRMapping waitAllRemap;
for (auto erase_op : target_ops) {
IRMapping waitAllRemap;
if (air::isAsyncOp(erase_op)) {
// Reconnect returned tokens.
rewriter.setInsertionPoint(erase_op);
Expand All @@ -766,10 +763,12 @@ scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter,
newWaitAll.getAsyncToken());
}
}
for (auto erase_op : target_ops)
// Erase the original ops in reverse order, to avoid erasing an op that still
// has valid uses.
for (auto erase_op : llvm::reverse(target_ops))
rewriter.eraseOp(erase_op);
for (auto user : for_op.getResults().front().getUsers()) {
air::addAsyncDependencyIfNew(user, new_for_op.getResults().front());
air::addAsyncDependencyIfNew(user, air::getAsyncTokenFromOp(new_for_op));
}

return new_for_op;
Expand Down
38 changes: 13 additions & 25 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1655,43 +1655,31 @@ bool air::hasNImpureOps(Block *block, unsigned N) {
// Check whether the given block holds exactly N operations, excluding the
// block's terminator from the count.
bool air::hasNElements(Block *block, unsigned N) {
  auto nonTerminatorOps = block->without_terminator();
  return llvm::range_size(nonTerminatorOps) == N;
}

// Clone backward slices of a list of values.
SmallVector<Operation *>
air::cloneDefiningOpsInRegion(OpBuilder builder, Region *region,
SmallVectorImpl<Value> &opers, IRMapping &remap) {
SmallVector<Operation *> clonedOps;
SetVector<Operation *> backwardSlices;
// Get the backward slice of a vector of values, within a specified region.
void air::getBackwardSliceInRegion(OpBuilder builder, Region *region,
SmallVectorImpl<Value> &vals,
SetVector<Operation *> &backwardSlices) {
BackwardSliceOptions bsOptions{
[&](Operation *o) { return region->isAncestor(o->getParentRegion()); }};
if (!region)
return clonedOps;
for (auto operand : opers) {
auto operandDefOp = operand.getDefiningOp();
if (!operandDefOp)
return;
for (auto val : vals) {
auto valDefOp = val.getDefiningOp();
if (!valDefOp)
continue;
if (!region->isAncestor(operandDefOp->getParentRegion()))
if (!region->isAncestor(valDefOp->getParentRegion()))
continue;
assert(air::isPure(operandDefOp) ||
isa<air::WaitAllOp>(operandDefOp)); // Pure ops and wait ops are
// safe to hoist out of loops.
// Get backward slices
SetVector<Operation *> operandBS;
getBackwardSlice(operandDefOp, &operandBS, bsOptions);
for (auto b : operandBS) {
assert(air::isPure(b) || isa<air::WaitAllOp>(b));
SetVector<Operation *> valBS;
getBackwardSlice(valDefOp, &valBS, bsOptions);
for (auto b : valBS) {
backwardSlices.insert(b);
}
backwardSlices.insert(operandDefOp);
backwardSlices.insert(valDefOp);
}
for (auto op : backwardSlices)
clonedOps.push_back(builder.clone(*op, remap));
return clonedOps;
}

// Buffer all allocations of L3 memref directly within the func op's body into
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,3 +969,79 @@ module {
return
}
}

// -----

// Loop splitting with deep backward slice.

// CHECK-LABEL: @func12
// CHECK: air.segment
// CHECK: air.channel.get {{.*}} @channel_24

// CHECK: scf.for %[[RES0:.*]] = %c0{{.*}} to %c8{{.*}} step %c4{{.*}} iter_args(%[[TOK0:.*]] = %{{.*}})
// CHECK: scf.for %[[RES1:.*]] = %c0{{.*}} to %c8{{.*}} step %c4{{.*}} iter_args(%[[TOK1:.*]] = %{{.*}})
// CHECK: %[[TOK2:.*]], %[[RES2:.*]] = air.execute [%[[TOK1]]] -> (index) {
// CHECK-NEXT: air.execute_terminator %[[RES0]] : index
// CHECK-NEXT: }
// CHECK: air.channel.put async [%[[TOK2]]] {{.*}}@channel_16[] {{.*}}[%[[RES2]]
// CHECK: scf.yield
// CHECK-NEXT: }
// CHECK: scf.yield
// CHECK-NEXT: }

// CHECK: scf.for %[[RES0:.*]] = %c0{{.*}} to %c8{{.*}} step %c4{{.*}} iter_args(%[[TOK0:.*]] = %{{.*}})
// CHECK: scf.for %[[RES1:.*]] = %c0{{.*}} to %c8{{.*}} step %c4{{.*}} iter_args(%[[TOK1:.*]] = %{{.*}})
// CHECK: %[[TOK2:.*]], %[[RES2:.*]] = air.execute [%[[TOK1]]] -> (index) {
// CHECK-NEXT: air.execute_terminator %[[RES0]] : index
// CHECK-NEXT: }
// CHECK: air.channel.put async [%[[TOK2]]] {{.*}}@channel_17[] {{.*}}[%[[RES2]]
// CHECK: scf.yield
// CHECK-NEXT: }
// CHECK: scf.yield
// CHECK-NEXT: }

// Test input: a segment with a two-level scf.for nest whose inner body holds
// two async channel puts (@channel_16 and @channel_17) on the same memref.
module {
air.channel @channel_16 [1, 1]
air.channel @channel_17 [1, 1]
air.channel @channel_24 [1, 1]
func.func @func12() {
%4 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c128 = arith.constant 128 : index
%c1024 = arith.constant 1024 : index
%c15 = arith.constant 15 : index
%c4 = arith.constant 4 : index
%c1_11 = arith.constant 1 : index
%c16384_12 = arith.constant 16384 : index
%c32_13 = arith.constant 32 : index
%c8_14 = arith.constant 8 : index
%c0_15 = arith.constant 0 : index
// Allocate the memref that @channel_24 writes into and that both channel
// puts below read from.
%async_token_26, %results_27 = air.execute -> (memref<8x16x32x32xbf16, 1 : i32>) {
%alloc = memref.alloc() : memref<8x16x32x32xbf16, 1 : i32>
air.execute_terminator %alloc : memref<8x16x32x32xbf16, 1 : i32>
}
%5 = air.channel.get async [%async_token_26] @channel_24[] (%results_27[] [] []) {id = 4 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
%7 = air.wait_all async [%5]
// Two-level loop nest; the async token is threaded through the iter_args of
// both loops.
%8 = scf.for %arg10 = %c0_15 to %c8_14 step %c4 iter_args(%arg11 = %7) -> (!air.async.token) {
%10 = scf.for %arg12 = %c0_15 to %c8_14 step %c4 iter_args(%arg13 = %arg11) -> (!air.async.token) {
// Dependency chain %arg13 -> %19 -> %20 -> air.execute -> channel.put.
// NOTE(review): presumably this multi-op chain is the "deep" backward slice
// the test title refers to -- confirm against the pass.
%19 = air.wait_all async [%arg13]
%20 = air.wait_all async [%19]
// Yields the outer induction variable %arg10 as an index result, which is
// used as the first offset of the @channel_16 put.
%async_token_50, %results_51 = air.execute [%20] -> (index) {
air.execute_terminator %arg10 : index
}
%21 = air.channel.put async [%async_token_50] @channel_16[] (%results_27[%results_51, %c15, %c0_15, %c0_15, %c0_15, %c0_15] [%c1_11, %c1_11, %c4, %c8_14, %c4, %c8_14] [%c16384_12, %c1024, %c8_14, %c128, %c32_13, %c1_11]) {id = 38 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
// Same pattern for the @channel_17 put; it shares the %20 dependency with
// the execute op above.
%async_token_52, %results_53 = air.execute [%20] -> (index) {
air.execute_terminator %arg10 : index
}
%22 = air.channel.put async [%async_token_52] @channel_17[] (%results_27[%results_53, %c15, %c0_15, %c0_15, %c0_15, %c0_15] [%c1_11, %c1_11, %c4, %c8_14, %c4, %c8_14] [%c16384_12, %c1024, %c8_14, %c128, %c32_13, %c1_11]) {id = 39 : i32} : (memref<8x16x32x32xbf16, 1 : i32>)
// The yielded token depends only on the loop-carried token %arg13, not on
// either channel put.
%29 = air.wait_all async [%arg13]
scf.yield %29 : !air.async.token
}
scf.yield %10 : !air.async.token
}
// Deallocate the memref once the outer loop's result token %8 is available.
%async_token_28 = air.execute [%8] {
memref.dealloc %results_27 : memref<8x16x32x32xbf16, 1 : i32>
}
}
return
}
}

0 comments on commit 0061fad

Please sign in to comment.