diff --git a/lib/Dialect/BGV/CMakeLists.txt b/lib/Dialect/BGV/CMakeLists.txt index d9b30c6ed..9cd151c5a 100644 --- a/lib/Dialect/BGV/CMakeLists.txt +++ b/lib/Dialect/BGV/CMakeLists.txt @@ -1,8 +1,6 @@ add_subdirectory(IR) -add_subdirectory(Transforms) add_mlir_dialect_library(MLIRBGV IR/BGVDialect.cpp - Transforms/AddClientInterface.cpp ADDITIONAL_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/IR @@ -10,7 +8,6 @@ add_mlir_dialect_library(MLIRBGV DEPENDS MLIRBGVIncGen MLIRBGVOpsIncGen - MLIRBGVPassesIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/lib/Dialect/BGV/Transforms/AddClientInterface.h b/lib/Dialect/BGV/Transforms/AddClientInterface.h deleted file mode 100644 index c3f86e2f5..000000000 --- a/lib/Dialect/BGV/Transforms/AddClientInterface.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef LIB_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_ -#define LIB_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_ - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace heir { -namespace bgv { - -#define GEN_PASS_DECL_ADDCLIENTINTERFACE -#include "lib/Dialect/BGV/Transforms/Passes.h.inc" - -} // namespace bgv -} // namespace heir -} // namespace mlir - -#endif // LIB_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_ diff --git a/lib/Dialect/BGV/Transforms/BUILD b/lib/Dialect/BGV/Transforms/BUILD deleted file mode 100644 index 6a2713a19..000000000 --- a/lib/Dialect/BGV/Transforms/BUILD +++ /dev/null @@ -1,59 +0,0 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") - -package( - default_applicable_licenses = ["@heir//:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "Transforms", - hdrs = [ - "Passes.h", - ], - deps = [ - ":AddClientInterface", - ":pass_inc_gen", - "@heir//lib/Dialect/BGV/IR:Dialect", - ], -) - -cc_library( - name = "AddClientInterface", - srcs = ["AddClientInterface.cpp"], - hdrs = [ - "AddClientInterface.h", - ], - deps = [ - ":pass_inc_gen", - "@heir//lib/Dialect/BGV/IR:Dialect", - "@heir//lib/Dialect/LWE/IR:Dialect", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], -) - -gentbl_cc_library( - name = "pass_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=BGV", - ], - "Passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "BGVPasses.md", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "Passes.td", - deps = [ - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - ], -) diff --git a/lib/Dialect/BGV/Transforms/CMakeLists.txt b/lib/Dialect/BGV/Transforms/CMakeLists.txt deleted file mode 100644 index 745dc7a7d..000000000 --- a/lib/Dialect/BGV/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name BGV) -add_public_tablegen_target(MLIRBGVPassesIncGen) diff --git a/lib/Dialect/BGV/Transforms/Passes.h b/lib/Dialect/BGV/Transforms/Passes.h deleted file mode 100644 index b08177883..000000000 --- a/lib/Dialect/BGV/Transforms/Passes.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef LIB_DIALECT_BGV_TRANSFORMS_PASSES_H_ -#define LIB_DIALECT_BGV_TRANSFORMS_PASSES_H_ - -#include "lib/Dialect/BGV/IR/BGVDialect.h" -#include "lib/Dialect/BGV/Transforms/AddClientInterface.h" - -namespace mlir { -namespace heir { -namespace bgv { - -#define GEN_PASS_REGISTRATION -#include "lib/Dialect/BGV/Transforms/Passes.h.inc" - -} // namespace bgv -} // namespace heir -} // namespace mlir - -#endif // LIB_DIALECT_BGV_TRANSFORMS_PASSES_H_ diff --git a/lib/Dialect/BGV/Transforms/Passes.td b/lib/Dialect/BGV/Transforms/Passes.td deleted file mode 100644 index bcd5cb67e..000000000 --- a/lib/Dialect/BGV/Transforms/Passes.td +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef LIB_DIALECT_BGV_TRANSFORMS_PASSES_TD_ -#define LIB_DIALECT_BGV_TRANSFORMS_PASSES_TD_ - -include "mlir/Pass/PassBase.td" - -def AddClientInterface : Pass<"bgv-add-client-interface"> { - let summary = "Add client interfaces to BGV encrypted functions"; - let description = [{ - This pass adds encrypt and decrypt functions for each compiled function in the - IR. These functions maintain the same interface as the original function, - while the compiled function may lose some of this information by the lowerings - to ciphertext types (e.g., a scalar ciphertext, when lowered through BGV, must - be encoded as a tensor). - - Example: - - For an input function with signature - - ```mlir - #encoding = ... - #params = ... - !in_ty = !lwe.rlwe_ciphertext> - !out_ty = !lwe.rlwe_ciphertext - func.func @my_func(%arg0: !in_ty) -> !out_ty { - ... - } - ``` - - The pass will generate two new functions with signatures - - ```mlir - func.func @my_func__encrypt( - %arg0: tensor<32xi16>, - %sk: !lwe.rlwe_secret_key<...> - ) -> !in_ty - - func.func @my_func__decrypt( - %arg0: !out_ty, - %sk: !lwe.rlwe_secret_key<...> - ) -> i16 - ``` - - The `my_func__encrypt` function has the same order of operands as `my_func`, - and uses their `underylying_type` as the corresponding input type. - The last operand is the encryption key. - The same holds for `my_func__decrypt`, but the inputs are the return types - of `my_func` and the results are the underlying types of the return types of `my_func`. - - If `use-public-key` is set to true, the encrypt function uses - `lwe.rlwe_public_key` for encryption. - - If `one-value-per-helper-fn` is set to true, the encryption helpers are split - into separate functions, one for each SSA value being converted. For example, - using the same `!in_ty` and `!out_ty` as above, this function signature - - ```mlir - func.func @my_func(%arg0: !in_ty, %arg1: !in_ty) -> (!out_ty, !out_ty) - ``` - - generates the following four helpers. - - ```mlir - func.func @my_func__encrypt__arg0(%arg0: tensor<32xi16>, %sk: !lwe.rlwe_secret_key<...>) -> !in_ty - func.func @my_func__encrypt__arg1(%arg1: tensor<32xi16>, %sk: !lwe.rlwe_secret_key<...>) -> !in_ty - func.func @my_func__decrypt__result0(%arg0: !out_ty, %sk: !lwe.rlwe_secret_key<...>) -> i16 - func.func @my_func__decrypt__result1(%arg1: !out_ty, %sk: !lwe.rlwe_secret_key<...>) -> i16 - } - ``` - - The suffix `__argN` indicates the SSA value being encrypted is the N-th argument of `my_func`, - and similarly for `__resultN`. - }]; - let dependentDialects = ["mlir::heir::bgv::BGVDialect"]; - let options = [ - Option<"usePublicKey", "use-public-key", "bool", /*default=*/"false", - "If true, generate a client interface that uses a public key for encryption.">, - Option<"oneValuePerHelperFn", "one-value-per-helper-fn", "bool", /*default=*/"false", - "If true, split encryption helpers into separate functions for each SSA value."> - ]; -} - -#endif // LIB_DIALECT_BGV_TRANSFORMS_PASSES_TD_ diff --git a/lib/Dialect/LWE/CMakeLists.txt b/lib/Dialect/LWE/CMakeLists.txt index df1443de6..9a6d3ada9 100644 --- a/lib/Dialect/LWE/CMakeLists.txt +++ b/lib/Dialect/LWE/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Transforms) add_mlir_dialect_library(MLIRLWE IR/LWEDialect.cpp Transforms/SetDefaultParameters.cpp + Transforms/AddClientInterface.cpp ADDITIONAL_HEADER_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/IR diff --git a/lib/Dialect/BGV/Transforms/AddClientInterface.cpp b/lib/Dialect/LWE/Transforms/AddClientInterface.cpp similarity index 98% rename from lib/Dialect/BGV/Transforms/AddClientInterface.cpp rename to lib/Dialect/LWE/Transforms/AddClientInterface.cpp index 1327c846f..beca4bcb5 100644 --- a/lib/Dialect/BGV/Transforms/AddClientInterface.cpp +++ b/lib/Dialect/LWE/Transforms/AddClientInterface.cpp @@ -1,9 +1,8 @@ -#include "lib/Dialect/BGV/Transforms/AddClientInterface.h" +#include "lib/Dialect/LWE/Transforms/AddClientInterface.h" #include #include -#include "lib/Dialect/BGV/IR/BGVOps.h" #include "lib/Dialect/LWE/IR/LWEAttributes.h" #include "lib/Dialect/LWE/IR/LWEOps.h" #include "lib/Dialect/LWE/IR/LWETypes.h" @@ -21,10 +20,10 @@ namespace mlir { namespace heir { -namespace bgv { +namespace lwe { #define GEN_PASS_DEF_ADDCLIENTINTERFACE -#include "lib/Dialect/BGV/Transforms/Passes.h.inc" +#include "lib/Dialect/LWE/Transforms/Passes.h.inc" FailureOr getRlweParmsFromFuncOp(func::FuncOp op) { lwe::RLWEParamsAttr rlweParams = nullptr; @@ -154,6 +153,7 @@ LogicalResult convertFunc(func::FuncOp op, bool usePublicKey, auto module = op->getParentOfType(); auto rlweParamsResult = getRlweParmsFromFuncOp(op); if (failed(rlweParamsResult)) { + // TODO (#891): Add support for schemes other than BGV return failure(); } lwe::RLWEParamsAttr rlweParams = rlweParamsResult.value(); @@ -271,6 +271,6 @@ struct AddClientInterface : impl::AddClientInterfaceBase { if (result.wasInterrupted()) signalPassFailure(); } }; -} // namespace bgv +} // namespace lwe } // namespace heir } // namespace mlir diff --git a/lib/Dialect/LWE/Transforms/AddClientInterface.h b/lib/Dialect/LWE/Transforms/AddClientInterface.h new file mode 100644 index 000000000..185ee934f --- /dev/null +++ b/lib/Dialect/LWE/Transforms/AddClientInterface.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_LWE_TRANSFORMS_ADDCLIENTINTERFACE_H_ +#define LIB_DIALECT_LWE_TRANSFORMS_ADDCLIENTINTERFACE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lwe { + +#define GEN_PASS_DECL_ADDCLIENTINTERFACE +#include "lib/Dialect/LWE/Transforms/Passes.h.inc" + +} // namespace lwe +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_LWE_TRANSFORMS_ADDCLIENTINTERFACE_H_ diff --git a/lib/Dialect/LWE/Transforms/BUILD b/lib/Dialect/LWE/Transforms/BUILD index 656fa0641..ecdb0442b 100644 --- a/lib/Dialect/LWE/Transforms/BUILD +++ b/lib/Dialect/LWE/Transforms/BUILD @@ -11,12 +11,31 @@ cc_library( "Passes.h", ], deps = [ + ":AddClientInterface", ":SetDefaultParameters", ":pass_inc_gen", "@heir//lib/Dialect/LWE/IR:Dialect", ], ) +cc_library( + name = "AddClientInterface", + srcs = ["AddClientInterface.cpp"], + hdrs = [ + "AddClientInterface.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/BGV/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "SetDefaultParameters", srcs = ["SetDefaultParameters.cpp"], diff --git a/lib/Dialect/LWE/Transforms/Passes.h b/lib/Dialect/LWE/Transforms/Passes.h index ca02de674..72527511d 100644 --- a/lib/Dialect/LWE/Transforms/Passes.h +++ b/lib/Dialect/LWE/Transforms/Passes.h @@ -2,6 +2,7 @@ #define LIB_DIALECT_LWE_TRANSFORMS_PASSES_H_ #include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/Transforms/AddClientInterface.h" #include "lib/Dialect/LWE/Transforms/SetDefaultParameters.h" namespace mlir { diff --git a/lib/Dialect/LWE/Transforms/Passes.td b/lib/Dialect/LWE/Transforms/Passes.td index 1c1a700a9..51b989033 100644 --- a/lib/Dialect/LWE/Transforms/Passes.td +++ b/lib/Dialect/LWE/Transforms/Passes.td @@ -3,6 +3,25 @@ include "mlir/Pass/PassBase.td" + +def AddClientInterface : Pass<"lwe-add-client-interface"> { + let summary = "Add client interfaces to (R)LWE encrypted functions"; + let description = [{ + This pass adds encrypt and decrypt functions for each compiled function in the + IR. These functions maintain the same interface as the original function, + while the compiled function may lose some of this information by the lowerings + to ciphertext types (e.g., a scalar ciphertext, when lowered through RLWE schemes, + must be encoded as a tensor). + }]; + let dependentDialects = ["mlir::heir::lwe::LWEDialect"]; + let options = [ + Option<"usePublicKey", "use-public-key", "bool", /*default=*/"false", + "If true, generate a client interface that uses a public key for encryption.">, + Option<"oneValuePerHelperFn", "one-value-per-helper-fn", "bool", /*default=*/"false", + "If true, split encryption helpers into separate functions for each SSA value."> + ]; +} + def SetDefaultParameters : Pass<"lwe-set-default-parameters"> { let summary = "Set default parameters for LWE ops"; let description = [{ diff --git a/tests/bgv/add_client_interface.mlir b/tests/lwe/add_client_interface.mlir similarity index 97% rename from tests/bgv/add_client_interface.mlir rename to tests/lwe/add_client_interface.mlir index df277a567..91e63808b 100644 --- a/tests/bgv/add_client_interface.mlir +++ b/tests/lwe/add_client_interface.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --bgv-add-client-interface %s | FileCheck %s +// RUN: heir-opt --lwe-add-client-interface %s | FileCheck %s // These two types differ only on their underlying_type. The IR stays as the !in_ty // for the entire computation until the final extract op. diff --git a/tests/bgv/add_client_interface_public_key.mlir b/tests/lwe/add_client_interface_public_key.mlir similarity index 96% rename from tests/bgv/add_client_interface_public_key.mlir rename to tests/lwe/add_client_interface_public_key.mlir index a4f61ff31..f6a19750b 100644 --- a/tests/bgv/add_client_interface_public_key.mlir +++ b/tests/lwe/add_client_interface_public_key.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --bgv-add-client-interface=use-public-key=true %s | FileCheck %s +// RUN: heir-opt --lwe-add-client-interface=use-public-key=true %s | FileCheck %s // These two types differ only on their underlying_type. The IR stays as the !in_ty // for the entire computation until the final extract op. diff --git a/tests/bgv/add_client_interface_split.mlir b/tests/lwe/add_client_interface_split.mlir similarity index 96% rename from tests/bgv/add_client_interface_split.mlir rename to tests/lwe/add_client_interface_split.mlir index bf6f55807..3e092d493 100644 --- a/tests/bgv/add_client_interface_split.mlir +++ b/tests/lwe/add_client_interface_split.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt '--bgv-add-client-interface=use-public-key=true one-value-per-helper-fn=true' %s | FileCheck %s +// RUN: heir-opt '--lwe-add-client-interface=use-public-key=true one-value-per-helper-fn=true' %s | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding #params = #lwe.rlwe_params>> diff --git a/tools/BUILD b/tools/BUILD index 189d3e445..a607c50f2 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -46,8 +46,6 @@ cc_binary( "@heir//lib/Conversion/SecretToBGV", "@heir//lib/Dialect/ArithExt/IR:Dialect", "@heir//lib/Dialect/BGV/IR:Dialect", - "@heir//lib/Dialect/BGV/Transforms", - "@heir//lib/Dialect/BGV/Transforms:AddClientInterface", "@heir//lib/Dialect/CGGI/IR:Dialect", "@heir//lib/Dialect/CGGI/Transforms", "@heir//lib/Dialect/CKKS/IR:Dialect", @@ -55,6 +53,7 @@ cc_binary( "@heir//lib/Dialect/Jaxite/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/LWE/Transforms", + "@heir//lib/Dialect/LWE/Transforms:AddClientInterface", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Dialect/Openfhe/Transforms", "@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 3e1a1261a..d1027de1c 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -15,14 +15,13 @@ #include "lib/Conversion/SecretToBGV/SecretToBGV.h" #include "lib/Dialect/ArithExt/IR/ArithExtDialect.h" #include "lib/Dialect/BGV/IR/BGVDialect.h" -#include "lib/Dialect/BGV/Transforms/AddClientInterface.h" -#include "lib/Dialect/BGV/Transforms/Passes.h" #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CGGI/Transforms/Passes.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Jaxite/IR/JaxiteDialect.h" #include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/Transforms/AddClientInterface.h" #include "lib/Dialect/LWE/Transforms/Passes.h" #include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h" @@ -466,11 +465,11 @@ void mlirToOpenFheBgvPipelineBuilder(OpPassManager &pm, mlirToBgvPipelineBuilder(pm, options); // Add client interface - auto addClientInterfaceOptions = bgv::AddClientInterfaceOptions{}; + auto addClientInterfaceOptions = lwe::AddClientInterfaceOptions{}; // OpenFHE's pke API, which this pipeline generates, is always public-key addClientInterfaceOptions.usePublicKey = true; addClientInterfaceOptions.oneValuePerHelperFn = true; - pm.addPass(bgv::createAddClientInterface(addClientInterfaceOptions)); + pm.addPass(lwe::createAddClientInterface(addClientInterfaceOptions)); // Lower to openfhe pm.addPass(bgv::createBGVToOpenfhe()); @@ -515,7 +514,6 @@ int main(int argc, char **argv) { registerAllPasses(); // Custom passes in HEIR - bgv::registerBGVPasses(); cggi::registerCGGIPasses(); lwe::registerLWEPasses(); ::mlir::heir::polynomial::registerPolynomialPasses();