From 85a99b3154f4c6559aa3657cdb72327ddf249c1f Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Thu, 9 May 2024 15:16:43 -0700 Subject: [PATCH] Purge scf.parallel op containing only wait_all no-ops (#569) --- mlir/lib/Conversion/AIRRtToNpuPass.cpp | 33 +++++++++++++++++++ .../Conversion/AIRRtToNpu/airrt_to_npu.mlir | 33 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/mlir/lib/Conversion/AIRRtToNpuPass.cpp b/mlir/lib/Conversion/AIRRtToNpuPass.cpp index 858e3ae0e..384a43dd9 100644 --- a/mlir/lib/Conversion/AIRRtToNpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToNpuPass.cpp @@ -871,6 +871,7 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { moveFuncOpToEndOfDeviceOp(module); // Purge all wait all ops + purgeSCFParContainingOnlyWaitAllOps(module); purgeWaitAlls(module); // Purge airrt.dma x and y fields, as they are obsolete for AIE2. @@ -1131,6 +1132,38 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { } } + void purgeSCFParContainingOnlyWaitAllOps(ModuleOp module) { + SmallVector scf_pars; + module.walk([&](mlir::func::FuncOp f) { + f.walk([&](scf::ParallelOp par_op) { scf_pars.push_back(par_op); }); + }); + OpBuilder builder(module); + for (auto par_op : scf_pars) { + bool containsOnlyWaitAll = true; + par_op.walk([&](Operation *o) { + if (isa(o)) + return; + else if (isa(o)) + return; + else if (o->mightHaveTrait()) + return; + else { + containsOnlyWaitAll = false; + return; + } + }); + if (!containsOnlyWaitAll) + assert(false && "found scf.parallel op at this IR, NYI"); + builder.setInsertionPoint(par_op); + auto newWaitAll = builder.create( + par_op->getLoc(), airrt::EventType::get(par_op->getContext()), + par_op.getInitVals()); + for (auto res : par_op->getResults()) + res.replaceAllUsesWith(newWaitAll->getResult(0)); + par_op->erase(); + } + } + std::optional getAllocOpForSymbol(SmallVector shimDmaAllocOps, StringRef sym_name) { diff --git a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir index 37f7ba4d6..3b3fe9aeb 100644 --- a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir +++ b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir @@ -897,3 +897,36 @@ module { return } } + +// ----- + +// Purge scf.parallel op which contains only no-ops. + +// CHECK-LABEL: func20 +// CHECK: return +module { + func.func @func20() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c152 = arith.constant 152 : index + %51 = airrt.wait_all : !airrt.event + %52 = scf.for %arg3 = %c0 to %c152 step %c1 iter_args(%arg4 = %51) -> (!airrt.event) { + %61 = airrt.wait_all : !airrt.event + %62 = airrt.wait_all %arg4, %61 : !airrt.event + %63 = airrt.wait_all : !airrt.event + %64 = airrt.wait_all %arg4, %63 : !airrt.event + %65 = scf.parallel (%arg5) = (%c0) to (%c2) step (%c1) init (%arg4) -> !airrt.event { + %66 = airrt.wait_all : !airrt.event + %67 = airrt.wait_all %arg4, %66 : !airrt.event + scf.reduce(%67 : !airrt.event) { + ^bb0(%arg6: !airrt.event, %arg7: !airrt.event): + %68 = airrt.wait_all %arg6, %arg7 : !airrt.event + scf.reduce.return %68 : !airrt.event + } + } + scf.yield %65 : !airrt.event + } + return + } +}