diff --git a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
index f045933ba..8b9e36c79 100644
--- a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
+++ b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
@@ -24,16 +24,14 @@
 #include <memory>
 
 namespace mlir {
+class ModuleOp; // forward decl
 
-namespace func {
-class FuncOp;
-} // namespace func
 
 void populateHloToByreTensorPattern(
     RewritePatternSet &patterns,
     const llvm::StringMap<StringRef> &supportMap, bool appendArgTypes);
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 createConvertHloToByreTensorPass(bool appendArgTypes = false);
 
 } // namespace mlir
diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td
index ead522d1c..f98447280 100644
--- a/compiler/include/byteir/Conversion/Passes.td
+++ b/compiler/include/byteir/Conversion/Passes.td
@@ -267,7 +267,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
 // HloToByreTensor
 //===----------------------------------------------------------------------===//
 
-def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
+def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "ModuleOp"> {
   let summary = "Convert hlo op to byre tensor op.";
   let constructor = "mlir::createConvertHloToByreTensorPass()";
   let dependentDialects = [
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index b3d3e9853..4bc351390 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -429,7 +429,7 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
 // ShapeReification
 //===----------------------------------------------------------------------===//
 
-def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
+def ShapeReification : Pass<"byteir-shape-reification", "ModuleOp"> {
   let summary = "Iteratively reify all shape computations.";
   let description = [{
     If an operation has a shape reification implementation, that is to say, we
@@ -443,6 +443,11 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
     "mlir::tensor::TensorDialect",
     "mlir::arith::ArithDialect",
   ];
+  let options = [
+    Option<"anchorFunc", "anchor-func", "std::string",
+           /*default=*/"",
+           "An optional function name used to specify the target function.">,
+  ];
 }
 
 #endif // BYTEIR_TRANSFORMS_PASSES
diff --git a/compiler/include/byteir/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
index 7c4cb5043..4cc2365e8 100644
--- a/compiler/include/byteir/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -22,11 +22,10 @@
 #include <memory>
 
 namespace mlir {
 
-namespace func {
-class FuncOp;
-} // namespace func
+class ModuleOp;
 
-std::unique_ptr<OperationPass<func::FuncOp>> createByteIRShapeReificationPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createByteIRShapeReificationPass(llvm::StringRef anchorFunc = "");
 
 } // namespace mlir
diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index 703dec1c4..e8800ddfc 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -118,12 +118,20 @@ SymbolicShapeAnalysis::SymbolicShapeAnalysis(ModuleOp moduleOp)
   }
 
   // run shape reification pass on all the auxiliary functions
-  PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
-  pm.addPass(createByteIRShapeReificationPass());
-  pm.addPass(createCSEPass());
   for (auto funcOp : shpFuncOps) {
-    if (mlir::failed(pm.run(funcOp))) {
-      llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+    {
+      PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
+      pm.addPass(createByteIRShapeReificationPass(funcOp.getName()));
+      if (mlir::failed(pm.run(moduleOp))) {
+        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+      }
+    }
+    {
+      PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
+      pm.addPass(createCSEPass());
+      if (mlir::failed(pm.run(funcOp))) {
+        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+      }
     }
   }
 }
diff --git a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
index 91104ec58..dae0bbdc1 100644
--- a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
+++ b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -768,7 +768,6 @@ struct ConvertHloToByreTensorPass
     MLIRContext &ctx = getContext();
     RewritePatternSet patterns(&ctx);
     ConversionTarget target(ctx);
-    auto funcOp = getOperation();
 
     populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
     target.addIllegalDialect<mhlo::MhloDialect>();
@@ -776,7 +775,8 @@ struct ConvertHloToByreTensorPass
                            shape::ShapeDialect, arith::ArithDialect>();
 
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-    if (failed(applyPartialConversion(funcOp, target, frozenPatterns))) {
+    if (failed(
+            applyPartialConversion(getOperation(), target, frozenPatterns))) {
       signalPassFailure();
     }
   }
@@ -810,7 +810,7 @@ void mlir::populateHloToByreTensorPattern(
       ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext());
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
   return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
 }
diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 5f2abec54..03e48010b 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -19,15 +19,18 @@
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Transforms/ShapeReification.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Transforms/TopologicalSortUtils.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
+
 #include <queue>
 #include <string>
 
@@ -184,32 +187,6 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
 }
 
 namespace {
-
-SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
-  llvm::DenseSet<Operation *> visitedOp;
-  std::queue<Operation *> opQueue;
-
-  opQueue.push(retOp);
-  while (!opQueue.empty()) {
-    auto frontOp = opQueue.front();
-    opQueue.pop();
-    if (visitedOp.find(frontOp) != visitedOp.end()) {
-      continue;
-    }
-    visitedOp.insert(frontOp);
-    for (Value operand : frontOp->getOperands()) {
-      if (!operand.getDefiningOp()) {
-        continue;
-      }
-      if (Operation *defOp = operand.getDefiningOp()) {
-        opQueue.push(defOp);
-      }
-    }
-  }
-  visitedOp.erase(retOp);
-  return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
-}
-
 bool deduceFromFuncArgShape(Value value) {
   if (value.isa<BlockArgument>()) {
     return false;
   }
@@ -240,42 +217,30 @@ bool deduceFromFuncArgShape(Value value) {
   return true;
 }
 
-LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
-                          SmallVectorImpl<Value> &reifications) {
-  OpBuilder::InsertionGuard guard(builder);
-  auto callOp = dyn_cast<func::CallOp>(op);
-  if (!callOp) {
-    return failure();
-  }
-
-  ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
-  // auxiliary builder used for create operations in shape func
-  // original builder maybe a rewriter, used for create operations in specific
-  // pattern.
-  OpBuilder auxiliaryBuilder(moduleOp);
-  StringRef funcName = callOp.getCallee();
-  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+FailureOr<func::FuncOp> createCorrespondingShapeFunc(func::FuncOp funcOp) {
+  ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
+  // use an auxiliary builder to create the shape func at the start of moduleOp
+  OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
 
-  // clone funcOp, newFuncOp used for deduce function shape
-  std::string newFuncName = funcName.str() + "_Shape";
-  auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
-  auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
-      funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
-  newFuncOp.setPrivate();
+  // clone funcOp; the cloned shapeFunc is used to deduce the output shapes
+  std::string shapeFuncName = (funcOp.getName() + "_Shape").str();
+  auto shapeFunc = builder.create<func::FuncOp>(
+      funcOp->getLoc(), shapeFuncName, funcOp.getFunctionType());
+  shapeFunc.setPrivate();
   IRMapping emptyBvm;
-  funcOp.cloneInto(newFuncOp, emptyBvm);
+  funcOp.cloneInto(shapeFunc, emptyBvm);
 
   // replace the operands of returnOp with corresponding shape
-  func::ReturnOp retOp = *newFuncOp.getOps<func::ReturnOp>().begin();
+  func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
   if (!retOp) {
-    newFuncOp->erase();
+    shapeFunc->erase();
     return failure();
   }
 
   for (Value &&retTensor : retOp.getOperands()) {
     auto retTy = retTensor.getType();
     if (!retTy.isa<RankedTensorType>()) {
-      newFuncOp->erase();
+      shapeFunc->erase();
       return failure();
     }
   }
@@ -283,44 +248,76 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
   SmallVector<Type, 4> allResultTypes;
   SmallVector<Value, 4> allResults;
 
-  auxiliaryBuilder.setInsertionPoint(retOp);
+  builder.setInsertionPoint(retOp);
   for (Value &&retTensor : retOp.getOperands()) {
-    auto retShape =
-        auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
+    auto retShape = builder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
     allResultTypes.emplace_back(retShape.getType());
     allResults.emplace_back(retShape);
   }
 
   // return the shape of original tensor returned by function
-  auto newRetOp =
-      auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
-  auto newFuncType = auxiliaryBuilder.getFunctionType(
-      newFuncOp.getArgumentTypes(), allResultTypes);
-  newFuncOp.setFunctionType(newFuncType);
+  builder.create<func::ReturnOp>(retOp.getLoc(), allResults);
+  auto shapeFuncType =
+      builder.getFunctionType(shapeFunc.getArgumentTypes(), allResultTypes);
+  shapeFunc.setFunctionType(shapeFuncType);
   retOp->erase();
 
-  // reify newFunc to get the shape computation for current callOp
+  // reify shapeFunc to get the shape computation.
   {
-    PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
-    pm.addPass(createCanonicalizerPass());
-    pm.addPass(createCSEPass());
-    pm.addPass(createByteIRShapeReificationPass());
+    PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
+    // anchor the reification pass so it only runs on shapeFunc
+    pm.addPass(createByteIRShapeReificationPass(shapeFunc.getName()));
+    if (mlir::failed(pm.run(moduleOp))) {
+      shapeFunc->erase();
+      return failure();
+    }
+  }
+
+  // canonicalize shapeFunc
+  {
+    PassManager pm(shapeFunc->getContext(), shapeFunc.getOperationName());
     pm.addPass(createCanonicalizerPass());
     pm.addPass(createCSEPass());
-
-    if (mlir::failed(pm.run(newFuncOp))) {
-      newFuncOp->erase();
+    // run only on shapeFunc, so no other op is modified
+    if (mlir::failed(pm.run(shapeFunc))) {
+      shapeFunc->erase();
       return failure();
     }
   }
+  return shapeFunc;
+}
 
-  // collect all shape computation ops
-  SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
+LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &reifications) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto callOp = dyn_cast<func::CallOp>(op);
+  if (!callOp) {
+    return failure();
+  }
+
+  ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+  StringRef funcName = callOp.getCallee();
+  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+  // create the corresponding shape function
+  auto maybeShapeFunc = createCorrespondingShapeFunc(funcOp);
+  if (failed(maybeShapeFunc)) {
+    return failure();
+  }
+
+  func::FuncOp shapeFunc = *maybeShapeFunc;
+  func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
+
+  // collect all shape computation ops
+  SetVector<Operation *> reificationOpSet;
+  getBackwardSlice(retOp.getOperation(), &reificationOpSet);
+  SmallVector<Operation *> reificationOps(reificationOpSet.begin(),
+                                          reificationOpSet.end());
 
   // value only depends on the shape of FuncArgs.
-  for (Value &&ret : newRetOp.getOperands()) {
+  for (Value &&ret : retOp.getOperands()) {
     if (!deduceFromFuncArgShape(ret)) {
-      newFuncOp->erase();
+      shapeFunc->erase();
       return failure();
     }
   }
@@ -330,9 +327,9 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
     mlir::computeTopologicalSorting(reificationOps);
 
     IRMapping bvm;
-    size_t numArg = newFuncOp.getNumArguments();
+    size_t numArg = shapeFunc.getNumArguments();
     for (size_t i = 0; i < numArg; ++i) {
-      bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
+      bvm.map(shapeFunc.getArgument(i), callOp.getOperand(i));
     }
 
     builder.setInsertionPoint(callOp);
@@ -341,13 +338,13 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
       auto newOp = builder.clone(*oldOp, bvm);
     }
 
-    for (Value &&ret : newRetOp.getOperands()) {
+    for (Value &&ret : retOp.getOperands()) {
       reifications.push_back(bvm.lookup(ret));
     }
   }
 
-  // remove newFuncOp
-  newFuncOp->erase();
+  // remove the temporary shape function
+  shapeFunc->erase();
 
   return success();
 }
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 4d5c2b5c6..3b87432a1 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -45,9 +45,8 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
   pm.addPass(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
-  pm.addNestedPass<func::FuncOp>(
-      createConvertHloToByreTensorPass(appendArgTypes));
-  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
+  pm.addPass(createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addPass(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace
diff --git a/compiler/lib/Pipelines/ShapeOpt.cpp b/compiler/lib/Pipelines/ShapeOpt.cpp
index 71870b343..6fd7204f5 100644
--- a/compiler/lib/Pipelines/ShapeOpt.cpp
+++ b/compiler/lib/Pipelines/ShapeOpt.cpp
@@ -27,17 +27,13 @@ using namespace mlir;
 
 void mlir::createShapeOptPipeline(OpPassManager &pm) {
-  invokeOpPassPipelineBuilder(
-      [](OpPassManager &pm) {
-        pm.addPass(createSetAssumingAlwaysTruePass());
-        pm.addPass(createCanonicalizeExtPass());
-        pm.addPass(createInsertTieShapePass());
-        pm.addPass(createInsertShapeConstraintPass());
-        pm.addPass(createByteIRShapeReificationPass());
-        addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
-        pm.addPass(createResolveShapeConstraintPass());
-        pm.addPass(createBoundedShapeInferencePass());
-        pm.addPass(createCanonicalizeExtPass());
-      },
-      pm);
+  pm.addNestedPass<func::FuncOp>(createSetAssumingAlwaysTruePass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
+  pm.addNestedPass<func::FuncOp>(createInsertTieShapePass());
+  pm.addNestedPass<func::FuncOp>(createInsertShapeConstraintPass());
+  pm.addPass(createByteIRShapeReificationPass());
+  addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
+  pm.addNestedPass<func::FuncOp>(createResolveShapeConstraintPass());
+  pm.addNestedPass<func::FuncOp>(createBoundedShapeInferencePass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
 }
diff --git a/compiler/lib/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
index f196d93a6..e6559f06c 100644
--- a/compiler/lib/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -65,6 +65,7 @@ struct ShapeReificationOnTensorDimPattern
     }
 
     rewriter.replaceOp(op, dimOfShape);
+    return success();
   }
 };
 
@@ -106,14 +107,16 @@ void PopulateShapeReificationPatterns(MLIRContext *ctx,
 
 struct ShapeReificationPass
     : public ShapeReificationBase<ShapeReificationPass> {
-  ShapeReificationPass()
+  ShapeReificationPass(const std::string &anchorFunc)
       : ShapeReificationBase::ShapeReificationBase() {
     // ReifyReturnType implementation could also be registered outside
     // ShapeReificationPass
     registerAllMhloReifyReturnTypeShapes();
+    this->anchorFunc = anchorFunc;
   }
 
   void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
     // Collect patterns.
     MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);
@@ -125,17 +128,22 @@ struct ShapeReificationPass
     // iteration.
     GreedyRewriteConfig cfg;
     cfg.useTopDownTraversal = false;
-    func::FuncOp f = getOperation();
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    // apply patterns on the target funcOp (all funcs when anchorFunc is empty)
+    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+      if (this->anchorFunc == "" || funcOp.getName() == this->anchorFunc) {
+        if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns, cfg))) {
+          return signalPassFailure();
+        }
+      }
-    if (failed(applyPatternsAndFoldGreedily(f, frozenPatterns, cfg))) {
-      return signalPassFailure();
     }
   }
 };
 
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createByteIRShapeReificationPass() {
-  return std::make_unique<ShapeReificationPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createByteIRShapeReificationPass(llvm::StringRef anchorFunc /*=""*/) {
+  return std::make_unique<ShapeReificationPass>(anchorFunc.str());
 }
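
Usage sketch, not part of the patch: after this change, `byteir-shape-reification` is anchored on `ModuleOp`, and the new `anchor-func` option (exposed in C++ as the `anchorFunc` argument of `createByteIRShapeReificationPass`) restricts rewriting to a single function. This is exactly how `SymbolicShape.cpp` and `ShapeInferUtil.cpp` above drive the pass. A minimal driver under those assumptions follows; the helper name `reifyShapesInFunc` is illustrative, while the pass-creation and PassManager APIs match the signatures introduced in this diff.

    #include "byteir/Transforms/ShapeReification.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"

    // Run shape reification on a single function of `module`, mirroring how
    // SymbolicShapeAnalysis drives the anchored pass in this patch.
    mlir::LogicalResult reifyShapesInFunc(mlir::ModuleOp module,
                                          llvm::StringRef anchor) {
      // The pass is now registered on ModuleOp (see Passes.td), so the
      // PassManager is rooted at the module rather than at func::FuncOp.
      mlir::PassManager pm(module->getContext(),
                           mlir::ModuleOp::getOperationName());
      pm.addPass(mlir::createByteIRShapeReificationPass(anchor));
      return pm.run(module);
    }

Passing an empty `anchor` reproduces the old whole-module behavior, since the pass then iterates over every `func::FuncOp` in the module.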