Skip to content

Commit

Permalink
[pass-status] tensor.extract ops are stil there
Browse files Browse the repository at this point in the history
replaces extract with forwarded insert value

moved rlwe params out

distincts forwarding cases

checks return values and adds cmakelists.txt
  • Loading branch information
ahmedshakill committed Oct 9, 2024
1 parent b79bccf commit a1bd760
Show file tree
Hide file tree
Showing 9 changed files with 445 additions and 0 deletions.
51 changes: 51 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ForwardInsertToExtract",
srcs = ["ForwardInsertToExtract.cpp"],
hdrs = [
"ForwardInsertToExtract.h",
],
deps = [
":pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)
# ForwardInsertToExtract tablegen and headers.

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ForwardInsertToExtract",
],
"ForwardInsertToExtract.h.inc",
),
(
["-gen-pass-doc"],
"ForwardInsertToExtractPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ForwardInsertToExtract.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
19 changes: 19 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
set(LLVM_TARGET_DEFINITIONS ForwardInsertToExtract.td)
mlir_tablegen(ForwardInsertToExtract.h.inc -gen-pass-decls -name ForwardInsertToExtract)
add_public_tablegen_target(MLIRHeirForwardInsertToExtractIncGen)

add_mlir_dialect_library(MLIRHeirForwardInsertToExtract
ForwardInsertToExtract.cpp

DEPENDS
MLIRHeirForwardInsertToExtractIncGen

LINK_LIBS PUBLIC
MLIRModArithDialect
MLIRIR
MLIRInferTypeOpInterface
MLIRArithDialect
MLIRSupport
MLIRDialect
MLIRIR
)
132 changes: 132 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h"

#include <utility>

#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "forward-insert-to-extract"

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_FORWARDINSERTTOEXTRACT
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

bool ForwardSingleInsertToExtract::isForwardableOp(
Operation *potentialInsert, tensor::ExtractOp &extractOp) const {
if (!dominanceInfo.properlyDominates(potentialInsert,
extractOp.getOperation())) {
LLVM_DEBUG(llvm::dbgs() << "insert op does not dominate extract op\n");
return false;
}

if (extractOp->getBlock() != potentialInsert->getBlock()) {
LLVM_DEBUG(llvm::dbgs()
<< "insert and extract op are not in the same block\n");
return false;
}

return llvm::TypeSwitch<Operation &, bool>(*potentialInsert)
.Case<tensor::InsertOp>([&](auto insertOp) {
ValueRange insertIndices = insertOp.getIndices();
ValueRange extractIndices = extractOp.getIndices();
if (insertIndices != extractIndices) {
LLVM_DEBUG(llvm::dbgs()
<< "insert and extract op do not have matching indices\n");
return false;
}

// Naively scan through the operations between the two ops and check if
// anything prevents forwarding.
for (auto currentNode = insertOp->getNextNode();
currentNode != extractOp.getOperation();
currentNode = currentNode->getNextNode()) {
if (currentNode->getNumRegions() > 0) {
LLVM_DEBUG(llvm::dbgs() << "an op with control flow is between the "
"insert and extract op\n");
return false;
}

if (auto op = dyn_cast<tensor::InsertOp>(currentNode)) {
if (op.getDest() == insertOp.getDest() &&
op.getIndices() == insertIndices) {
LLVM_DEBUG(llvm::dbgs()
<< "an intermediate op inserts to the same index\n");
return false;
}
}
}
return true;
})
.Default([&](Operation &) {
LLVM_DEBUG(llvm::dbgs()
<< "Unsupported op type, cannot check for forwardability\n");
return false;
});
}

FailureOr<Value> getInsertedValue(Operation *insertOp) {
return llvm::TypeSwitch<Operation &, FailureOr<Value>>(*insertOp)
.Case<tensor::InsertOp>(
[&](auto insertOp) { return insertOp.getScalar(); })
.Default([&](Operation &) { return failure(); });
}

LogicalResult ForwardSingleInsertToExtract::matchAndRewrite(
tensor::ExtractOp extractOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(llvm::dbgs() << "Considering extractOp for replacement: "
<< extractOp << "\n");

auto *def = extractOp.getTensor().getDefiningOp();
if (def != nullptr) {
LLVM_DEBUG(llvm::dbgs()
<< "DefiningOp of the one considered: "
<< *extractOp.getTensor().getDefiningOp<tensor::InsertOp>()
<< "\n");

LLVM_DEBUG(llvm::dbgs()
<< "Considering def for forwarding: " << *def << "\n");
if (isForwardableOp(def, extractOp)) {
auto result = getInsertedValue(def);
LLVM_DEBUG(llvm::dbgs() << "def is forwardable: " << *def << "\n");
if (failed(result)) {
return failure();
}
auto value = result.value();
rewriter.replaceAllUsesWith(extractOp, value);
return success();
}
LLVM_DEBUG(llvm::dbgs() << "def is not forwardable: " << *def << "\n");
} else {
LLVM_DEBUG(llvm::dbgs() << "def is nullptr " << "\n");
}
return failure();
}

struct ForwardInsertToExtract
: impl::ForwardInsertToExtractBase<ForwardInsertToExtract> {
using ForwardInsertToExtractBase::ForwardInsertToExtractBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
DominanceInfo dom(getOperation());
patterns.add<ForwardSingleInsertToExtract>(context, dom);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace heir
} // namespace mlir
43 changes: 43 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_

#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dominance.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc"

struct ForwardSingleInsertToExtract
: public OpRewritePattern<tensor::ExtractOp> {
ForwardSingleInsertToExtract(mlir::MLIRContext *context, DominanceInfo &dom)
: OpRewritePattern<tensor::ExtractOp>(context, 3), dominanceInfo(dom) {}

public:
LogicalResult matchAndRewrite(tensor::ExtractOp op,
PatternRewriter &rewriter) const override;

private:
bool isForwardableOp(Operation *potentialInsert,
tensor::ExtractOp &extractOp) const;

DominanceInfo &dominanceInfo;
};

} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_
23 changes: 23 additions & 0 deletions lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_

include "mlir/Pass/PassBase.td"

def ForwardInsertToExtract : Pass<"forward-insert-to-extract"> {
let summary = "Forward inserts to extracts within a single block";
let description = [{
This pass is similar to forward-store-to-load pass where store ops
are forwarded load ops; here instead tensor.insert ops are forwarded
to tensor.extract ops.

Does not support complex control flow within a block, nor ops with
arbitrary subregions.
}];
let dependentDialects = [
"mlir::affine::AffineDialect",
"mlir::memref::MemRefDialect",
"mlir::tensor::TensorDialect"
];
}

#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_
10 changes: 10 additions & 0 deletions tests/forward_insert_to_extract/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
Loading

0 comments on commit a1bd760

Please sign in to comment.