diff --git a/mlir/include/air/Util/Dependency.h b/mlir/include/air/Util/Dependency.h index 5d6387d99..307023496 100644 --- a/mlir/include/air/Util/Dependency.h +++ b/mlir/include/air/Util/Dependency.h @@ -70,7 +70,8 @@ void addAsyncDependencyIfNew(Operation *op, Value token); bool isAsyncOp(Operation *op); bool areAsyncDependent(Operation *a, Operation *b); bool isAsyncDependent(Operation *a, Operation *b); -scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, +scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter, + scf::ForOp for_op, SmallVector target_ops); LogicalResult unrollAIRChannelPutGetInScfParallel(OpBuilder builder, scf::ParallelOp par, diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index f884e9238..55a566e4a 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -660,14 +660,27 @@ bool isAsyncDependent(Operation *a, Operation *b) { // Splits an SCF for loop into two for loops, by hoisting target operations in // for loop to a new for loop located at the same scope. -scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, +scf::ForOp hoistTargetOpsToNewSCFFor(PatternRewriter &rewriter, + scf::ForOp for_op, SmallVector target_ops) { auto loc = for_op->getLoc(); // If target ops are already perfectly nested, then skip - auto hasNChannelOps = [](Block *block, unsigned N) { - SmallVector chanOps; - block->walk([&](air::ChannelInterface op) { chanOps.push_back(op); }); - return chanOps.size() == N; + auto hasNChannelOps = [target_ops](Block *block, unsigned N) { + unsigned counter = 0; + block->walk>( + [target_ops, &counter](Operation *op) { + if (op->hasTrait()) + return WalkResult::skip(); + if (llvm::is_contained(target_ops, op)) { + counter++; + return WalkResult::skip(); + } + if (isa(op)) + counter++; + counter++; + return WalkResult::advance(); + }); + return counter == N; }; if (hasNChannelOps(for_op.getBody(), 1)) return for_op; @@ -686,20 +699,20 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, } } - builder.setInsertionPoint(for_op); + rewriter.setInsertionPoint(for_op); IRMapping remap; - auto new_for_op = builder.create( + auto new_for_op = rewriter.create( loc, for_op.getLowerBound(), for_op.getUpperBound(), for_op.getStep(), - SmallVector{builder - .create( - loc, - air::AsyncTokenType::get(builder.getContext()), - SmallVector{}) - .getAsyncToken()}); + SmallVector{ + rewriter + .create( + loc, air::AsyncTokenType::get(rewriter.getContext()), + SmallVector{}) + .getAsyncToken()}); remap.map(for_op.getInductionVar(), new_for_op.getInductionVar()); remap.map(getLoopCarriedTokenFromScfOp(for_op, "argument"), getLoopCarriedTokenFromScfOp(new_for_op, "argument")); - builder.setInsertionPointToStart(new_for_op.getBody()); + rewriter.setInsertionPointToStart(new_for_op.getBody()); SmallVector yield_operands; // Build up a log of ops to be cloned; using SetVector to avoid repetition. llvm::SetVector ops_to_be_cloned; @@ -719,14 +732,14 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, } Operation *back_of_dep_chain; for (auto o : ops_to_be_cloned) - back_of_dep_chain = builder.clone(*o, remap); + back_of_dep_chain = rewriter.clone(*o, remap); yield_operands.push_back(getAsyncTokenFromOp(back_of_dep_chain)); - builder.create( + rewriter.create( loc, SmallVector{ - builder + rewriter .create( - loc, air::AsyncTokenType::get(builder.getContext()), + loc, air::AsyncTokenType::get(rewriter.getContext()), yield_operands) ->getResult(0)}); @@ -738,7 +751,7 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, } for (auto erase_op : target_ops) { // Reconnect returned tokens. - builder.setInsertionPoint(erase_op); + rewriter.setInsertionPoint(erase_op); for (auto res : erase_op->getResults()) { if (!isa(res.getType())) continue; @@ -752,9 +765,9 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, // User op doesn't have air::AsyncOpInterface. Replace uses with newly // generated air.wait_all op. u->replaceUsesOfWith( - res, builder + res, rewriter .create( - loc, air::AsyncTokenType::get(builder.getContext()), + loc, air::AsyncTokenType::get(rewriter.getContext()), getAsyncDependenciesFromOp(erase_op)) .getAsyncToken()); } @@ -762,7 +775,7 @@ scf::ForOp hoistTargetOpsToNewSCFFor(OpBuilder builder, scf::ForOp for_op, } } for (auto erase_op : target_ops) - erase_op->erase(); + rewriter.eraseOp(erase_op); for (auto user : for_op.getResults().front().getUsers()) { air::addAsyncDependencyIfNew(user, new_for_op.getResults().front()); }