[compiler] support shape reification for callOp #182

Open · wants to merge 5 commits into base: main
@@ -24,16 +24,14 @@
#include <memory>

namespace mlir {
class ModuleOp;
// forward decl
namespace func {
class FuncOp;
} // namespace func

void populateHloToByreTensorPattern(
RewritePatternSet &patterns,
const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes);

std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
createConvertHloToByreTensorPass(bool appendArgTypes = false);

} // namespace mlir
2 changes: 1 addition & 1 deletion compiler/include/byteir/Conversion/Passes.td
@@ -267,7 +267,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
// HloToByreTensor
//===----------------------------------------------------------------------===//

def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "ModuleOp"> {
let summary = "Convert hlo op to byre tensor op.";
let constructor = "mlir::createConvertHloToByreTensorPass()";
let dependentDialects = [
1 change: 0 additions & 1 deletion compiler/include/byteir/Dialect/mhlo/Passes.h
@@ -36,7 +36,6 @@
#include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h"
#include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h"
#include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h"
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
#include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h"
#include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h"

19 changes: 0 additions & 19 deletions compiler/include/byteir/Dialect/mhlo/Passes.td
@@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp"> {
let constructor = "mlir::createRewriteWithConstraintPass()";
}

//===----------------------------------------------------------------------===//
// ShapeReification
//===----------------------------------------------------------------------===//

def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
let summary = "Iteratively reify all shape computations.";
let description = [{
If an operation has a shape reification implementation, that is, we know
how to express the shapes of its outputs symbolically in terms of the
shapes of its inputs, then a tensor.dim or shape.shape_of on such an
operation can be reified. The shape reification procedure is applied
recursively.
}];
let constructor = "mlir::createByteIRShapeReificationPass()";
let dependentDialects = [
"mlir::shape::ShapeDialect",
"mlir::tensor::TensorDialect"
];
}

//===----------------------------------------------------------------------===//
// Static Shape Inference
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions compiler/include/byteir/Transforms/Passes.h
@@ -35,6 +35,7 @@
#include "byteir/Transforms/RewriteOpToStdCall.h"
#include "byteir/Transforms/SetArgShape.h"
#include "byteir/Transforms/SetSpace.h"
#include "byteir/Transforms/ShapeReification.h"
#include "byteir/Transforms/TryCatchModulePipeline.h"

namespace mlir {
20 changes: 20 additions & 0 deletions compiler/include/byteir/Transforms/Passes.td
@@ -425,4 +425,24 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
];
}

//===----------------------------------------------------------------------===//
// ShapeReification
//===----------------------------------------------------------------------===//

def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
let summary = "Iteratively reify all shape computations.";
let description = [{
If an operation has a shape reification implementation, that is, we know
how to express the shapes of its outputs symbolically in terms of the
shapes of its inputs, then a tensor.dim or shape.shape_of on such an
operation can be reified. The shape reification procedure is applied
recursively.
}];
let constructor = "mlir::createByteIRShapeReificationPass()";
let dependentDialects = [
"mlir::shape::ShapeDialect",
"mlir::tensor::TensorDialect",
"mlir::arith::ArithDialect",
];
}

#endif // BYTEIR_TRANSFORMS_PASSES
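
For intuition, a minimal before/after sketch of what this pass does, on hypothetical IR (assuming mhlo.add provides a shape reification implementation; names are illustrative):

// before byteir-shape-reification: the dim queries an intermediate value
func.func @f(%arg0: tensor<?x4xf32>) -> index {
  %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %0, %c0 : tensor<?x4xf32>
  return %dim : index
}

// after: the same dimension is read off the function argument
func.func @f(%arg0: tensor<?x4xf32>) -> index {
  %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim = tensor.dim %arg0, %c0 : tensor<?x4xf32>
  return %dim : index
}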
@@ -1,6 +1,6 @@
//===- ShapeReification.h -------------------------------------*--- C++ -*-===//
//
// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
2 changes: 1 addition & 1 deletion compiler/lib/Analysis/SymbolicShape.cpp
@@ -16,7 +16,7 @@
//===----------------------------------------------------------------------===//

#include "byteir/Analysis/SymbolicShape.h"
#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
#include "byteir/Transforms/ShapeReification.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
6 changes: 3 additions & 3 deletions compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -768,15 +768,15 @@ struct ConvertHloToByreTensorPass
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
ConversionTarget target(ctx);
auto funcOp = getOperation();

populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
target.addIllegalDialect<mhlo::MhloDialect>();
target.addLegalDialect<tensor::TensorDialect, byre::ByreDialect,
shape::ShapeDialect, arith::ArithDialect>();

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPartialConversion(funcOp, target, frozenPatterns))) {
if (failed(
applyPartialConversion(getOperation(), target, frozenPatterns))) {
signalPassFailure();
}
}
@@ -810,7 +810,7 @@ void mlir::populateHloToByreTensorPattern(
ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext());
}

std::unique_ptr<OperationPass<func::FuncOp>>
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
}
1 change: 0 additions & 1 deletion compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
Transforms/ReduceFusion.cpp
Transforms/ReshapeGather.cpp
Transforms/RewriteWithConstraint.cpp
Transforms/ShapeReification.cpp
Transforms/StaticShapeInference.cpp
Transforms/TrivialFusion.cpp
Transforms/UnfuseBatchNorm.cpp
217 changes: 217 additions & 0 deletions compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -17,14 +17,24 @@

#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Transforms/ShapeReification.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/TopologicalSortUtils.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Debug.h"

#include <queue>
#include <string>

using namespace mlir;

#define DEBUG_TYPE "shape-infer-util"
@@ -177,6 +187,203 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
return nullptr;
}

namespace {
bool deduceFromFuncArgShape(Value value) {
if (value.isa<BlockArgument>()) {
return false;
}

auto defOp = value.getDefiningOp();
if (!defOp) {
return false;
}

if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
return true;
}

if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
auto operand = defOp->getOperand(0);
if (operand.isa<BlockArgument>()) {
return true;
}
return false;
}

for (Value &&operand : defOp->getOperands()) {
if (!deduceFromFuncArgShape(operand)) {
return false;
}
}
return true;
}
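
// For intuition, a hypothetical IR fragment (value names are illustrative):
//   %c0 = arith.constant 0 : index
//   %d  = tensor.dim %arg0, %c0 : tensor<?x4xf32>                 // true: dim of a block argument
//   %sz = arith.muli %d, %d : index                               // true: all operands recurse to %d
//   %s  = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>  // false: %0 is an intermediate value
// deduceFromFuncArgShape returns true for %d and %sz but false for %s, since
// %s queries the shape of a value that is not a function argument.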

// auxiliaryModuleOp must be an empty module, used only to hold the shape function
FailureOr<func::FuncOp>
createCorrespondingShapeFunc(func::FuncOp funcOp, ModuleOp auxiliaryModuleOp) {
// use an auxiliary builder to create the shape function at the start of auxiliaryModuleOp
ModuleOp oriModuleOp = funcOp->getParentOfType<ModuleOp>();
OpBuilder builder = OpBuilder::atBlockBegin(auxiliaryModuleOp.getBody());

// clone funcOp; the clone (shapeFunc) is used to deduce the function's result shapes
Twine shapeFuncName = funcOp.getName() + "_Shape";
auto shapeFunc = builder.create<func::FuncOp>(
funcOp->getLoc(), shapeFuncName.str(), funcOp.getFunctionType());
shapeFunc.setPrivate();
IRMapping emptyBvm;
funcOp.cloneInto(shapeFunc, emptyBvm);
llvm::DenseSet<func::CallOp> callOpSet;
shapeFunc.walk([&](func::CallOp callOp) { callOpSet.insert(callOp); });

while (!callOpSet.empty()) {
auto callOp = *callOpSet.begin();
callOpSet.erase(callOpSet.begin());
auto callFunc = oriModuleOp.lookupSymbol<func::FuncOp>(callOp.getCallee());
// inline the callee.
builder.setInsertionPoint(callOp);
IRMapping bvm;
for (auto inputAndArg :
llvm::zip(callFunc.getArguments(), callOp.getOperands())) {
bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
}
Block &entryBlock = callFunc.getBlocks().front();
ValueRange funcOuts;
for (Operation &op : entryBlock) {
auto retOp = mlir::dyn_cast<func::ReturnOp>(op);
if (!retOp) {
auto newOp = builder.clone(op, bvm);
if (auto nestedCall = dyn_cast<func::CallOp>(newOp)) {
callOpSet.insert(nestedCall);
}
} else {
funcOuts = retOp.getOperands();
}
}

for (auto callResultAndFuncOuts :
llvm::zip(callOp.getResults(), funcOuts)) {
auto mappedOut = bvm.lookup(std::get<1>(callResultAndFuncOuts));
std::get<0>(callResultAndFuncOuts).replaceAllUsesWith(mappedOut);
}
callOp->erase();
}

// replace the operands of the returnOp with their corresponding shapes
func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
if (!retOp) {
shapeFunc->erase();
return failure();
}

for (Value &&retTensor : retOp.getOperands()) {
auto retTy = retTensor.getType();
if (!retTy.isa<RankedTensorType>()) {
shapeFunc->erase();
return failure();
}
}

SmallVector<Type> allResultTypes;
SmallVector<Value> allResults;

builder.setInsertionPoint(retOp);
for (Value &&retTensor : retOp.getOperands()) {
auto retShape = builder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
allResultTypes.emplace_back(retShape.getType());
allResults.emplace_back(retShape);
}

// return the shapes of the tensors originally returned by the function
builder.create<func::ReturnOp>(retOp.getLoc(), allResults);
auto shapeFuncType =
builder.getFunctionType(shapeFunc.getArgumentTypes(), allResultTypes);
shapeFunc.setFunctionType(shapeFuncType);
retOp->erase();

// reify shapeFunc to get the shape computation.
{
PassManager pm(oriModuleOp->getContext(), func::FuncOp::getOperationName());
// run the pipeline only on shapeFunc
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addPass(createByteIRShapeReificationPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
if (mlir::failed(pm.run(shapeFunc))) {
shapeFunc->erase();
return failure();
}
}
return shapeFunc;
}
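
// A sketch of what createCorrespondingShapeFunc produces (hypothetical callee;
// the exact IR after the canonicalize/CSE/reification pipeline may differ).
// Given:
//   func.func @foo(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
//     %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
//     return %0 : tensor<?x4xf32>
//   }
// the helper creates roughly:
//   func.func private @foo_Shape(%arg0: tensor<?x4xf32>) -> tensor<2xindex> {
//     %shape = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
//     return %shape : tensor<2xindex>
//   }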

LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
OpBuilder::InsertionGuard guard(builder);
auto callOp = dyn_cast<func::CallOp>(op);
if (!callOp) {
return failure();
}

ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
StringRef funcName = callOp.getCallee();
auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);

// create a temp module, then insert the corresponding shape function into it
OwningOpRef<ModuleOp> auxiliaryModuleOp(
ModuleOp::create(UnknownLoc::get(moduleOp->getContext())));
auto maybeShapeFunc =
createCorrespondingShapeFunc(funcOp, auxiliaryModuleOp.get());
if (failed(maybeShapeFunc)) {
return failure();
}

func::FuncOp shapeFunc = *maybeShapeFunc;
func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();

// collect all shape computation ops
SetVector<Operation *> reificationOpSet;
getBackwardSlice(retOp.getOperation(), &reificationOpSet);
SmallVector<Operation *> reificationOps(reificationOpSet.begin(),
reificationOpSet.end());
// each returned value must depend only on the shapes of the function arguments
for (Value &&ret : retOp.getOperands()) {
if (!deduceFromFuncArgShape(ret)) {
shapeFunc->erase();
return failure();
}
}

// clone the shape computation ops into the caller and collect the reified values
{
mlir::computeTopologicalSorting(reificationOps);

IRMapping bvm;
size_t numArg = shapeFunc.getNumArguments();
for (size_t i = 0; i < numArg; ++i) {
bvm.map(shapeFunc.getArgument(i), callOp.getOperand(i));
}

builder.setInsertionPoint(callOp);

for (Operation *oldOp : reificationOps) {
builder.clone(*oldOp, bvm);
}

for (Value &&ret : retOp.getOperands()) {
reifications.push_back(bvm.lookup(ret));
}
}

// erase the temporary shape function
shapeFunc->erase();
return success();
}

} // namespace

LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
SmallVectorImpl<Value> &reifications) {
if (!op)
@@ -207,6 +414,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
}
if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
return failure();
} else if (auto callOp = dyn_cast<func::CallOp>(op)) {
if (failed(reifyCallOp(builder, op, reifications))) {
return failure();
}
} else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
for (OpResult &&result : op->getOpResults()) {
auto tiedOperand = dpsOp.getTiedOpOperand(result);
reifications.push_back(
builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
}
} else {
// Return failure if the op neither implements InferShapedTypeOpInterface
// nor has a registered shape inference function.
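
Taken together, the new func.call branch lets a shape query over a call result be reified into computation on the call operands. A hypothetical caller-side sketch (a client of reifyShapes, e.g. a shape.shape_of canonicalization, would perform the final substitution):

// before
%0 = func.call @foo(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xf32>
%1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>

// after reifyShapes on the call: the reified shape depends only on the
// call operands, so uses of %1 can be replaced with %2
%0 = func.call @foo(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xf32>
%2 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>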