From 0f0a42f6c87e348e9439145a02b8c1d348257fb2 Mon Sep 17 00:00:00 2001 From: zhengxuegui Date: Sun, 7 Apr 2024 10:57:44 +0800 Subject: [PATCH 1/5] [compiler] support shape reification for callOp --- compiler/include/byteir/Dialect/mhlo/Passes.h | 1 - .../include/byteir/Dialect/mhlo/Passes.td | 19 -- compiler/include/byteir/Transforms/Passes.h | 1 + compiler/include/byteir/Transforms/Passes.td | 19 ++ .../mhlo => }/Transforms/ShapeReification.h | 2 +- compiler/lib/Analysis/SymbolicShape.cpp | 2 +- compiler/lib/Dialect/mhlo/CMakeLists.txt | 1 - .../lib/Dialect/mhlo/Util/ShapeInferUtil.cpp | 178 ++++++++++++++++++ compiler/lib/Pipelines/ByreTensorOpt.cpp | 1 + compiler/lib/Transforms/CMakeLists.txt | 1 + compiler/lib/Transforms/PassDetail.h | 8 + .../mhlo => }/Transforms/ShapeReification.cpp | 4 +- .../FuncToByre/func_to_byre_tensor.mlir | 17 ++ .../test/Transforms/shapeReification.mlir | 57 +++++- 14 files changed, 285 insertions(+), 26 deletions(-) rename compiler/include/byteir/{Dialect/mhlo => }/Transforms/ShapeReification.h (94%) rename compiler/lib/{Dialect/mhlo => }/Transforms/ShapeReification.cpp (97%) diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.h b/compiler/include/byteir/Dialect/mhlo/Passes.h index 351b071df..58ebeb6f7 100644 --- a/compiler/include/byteir/Dialect/mhlo/Passes.h +++ b/compiler/include/byteir/Dialect/mhlo/Passes.h @@ -36,7 +36,6 @@ #include "byteir/Dialect/mhlo/Transforms/LayoutTransformation.h" #include "byteir/Dialect/mhlo/Transforms/MatmulLayoutTransform.h" #include "byteir/Dialect/mhlo/Transforms/RewriteWithConstraint.h" -#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h" #include "byteir/Dialect/mhlo/Transforms/StaticShapeInference.h" #include "byteir/Dialect/mhlo/Transforms/UnfuseBatchNorm.h" diff --git a/compiler/include/byteir/Dialect/mhlo/Passes.td b/compiler/include/byteir/Dialect/mhlo/Passes.td index 7fe03f8f8..58f35033d 100644 --- a/compiler/include/byteir/Dialect/mhlo/Passes.td +++ b/compiler/include/byteir/Dialect/mhlo/Passes.td @@ -305,25 +305,6 @@ def RewriteWithConstraint : Pass<"rewrite-with-constraint", "mlir::func::FuncOp let constructor = "mlir::createRewriteWithConstraintPass()"; } -//===----------------------------------------------------------------------===// -// ShapeReification -//===----------------------------------------------------------------------===// - -def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> { - let summary = "Iteratively reify all shape computations."; - let description = [{ - If an operation has a shape reification implementation, that is to say, we - know how to express the outputs' shape by it's inputs' shape symbolicly, - then a tensor.dim or shape.shape_of on this type of operation could be - reified. And shape reification procedure could be handled recursively. 
-  }];
-  let constructor = "mlir::createByteIRShapeReificationPass()";
-  let dependentDialects = [
-    "mlir::shape::ShapeDialect",
-    "mlir::tensor::TensorDialect"
-  ];
-}
-
 //===----------------------------------------------------------------------===//
 // Static Shape Inference
 //===----------------------------------------------------------------------===//
diff --git a/compiler/include/byteir/Transforms/Passes.h b/compiler/include/byteir/Transforms/Passes.h
index 7a179a34e..0f4beb1ac 100644
--- a/compiler/include/byteir/Transforms/Passes.h
+++ b/compiler/include/byteir/Transforms/Passes.h
@@ -35,6 +35,7 @@
 #include "byteir/Transforms/RewriteOpToStdCall.h"
 #include "byteir/Transforms/SetArgShape.h"
 #include "byteir/Transforms/SetSpace.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "byteir/Transforms/TryCatchModulePipeline.h"
 
 namespace mlir {
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index 97d69c022..17e88c1e6 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -425,4 +425,23 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ShapeReification
+//===----------------------------------------------------------------------===//
+
+def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
+  let summary = "Iteratively reify all shape computations.";
+  let description = [{
+    If an operation has a shape reification implementation, that is to say, we
+    know how to express the outputs' shape by its inputs' shape symbolically,
+    then a tensor.dim or shape.shape_of on this type of operation could be
+    reified. And the shape reification procedure could be handled recursively.
+  }];
+  let constructor = "mlir::createByteIRShapeReificationPass()";
+  let dependentDialects = [
+    "mlir::shape::ShapeDialect",
+    "mlir::tensor::TensorDialect"
+  ];
+}
+
 #endif // BYTEIR_TRANSFORMS_PASSES
diff --git a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
similarity index 94%
rename from compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
rename to compiler/include/byteir/Transforms/ShapeReification.h
index 19f338f22..7c4cb5043 100644
--- a/compiler/include/byteir/Dialect/mhlo/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -1,6 +1,6 @@
 //===- ShapeReification.h -------------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index 1f5d9f499..703dec1c4 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -16,7 +16,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "byteir/Analysis/SymbolicShape.h"
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/IRMapping.h"
diff --git a/compiler/lib/Dialect/mhlo/CMakeLists.txt b/compiler/lib/Dialect/mhlo/CMakeLists.txt
index 81667fb71..a6501cf0f 100644
--- a/compiler/lib/Dialect/mhlo/CMakeLists.txt
+++ b/compiler/lib/Dialect/mhlo/CMakeLists.txt
@@ -105,7 +105,6 @@ add_mlir_dialect_library(ByteIRMhloPasses
   Transforms/ReduceFusion.cpp
   Transforms/ReshapeGather.cpp
   Transforms/RewriteWithConstraint.cpp
-  Transforms/ShapeReification.cpp
   Transforms/StaticShapeInference.cpp
   Transforms/TrivialFusion.cpp
   Transforms/UnfuseBatchNorm.cpp
diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 0bf8250b5..965adc765 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -17,13 +17,19 @@
 
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
+#include "byteir/Transforms/ShapeReification.h"
 #include "mhlo/IR/hlo_ops.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/TopologicalSortUtils.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
+#include <queue>
+#include <string>
 
 using namespace mlir;
 
@@ -177,6 +183,168 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) {
   return nullptr;
 }
 
+namespace {
+
+SmallVector<Operation *> collectAllOpsForReturn(Operation *retOp) {
+  llvm::DenseSet<Operation *> visitedOp;
+  std::queue<Operation *> opQueue;
+
+  opQueue.push(retOp);
+  while (!opQueue.empty()) {
+    auto frontOp = opQueue.front();
+    opQueue.pop();
+    if (visitedOp.find(frontOp) != visitedOp.end()) {
+      continue;
+    }
+    visitedOp.insert(frontOp);
+    for (Value operand : frontOp->getOperands()) {
+      if (!operand.getDefiningOp()) {
+        continue;
+      }
+      if (Operation *defOp = operand.getDefiningOp()) {
+        opQueue.push(defOp);
+      }
+    }
+  }
+  visitedOp.erase(retOp);
+  return SmallVector<Operation *>(visitedOp.begin(), visitedOp.end());
+}
+
+bool deduceFromFuncArgShape(Value value) {
+  if (value.isa<BlockArgument>()) {
+    return false;
+  }
+
+  auto defOp = value.getDefiningOp();
+  if (!defOp) {
+    return false;
+  }
+
+  if (isa<arith::ConstantIndexOp, arith::ConstantOp>(defOp)) {
+    return true;
+  }
+
+  if (isa<tensor::DimOp, shape::ShapeOfOp>(defOp)) {
+    auto operand = defOp->getOperand(0);
+    if (operand.isa<BlockArgument>()) {
+      return true;
+    }
+    return false;
+  }
+
+  for (Value &&operand : defOp->getOperands()) {
+    if (!deduceFromFuncArgShape(operand)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &reifications) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto callOp = dyn_cast<func::CallOp>(op);
+  if (!callOp) {
+    return failure();
+  }
+
+  ModuleOp moduleOp = op->getParentRegion()->getParentOfType<ModuleOp>();
+  // auxiliary builder used to create operations in the shape func; the
+  // original builder may be a rewriter, used only inside a specific
+  // pattern.
+  OpBuilder auxiliaryBuilder(moduleOp);
+  StringRef funcName = callOp.getCallee();
+  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+
+  // clone funcOp; newFuncOp is used to deduce the function shape
+  std::string newFuncName = funcName.str() + "_Shape";
+  auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody());
+  auto newFuncOp = auxiliaryBuilder.create<func::FuncOp>(
+      funcOp->getLoc(), newFuncName, funcOp.getFunctionType());
+  newFuncOp.setPrivate();
+  IRMapping emptyBvm;
+  funcOp.cloneInto(newFuncOp, emptyBvm);
+
+  // replace the operands of returnOp with corresponding shape
+  func::ReturnOp retOp = *newFuncOp.getOps<func::ReturnOp>().begin();
+  if (!retOp) {
+    newFuncOp->erase();
+    return failure();
+  }
+
+  SmallVector<Type> allResultTypes;
+  SmallVector<Value> allResults;
+
+  auxiliaryBuilder.setInsertionPoint(retOp);
+  for (Value &&retTensor : retOp.getOperands()) {
+    auto retShape =
+        auxiliaryBuilder.create<shape::ShapeOfOp>(retOp.getLoc(), retTensor);
+    allResultTypes.emplace_back(retShape.getType());
+    allResults.emplace_back(retShape);
+  }
+
+  // return the shapes of the original tensors returned by the function
+  auto newRetOp =
+      auxiliaryBuilder.create<func::ReturnOp>(retOp.getLoc(), allResults);
+  auto newFuncType = auxiliaryBuilder.getFunctionType(
+      newFuncOp.getArgumentTypes(), allResultTypes);
+  newFuncOp.setFunctionType(newFuncType);
+  retOp->erase();
+
+  // reify newFunc to get the shape computation for current callOp
+  {
+    PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+    pm.addPass(createByteIRShapeReificationPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
+
+    if (mlir::failed(pm.run(newFuncOp))) {
+      newFuncOp->erase();
+      return failure();
+    }
+  }
+
+  // collect all shape computation ops
+  SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
+
+  // the returned value must only depend on the shape of FuncArgs.
+  for (Value &&ret : newRetOp.getOperands()) {
+    if (!deduceFromFuncArgShape(ret)) {
+      newFuncOp->erase();
+      return failure();
+    }
+  }
+
+  // mapping the shape computation ops and collect reifications
+  {
+    mlir::computeTopologicalSorting(reificationOps);
+
+    IRMapping bvm;
+    size_t numArg = newFuncOp.getNumArguments();
+    for (size_t i = 0; i < numArg; ++i) {
+      bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
+    }
+
+    builder.setInsertionPoint(callOp);
+
+    for (Operation *oldOp : reificationOps) {
+      auto newOp = builder.clone(*oldOp, bvm);
+    }
+
+    for (Value &&ret : newRetOp.getOperands()) {
+      reifications.push_back(bvm.lookup(ret));
+    }
+  }
+
+  // remove newFuncOp
+  newFuncOp->erase();
+  return success();
+}
+
+} // namespace
+
 LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
                                 SmallVectorImpl<Value> &reifications) {
   if (!op)
@@ -207,6 +375,16 @@ LogicalResult mlir::reifyShapes(OpBuilder &builder, Operation *op,
     }
     if (failed(inferFunc(op, builder, op->getOperands(), reifications)))
       return failure();
+  } else if (auto callOp = dyn_cast<func::CallOp>(op)) {
+    if (failed(reifyCallOp(builder, op, reifications))) {
+      return failure();
+    }
+  } else if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op)) {
+    for (OpResult &&result : op->getOpResults()) {
+      auto tiedOperand = dpsOp.getTiedOpOperand(result);
+      reifications.push_back(
+          builder.create<shape::ShapeOfOp>(op->getLoc(), tiedOperand->get()));
+    }
   } else {
     // Return failure if op doesn't have InferShapedTypeOpInterface and not
     // registered.
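[Reviewer note, illustration only, not part of the patch. The function names and
shapes below are hypothetical.] To make the new reifyCallOp path concrete,
consider a call to a private element-wise function whose result shape is
queried:

  func.func private @add_one(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
    %0 = mhlo.add %arg0, %arg0 : tensor<?x4xf32>
    return %0 : tensor<?x4xf32>
  }
  func.func @caller(%arg0: tensor<?x4xf32>) -> tensor<2xindex> {
    %0 = call @add_one(%arg0) : (tensor<?x4xf32>) -> tensor<?x4xf32>
    %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
    return %1 : tensor<2xindex>
  }

reifyCallOp clones @add_one into a private @add_one_Shape function that returns
shape.shape_of of each original result, reifies and canonicalizes that clone,
checks that every returned shape derives only from the arguments, and then
clones the surviving shape computation back to the call site. After
canonicalization, %1 above folds to shape.shape_of %arg0, so the result shape
of the call is available without reasoning about the callee's body.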
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 5b1f710ad..4d5c2b5c6 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -47,6 +47,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace
diff --git a/compiler/lib/Transforms/CMakeLists.txt b/compiler/lib/Transforms/CMakeLists.txt
index 9ac510696..3ab4a25ab 100644
--- a/compiler/lib/Transforms/CMakeLists.txt
+++ b/compiler/lib/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_library(ByteIRTransforms
   RewriteOpToStdCall.cpp
   SetArgShape.cpp
   SetSpace.cpp
+  ShapeReification.cpp
   Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h
index f0cf6f3fa..05ade22f2 100644
--- a/compiler/lib/Transforms/PassDetail.h
+++ b/compiler/lib/Transforms/PassDetail.h
@@ -51,6 +51,14 @@ namespace scf {
 class SCFDialect;
 } // namespace scf
 
+namespace shape {
+class ShapeDialect;
+} // namespace shape
+
+namespace tensor {
+class TensorDialect;
+} // namespace tensor
+
 #define GEN_PASS_CLASSES
 #include "byteir/Transforms/Passes.h.inc"
 
diff --git a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
similarity index 97%
rename from compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
rename to compiler/lib/Transforms/ShapeReification.cpp
index 7b6c1b548..382e2ed10 100644
--- a/compiler/lib/Dialect/mhlo/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -1,6 +1,6 @@
 //===- ShapeReification.cpp -----------------------------------*--- C++ -*-===//
 //
-// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved.
+// Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved.
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -15,7 +15,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "byteir/Dialect/mhlo/Transforms/ShapeReification.h"
+#include "byteir/Transforms/ShapeReification.h"
 
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
diff --git a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
index 31a468e7d..4821a5915 100644
--- a/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
+++ b/compiler/test/Conversion/FuncToByre/func_to_byre_tensor.mlir
@@ -20,3 +20,20 @@ func.func @test_normal_function_call(%arg0 : tensor<4xf32>) -> tensor<4xf32> att
 }
 // CHECK-LABEL: test_normal_function_call
 // CHECK: call @some_func
+
+
+// -----
+
+func.func private @Unknown0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> tensor<?x4xf32> attributes {__byteir_elementwise_fusion__, byre_compute_name = "Unknown0"} {
+  %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
+  return %0 : tensor<?x4xf32>
+}
+
+func.func @forward(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> tensor<?x4xf32> attributes {__placeholder__byre.entry_point} {
+  %1 = call @Unknown0(%arg1, %arg0) : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
+  return %1 : tensor<?x4xf32>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: tensor.empty
+// CHECK-NEXT: byre.compute_on_tensor @Unknown0
diff --git a/compiler/test/Transforms/shapeReification.mlir b/compiler/test/Transforms/shapeReification.mlir
index d1b3cd530..157d5e176 100644
--- a/compiler/test/Transforms/shapeReification.mlir
+++ b/compiler/test/Transforms/shapeReification.mlir
@@ -1,4 +1,4 @@
-// RUN: byteir-opt %s -byteir-shape-reification -canonicalize -cse | FileCheck %s
+// RUN: byteir-opt %s --split-input-file -byteir-shape-reification -canonicalize -cse | FileCheck %s
 
 func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<4xf32>) -> (!shape.shape, !shape.shape, !shape.shape, !shape.shape) {
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<?x2xf32>, tensor<2x4xf32>) -> tensor<?x4xf32>
@@ -26,6 +26,8 @@ func.func @several_ops(%arg0: tensor<?x2xf32>, %arg1: tensor<2x4xf32>, %arg2: te
 // CHECK-DAG: %[[V3:.+]] = shape.value_as_shape %[[C2]] : tensor<1xindex> -> !shape.shape
 // CHECK-DAG: return %[[V2]], %[[V3]], %[[V2]], %[[V2]] : !shape.shape, !shape.shape, !shape.shape, !shape.shape
 
+// -----
+
 // CHECK-LABEL: @infer_shape_using_dim_op
 func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>, %arg2: tensor<4x4xf32>) -> !shape.shape {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
@@ -40,6 +42,8 @@ func.func @infer_shape_using_dim_op(%arg0: tensor<?x4xf32>, %arg1: tensor
 
+// -----
+
 func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?x4xf32>, %arg3: tensor<?x4xf32>) -> tensor<?x4xf32> {
   %0 = "mhlo.custom_call"(%arg0, %arg1, %arg2, %arg3) {call_target_name = "tf.DynamicStitch", has_side_effect = false} : (tensor<?xi32>, tensor<?xi32>, tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
   %c0 = arith.constant 0 : index
@@ -52,6 +56,8 @@ func.func @dynamic_stitch(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: ten
   return %0 : tensor<?x4xf32>
 }
 
+// -----
+
 func.func @gelu(%arg0: tensor<?x64xf32>) -> tensor<?x64xf32> {
   %0 = mhlo.custom_call @byteir.gelu(%arg0) {backend_config = "", byteir_attrs = {approximate = "erf"}} : (tensor<?x64xf32>) -> tensor<?x64xf32>
   %c0 = arith.constant 0 : index
@@ -62,6 +68,8 @@ func.func @gelu(%arg0: tensor<?x64xf32>) -> tensor<?x64xf32> {
   return %0 : tensor<?x64xf32>
 }
 
+// -----
+
 // CHECK-LABEL: func.func @dot_general
 func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<3xindex> {
   %c1 = arith.constant 1 : index
@@ -80,11 +88,14 @@ func.func @dot_general(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) ->
   return %3 : tensor<3xindex>
 }
 
+// -----
+
 // TODO: Check this after nested function call is supported
 func.func private @inner_func(%arg0 : tensor<?x4xf32>, %arg1 : tensor<?x4xf32>) -> tensor<?x4xf32> {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   return %0 : tensor<?x4xf32>
 }
+
 func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape.shape, !shape.shape) {
   %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
   %1 = shape.shape_of %0 : tensor<?x4xf32> -> tensor<2xindex>
@@ -94,3 +105,47 @@ func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape
   %5 = shape.value_as_shape %4 : tensor<2xindex> -> !shape.shape
   return %2, %5 : !shape.shape, !shape.shape
 }
+// CHECK-LABEL: func.func @outer_func
+// CHECK: %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
+// CHECK: %[[V1:.*]] = shape.value_as_shape %[[V0]] : tensor<2xindex> -> !shape.shape
+// CHECK: return %[[V1]], %[[V1]] : !shape.shape, !shape.shape
+
+// -----
+
+func.func private @Unknown1(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %2 = mhlo.add %1, %arg1 : tensor<?x20xf32>
+  return %2 : tensor<?x20xf32>
+}
+
+func.func private @Unknown0(%arg0: tensor<?x10xf32>, %arg1: tensor<20xf32>, %arg2: tensor<?x20xf32>, %arg3: tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>) {
+  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %c20 = arith.constant 20 : index
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+  %from_elements = tensor.from_elements %dim, %c20 : tensor<2xindex>
+  %1 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %from_elements) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<20xf32>, tensor<2xindex>) -> tensor<?x20xf32>
+  %2 = mhlo.add %arg2, %1 : tensor<?x20xf32>
+  %3 = call @Unknown1(%arg0, %2) : (tensor<?x10xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
+  %4 = mhlo.maximum %2, %3 : tensor<?x20xf32>
+  return %4, %3 : tensor<?x20xf32>, tensor<?x20xf32>
+}
+
+func.func @forward(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>, %arg2: tensor<20x?xf32>) -> tensor<2xindex> attributes {__placeholder__byre.entry_point} {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = mhlo.constant dense_resource<__elided__> : tensor<20xf32>
+  %2 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %3:2 = call @Unknown0(%arg0, %1, %2, %arg1) : (tensor<?x10xf32>, tensor<20xf32>, tensor<?x20xf32>, tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>)
+  %4 = "mhlo.dot"(%3#0, %arg2) : (tensor<?x20xf32>, tensor<20x?xf32>) -> tensor<?x?xf32>
+  %5 = shape.shape_of %4 : tensor<?x?xf32> -> tensor<2xindex>
+  return %5 : tensor<2xindex>
+}
+
+// CHECK-LABEL: func.func @forward
+// CHECK: %[[DIM:.*]] = tensor.dim %arg0, %c0 : tensor<?x10xf32>
+// CHECK: %[[DIM0:.*]] = tensor.dim %arg2, %c1 : tensor<20x?xf32>
+// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM:.*]], %[[DIM0:.*]] : tensor<2xindex>
+// CHECK: return %[[SHAPE:.*]] : tensor<2xindex>
\ No newline at end of file
From da7551a6bdbbeba5f38a3039bf2204643f393fc4 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Sun, 7 Apr 2024 11:16:54 +0800
Subject: [PATCH 2/5] [compiler] check return type of func

---
 compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 965adc765..5f2abec54 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -272,6 +272,14 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
     return failure();
   }
 
+  for (Value &&retTensor : retOp.getOperands()) {
+    auto retTy = retTensor.getType();
+    if (!retTy.isa<ShapedType>()) {
+      newFuncOp->erase();
+      return failure();
+    }
+  }
+
   SmallVector<Type> allResultTypes;
   SmallVector<Value> allResults;
 
From 7363a9c90f0f1e11ed828c6e6f4bf5a6c05b4354 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Mon, 8 Apr 2024 17:48:14 +0800
Subject: [PATCH 3/5] [compiler] fix error cast in shape reification pass

---
 compiler/include/byteir/Transforms/Passes.td | 3 ++-
 compiler/lib/Transforms/PassDetail.h         | 4 ++++
 compiler/lib/Transforms/ShapeReification.cpp | 5 +++--
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index 17e88c1e6..b3d3e9853 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -440,7 +440,8 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
   let constructor = "mlir::createByteIRShapeReificationPass()";
   let dependentDialects = [
     "mlir::shape::ShapeDialect",
-    "mlir::tensor::TensorDialect"
+    "mlir::tensor::TensorDialect",
+    "mlir::arith::ArithDialect",
   ];
 }
 
diff --git a/compiler/lib/Transforms/PassDetail.h b/compiler/lib/Transforms/PassDetail.h
index 05ade22f2..0a63eda78 100644
--- a/compiler/lib/Transforms/PassDetail.h
+++ b/compiler/lib/Transforms/PassDetail.h
@@ -43,6 +43,10 @@ namespace memref {
 class MemRefDialect;
 } // namespace memref
 
+namespace arith {
+class ArithDialect;
+} // namespace arith
+
 namespace mhlo {
 class MhloDialect;
 } // namespace mhlo
 
diff --git a/compiler/lib/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
index 382e2ed10..f196d93a6 100644
--- a/compiler/lib/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -20,6 +20,7 @@
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -59,8 +60,8 @@ struct ShapeReificationOnTensorDimPattern
 
     // Insert cast, if needed.
     if (dimOfShape.getType() != op.getType()) {
-      dimOfShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
-                                                   dimOfShape);
+      dimOfShape = rewriter.create<arith::IndexCastOp>(
+          op.getLoc(), op.getType(), dimOfShape);
     }
 
     rewriter.replaceOp(op, dimOfShape);
From 80408e2a37afaa1a7828f996fb524b593fb22fe8 Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Wed, 10 Apr 2024 19:02:34 +0800
Subject: [PATCH 4/5] [compiler] modified top level of shapeReification to
 moduleOp

---
 .../HloToByreTensor/HloToByreTensor.h         |   6 +-
 compiler/include/byteir/Conversion/Passes.td  |   2 +-
 compiler/include/byteir/Transforms/Passes.td  |   7 +-
 .../byteir/Transforms/ShapeReification.h      |   7 +-
 compiler/lib/Analysis/SymbolicShape.cpp       |  18 ++-
 .../HloToByreTensor/HloToByreTensor.cpp       |   6 +-
 .../lib/Dialect/mhlo/Util/ShapeInferUtil.cpp  | 145 +++++++++---------
 compiler/lib/Pipelines/ByreTensorOpt.cpp      |   5 +-
 compiler/lib/Pipelines/ShapeOpt.cpp           |  22 ++-
 compiler/lib/Transforms/ShapeReification.cpp  |  22 ++-
 10 files changed, 125 insertions(+), 115 deletions(-)

diff --git a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
index f045933ba..8b9e36c79 100644
--- a/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
+++ b/compiler/include/byteir/Conversion/HloToByreTensor/HloToByreTensor.h
@@ -24,16 +24,14 @@
 #include <memory>
 
 namespace mlir {
+class ModuleOp;
 // forward decl
-namespace func {
-class FuncOp;
-} // namespace func
 
 void populateHloToByreTensorPattern(
     RewritePatternSet &patterns,
     const llvm::StringMap<llvm::StringRef> &supportMap, bool appendArgTypes);
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 createConvertHloToByreTensorPass(bool appendArgTypes = false);
 
 } // namespace mlir
diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td
index ead522d1c..f98447280 100644
--- a/compiler/include/byteir/Conversion/Passes.td
+++ b/compiler/include/byteir/Conversion/Passes.td
@@ -267,7 +267,7 @@ def MhloToCat : Pass<"mhlo-to-cat", "func::FuncOp"> {
 // HloToByreTensor
 //===----------------------------------------------------------------------===//
 
-def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "func::FuncOp"> {
+def ConvertHloToByreTensor : Pass<"hlo-to-byre-tensor", "ModuleOp"> {
   let summary = "Convert hlo op to byre tensor op.";
   let constructor = "mlir::createConvertHloToByreTensorPass()";
   let dependentDialects = [
diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index b3d3e9853..4bc351390 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -429,7 +429,7 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
 // ShapeReification
 //===----------------------------------------------------------------------===//
 
-def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
+def ShapeReification : Pass<"byteir-shape-reification", "ModuleOp"> {
   let summary = "Iteratively reify all shape computations.";
   let description = [{
     If an operation has a shape reification implementation, that is to say, we
@@ -443,6 +443,11 @@ def ShapeReification : Pass<"byteir-shape-reification", "ModuleOp"> {
     "mlir::tensor::TensorDialect",
     "mlir::arith::ArithDialect",
   ];
+  let options = [
+    Option<"anchorFunc", "anchor-func", "std::string",
+           /*default=*/"",
+           "An optional funcName used to specify the target function.">,
+  ];
 }
 
 #endif // BYTEIR_TRANSFORMS_PASSES
diff --git a/compiler/include/byteir/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
index 7c4cb5043..4cc2365e8 100644
--- a/compiler/include/byteir/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -22,11 +22,10 @@
 #include <memory>
 
 namespace mlir {
-namespace func {
-class FuncOp;
-} // namespace func
+class ModuleOp;
 
-std::unique_ptr<OperationPass<func::FuncOp>> createByteIRShapeReificationPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createByteIRShapeReificationPass(llvm::StringRef anchorFunc = "");
 
 } // namespace mlir
 
diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index 703dec1c4..e8800ddfc 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -118,12 +118,20 @@ SymbolicShapeAnalysis::SymbolicShapeAnalysis(ModuleOp moduleOp)
   }
 
   // run shape reification pass on all the auxiliary functions
-  PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
-  pm.addPass(createByteIRShapeReificationPass());
-  pm.addPass(createCSEPass());
   for (auto funcOp : shpFuncOps) {
-    if (mlir::failed(pm.run(funcOp))) {
-      llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+    {
+      PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
+      pm.addPass(createByteIRShapeReificationPass(funcOp.getName()));
+      if (mlir::failed(pm.run(moduleOp))) {
+        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+      }
+    }
+    {
+      PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
+      pm.addPass(createCSEPass());
+      if (mlir::failed(pm.run(funcOp))) {
+        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
+      }
     }
   }
 }
diff --git a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
index 91104ec58..dae0bbdc1 100644
--- a/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
+++ b/compiler/lib/Conversion/HloToByreTensor/HloToByreTensor.cpp
@@ -768,7 +768,6 @@ struct ConvertHloToByreTensorPass
     MLIRContext &ctx = getContext();
     RewritePatternSet patterns(&ctx);
     ConversionTarget target(ctx);
-    auto funcOp = getOperation();
 
     populateHloToByreTensorPattern(patterns, supportMap, appendArgTypes);
     target.addIllegalDialect<mhlo::MhloDialect>();
    target.addLegalDialect<byre::ByreDialect, tensor::TensorDialect,
                            shape::ShapeDialect, arith::ArithDialect>();
 
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-    if (failed(applyPartialConversion(funcOp, target, frozenPatterns))) {
+    if (failed(
+            applyPartialConversion(getOperation(), target, frozenPatterns))) {
       signalPassFailure();
     }
   }
@@ -810,7 +810,7 @@ void mlir::populateHloToByreTensorPattern(
     ConvertSliceOp, ConvertConcatenateOp>(patterns.getContext());
 }
 
-std::unique_ptr<OperationPass<func::FuncOp>>
+std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createConvertHloToByreTensorPass(bool appendArgTypes) {
   return std::make_unique<ConvertHloToByreTensorPass>(appendArgTypes);
 }
diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 5f2abec54..03e48010b 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -19,15 +19,18 @@
 #include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
 #include "byteir/Transforms/ShapeReification.h"
 #include "mhlo/IR/hlo_ops.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/TopologicalSortUtils.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Debug.h" + #include #include @@ -184,32 +187,6 @@ mlir::inferReturnTypeComponents(llvm::StringRef name) { } namespace { - -SmallVector collectAllOpsForReturn(Operation *retOp) { - llvm::DenseSet visitedOp; - std::queue opQueue; - - opQueue.push(retOp); - while (!opQueue.empty()) { - auto frontOp = opQueue.front(); - opQueue.pop(); - if (visitedOp.find(frontOp) != visitedOp.end()) { - continue; - } - visitedOp.insert(frontOp); - for (Value operand : frontOp->getOperands()) { - if (!operand.getDefiningOp()) { - continue; - } - if (Operation *defOp = operand.getDefiningOp()) { - opQueue.push(defOp); - } - } - } - visitedOp.erase(retOp); - return SmallVector(visitedOp.begin(), visitedOp.end()); -} - bool deduceFromFuncArgShape(Value value) { if (value.isa()) { return false; @@ -240,42 +217,30 @@ bool deduceFromFuncArgShape(Value value) { return true; } -LogicalResult reifyCallOp(OpBuilder &builder, Operation *op, - SmallVectorImpl &reifications) { - OpBuilder::InsertionGuard guard(builder); - auto callOp = dyn_cast(op); - if (!callOp) { - return failure(); - } - - ModuleOp moduleOp = op->getParentRegion()->getParentOfType(); - // auxiliary builder used for create operations in shape func - // original builder maybe a rewriter, used for create operations in specific - // pattern. - OpBuilder auxiliaryBuilder(moduleOp); - StringRef funcName = callOp.getCallee(); - auto funcOp = moduleOp.lookupSymbol(funcName); +FailureOr createCorrespondingShapeFunc(func::FuncOp funcOp) { + ModuleOp moduleOp = funcOp->getParentOfType(); + // use auxiliary builder, create shape func in the start of moduleOp + OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody()); // clone funcOp, newFuncOp used for deduce function shape - std::string newFuncName = funcName.str() + "_Shape"; - auxiliaryBuilder.setInsertionPointToStart(moduleOp.getBody()); - auto newFuncOp = auxiliaryBuilder.create( - funcOp->getLoc(), newFuncName, funcOp.getFunctionType()); - newFuncOp.setPrivate(); + Twine shapeFuncName = funcOp.getName() + "_Shape"; + auto shapeFunc = builder.create( + funcOp->getLoc(), shapeFuncName.str(), funcOp.getFunctionType()); + shapeFunc.setPrivate(); IRMapping emptyBvm; - funcOp.cloneInto(newFuncOp, emptyBvm); + funcOp.cloneInto(shapeFunc, emptyBvm); // replace the operands of returnOp with corresponding shape - func::ReturnOp retOp = *newFuncOp.getOps().begin(); + func::ReturnOp retOp = *shapeFunc.getOps().begin(); if (!retOp) { - newFuncOp->erase(); + shapeFunc->erase(); return failure(); } for (Value &&retTensor : retOp.getOperands()) { auto retTy = retTensor.getType(); if (!retTy.isa()) { - newFuncOp->erase(); + shapeFunc->erase(); return failure(); } } @@ -283,44 +248,76 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op, SmallVector allResultTypes; SmallVector allResults; - auxiliaryBuilder.setInsertionPoint(retOp); + builder.setInsertionPoint(retOp); for (Value &&retTensor : retOp.getOperands()) { - auto retShape = - auxiliaryBuilder.create(retOp.getLoc(), retTensor); + auto retShape = builder.create(retOp.getLoc(), retTensor); allResultTypes.emplace_back(retShape.getType()); allResults.emplace_back(retShape); } // return the shape of original tensor returned by function - auto newRetOp = - auxiliaryBuilder.create(retOp.getLoc(), allResults); - auto newFuncType = auxiliaryBuilder.getFunctionType( - 
-      newFuncOp.getArgumentTypes(), allResultTypes);
-  newFuncOp.setFunctionType(newFuncType);
+  auto shapeFuncRetOp =
+      builder.create<func::ReturnOp>(retOp.getLoc(), allResults);
+  auto shapeFuncType =
+      builder.getFunctionType(shapeFunc.getArgumentTypes(), allResultTypes);
+  shapeFunc.setFunctionType(shapeFuncType);
   retOp->erase();
 
-  // reify newFunc to get the shape computation for current callOp
+  // reify shapeFunc to get the shape computation.
   {
-    PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
-    pm.addPass(createCanonicalizerPass());
-    pm.addPass(createCSEPass());
-    pm.addPass(createByteIRShapeReificationPass());
+    PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
+    // only run pass on shapeFunc
+    pm.addPass(createByteIRShapeReificationPass(shapeFunc.getName()));
+    if (mlir::failed(pm.run(moduleOp))) {
+      shapeFunc->erase();
+      return failure();
+    }
+  }
+
+  // canonicalize shapeFunc
+  {
+    PassManager pm(shapeFunc->getContext(), shapeFunc.getOperationName());
     pm.addPass(createCanonicalizerPass());
     pm.addPass(createCSEPass());
-
-    if (mlir::failed(pm.run(newFuncOp))) {
-      newFuncOp->erase();
+    // only run pass on shapeFunc, don't modify other ops.
+    if (mlir::failed(pm.run(shapeFunc))) {
+      shapeFunc->erase();
       return failure();
     }
   }
+  return shapeFunc;
+}
 
-  // collect all shape computation ops
-  SmallVector<Operation *> reificationOps = collectAllOpsForReturn(newRetOp);
+LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &reifications) {
+  OpBuilder::InsertionGuard guard(builder);
+  auto callOp = dyn_cast<func::CallOp>(op);
+  if (!callOp) {
+    return failure();
+  }
+
+  ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
+  StringRef funcName = callOp.getCallee();
+  auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
+  // create corresponding shape function
+  auto maybeShapeFunc = createCorrespondingShapeFunc(funcOp);
+  if (failed(maybeShapeFunc)) {
+    return failure();
+  }
+
+  func::FuncOp shapeFunc = *maybeShapeFunc;
+  func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
+
+  // collect all shape computation ops
+  SetVector<Operation *> reificationOpSet;
+  getBackwardSlice(retOp.getOperation(), &reificationOpSet);
+  SmallVector<Operation *> reificationOps(reificationOpSet.begin(),
+                                          reificationOpSet.end());
 
   // the returned value must only depend on the shape of FuncArgs.
-  for (Value &&ret : newRetOp.getOperands()) {
+  for (Value &&ret : retOp.getOperands()) {
     if (!deduceFromFuncArgShape(ret)) {
-      newFuncOp->erase();
+      shapeFunc->erase();
       return failure();
     }
   }
@@ -330,9 +327,9 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
     mlir::computeTopologicalSorting(reificationOps);
 
     IRMapping bvm;
-    size_t numArg = newFuncOp.getNumArguments();
+    size_t numArg = shapeFunc.getNumArguments();
     for (size_t i = 0; i < numArg; ++i) {
-      bvm.map(newFuncOp.getArgument(i), callOp.getOperand(i));
+      bvm.map(shapeFunc.getArgument(i), callOp.getOperand(i));
     }
 
     builder.setInsertionPoint(callOp);
@@ -341,13 +338,13 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
       auto newOp = builder.clone(*oldOp, bvm);
     }
 
-    for (Value &&ret : newRetOp.getOperands()) {
+    for (Value &&ret : retOp.getOperands()) {
       reifications.push_back(bvm.lookup(ret));
     }
   }
 
   // remove newFuncOp
-  newFuncOp->erase();
+  shapeFunc->erase();
   return success();
 }
 
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 4d5c2b5c6..3b87432a1 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -45,9 +45,8 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
   pm.addPass(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
-  pm.addNestedPass<func::FuncOp>(
-      createConvertHloToByreTensorPass(appendArgTypes));
-  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
+  pm.addPass(createConvertHloToByreTensorPass(appendArgTypes));
+  pm.addPass(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace
diff --git a/compiler/lib/Pipelines/ShapeOpt.cpp b/compiler/lib/Pipelines/ShapeOpt.cpp
index 71870b343..6fd7204f5 100644
--- a/compiler/lib/Pipelines/ShapeOpt.cpp
+++ b/compiler/lib/Pipelines/ShapeOpt.cpp
@@ -27,17 +27,13 @@ using namespace mlir;
 
 void mlir::createShapeOptPipeline(OpPassManager &pm) {
-  invokeOpPassPipelineBuilder(
-      [](OpPassManager &pm) {
-        pm.addPass(createSetAssumingAlwaysTruePass());
-        pm.addPass(createCanonicalizeExtPass());
-        pm.addPass(createInsertTieShapePass());
-        pm.addPass(createInsertShapeConstraintPass());
-        pm.addPass(createByteIRShapeReificationPass());
-        addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
-        pm.addPass(createResolveShapeConstraintPass());
-        pm.addPass(createBoundedShapeInferencePass());
-        pm.addPass(createCanonicalizeExtPass());
-      },
-      pm);
+  pm.addNestedPass<func::FuncOp>(createSetAssumingAlwaysTruePass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
+  pm.addNestedPass<func::FuncOp>(createInsertTieShapePass());
+  pm.addNestedPass<func::FuncOp>(createInsertShapeConstraintPass());
+  pm.addPass(createByteIRShapeReificationPass());
+  addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
+  pm.addNestedPass<func::FuncOp>(createResolveShapeConstraintPass());
+  pm.addNestedPass<func::FuncOp>(createBoundedShapeInferencePass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
 }
diff --git a/compiler/lib/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
index f196d93a6..e6559f06c 100644
--- a/compiler/lib/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -65,6 +65,7 @@ struct ShapeReificationOnTensorDimPattern
     }
 
     rewriter.replaceOp(op, dimOfShape);
+
     return success();
   }
 };
@@ -106,14 +107,16 @@ void PopulateShapeReificationPatterns(MLIRContext *ctx,
 
 struct ShapeReificationPass
     : public ShapeReificationBase<ShapeReificationPass> {
-  ShapeReificationPass()
+  ShapeReificationPass(const std::string &anchorFunc)
       : ShapeReificationBase<ShapeReificationPass>::ShapeReificationBase() {
     // ReifyReturnType implementation could also be registered outside
     // ShapeReificationPass
     registerAllMhloReifyReturnTypeShapes();
+    this->anchorFunc = anchorFunc;
   }
 
   void runOnOperation() override {
+    ModuleOp moduleOp = getOperation();
     // Collect patterns.
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -125,17 +128,22 @@ struct ShapeReificationPass
     // iteration.
     GreedyRewriteConfig cfg;
     cfg.useTopDownTraversal = false;
-    func::FuncOp f = getOperation();
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-    if (failed(applyPatternsAndFoldGreedily(f, frozenPatterns, cfg))) {
-      return signalPassFailure();
+
+    // apply Patterns on target funcOp.
+    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
+      if (this->anchorFunc == "" || funcOp.getName() == this->anchorFunc) {
+        if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns, cfg))) {
+          return signalPassFailure();
+        }
+      }
     }
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createByteIRShapeReificationPass() {
-  return std::make_unique<ShapeReificationPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createByteIRShapeReificationPass(llvm::StringRef anchorFunc /*=""*/) {
+  return std::make_unique<ShapeReificationPass>(anchorFunc.str());
 }
From 0ad9b82237b9acb05659b86281b8b800d5f5a81f Mon Sep 17 00:00:00 2001
From: zhengxuegui
Date: Sat, 13 Apr 2024 22:59:02 +0800
Subject: [PATCH 5/5] [compiler] save shapeFunc in a temp module to avoid
 modifying origin module

---
 compiler/include/byteir/Transforms/Passes.td  |  7 +-
 .../byteir/Transforms/ShapeReification.h      |  7 +-
 compiler/lib/Analysis/SymbolicShape.cpp       | 18 ++--
 .../lib/Dialect/mhlo/Util/ShapeInferUtil.cpp  | 70 ++++++++++++-----
 compiler/lib/Pipelines/ByreTensorOpt.cpp      |  2 +-
 compiler/lib/Pipelines/ShapeOpt.cpp           | 22 +++--
 compiler/lib/Transforms/ShapeReification.cpp  | 21 ++---
 .../test/Transforms/shapeReification.mlir     | 20 ++++-
 8 files changed, 97 insertions(+), 70 deletions(-)

diff --git a/compiler/include/byteir/Transforms/Passes.td b/compiler/include/byteir/Transforms/Passes.td
index 4bc351390..b3d3e9853 100644
--- a/compiler/include/byteir/Transforms/Passes.td
+++ b/compiler/include/byteir/Transforms/Passes.td
@@ -429,7 +429,7 @@ def SetOpSpace: Pass<"set-op-space", "func::FuncOp"> {
 // ShapeReification
 //===----------------------------------------------------------------------===//
 
-def ShapeReification : Pass<"byteir-shape-reification", "ModuleOp"> {
+def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
   let summary = "Iteratively reify all shape computations.";
   let description = [{
     If an operation has a shape reification implementation, that is to say, we
@@ -443,11 +443,6 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
     "mlir::tensor::TensorDialect",
     "mlir::arith::ArithDialect",
   ];
-  let options = [
-    Option<"anchorFunc", "anchor-func", "std::string",
-           /*default=*/"",
-           "An optional funcName used to specify the target function.">,
-  ];
 }
 
 #endif // BYTEIR_TRANSFORMS_PASSES
diff --git a/compiler/include/byteir/Transforms/ShapeReification.h b/compiler/include/byteir/Transforms/ShapeReification.h
index 4cc2365e8..7c4cb5043 100644
--- a/compiler/include/byteir/Transforms/ShapeReification.h
+++ b/compiler/include/byteir/Transforms/ShapeReification.h
@@ -22,10 +22,11 @@
 #include <memory>
 
 namespace mlir {
-class ModuleOp;
+namespace func {
+class FuncOp;
+} // namespace func
 
-std::unique_ptr<OperationPass<ModuleOp>>
-createByteIRShapeReificationPass(llvm::StringRef anchorFunc = "");
+std::unique_ptr<OperationPass<func::FuncOp>> createByteIRShapeReificationPass();
 
 } // namespace mlir
diff --git a/compiler/lib/Analysis/SymbolicShape.cpp b/compiler/lib/Analysis/SymbolicShape.cpp
index e8800ddfc..703dec1c4 100644
--- a/compiler/lib/Analysis/SymbolicShape.cpp
+++ b/compiler/lib/Analysis/SymbolicShape.cpp
@@ -118,20 +118,12 @@ SymbolicShapeAnalysis::SymbolicShapeAnalysis(ModuleOp moduleOp)
   }
 
   // run shape reification pass on all the auxiliary functions
+  PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
+  pm.addPass(createByteIRShapeReificationPass());
+  pm.addPass(createCSEPass());
   for (auto funcOp : shpFuncOps) {
-    {
-      PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
-      pm.addPass(createByteIRShapeReificationPass(funcOp.getName()));
-      if (mlir::failed(pm.run(moduleOp))) {
-        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
-      }
-    }
-    {
-      PassManager pm(moduleOp->getContext(), func::FuncOp::getOperationName());
-      pm.addPass(createCSEPass());
-      if (mlir::failed(pm.run(funcOp))) {
-        llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
-      }
+    if (mlir::failed(pm.run(funcOp))) {
+      llvm::errs() << "Pass pipeline inside symbolic shape analysis failed.";
     }
   }
 }
diff --git a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
index 03e48010b..313025a1d 100644
--- a/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
+++ b/compiler/lib/Dialect/mhlo/Util/ShapeInferUtil.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OwningOpRef.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
@@ -217,10 +218,12 @@ bool deduceFromFuncArgShape(Value value) {
   return true;
 }
 
-FailureOr<func::FuncOp> createCorrespondingShapeFunc(func::FuncOp funcOp) {
-  ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
-  // use auxiliary builder to create shape func at the start of moduleOp
-  OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+// the auxiliaryModuleOp must be an empty module, used only to hold shapeFunc
+FailureOr<func::FuncOp>
+createCorrespondingShapeFunc(func::FuncOp funcOp, ModuleOp auxiliaryModuleOp) {
+  // use auxiliary builder to create shape func at the start of auxiliaryModuleOp
+  ModuleOp oriModuleOp = funcOp->getParentOfType<ModuleOp>();
+  OpBuilder builder = OpBuilder::atBlockBegin(auxiliaryModuleOp.getBody());
 
   // clone funcOp; newFuncOp is used to deduce the function shape
   Twine shapeFuncName = funcOp.getName() + "_Shape";
@@ -229,6 +232,41 @@ createCorrespondingShapeFunc(func::FuncOp funcOp) {
   shapeFunc.setPrivate();
   IRMapping emptyBvm;
   funcOp.cloneInto(shapeFunc, emptyBvm);
+  llvm::DenseSet<func::CallOp> callOpSet;
+  shapeFunc.walk([&](func::CallOp callOp) { callOpSet.insert(callOp); });
+
+  while (!callOpSet.empty()) {
+    auto callOp = *callOpSet.begin();
+    callOpSet.erase(callOpSet.begin());
+    auto callFunc = oriModuleOp.lookupSymbol<func::FuncOp>(callOp.getCallee());
+    // inline this func.
+    builder.setInsertionPoint(callOp);
+    IRMapping bvm;
+    for (auto inputAndArg :
+         llvm::zip(callFunc.getArguments(), callOp.getOperands())) {
+      bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
+    }
+    Block &entryBlock = callFunc.getBlocks().front();
+    ValueRange funcOuts;
+    for (Operation &op : entryBlock) {
+      auto retOp = mlir::dyn_cast<func::ReturnOp>(op);
+      if (!retOp) {
+        auto newOp = builder.clone(op, bvm);
+        if (auto nestedCall = dyn_cast<func::CallOp>(newOp)) {
+          callOpSet.insert(nestedCall);
+        }
+      } else {
+        funcOuts = retOp.getOperands();
+      }
+    }
+
+    for (auto callResultAndFuncOuts :
+         llvm::zip(callOp.getResults(), funcOuts)) {
+      auto mappedOut = bvm.lookup(std::get<1>(callResultAndFuncOuts));
+      std::get<0>(callResultAndFuncOuts).replaceAllUsesWith(mappedOut);
+    }
+    callOp->erase();
+  }
 
   // replace the operands of returnOp with corresponding shape
   func::ReturnOp retOp = *shapeFunc.getOps<func::ReturnOp>().begin();
@@ -265,21 +303,13 @@ FailureOr<func::FuncOp> createCorrespondingShapeFunc(func::FuncOp funcOp) {
 
   // reify shapeFunc to get the shape computation.
   {
-    PassManager pm(moduleOp->getContext(), moduleOp.getOperationName());
+    PassManager pm(oriModuleOp->getContext(), func::FuncOp::getOperationName());
     // only run pass on shapeFunc
-    pm.addPass(createByteIRShapeReificationPass(shapeFunc.getName()));
-    if (mlir::failed(pm.run(moduleOp))) {
-      shapeFunc->erase();
-      return failure();
-    }
-  }
-
-  // canonicalize shapeFunc
-  {
-    PassManager pm(shapeFunc->getContext(), shapeFunc.getOperationName());
     pm.addPass(createCanonicalizerPass());
     pm.addPass(createCSEPass());
-    // only run pass on shapeFunc, don't modify other ops.
+    pm.addPass(createByteIRShapeReificationPass());
+    pm.addPass(createCanonicalizerPass());
+    pm.addPass(createCSEPass());
     if (mlir::failed(pm.run(shapeFunc))) {
       shapeFunc->erase();
       return failure();
@@ -300,8 +330,12 @@ LogicalResult reifyCallOp(OpBuilder &builder, Operation *op,
   ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
   StringRef funcName = callOp.getCallee();
   auto funcOp = moduleOp.lookupSymbol<func::FuncOp>(funcName);
-  // create corresponding shape function
-  auto maybeShapeFunc = createCorrespondingShapeFunc(funcOp);
+  // create a temp module, then insert the corresponding shape function into
+  // this module
+  OwningOpRef<ModuleOp> auxiliaryModuleOp(
+      ModuleOp::create(UnknownLoc::get(moduleOp->getContext())));
+  auto maybeShapeFunc =
+      createCorrespondingShapeFunc(funcOp, auxiliaryModuleOp.get());
   if (failed(maybeShapeFunc)) {
     return failure();
   }
diff --git a/compiler/lib/Pipelines/ByreTensorOpt.cpp b/compiler/lib/Pipelines/ByreTensorOpt.cpp
index 3b87432a1..91940b0d4 100644
--- a/compiler/lib/Pipelines/ByreTensorOpt.cpp
+++ b/compiler/lib/Pipelines/ByreTensorOpt.cpp
@@ -46,7 +46,7 @@ void createByreTensorOptPipelineImpl(OpPassManager &pm, std::string entryFunc,
   pm.addNestedPass<func::FuncOp>(
       createConvertHloToByreCustomPass(getCudaByreCustomConfig()));
   pm.addPass(createConvertHloToByreTensorPass(appendArgTypes));
-  pm.addPass(createByteIRShapeReificationPass());
+  pm.addNestedPass<func::FuncOp>(createByteIRShapeReificationPass());
   pm.addPass(createCanonicalizerPass());
 }
 } // namespace
diff --git a/compiler/lib/Pipelines/ShapeOpt.cpp b/compiler/lib/Pipelines/ShapeOpt.cpp
index 6fd7204f5..71870b343 100644
--- a/compiler/lib/Pipelines/ShapeOpt.cpp
+++ b/compiler/lib/Pipelines/ShapeOpt.cpp
@@ -27,13 +27,17 @@ using namespace mlir;
 
 void mlir::createShapeOptPipeline(OpPassManager &pm) {
-  pm.addNestedPass<func::FuncOp>(createSetAssumingAlwaysTruePass());
-  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
-  pm.addNestedPass<func::FuncOp>(createInsertTieShapePass());
-  pm.addNestedPass<func::FuncOp>(createInsertShapeConstraintPass());
-  pm.addPass(createByteIRShapeReificationPass());
-  addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
-  pm.addNestedPass<func::FuncOp>(createResolveShapeConstraintPass());
-  pm.addNestedPass<func::FuncOp>(createBoundedShapeInferencePass());
-  pm.addNestedPass<func::FuncOp>(createCanonicalizeExtPass());
+  invokeOpPassPipelineBuilder(
+      [](OpPassManager &pm) {
+        pm.addPass(createSetAssumingAlwaysTruePass());
+        pm.addPass(createCanonicalizeExtPass());
+        pm.addPass(createInsertTieShapePass());
+        pm.addPass(createInsertShapeConstraintPass());
+        pm.addPass(createByteIRShapeReificationPass());
+        addCleanUpExtPassPipeline(pm, /*topHasSymTable*/ false);
+        pm.addPass(createResolveShapeConstraintPass());
+        pm.addPass(createBoundedShapeInferencePass());
+        pm.addPass(createCanonicalizeExtPass());
+      },
+      pm);
 }
diff --git a/compiler/lib/Transforms/ShapeReification.cpp b/compiler/lib/Transforms/ShapeReification.cpp
index e6559f06c..8d09215a8 100644
--- a/compiler/lib/Transforms/ShapeReification.cpp
+++ b/compiler/lib/Transforms/ShapeReification.cpp
@@ -107,16 +107,14 @@ void PopulateShapeReificationPatterns(MLIRContext *ctx,
 
 struct ShapeReificationPass
     : public ShapeReificationBase<ShapeReificationPass> {
-  ShapeReificationPass(const std::string &anchorFunc)
+  ShapeReificationPass()
       : ShapeReificationBase<ShapeReificationPass>::ShapeReificationBase() {
     // ReifyReturnType implementation could also be registered outside
     // ShapeReificationPass
     registerAllMhloReifyReturnTypeShapes();
-    this->anchorFunc = anchorFunc;
   }
 
   void runOnOperation() override {
-    ModuleOp moduleOp = getOperation();
     // Collect patterns.
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
@@ -125,22 +126,17 @@ struct ShapeReificationPass
     // iteration.
     GreedyRewriteConfig cfg;
     cfg.useTopDownTraversal = false;
+    func::FuncOp f = getOperation();
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
-
-    // apply Patterns on target funcOp.
-    for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
-      if (this->anchorFunc == "" || funcOp.getName() == this->anchorFunc) {
-        if (failed(applyPatternsAndFoldGreedily(funcOp, frozenPatterns, cfg))) {
-          return signalPassFailure();
-        }
-      }
+    if (failed(applyPatternsAndFoldGreedily(f, frozenPatterns, cfg))) {
+      return signalPassFailure();
     }
   }
 };
 } // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createByteIRShapeReificationPass(llvm::StringRef anchorFunc /*=""*/) {
-  return std::make_unique<ShapeReificationPass>(anchorFunc.str());
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::createByteIRShapeReificationPass() {
+  return std::make_unique<ShapeReificationPass>();
 }
diff --git a/compiler/test/Transforms/shapeReification.mlir b/compiler/test/Transforms/shapeReification.mlir
index 157d5e176..cfba2985c 100644
--- a/compiler/test/Transforms/shapeReification.mlir
+++ b/compiler/test/Transforms/shapeReification.mlir
@@ -112,11 +112,19 @@ func.func @outer_func(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (!shape
 
 // -----
 
+func.func private @Unknown2(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
+  %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
+  %1 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %2 = mhlo.add %1, %arg1 : tensor<?x20xf32>
+  return %2 : tensor<?x20xf32>
+}
+
 func.func private @Unknown1(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>) -> tensor<?x20xf32> attributes {__byteir_matmul_epilogue_fusion__} {
   %0 = mhlo.constant dense_resource<__elided__> : tensor<10x20xf32>
-  %1 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
-  %2 = mhlo.add %1, %arg1 : tensor<?x20xf32>
-  return %2 : tensor<?x20xf32>
+  %1 = call @Unknown2(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x20xf32>) -> tensor<?x20xf32>
+  %2 = "mhlo.dot"(%arg0, %0) : (tensor<?x10xf32>, tensor<10x20xf32>) -> tensor<?x20xf32>
+  %3 = mhlo.add %2, %1 : tensor<?x20xf32>
+  return %3 : tensor<?x20xf32>
 }
 
 func.func private @Unknown0(%arg0: tensor<?x10xf32>, %arg1: tensor<20xf32>, %arg2: tensor<?x20xf32>, %arg3: tensor<?x20xf32>) -> (tensor<?x20xf32>, tensor<?x20xf32>) {
@@ -146,6 +154,6 @@ func.func @forward(%arg0: tensor<?x10xf32>, %arg1: tensor<?x20xf32>, %arg2: tens
 
 // CHECK-LABEL: func.func @forward
 // CHECK: %[[DIM:.*]] = tensor.dim %arg0, %c0 : tensor<?x10xf32>
-// CHECK: %[[DIM0:.*]] = tensor.dim %arg2, %c1 : tensor<20x?xf32>
-// CHECK: %[[SHAPE:.*]] = tensor.from_elements %[[DIM:.*]], %[[DIM0:.*]] : tensor<2xindex>
-// CHECK: return %[[SHAPE:.*]] : tensor<2xindex>
\ No newline at end of file
+// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %arg2, %c1 : tensor<20x?xf32>
+// CHECK-NEXT: %[[SHAPE:.*]] = tensor.from_elements %[[DIM:.*]], %[[DIM0:.*]] : tensor<2xindex>
+// CHECK-NEXT: return %[[SHAPE:.*]] : tensor<2xindex>
\ No newline at end of file
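[Reviewer note, illustration only, not part of the patches. The file name and
shapes are hypothetical.] End to end, the pass moved and extended by this
series is exercised the same way as in the tests above:

  // byteir-opt example.mlir -byteir-shape-reification -canonicalize -cse
  func.func @dim_of_add(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> index {
    %c0 = arith.constant 0 : index
    %0 = mhlo.add %arg0, %arg1 : tensor<?x4xf32>
    %dim = tensor.dim %0, %c0 : tensor<?x4xf32>
    return %dim : index
  }

Because mhlo.add has a shape reification implementation, the tensor.dim on its
result is rewritten to tensor.dim %arg0, %c0; the dynamic dimension is
recovered directly from the function argument, which is exactly the property
that patch 1's deduceFromFuncArgShape check enforces for reified call results.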