Skip to content

Commit

Permalink
Purge scf.parallel op containing only wait_all no-ops (Xilinx#569)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored May 9, 2024
1 parent ccca6b5 commit 85a99b3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
33 changes: 33 additions & 0 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,7 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
moveFuncOpToEndOfDeviceOp(module);

// Purge all wait all ops
purgeSCFParContainingOnlyWaitAllOps(module);
purgeWaitAlls(module);

// Purge airrt.dma x and y fields, as they are obsolete for AIE2.
Expand Down Expand Up @@ -1131,6 +1132,38 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
}
}

void purgeSCFParContainingOnlyWaitAllOps(ModuleOp module) {
SmallVector<scf::ParallelOp> scf_pars;
module.walk([&](mlir::func::FuncOp f) {
f.walk([&](scf::ParallelOp par_op) { scf_pars.push_back(par_op); });
});
OpBuilder builder(module);
for (auto par_op : scf_pars) {
bool containsOnlyWaitAll = true;
par_op.walk([&](Operation *o) {
if (isa<airrt::WaitAllOp>(o))
return;
else if (isa<scf::ParallelOp>(o))
return;
else if (o->mightHaveTrait<OpTrait::IsTerminator>())
return;
else {
containsOnlyWaitAll = false;
return;
}
});
if (!containsOnlyWaitAll)
assert(false && "found scf.parallel op at this IR, NYI");
builder.setInsertionPoint(par_op);
auto newWaitAll = builder.create<airrt::WaitAllOp>(
par_op->getLoc(), airrt::EventType::get(par_op->getContext()),
par_op.getInitVals());
for (auto res : par_op->getResults())
res.replaceAllUsesWith(newWaitAll->getResult(0));
par_op->erase();
}
}

std::optional<AIE::ShimDMAAllocationOp>
getAllocOpForSymbol(SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps,
StringRef sym_name) {
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -897,3 +897,36 @@ module {
return
}
}

// -----

// Purge scf.parallel op which contains only no-ops.

// CHECK-LABEL: func20
// CHECK: return
module {
func.func @func20() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c152 = arith.constant 152 : index
%51 = airrt.wait_all : !airrt.event
%52 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %51) -> (!airrt.event) {
%61 = airrt.wait_all : !airrt.event
%62 = airrt.wait_all %arg4, %61 : !airrt.event
%63 = airrt.wait_all : !airrt.event
%64 = airrt.wait_all %arg4, %63 : !airrt.event
%65 = scf.parallel (%arg5) = (%c0) to (%c2) step (%c1) init (%arg4) -> !airrt.event {
%66 = airrt.wait_all : !airrt.event
%67 = airrt.wait_all %arg4, %66 : !airrt.event
scf.reduce(%67 : !airrt.event) {
^bb0(%arg6: !airrt.event, %arg7: !airrt.event):
%68 = airrt.wait_all %arg6, %arg7 : !airrt.event
scf.reduce.return %68 : !airrt.event
}
}
scf.yield %65 : !airrt.event
}
return
}
}

0 comments on commit 85a99b3

Please sign in to comment.