Skip to content

Commit

Permalink
Improve code generality to cover for ops without air::AsyncOpInterface (Xilinx#453)
Browse files
Browse the repository at this point in the history

* Improve code generality to cover for ops without air::AsyncOpInterface

* Add unit test
  • Loading branch information
erwei-xilinx authored Feb 24, 2024
1 parent fc3ee49 commit fb243b2
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 16 deletions.
44 changes: 28 additions & 16 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ bool isAsyncOp(Operation *op) {
// for loop to a new for loop located at the same scope.
scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
SmallVector<Operation *> target_ops) {

auto loc = for_op->getLoc();
// If target ops are already perfectly nested, then skip
auto hasNElements = [](Block *block, unsigned N) {
auto op_ptr = block->begin();
Expand Down Expand Up @@ -604,8 +604,8 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
builder.setInsertionPoint(for_op);
IRMapping remap;
auto new_for_op = builder.create<scf::ForOp>(
for_op.getLoc(), for_op.getLowerBound(), for_op.getUpperBound(),
for_op.getStep(), for_op.getInitArgs());
loc, for_op.getLowerBound(), for_op.getUpperBound(), for_op.getStep(),
for_op.getInitArgs());
remap.map(for_op.getInductionVar(), new_for_op.getInductionVar());
remap.map(getLoopCarriedTokenFromScfOp(for_op, "argument"),
getLoopCarriedTokenFromScfOp(new_for_op, "argument"));
Expand All @@ -618,13 +618,12 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
yield_operands.push_back(new_op->getResult(0));
}
builder.create<scf::YieldOp>(
new_for_op.getLoc(),
SmallVector<Value>{builder
.create<air::WaitAllOp>(
new_for_op.getLoc(),
air::AsyncTokenType::get(builder.getContext()),
yield_operands)
->getResult(0)});
loc, SmallVector<Value>{
builder
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(builder.getContext()),
yield_operands)
->getResult(0)});

// Update dependency to hoisted ops
for (auto herd : new_for_op.getOps<air::HerdOp>()) {
Expand All @@ -634,13 +633,26 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op,
}
for (auto erase_op : target_ops) {
// Reconnect returned tokens.
for (auto user : erase_op->getResult(0).getUsers()) {
if (auto async_user = dyn_cast<air::AsyncOpInterface>(user)) {
eraseAsyncDependencyFromAsyncOp(async_user, erase_op->getResult(0));
for (auto dep : getAsyncDependenciesFromOp(erase_op)) {
if (dep != getLoopCarriedTokenFromScfOp(for_op, "argument")) {
air::addAsyncDependencyIfNew(user, dep);
builder.setInsertionPoint(erase_op);
for (auto res : erase_op->getResults()) {
if (!isa<air::AsyncTokenType>(res.getType()))
continue;
for (auto &u : res.getUses()) {
if (auto async_user = dyn_cast<air::AsyncOpInterface>(u.getOwner())) {
eraseAsyncDependencyFromAsyncOp(async_user, res);
for (auto dep : getAsyncDependenciesFromOp(erase_op)) {
if (dep != getLoopCarriedTokenFromScfOp(for_op, "argument")) {
air::addAsyncDependencyIfNew(u.getOwner(), dep);
}
}
} else {
// User op doesn't have air::AsyncOpInterface. Replace uses with newly
// generated air.wait_all op.
u.assign(builder
.create<air::WaitAllOp>(
loc, air::AsyncTokenType::get(builder.getContext()),
getAsyncDependenciesFromOp(erase_op))
.getAsyncToken());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,92 @@ module {
return
}
}

// -----

// Loop nest.

// CHECK-LABEL: func2

// CHECK: air.launch
// CHECK: air.segment @segment_0
// CHECK: air.herd @herd_0

// CHECK: scf.for %{{.*}} = %c0 to %c2048 step %c256 iter_args(%{{.*}} = %{{.*}}) -> (!air.async.token) {
// CHECK: scf.for %{{.*}} = %c0 to %c256 step %c64 iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) -> (!air.async.token, !air.async.token, !air.async.token, !air.async.token) {
// CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !air.async.token, !air.async.token, !air.async.token, !air.async.token
// CHECK: scf.yield %{{.*}} : !air.async.token

// CHECK: air.herd_terminator
// CHECK: air.segment_terminator
// CHECK: air.launch_terminator

// Test input for @func2: an async air.launch > air.segment > air.herd
// hierarchy whose herd body holds a two-level scf.for nest. Both loops thread
// !air.async.token values through their iter_args.
module {
func.func @func2() {
%c32 = arith.constant 32 : index
%0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) attributes {id = 1 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c2 = arith.constant 2 : index
%2 = air.herd @herd_0 async tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) attributes {id = 3 : i32} {
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c256 = arith.constant 256 : index
%c2048 = arith.constant 2048 : index
// Buffer allocated before the loop nest; its token seeds the outer loop's
// iter_arg and the buffer itself is consumed by the channel.put after the loop.
%async_token_1, %results_2 = air.execute -> (memref<32x32xi32, 2>) {
%alloc = memref.alloc() : memref<32x32xi32, 2>
air.execute_terminator %alloc : memref<32x32xi32, 2>
}
// Outer loop: 0 to 2048 in steps of 256, carrying a single async token.
%3 = scf.for %arg11 = %c0 to %c2048 step %c256 iter_args(%arg12 = %async_token_1) -> (!air.async.token) {
// Four allocations in the outer loop body. The first depends on the
// loop-carried token %arg12, the second on the first, and the third and
// fourth both depend on the second (%async_token_6).
%async_token_4, %results_5 = air.execute [%arg12] -> (memref<32x32xi32, 2>) {
%alloc = memref.alloc() : memref<32x32xi32, 2>
air.execute_terminator %alloc : memref<32x32xi32, 2>
}
%async_token_6, %results_7 = air.execute [%async_token_4] -> (memref<32x32xi32, 2>) {
%alloc = memref.alloc() : memref<32x32xi32, 2>
air.execute_terminator %alloc : memref<32x32xi32, 2>
}
%async_token_8, %results_9 = air.execute [%async_token_6] -> (memref<32x32xi32, 2>) {
%alloc = memref.alloc() : memref<32x32xi32, 2>
air.execute_terminator %alloc : memref<32x32xi32, 2>
}
%async_token_10, %results_11 = air.execute [%async_token_6] -> (memref<32x32xi32, 2>) {
%alloc = memref.alloc() : memref<32x32xi32, 2>
air.execute_terminator %alloc : memref<32x32xi32, 2>
}
// Inner loop: 0 to 256 in steps of 64, carrying four async tokens. The
// first is seeded from %async_token_8, the other three from %async_token_10.
%5:4 = scf.for %arg13 = %c0 to %c256 step %c64 iter_args(%arg14 = %async_token_8, %arg15 = %async_token_10, %arg16 = %async_token_10, %arg17 = %async_token_10) -> (!air.async.token, !air.async.token, !air.async.token, !air.async.token) {
// First pair of channel gets, writing into %results_9 and %results_11.
%6 = air.channel.get async [%arg17, %arg14, %async_token_8] @channel_2[%arg7, %arg8] (%results_9[] [] []) {id = 9 : i32} : (memref<32x32xi32, 2>)
%7 = air.channel.get async [%arg17, %arg14, %async_token_10] @channel_3[%arg7, %arg8] (%results_11[] [] []) {id = 10 : i32} : (memref<32x32xi32, 2>)
// Joins both gets with the loop-carried token %arg16.
%async_token_12 = air.wait_all async [%arg16, %7, %6]
// These dealloc air.execute ops are written without a dependency list.
%async_token_13 = air.execute {
memref.dealloc %results_9 : memref<32x32xi32, 2>
}
%async_token_14 = air.execute {
memref.dealloc %results_11 : memref<32x32xi32, 2>
}
// Second pair of channel gets, writing into %results_7 and %results_5; both
// depend on the first pair and on the loop-carried token %arg15.
%8 = air.channel.get async [%7, %6, %arg15] @channel_2[%arg7, %arg8] (%results_7[] [] []) {id = 9 : i32} : (memref<32x32xi32, 2>)
%9 = air.channel.get async [%7, %6, %arg15] @channel_3[%arg7, %arg8] (%results_5[] [] []) {id = 10 : i32} : (memref<32x32xi32, 2>)
%async_token_15 = air.wait_all async [%async_token_12, %9, %8]
%async_token_16 = air.execute {
memref.dealloc %results_7 : memref<32x32xi32, 2>
}
%async_token_17 = air.execute {
memref.dealloc %results_5 : memref<32x32xi32, 2>
}
%10 = air.wait_all async [%8, %9]
// The four yielded tokens feed %arg14..%arg17 of the next inner iteration.
scf.yield %async_token_12, %async_token_15, %async_token_15, %10 : !air.async.token, !air.async.token, !air.async.token, !air.async.token
}
// Only the second inner-loop result is carried by the outer loop.
scf.yield %5#1 : !air.async.token
}
// Work after the loop nest is ordered behind the outer loop's result token %3.
%4 = air.channel.put async [%3] @channel_4[%arg7, %arg8] (%results_2[] [] []) {id = 11 : i32} : (memref<32x32xi32, 2>)
%async_token_3 = air.execute [%4] {
memref.dealloc %results_2 : memref<32x32xi32, 2>
}
air.herd_terminator
}
air.segment_terminator
}
air.launch_terminator
}
return
}
}

0 comments on commit fb243b2

Please sign in to comment.