diff --git a/lib/Transforms/ForwardInsertToExtract/BUILD b/lib/Transforms/ForwardInsertToExtract/BUILD new file mode 100644 index 000000000..e91d2969f --- /dev/null +++ b/lib/Transforms/ForwardInsertToExtract/BUILD @@ -0,0 +1,51 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ForwardInsertToExtract", + srcs = ["ForwardInsertToExtract.cpp"], + hdrs = [ + "ForwardInsertToExtract.h", + ], + deps = [ + ":pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) +# ForwardInsertToExtract tablegen and headers. + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ForwardInsertToExtract", + ], + "ForwardInsertToExtract.h.inc", + ), + ( + ["-gen-pass-doc"], + "ForwardInsertToExtractPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ForwardInsertToExtract.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/lib/Transforms/ForwardInsertToExtract/CMakeLists.txt b/lib/Transforms/ForwardInsertToExtract/CMakeLists.txt new file mode 100644 index 000000000..7deb87f86 --- /dev/null +++ b/lib/Transforms/ForwardInsertToExtract/CMakeLists.txt @@ -0,0 +1,19 @@ +set(LLVM_TARGET_DEFINITIONS ForwardInsertToExtract.td) +mlir_tablegen(ForwardInsertToExtract.h.inc -gen-pass-decls -name ForwardInsertToExtract) +add_public_tablegen_target(MLIRHeirForwardInsertToExtractIncGen) + +add_mlir_dialect_library(MLIRHeirForwardInsertToExtract + ForwardInsertToExtract.cpp + + DEPENDS + MLIRHeirForwardInsertToExtractIncGen + + LINK_LIBS PUBLIC + MLIRModArithDialect + MLIRIR + MLIRInferTypeOpInterface + MLIRArithDialect + MLIRSupport + MLIRDialect + MLIRIR +) diff --git a/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp new file mode 100644 index 000000000..4cad90e89 --- /dev/null +++ b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.cpp @@ -0,0 +1,132 @@ +#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h" + +#include + +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +#define DEBUG_TYPE "forward-insert-to-extract" + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_FORWARDINSERTTOEXTRACT +#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" + +bool ForwardSingleInsertToExtract::isForwardableOp( + Operation *potentialInsert, tensor::ExtractOp &extractOp) const { + if (!dominanceInfo.properlyDominates(potentialInsert, + extractOp.getOperation())) { + LLVM_DEBUG(llvm::dbgs() << "insert op does not dominate extract op\n"); + return false; + } + + if (extractOp->getBlock() != potentialInsert->getBlock()) { + LLVM_DEBUG(llvm::dbgs() + << "insert and extract op are not in the same block\n"); + return false; + } + + return llvm::TypeSwitch(*potentialInsert) + .Case([&](auto insertOp) { + ValueRange insertIndices = insertOp.getIndices(); + ValueRange extractIndices = extractOp.getIndices(); + if (insertIndices != extractIndices) { + LLVM_DEBUG(llvm::dbgs() + << "insert and extract op do not have matching indices\n"); + return false; + } + + // Naively scan through the operations between the two ops and check if + // anything prevents forwarding. + for (auto currentNode = insertOp->getNextNode(); + currentNode != extractOp.getOperation(); + currentNode = currentNode->getNextNode()) { + if (currentNode->getNumRegions() > 0) { + LLVM_DEBUG(llvm::dbgs() << "an op with control flow is between the " + "insert and extract op\n"); + return false; + } + + if (auto op = dyn_cast(currentNode)) { + if (op.getDest() == insertOp.getDest() && + op.getIndices() == insertIndices) { + LLVM_DEBUG(llvm::dbgs() + << "an intermediate op inserts to the same index\n"); + return false; + } + } + } + return true; + }) + .Default([&](Operation &) { + LLVM_DEBUG(llvm::dbgs() + << "Unsupported op type, cannot check for forwardability\n"); + return false; + }); +} + +FailureOr getInsertedValue(Operation *insertOp) { + return llvm::TypeSwitch>(*insertOp) + .Case( + [&](auto insertOp) { return insertOp.getScalar(); }) + .Default([&](Operation &) { return failure(); }); +} + +LogicalResult ForwardSingleInsertToExtract::matchAndRewrite( + tensor::ExtractOp extractOp, PatternRewriter &rewriter) const { + LLVM_DEBUG(llvm::dbgs() << "Considering extractOp for replacement: " + << extractOp << "\n"); + + auto *def = extractOp.getTensor().getDefiningOp(); + if (def != nullptr) { + LLVM_DEBUG(llvm::dbgs() + << "DefiningOp of the one considered: " + << *extractOp.getTensor().getDefiningOp() + << "\n"); + + LLVM_DEBUG(llvm::dbgs() + << "Considering def for forwarding: " << *def << "\n"); + if (isForwardableOp(def, extractOp)) { + auto result = getInsertedValue(def); + LLVM_DEBUG(llvm::dbgs() << "def is forwardable: " << *def << "\n"); + if (failed(result)) { + return failure(); + } + auto value = result.value(); + rewriter.replaceAllUsesWith(extractOp, value); + return success(); + } + LLVM_DEBUG(llvm::dbgs() << "def is not forwardable: " << *def << "\n"); + } else { + LLVM_DEBUG(llvm::dbgs() << "def is nullptr " << "\n"); + } + return failure(); +} + +struct ForwardInsertToExtract + : impl::ForwardInsertToExtractBase { + using ForwardInsertToExtractBase::ForwardInsertToExtractBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + DominanceInfo dom(getOperation()); + patterns.add(context, dom); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h new file mode 100644 index 000000000..34d4adb67 --- /dev/null +++ b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h @@ -0,0 +1,43 @@ +#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ +#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dominance.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h.inc" + +struct ForwardSingleInsertToExtract + : public OpRewritePattern { + ForwardSingleInsertToExtract(mlir::MLIRContext *context, DominanceInfo &dom) + : OpRewritePattern(context, 3), dominanceInfo(dom) {} + + public: + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override; + + private: + bool isForwardableOp(Operation *potentialInsert, + tensor::ExtractOp &extractOp) const; + + DominanceInfo &dominanceInfo; +}; + +} // namespace heir +} // namespace mlir + +#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_H_ diff --git a/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td new file mode 100644 index 000000000..e5a3f6886 --- /dev/null +++ b/lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.td @@ -0,0 +1,23 @@ +#ifndef LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ +#define LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ + +include "mlir/Pass/PassBase.td" + +def ForwardInsertToExtract : Pass<"forward-insert-to-extract"> { + let summary = "Forward inserts to extracts within a single block"; + let description = [{ + This pass is similar to forward-store-to-load pass where store ops + are forwarded load ops; here instead tensor.insert ops are forwarded + to tensor.extract ops. + + Does not support complex control flow within a block, nor ops with + arbitrary subregions. + }]; + let dependentDialects = [ + "mlir::affine::AffineDialect", + "mlir::memref::MemRefDialect", + "mlir::tensor::TensorDialect" + ]; +} + +#endif // LIB_TRANSFORMS_FORWARDINSERTTOEXTRACT_FORWARDINSERTTOEXTRACT_TD_ diff --git a/tests/forward_insert_to_extract/BUILD b/tests/forward_insert_to_extract/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/forward_insert_to_extract/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/forward_insert_to_extract/forward_insert_to_extract.mlir b/tests/forward_insert_to_extract/forward_insert_to_extract.mlir new file mode 100644 index 000000000..f2ecbff1e --- /dev/null +++ b/tests/forward_insert_to_extract/forward_insert_to_extract.mlir @@ -0,0 +1,164 @@ +// RUN: heir-opt -forward-insert-to-extract %s | FileCheck %s + + +#encoding = #lwe.polynomial_evaluation_encoding +#my_poly = #polynomial.int_polynomial<1 + x**16> +#ring= #polynomial.ring +#rlwe_params = #lwe.rlwe_params + + +!cc = !openfhe.crypto_context +!pt = !lwe.rlwe_plaintext +!ptf16 = !lwe.rlwe_plaintext +!ct = !lwe.rlwe_ciphertext + + +// CHECK-LABEL: @successful_forwarding +// CHECK-SAME: (%[[ARG0:.*]]: !openfhe.crypto_context, + + +func.func @successful_forwarding(%arg0: !cc, %arg1: tensor<1x16x!ct>, %arg2: tensor<1x16x!ct>, %arg3: tensor<16xf64>, %arg4: tensor<16xf64>) -> tensor<1x16x!ct> { + + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index + %c1 = arith.constant 1 : index + + // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract + %extracted = tensor.extract %arg1[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[EXTRACTED0:.*]] = tensor.extract + %extracted_0 = tensor.extract %arg2[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[VAL0:.*]] = openfhe.make_ckks_packed_plaintext %[[ARG0]] + %0 = openfhe.make_ckks_packed_plaintext %arg0, %arg3 : (!cc, tensor<16xf64>) -> !pt + // CHECK-NEXT: %[[VAL1:.*]] = openfhe.mul_plain %[[ARG0]], %[[EXTRACTED]], %[[VAL0]] + %1 = openfhe.mul_plain %arg0, %extracted, %0 : (!cc, !ct, !pt) -> !ct + // CHECK-NEXT: %[[VAL2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL1]] + %2 = openfhe.add %arg0, %extracted_0, %1 : (!cc, !ct, !ct) -> !ct + + // CHECK-NEXT: %[[INSERTED0:.*]] = tensor.insert %[[VAL2]] + %inserted = tensor.insert %2 into %arg2[%c0, %c0] : tensor<1x16x!ct> + + // CHECK-NEXT: %[[EXTRACTED1:.*]] = tensor.extract + %extracted_1 = tensor.extract %arg1[%c0, %c1] : tensor<1x16x!ct> + // CHECK-NOT: tensor.extract %[[INSERTED0]] + %extracted_2 = tensor.extract %inserted[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[VAL3:.*]] = openfhe.make_ckks_packed_plaintext + %3 = openfhe.make_ckks_packed_plaintext %arg0, %arg4 : (!cc, tensor<16xf64>) -> !lwe.rlwe_plaintext>, underlying_type = f32> + // CHECK-NEXT: %[[VAL4:.*]] = openfhe.mul_plain + %4 = openfhe.mul_plain %arg0, %extracted_1, %3 : (!cc, !ct, !lwe.rlwe_plaintext>, underlying_type = f32>) -> !ct + // CHECK-NEXT: %[[VAL5:.*]] = openfhe.add + %5 = openfhe.add %arg0, %extracted_2, %4 : (!cc, !ct, !ct) -> !ct + // CHECK-NEXT: %[[INSERTED1:.*]] = tensor.insert + %inserted_3 = tensor.insert %5 into %inserted[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: return %[[INSERTED1]] + return %inserted_3 : tensor<1x16x!ct> +} + + +//hits def == nullptr +// CHECK-LABEL: @forward_from_func_arg +// CHECK-SAME: (%[[ARG0:.*]]: !openfhe.crypto_context, + +func.func @forward_from_func_arg(%arg0: !cc, %arg1: tensor<1x16x!ct>, %arg2: tensor<1x16x!ct>)-> !ct { + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract + %extracted = tensor.extract %arg1[%c0, %c0] : tensor<1x16x!ct> + + return %extracted : !ct +} + +// CHECK-LABEL: @forwarding_with_an_insert_in_between +// CHECK-SAME: (%[[ARG0:.*]]: !openfhe.crypto_context, + +func.func @forwarding_with_an_insert_in_between(%arg0: !cc, %arg1: tensor<1x16x!ct>, %arg2: tensor<1x16x!ct>, %arg3: tensor<16xf64> )-> !ct { + + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract + %extracted = tensor.extract %arg1[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[EXTRACTED0:.*]] = tensor.extract + %extracted_0 = tensor.extract %arg2[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[VAL0:.*]] = openfhe.make_ckks_packed_plaintext %[[ARG0]] + %0 = openfhe.make_ckks_packed_plaintext %arg0, %arg3 : (!cc, tensor<16xf64>) -> !pt + // CHECK-NEXT: %[[VAL1:.*]] = openfhe.mul_plain %[[ARG0]], %[[EXTRACTED]], %[[VAL0]] + %1 = openfhe.mul_plain %arg0, %extracted, %0 : (!cc, !ct, !pt) -> !ct + // CHECK-NEXT: %[[VAL2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL1]] + %2 = openfhe.add %arg0, %extracted_0, %1 : (!cc, !ct, !ct) -> !ct + // CHECK-NEXT: %[[VALA2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL2]] + %a2 = openfhe.add %arg0, %extracted_0, %2 : (!cc, !ct, !ct) -> !ct + // CHECK-NOT: %[[INSERTED0:.*]] = tensor.insert %[[VAL2]] + %inserted = tensor.insert %2 into %arg2[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NOT: %[[INSERTED1:.*]] = tensor.insert %[[VALA2]] + %inserted_1 = tensor.insert %a2 into %arg1[%c0, %c0] : tensor<1x16x!ct> + + // CHECK-NOT: tensor.extract %[[INSERTED1]] + %extracted_2 = tensor.extract %inserted_1[%c0, %c0] : tensor<1x16x!ct> + // CHECK: return %[[VALA2]] + return %extracted_2 : !ct +} + +// CHECK-LABEL: @forwarding_with_an_operation_in_between +// CHECK-SAME: (%[[ARG0:.*]]: !openfhe.crypto_context, + +func.func @forwarding_with_an_operation_in_between(%arg0: !cc, %arg1: tensor<1x16x!ct>, %arg2: tensor<1x16x!ct>, %arg3: tensor<16xf64>, %arg4: i1 )-> !ct { + + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract + %extracted = tensor.extract %arg1[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[EXTRACTED0:.*]] = tensor.extract + %extracted_0 = tensor.extract %arg2[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[VAL0:.*]] = openfhe.make_ckks_packed_plaintext %[[ARG0]] + %0 = openfhe.make_ckks_packed_plaintext %arg0, %arg3 : (!cc, tensor<16xf64>) -> !pt + // CHECK-NEXT: %[[VAL1:.*]] = openfhe.mul_plain %[[ARG0]], %[[EXTRACTED]], %[[VAL0]] + %1 = openfhe.mul_plain %arg0, %extracted, %0 : (!cc, !ct, !pt) -> !ct + // CHECK-NEXT: %[[VAL2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL1]] + %2 = openfhe.add %arg0, %extracted_0, %1 : (!cc, !ct, !ct) -> !ct + + // CHECK-NOT: %[[INSERTED0:.*]] = tensor.insert %[[VAL2]] + %inserted = tensor.insert %2 into %arg2[%c0, %c0] : tensor<1x16x!ct> + + scf.if %arg4 { + // CHECK-NOT: %[[VALa2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL2]] + %a2 = openfhe.add %arg0, %extracted_0, %2 : (!cc, !ct, !ct) -> !ct + // CHECK-NOT: %[[INSERTED1:.*]] = tensor.insert %[[VAL1]] + %inserted_1 = tensor.insert %a2 into %arg2[%c0, %c0] : tensor<1x16x!ct> + } + // CHECK-NOT: tensor.extract %[[INSERTED0]] + %extracted_2 = tensor.extract %inserted[%c0, %c0] : tensor<1x16x!ct> + return %extracted_2 : !ct +} + + +// CHECK-LABEL: @two_extracts_both_forwarded +// CHECK-SAME: (%[[ARG0:.*]]: !openfhe.crypto_context, + +func.func @two_extracts_both_forwarded(%arg0: !cc, %arg1: tensor<1x16x!ct>, %arg2: tensor<1x16x!ct>, %arg3: tensor<16xf64>) -> !ct { + + // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + + // CHECK-NEXT: %[[EXTRACTED:.*]] = tensor.extract + %extracted = tensor.extract %arg1[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[EXTRACTED0:.*]] = tensor.extract + %extracted_0 = tensor.extract %arg2[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NEXT: %[[VAL0:.*]] = openfhe.make_ckks_packed_plaintext %[[ARG0]] + %0 = openfhe.make_ckks_packed_plaintext %arg0, %arg3 : (!cc, tensor<16xf64>) -> !pt + // CHECK-NEXT: %[[VAL1:.*]] = openfhe.mul_plain %[[ARG0]], %[[EXTRACTED]], %[[VAL0]] + %1 = openfhe.mul_plain %arg0, %extracted, %0 : (!cc, !ct, !pt) -> !ct + // CHECK-NEXT: %[[VAL2:.*]] = openfhe.add %[[ARG0]], %[[EXTRACTED0]], %[[VAL1]] + %2 = openfhe.add %arg0, %extracted_0, %1 : (!cc, !ct, !ct) -> !ct + + %inserted = tensor.insert %2 into %arg2[%c0, %c0] : tensor<1x16x!ct> + + // CHECK-NOT: %[[EXTRACTED1:.*]] = tensor.extract %[[INSERTED0]] + %extracted_1 = tensor.extract %inserted[%c0, %c0] : tensor<1x16x!ct> + // CHECK-NOT: %[[EXTRACTED2:.*]] = tensor.extract %[[INSERTED0]] + %extracted_2 = tensor.extract %inserted[%c0, %c0] : tensor<1x16x!ct> + // CHECK: openfhe.add %[[ARG0]], %[[VAL2]], %[[VAL2]] + %3 = openfhe.add %arg0, %extracted_1, %extracted_2 : (!cc, !ct, !ct) -> !ct + return %3: !ct +} diff --git a/tools/BUILD b/tools/BUILD index f7694422d..03b9c6e51 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -83,6 +83,7 @@ cc_binary( "@heir//lib/Transforms/ConvertSecretInsertToStaticInsert", "@heir//lib/Transforms/ConvertSecretWhileToStaticFor", "@heir//lib/Transforms/ElementwiseToAffine", + "@heir//lib/Transforms/ForwardInsertToExtract", "@heir//lib/Transforms/ForwardStoreToLoad", "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/LinalgCanonicalizations", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 8cb7bc4de..23fabe173 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -54,6 +54,7 @@ #include "lib/Transforms/ConvertSecretInsertToStaticInsert/ConvertSecretInsertToStaticInsert.h" #include "lib/Transforms/ConvertSecretWhileToStaticFor/ConvertSecretWhileToStaticFor.h" #include "lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.h" +#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h" #include "lib/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" #include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h" @@ -701,6 +702,7 @@ int main(int argc, char **argv) { registerConvertSecretExtractToStaticExtractPasses(); registerConvertSecretInsertToStaticInsertPasses(); registerApplyFoldersPasses(); + registerForwardInsertToExtractPasses(); registerForwardStoreToLoadPasses(); registerOperationBalancerPasses(); registerStraightLineVectorizerPasses();