[compiler] modified top level of shapeReification to moduleOp
XG-zheng committed Apr 10, 2024
1 parent 7363a9c · commit 80408e2
Showing 10 changed files with 125 additions and 115 deletions.
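At a glance, the caller-side effect of this commit: ConvertHloToByreTensor and ShapeReification are now anchored on ModuleOp, so pipelines add them with addPass on a module-level pass manager instead of addNestedPass<func::FuncOp> (see the ByreTensorOpt.cpp and ShapeOpt.cpp hunks below). A minimal sketch, assuming the pass constructors declared in this diff; the wrapper function and error handling are illustrative and not part of the commit:

// Sketch (assumed usage, not from this commit): run the two reworked passes
// with their new ModuleOp anchoring. Needs mlir/Pass/PassManager.h plus the
// ByteIR headers touched in this diff.
static LogicalResult runModuleAnchoredPasses(ModuleOp moduleOp) {
  PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
  pm.addPass(createConvertHloToByreTensorPass(/*appendArgTypes=*/false));
  pm.addPass(createByteIRShapeReificationPass()); // empty anchor-func: whole module
  return pm.run(moduleOp);
}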
@@ -24,16 +24,14 @@
#include <memory>

namespace mlir {
class ModuleOp;
// forward decl
namespace func {
class FuncOp;
} // namespace func

void populateHloToByreTensorPattern(
RewritePatternSet &patterns,
const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes);

std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
createConvertHloToByreTensorPass(bool appendArgTypes = false);

} // namespace mlir
2 changes: 1 addition & 1 deletion compiler/include/byteir/Conversion/Passes.td
@@ -267,7 +267,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
// HloToByreTensor
//===----------------------------------------------------------------------===//

def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "ModuleOp"> {
let summary = "Convert hlo op to byre tensor op.";
let constructor = "mlir::createConvertHloToByreTensorPass()";
let dependentDialects = [
7 changes: 6 additions & 1 deletion compiler/include/byteir/Transforms/Passes.td
@@ -429,7 +429,7 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
// ShapeReification
//===----------------------------------------------------------------------===//

def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
def ShapeReification : Pass<"byteir-shape-reification", "ModuleOp"> {
let summary = "Iteratively reify all shape computations.";
let description = [{
If an operation has a shape reification implementation, that is to say, we
@@ -443,6 +443,11 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
"mlir::tensor::TensorDialect",
"mlir::arith::ArithDialect",
];
let options = [
Option<"anchorFunc", "anchor-func", "std::string",
/*default=*/"",
"An optional funcName used to specify the target function.">,
];
}

#endif // BYTEIR_TRANSFORMS_PASSES
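
The new anchor-func option lets the now module-anchored pass be restricted to a single function, which is how SymbolicShape.cpp and ShapeInferUtil.cpp below drive it per shape function. A minimal sketch, assuming the createByteIRShapeReificationPass(llvm::StringRef) signature from ShapeReification.h; the function name "main" is illustrative:

// Sketch (assumed usage, not from this commit): reify shapes only inside the
// function named "main"; other functions in the module are left untouched.
// The equivalent textual pipeline option would be
// byteir-shape-reification{anchor-func=main}.
PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
pm.addPass(createByteIRShapeReificationPass(/*anchorFunc=*/"main"));
if (mlir::failed(pm.run(moduleOp)))
  llvm::errs() << "byteir-shape-reification failed\n";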
7 changes: 3 additions & 4 deletions compiler/include/byteir/Transforms/ShapeReification.h
@@ -22,11 +22,10 @@
#include <memory>

namespace mlir {
namespace func {
class FuncOp;
} // namespace func
class ModuleOp;

std::unique_ptr<OperationPass<func::FuncOp>> createByteIRShapeReificationPass();
std::unique_ptr<OperationPass<ModuleOp>>
createByteIRShapeReificationPass(llvm::StringRef anchorFunc = "");

} // namespace mlir

18 changes: 13 additions & 5 deletions compiler/lib/Analysis/SymbolicShape.cpp
@@ -118,12 +118,20 @@ SymbolicShapeAnalysis::SymbolicShapeAnalysis(ModuleOp moduleOp)
}

// run shape reification pass on all the auxiliary functions
PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
pm.addPass(createByteIRShapeReificationPass());
pm.addPass(createCSEPass());
for (auto funcOp : shpFuncOps) {
if (mlir::failed(pm.run(funcOp))) {
llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
{
PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
pm.addPass(createByteIRShapeReificationPass(funcOp.getName()));
if (mlir::failed(pm.run(moduleOp))) {
llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
}
}
{
PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
pm.addPass(createCSEPass());
if (mlir::failed(pm.run(funcOp))) {
llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
}
}
}
}
6 changes: 3 additions & 3 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -768,15 +768,15 @@ struct ConvertHloToByreTensorPass
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
ConversionTarget target(ctx);
auto funcOp = getOperation();

populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
target.addIllegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<tensor::TensorDialect, byre::ByreDialect,
shape::ShapeDialect, arith::ArithDialect>();

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPartialConversion(funcOp, target, frozenPatterns))) {
if (failed(
applyPartialConversion(getOperation(), target, frozenPatterns))) {
signalPassFailure();
}
}
@@ -810,7 +810,7 @@ void mlir::populateHloToByreTensorPattern(
ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext());
}

std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
}
145 changes: 71 additions & 74 deletions compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -19,15 +19,18 @@
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Transforms/ShapeReification.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Debug.h"

#include <queue>
#include <string>

@@ -184,32 +187,6 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
}

namespace {

SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
llvm::DenseSet<Operation *> visitedOp;
std::queue<Operation *> opQueue;

opQueue.push(retOp);
while (!opQueue.empty()) {
auto frontOp = opQueue.front();
opQueue.pop();
if (visitedOp.find(frontOp) != visitedOp.end()) {
continue;
}
visitedOp.insert(frontOp);
for (Value operand : frontOp->getOperands()) {
if (!operand.getDefiningOp()) {
continue;
}
if (Operation *defOp = operand.getDefiningOp()) {
opQueue.push(defOp);
}
}
}
visitedOp.erase(retOp);
return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
}

bool deduceFromFuncArgShape(Value value) {
if (value.isa<BlockArgument>()) {
return false;
Expand Down Expand Up @@ -240,87 +217,107 @@ bool deduceFromFuncArgShape(Value value) {
return true;
}

LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
OpBuilder::InsertionGuard guard(builder);
auto callOp = dyn_cast<func::CallOp>(op);
if (!callOp) {
return failure();
}

ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
// auxiliary builder used for create operations in shape func
// original builder maybe a rewriter, used for create operations in specific
// pattern.
OpBuilder auxiliaryBuilder(moduleOp);
StringRef funcName = callOp.getCallee();
auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
FailureOr<func::FuncOp> createCorrespondingShapeFunc(func::FuncOp funcOp) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
// use auxiliary builder, create shape func in the start of moduleOp
OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());

// clone funcOp, newFuncOp used for deduce function shape
std::string newFuncName = funcName.str() + "_Shape";
auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
newFuncOp.setPrivate();
Twine shapeFuncName = funcOp.getName() + "_Shape";
auto shapeFunc = builder.create<func::FuncOp>(
funcOp->getLoc(), shapeFuncName.str(), funcOp.getFunctionType());
shapeFunc.setPrivate();
IRMapping emptyBvm;
funcOp.cloneInto(newFuncOp, emptyBvm);
funcOp.cloneInto(shapeFunc, emptyBvm);

// replace the operands of returnOp with corresponding shape
func::ReturnOp retOp = *newFuncOp.getOps<func::ReturnOp>().begin();
func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
if (!retOp) {
newFuncOp->erase();
shapeFunc->erase();
return failure();
}

for (Value &&retTensor : retOp.getOperands()) {
auto retTy = retTensor.getType();
if (!retTy.isa<RankedTensorType>()) {
newFuncOp->erase();
shapeFunc->erase();
return failure();
}
}

SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;

auxiliaryBuilder.setInsertionPoint(retOp);
builder.setInsertionPoint(retOp);
for (Value &&retTensor : retOp.getOperands()) {
auto retShape =
auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
auto retShape = builder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
allResultTypes.emplace_back(retShape.getType());
allResults.emplace_back(retShape);
}

// return the shape of original tensor returned by function
auto newRetOp =
auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
auto newFuncType = auxiliaryBuilder.getFunctionType(
newFuncOp.getArgumentTypes(), allResultTypes);
newFuncOp.setFunctionType(newFuncType);
auto shapeFuncRetOp =
builder.create<func::ReturnOp>(retOp.getLoc(), allResults);
auto shapeFuncType =
builder.getFunctionType(shapeFunc.getArgumentTypes(), allResultTypes);
shapeFunc.setFunctionType(shapeFuncType);
retOp->erase();

// reify newFunc to get the shape computation for current callOp
// reify shapeFunc to get the shape computation.
{
PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createByteIRShapeReificationPass());
PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
// only run pass on shapeFunc
pm.addPass(createByteIRShapeReificationPass(shapeFunc.getName()));
if (mlir::failed(pm.run(moduleOp))) {
shapeFunc->erase();
return failure();
}
}

// canonicalize shapeFunc
{
PassManager pm(shapeFunc->getContext(), shapeFunc.getOperationName());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

if (mlir::failed(pm.run(newFuncOp))) {
newFuncOp->erase();
// only run pass on shapeFunc, don't modify other ops.
if (mlir::failed(pm.run(shapeFunc))) {
shapeFunc->erase();
return failure();
}
}
return shapeFunc;
}

// collect all shape computation ops
SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
OpBuilder::InsertionGuard guard(builder);
auto callOp = dyn_cast<func::CallOp>(op);
if (!callOp) {
return failure();
}

ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
StringRef funcName = callOp.getCallee();
auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);

// create corresponding shape function
auto maybeShapeFunc = createCorrespondingShapeFunc(funcOp);
if (failed(maybeShapeFunc)) {
return failure();
}

func::FuncOp shapeFunc = *maybeShapeFunc;
func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();

// collect all shape computation ops
SetVector<Operation *> reificationOpSet;
getBackwardSlice(retOp.getOperation(), &reificationOpSet);
SmallVector<Operation *> reificationOps(reificationOpSet.begin(),
reificationOpSet.end());
// value only depends on the shape of FuncArgs.
for (Value &&ret : newRetOp.getOperands()) {
for (Value &&ret : retOp.getOperands()) {
if (!deduceFromFuncArgShape(ret)) {
newFuncOp->erase();
shapeFunc->erase();
return failure();
}
}
@@ -330,9 +327,9 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
mlir::computeTopologicalSorting(reificationOps);

IRMapping bvm;
size_t numArg = newFuncOp.getNumArguments();
size_t numArg = shapeFunc.getNumArguments();
for (size_t i = 0; i < numArg; ++i) {
bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
bvm.map(shapeFunc.getArgument(i), callOp.getOperand(i));
}

builder.setInsertionPoint(callOp);
@@ -341,13 +338,13 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
auto newOp = builder.clone(*oldOp, bvm);
}

for (Value &&ret : newRetOp.getOperands()) {
for (Value &&ret : retOp.getOperands()) {
reifications.push_back(bvm.lookup(ret));
}
}

// remove newFuncOp
newFuncOp->erase();
shapeFunc->erase();
return success();
}

5 changes: 2 additions & 3 deletions compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -45,9 +45,8 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
pm.addPass(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
pm.addNestedPass<func::FuncOp>(
createConvertHloToByreTensorPass(appendArgTypes));
pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
pm.addPass(createConvertHloToByreTensorPass(appendArgTypes));
pm.addPass(createByteIRShapeReificationPass());
pm.addPass(createCanonicalizerPass());
}
} // namespace
22 changes: 9 additions & 13 deletions compiler/lib/Pipelines/ShapeOpt.cpp
@@ -27,17 +27,13 @@
using namespace mlir;

void mlir::createShapeOptPipeline(OpPassManager &pm) {
invokeOpPassPipelineBuilder<func::FuncOp>(
[](OpPassManager &pm) {
pm.addPass(createSetAssumingAlwaysTruePass());
pm.addPass(createCanonicalizeExtPass());
pm.addPass(createInsertTieShapePass());
pm.addPass(createInsertShapeConstraintPass());
pm.addPass(createByteIRShapeReificationPass());
addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
pm.addPass(createResolveShapeConstraintPass());
pm.addPass(createBoundedShapeInferencePass());
pm.addPass(createCanonicalizeExtPass());
},
pm);
pm.addNestedPass<func::FuncOp>(createSetAssumingAlwaysTruePass());
pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
pm.addNestedPass<func::FuncOp>(createInsertTieShapePass());
pm.addNestedPass<func::FuncOp>(createInsertShapeConstraintPass());
pm.addPass(createByteIRShapeReificationPass());
addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
pm.addNestedPass<func::FuncOp>(createResolveShapeConstraintPass());
pm.addNestedPass<func::FuncOp>(createBoundedShapeInferencePass());
pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
}