Skip to content

Commit

Permalink
Merge pull request google#991 from ahmedshakill:forward_insert_to_ext…
Browse files Browse the repository at this point in the history
…ract

PiperOrigin-RevId: 684141392
  • Loading branch information
copybara-github committed Oct 9, 2024
2 parents 2fb6474 + a1bd760 commit 2391792
Show file tree
Hide file tree
Showing 9 changed files with 436 additions and 0 deletions.
51 changes: 51 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ForwardInsertToExtract",
srcs = ["ForwardInsertToExtract.cpp"],
hdrs = [
"ForwardInsertToExtract.h",
],
deps = [
":pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)
# ForwardInsertToExtract tablegen and headers.

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ForwardInsertToExtract",
],
"ForwardInsertToExtract.h.inc",
),
(
["-gen-pass-doc"],
"ForwardInsertToExtractPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ForwardInsertToExtract.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
19 changes: 19 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
set(LLVM_TARGET_DEFINITIONS ForwardInsertToExtract.td)
mlir_tablegen(ForwardInsertToExtract.h.inc -gen-pass-decls -name ForwardInsertToExtract)
add_public_tablegen_target(MLIRHeirForwardInsertToExtractIncGen)

add_mlir_dialect_library(MLIRHeirForwardInsertToExtract
ForwardInsertToExtract.cpp

DEPENDS
MLIRHeirForwardInsertToExtractIncGen

LINK_LIBS PUBLIC
MLIRModArithDialect
MLIRIR
MLIRInferTypeOpInterface
MLIRArithDialect
MLIRSupport
MLIRDialect
MLIRIR
)
130 changes: 130 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h"

#include <utility>

#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "forward-insert-to-extract"

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_FORWARDINSERTTOEXTRACT
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

bool ForwardSingleInsertToExtract::isForwardableOp(
Operation *potentialInsert, tensor::ExtractOp &extractOp) const {
if (!dominanceInfo.properlyDominates(potentialInsert,
extractOp.getOperation())) {
LLVM_DEBUG(llvm::dbgs() << "insert op does not dominate extract op\n");
return false;
}

if (extractOp->getBlock() != potentialInsert->getBlock()) {
LLVM_DEBUG(llvm::dbgs()
<< "insert and extract op are not in the same block\n");
return false;
}

return llvm::TypeSwitch<Operation &, bool>(*potentialInsert)
.Case<tensor::InsertOp>([&](auto insertOp) {
ValueRange insertIndices = insertOp.getIndices();
ValueRange extractIndices = extractOp.getIndices();
if (insertIndices != extractIndices) {
LLVM_DEBUG(llvm::dbgs()
<< "insert and extract op do not have matching indices\n");
return false;
}

// Naively scan through the operations between the two ops and check if
// anything prevents forwarding.
for (auto currentNode = insertOp->getNextNode();
currentNode != extractOp.getOperation();
currentNode = currentNode->getNextNode()) {
if (currentNode->getNumRegions() > 0) {
LLVM_DEBUG(llvm::dbgs() << "an op with control flow is between the "
"insert and extract op\n");
return false;
}

if (auto op = dyn_cast<tensor::InsertOp>(currentNode)) {
if (op.getDest() == insertOp.getDest() &&
op.getIndices() == insertIndices) {
LLVM_DEBUG(llvm::dbgs()
<< "an intermediate op inserts to the same index\n");
return false;
}
}
}
return true;
})
.Default([&](Operation &) {
LLVM_DEBUG(llvm::dbgs()
<< "Unsupported op type, cannot check for forwardability\n");
return false;
});
}

FailureOr<Value> getInsertedValue(Operation *insertOp) {
return llvm::TypeSwitch<Operation &, FailureOr<Value>>(*insertOp)
.Case<tensor::InsertOp>(
[&](auto insertOp) { return insertOp.getScalar(); })
.Default([&](Operation &) { return failure(); });
}

LogicalResult ForwardSingleInsertToExtract::matchAndRewrite(
tensor::ExtractOp extractOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(llvm::dbgs() << "Considering extractOp for replacement: "
<< extractOp << "\n");

auto *def = extractOp.getTensor().getDefiningOp();
if (def != nullptr) {
LLVM_DEBUG(llvm::dbgs()
<< "DefiningOp of the one considered: "
<< *extractOp.getTensor().getDefiningOp<tensor::InsertOp>()
<< "\n");

LLVM_DEBUG(llvm::dbgs()
<< "Considering def for forwarding: " << *def << "\n");
if (isForwardableOp(def, extractOp)) {
auto result = getInsertedValue(def);
LLVM_DEBUG(llvm::dbgs() << "def is forwardable: " << *def << "\n");
if (failed(result)) {
return failure();
}
auto value = result.value();
rewriter.replaceAllUsesWith(extractOp, value);
return success();
}
LLVM_DEBUG(llvm::dbgs() << "def is not forwardable: " << *def << "\n");
} else {
LLVM_DEBUG(llvm::dbgs() << "def is nullptr " << "\n");
}
return failure();
}

struct ForwardInsertToExtract
: impl::ForwardInsertToExtractBase<ForwardInsertToExtract> {
using ForwardInsertToExtractBase::ForwardInsertToExtractBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
DominanceInfo dom(getOperation());
patterns.add<ForwardSingleInsertToExtract>(context, dom);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace heir
} // namespace mlir
41 changes: 41 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_

#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dominance.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

struct ForwardSingleInsertToExtract
: public OpRewritePattern<tensor::ExtractOp> {
ForwardSingleInsertToExtract(mlir::MLIRContext *context, DominanceInfo &dom)
: OpRewritePattern<tensor::ExtractOp>(context, 3), dominanceInfo(dom) {}

public:
LogicalResult matchAndRewrite(tensor::ExtractOp op,
PatternRewriter &rewriter) const override;

private:
bool isForwardableOp(Operation *potentialInsert,
tensor::ExtractOp &extractOp) const;

DominanceInfo &dominanceInfo;
};

} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_
18 changes: 18 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_

include "mlir/Pass/PassBase.td"

def ForwardInsertToExtract : Pass<"forward-insert-to-extract"> {
let summary = "Forward inserts to extracts within a single block";
let description = [{
This pass is similar to forward-store-to-load pass where store ops
are forwarded load ops; here instead tensor.insert ops are forwarded
to tensor.extract ops.

Does not support complex control flow within a block, nor ops with
arbitrary subregions.
}];
}

#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_
10 changes: 10 additions & 0 deletions tests/forward_insert_to_extract/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
Loading

0 comments on commit 2391792

Please sign in to comment.