forked from google/heir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request google#991 from ahmedshakill:forward_insert_to_ext…
…ract PiperOrigin-RevId: 684141392
- Loading branch information
Showing
9 changed files
with
436 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "ForwardInsertToExtract", | ||
srcs = ["ForwardInsertToExtract.cpp"], | ||
hdrs = [ | ||
"ForwardInsertToExtract.h", | ||
], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:AffineDialect", | ||
"@llvm-project//mlir:AffineUtils", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:MemRefDialect", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:TransformUtils", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
# ForwardInsertToExtract tablegen and headers. | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=ForwardInsertToExtract", | ||
], | ||
"ForwardInsertToExtract.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"ForwardInsertToExtractPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "ForwardInsertToExtract.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
set(LLVM_TARGET_DEFINITIONS ForwardInsertToExtract.td) | ||
mlir_tablegen(ForwardInsertToExtract.h.inc -gen-pass-decls -name ForwardInsertToExtract) | ||
add_public_tablegen_target(MLIRHeirForwardInsertToExtractIncGen) | ||
|
||
add_mlir_dialect_library(MLIRHeirForwardInsertToExtract | ||
ForwardInsertToExtract.cpp | ||
|
||
DEPENDS | ||
MLIRHeirForwardInsertToExtractIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRModArithDialect | ||
MLIRIR | ||
MLIRInferTypeOpInterface | ||
MLIRArithDialect | ||
MLIRSupport | ||
MLIRDialect | ||
MLIRIR | ||
) |
130 changes: 130 additions & 0 deletions
130
lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h" | ||
|
||
#include <utility> | ||
|
||
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "forward-insert-to-extract" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_FORWARDINSERTTOEXTRACT | ||
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" | ||
|
||
bool ForwardSingleInsertToExtract::isForwardableOp( | ||
Operation *potentialInsert, tensor::ExtractOp &extractOp) const { | ||
if (!dominanceInfo.properlyDominates(potentialInsert, | ||
extractOp.getOperation())) { | ||
LLVM_DEBUG(llvm::dbgs() << "insert op does not dominate extract op\n"); | ||
return false; | ||
} | ||
|
||
if (extractOp->getBlock() != potentialInsert->getBlock()) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "insert and extract op are not in the same block\n"); | ||
return false; | ||
} | ||
|
||
return llvm::TypeSwitch<Operation &, bool>(*potentialInsert) | ||
.Case<tensor::InsertOp>([&](auto insertOp) { | ||
ValueRange insertIndices = insertOp.getIndices(); | ||
ValueRange extractIndices = extractOp.getIndices(); | ||
if (insertIndices != extractIndices) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "insert and extract op do not have matching indices\n"); | ||
return false; | ||
} | ||
|
||
// Naively scan through the operations between the two ops and check if | ||
// anything prevents forwarding. | ||
for (auto currentNode = insertOp->getNextNode(); | ||
currentNode != extractOp.getOperation(); | ||
currentNode = currentNode->getNextNode()) { | ||
if (currentNode->getNumRegions() > 0) { | ||
LLVM_DEBUG(llvm::dbgs() << "an op with control flow is between the " | ||
"insert and extract op\n"); | ||
return false; | ||
} | ||
|
||
if (auto op = dyn_cast<tensor::InsertOp>(currentNode)) { | ||
if (op.getDest() == insertOp.getDest() && | ||
op.getIndices() == insertIndices) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "an intermediate op inserts to the same index\n"); | ||
return false; | ||
} | ||
} | ||
} | ||
return true; | ||
}) | ||
.Default([&](Operation &) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Unsupported op type, cannot check for forwardability\n"); | ||
return false; | ||
}); | ||
} | ||
|
||
FailureOr<Value> getInsertedValue(Operation *insertOp) { | ||
return llvm::TypeSwitch<Operation &, FailureOr<Value>>(*insertOp) | ||
.Case<tensor::InsertOp>( | ||
[&](auto insertOp) { return insertOp.getScalar(); }) | ||
.Default([&](Operation &) { return failure(); }); | ||
} | ||
|
||
LogicalResult ForwardSingleInsertToExtract::matchAndRewrite( | ||
tensor::ExtractOp extractOp, PatternRewriter &rewriter) const { | ||
LLVM_DEBUG(llvm::dbgs() << "Considering extractOp for replacement: " | ||
<< extractOp << "\n"); | ||
|
||
auto *def = extractOp.getTensor().getDefiningOp(); | ||
if (def != nullptr) { | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "DefiningOp of the one considered: " | ||
<< *extractOp.getTensor().getDefiningOp<tensor::InsertOp>() | ||
<< "\n"); | ||
|
||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Considering def for forwarding: " << *def << "\n"); | ||
if (isForwardableOp(def, extractOp)) { | ||
auto result = getInsertedValue(def); | ||
LLVM_DEBUG(llvm::dbgs() << "def is forwardable: " << *def << "\n"); | ||
if (failed(result)) { | ||
return failure(); | ||
} | ||
auto value = result.value(); | ||
rewriter.replaceAllUsesWith(extractOp, value); | ||
return success(); | ||
} | ||
LLVM_DEBUG(llvm::dbgs() << "def is not forwardable: " << *def << "\n"); | ||
} else { | ||
LLVM_DEBUG(llvm::dbgs() << "def is nullptr " << "\n"); | ||
} | ||
return failure(); | ||
} | ||
|
||
struct ForwardInsertToExtract | ||
: impl::ForwardInsertToExtractBase<ForwardInsertToExtract> { | ||
using ForwardInsertToExtractBase::ForwardInsertToExtractBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
DominanceInfo dom(getOperation()); | ||
patterns.add<ForwardSingleInsertToExtract>(context, dom); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
41 changes: 41 additions & 0 deletions
41
lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ | ||
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ | ||
|
||
#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Dominance.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" | ||
|
||
struct ForwardSingleInsertToExtract | ||
: public OpRewritePattern<tensor::ExtractOp> { | ||
ForwardSingleInsertToExtract(mlir::MLIRContext *context, DominanceInfo &dom) | ||
: OpRewritePattern<tensor::ExtractOp>(context, 3), dominanceInfo(dom) {} | ||
|
||
public: | ||
LogicalResult matchAndRewrite(tensor::ExtractOp op, | ||
PatternRewriter &rewriter) const override; | ||
|
||
private: | ||
bool isForwardableOp(Operation *potentialInsert, | ||
tensor::ExtractOp &extractOp) const; | ||
|
||
DominanceInfo &dominanceInfo; | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ |
18 changes: 18 additions & 0 deletions
18
lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ | ||
#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def ForwardInsertToExtract : Pass<"forward-insert-to-extract"> { | ||
let summary = "Forward inserts to extracts within a single block"; | ||
let description = [{ | ||
This pass is similar to forward-store-to-load pass where store ops | ||
are forwarded load ops; here instead tensor.insert ops are forwarded | ||
to tensor.extract ops. | ||
|
||
Does not support complex control flow within a block, nor ops with | ||
arbitrary subregions. | ||
}]; | ||
} | ||
|
||
#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
load("//bazel:lit.bzl", "glob_lit_tests") | ||
|
||
package(default_applicable_licenses = ["@heir//:license"]) | ||
|
||
glob_lit_tests( | ||
name = "all_tests", | ||
data = ["@heir//tests:test_utilities"], | ||
driver = "@heir//tests:run_lit.sh", | ||
test_file_exts = ["mlir"], | ||
) |
Oops, something went wrong.