diff --git a/include/ttmlir/Conversion/Passes.h b/include/ttmlir/Conversion/Passes.h index 4c7464d88..cb629925a 100644 --- a/include/ttmlir/Conversion/Passes.h +++ b/include/ttmlir/Conversion/Passes.h @@ -7,6 +7,7 @@ #ifdef TTMLIR_ENABLE_STABLEHLO #include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h" +#include "ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h" #include "ttmlir/Conversion/StableHLOToTTIR/StableHLOToTTIR.h" #endif #include "ttmlir/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.h" diff --git a/include/ttmlir/Conversion/Passes.td b/include/ttmlir/Conversion/Passes.td index d85396459..3c6ee372f 100644 --- a/include/ttmlir/Conversion/Passes.td +++ b/include/ttmlir/Conversion/Passes.td @@ -18,6 +18,11 @@ let summary = "Convert Arith Dialect to StableHLO dialect."; let constructor = "createConvertArithToStableHLOPass()"; let dependentDialects = ["mlir::stablehlo::StablehloDialect", "mlir::arith::ArithDialect"]; } +def RedundantBroadcastElimination : Pass<"redundant-broadcast-elimination", "::mlir::ModuleOp"> { +let summary = "Eliminate any redundant broadcast ops by folding them."; + let constructor = "createRedundantBroadcastEliminationPass()"; + let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; +} #endif def ConvertTosaToTTIR : Pass<"convert-tosa-to-ttir", "::mlir::ModuleOp"> { diff --git a/include/ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h b/include/ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h new file mode 100644 index 000000000..0e1dac4de --- /dev/null +++ b/include/ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_CONVERSION_REDUNDANT_BROADCAST_ELIMINATION_H +#define TTMLIR_CONVERSION_REDUNDANT_BROADCAST_ELIMINATION_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir::tt { + +#ifdef TTMLIR_ENABLE_STABLEHLO +std::unique_ptr> +createRedundantBroadcastEliminationPass(); +#endif + +} // namespace mlir::tt + +#endif // TTMLIR_CONVERSION_REDUNDANT_BROADCAST_ELIMINATION_H diff --git a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h index f922c2501..250ef96d8 100644 --- a/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h +++ b/include/ttmlir/Dialect/TTIR/Pipelines/TTIRPipelines.h @@ -30,6 +30,12 @@ struct StableHLOToTTIRPipelineOptions // This pass will convert stablehlo.composite ops into func.call ops so // that the TTIR inliner pass may inline the ops. llvm::cl::init(true)}; + Option eliminateRedundantBroadcast{ + *this, "eliminate-redundant-broadcast", + llvm::cl::desc("Eliminate redundant broadcast ops."), + // Stablehlo can generate redundant broadcast ops where the input and + // output shapes are same. This pass folds those broadcasts. + llvm::cl::init(true)}; }; void createStableHLOToTTIRPipeline( diff --git a/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp b/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp index e759170e8..1f8f7b88b 100644 --- a/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp +++ b/lib/Conversion/StableHLOToTTIR/ArithToStableHLOPass.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttmlir/Conversion/ArithToStableHLO/ArithToStableHLO.h" +#include "ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h" #include #include diff --git a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt index c09068750..cc7a6e008 100644 --- a/lib/Conversion/StableHLOToTTIR/CMakeLists.txt +++ b/lib/Conversion/StableHLOToTTIR/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_library(TTMLIRStableHLOToTTIR StableHLOToTTIRPatterns.cpp StableHLOToTTIRPass.cpp ArithToStableHLOPass.cpp + RedundantBroadcastElimination.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/StableHLOToTTIR diff --git a/lib/Conversion/StableHLOToTTIR/RedundantBroadcastElimination.cpp b/lib/Conversion/StableHLOToTTIR/RedundantBroadcastElimination.cpp new file mode 100644 index 000000000..254b44804 --- /dev/null +++ b/lib/Conversion/StableHLOToTTIR/RedundantBroadcastElimination.cpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "ttmlir/Conversion/RedundantBroadcastElimination/RedundantBroadcastElimination.h" +#include "ttmlir/Dialect/TT/IR/TT.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_REDUNDANTBROADCASTELIMINATION +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +class RedundantBroadcastEliminationPass + : public ttir::impl::RedundantBroadcastEliminationBase< + RedundantBroadcastEliminationPass> { +public: + using ttir::impl::RedundantBroadcastEliminationBase< + RedundantBroadcastEliminationPass>::RedundantBroadcastEliminationBase; + + void runOnOperation() final { + ModuleOp module = getOperation(); + IRRewriter rewriter(&getContext()); + + module->walk([&](Operation *op) { + if (mlir::isa(op)) { + if (op->use_empty()) { + return; + } + + if (op->getResult(0).getType() == op->getOperand(0).getType()) { + // This broadcast is redundant + rewriter.replaceAllUsesWith((Value)op->getResult(0), + (Value)op->getOperand(0)); + rewriter.eraseOp(op); + } + } + }); + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> +createRedundantBroadcastEliminationPass() { + return std::make_unique(); +} + +} // namespace mlir::tt diff --git a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp index a092a36e1..5aab0b0e1 100644 --- a/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp +++ b/lib/Dialect/TTIR/Pipelines/TTIRPipelines.cpp @@ -28,6 +28,9 @@ void createStableHLOToTTIRPipeline( pm.addPass(stablehlo::createStablehloLegalizeCompositeToCallPass()); } pm.addPass(createConvertStableHLOToTTIRPass()); + if (options.eliminateRedundantBroadcast) { + pm.addPass(createRedundantBroadcastEliminationPass()); + } if (options.removeDeadValuesEnabled) { pm.addPass(mlir::createRemoveDeadValuesPass()); }