From a80bc9d2a7537037cf9003c441740bbcb1c044f6 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 26 Oct 2023 15:46:25 -0700 Subject: [PATCH 1/6] fix doc filename in BGV dialect --- include/Dialect/BGV/IR/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/Dialect/BGV/IR/BUILD b/include/Dialect/BGV/IR/BUILD index 02824a330..e2b899efc 100644 --- a/include/Dialect/BGV/IR/BUILD +++ b/include/Dialect/BGV/IR/BUILD @@ -102,7 +102,7 @@ gentbl_cc_library( ), ( ["-gen-op-doc"], - "SecretOps.md", + "BGVOps.md", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", From a5395cbfe0498f386d3cdf543542d200b6a0c836 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 26 Oct 2023 15:52:20 -0700 Subject: [PATCH 2/6] Add empty TfheRust dialect shell --- include/Dialect/TfheRust/IR/BUILD | 111 ++++++++++++++++++ include/Dialect/TfheRust/IR/TfheRustDialect.h | 12 ++ .../Dialect/TfheRust/IR/TfheRustDialect.td | 22 ++++ include/Dialect/TfheRust/IR/TfheRustOps.h | 15 +++ include/Dialect/TfheRust/IR/TfheRustOps.td | 20 ++++ include/Dialect/TfheRust/IR/TfheRustTypes.h | 9 ++ include/Dialect/TfheRust/IR/TfheRustTypes.td | 17 +++ 7 files changed, 206 insertions(+) create mode 100644 include/Dialect/TfheRust/IR/BUILD create mode 100644 include/Dialect/TfheRust/IR/TfheRustDialect.h create mode 100644 include/Dialect/TfheRust/IR/TfheRustDialect.td create mode 100644 include/Dialect/TfheRust/IR/TfheRustOps.h create mode 100644 include/Dialect/TfheRust/IR/TfheRustOps.td create mode 100644 include/Dialect/TfheRust/IR/TfheRustTypes.h create mode 100644 include/Dialect/TfheRust/IR/TfheRustTypes.td diff --git a/include/Dialect/TfheRust/IR/BUILD b/include/Dialect/TfheRust/IR/BUILD new file mode 100644 index 000000000..f7180a1ff --- /dev/null +++ b/include/Dialect/TfheRust/IR/BUILD @@ -0,0 +1,111 @@ +# TfheRust, an exit dialect to the tfhe-rs API + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + [ + "TfheRustDialect.h", + "TfheRustOps.h", + "TfheRustTypes.h", + ], +) + +td_library( + name = "td_files", + srcs = [ + "TfheRustDialect.td", + "TfheRustOps.td", + "TfheRustTypes.td", + ], + deps = [ + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "TfheRustDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "TfheRustDialect.cpp.inc", + ), + ( + [ + "-gen-dialect-doc", + ], + "TfheRustDialect.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TfheRustDialect.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "types_inc_gen", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + ], + "TfheRustTypes.h.inc", + ), + ( + [ + "-gen-typedef-defs", + ], + "TfheRustTypes.cpp.inc", + ), + ( + ["-gen-typedef-doc"], + "TfheRustTypes.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TfheRustTypes.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + +gentbl_cc_library( + name = "ops_inc_gen", + tbl_outs = [ + ( + ["-gen-op-decls"], + "TfheRustOps.h.inc", + ), + ( + ["-gen-op-defs"], + "TfheRustOps.cpp.inc", + ), + ( + ["-gen-op-doc"], + "TfheRustOps.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "TfheRustOps.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ":types_inc_gen", + ], +) diff --git a/include/Dialect/TfheRust/IR/TfheRustDialect.h b/include/Dialect/TfheRust/IR/TfheRustDialect.h new file mode 100644 index 000000000..e246f5e60 --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustDialect.h @@ -0,0 +1,12 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_ + +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +// Generated headers (block clang-format from messing up order) +#include "include/Dialect/TfheRust/IR/TfheRustDialect.h.inc" + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_ diff --git a/include/Dialect/TfheRust/IR/TfheRustDialect.td b/include/Dialect/TfheRust/IR/TfheRustDialect.td new file mode 100644 index 000000000..151802e69 --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustDialect.td @@ -0,0 +1,22 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_ + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +def TfheRust_Dialect : Dialect { + let name = "tfhe_rs"; + + let description = [{ + The `thfe_rs` dialect is an exit dialect for generating rust code against the tfhe-rs library API. + + See https://github.com/zama-ai/tfhe-rs + }]; + + let cppNamespace = "::mlir::heir::tfhe_rs"; + + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; +} + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_ diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.h b/include/Dialect/TfheRust/IR/TfheRustOps.h new file mode 100644 index 000000000..17faff51f --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustOps.h @@ -0,0 +1,15 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_ + +#include "include/Dialect/TfheRust/IR/TfheRustDialect.h" +#include "include/Dialect/TfheRust/IR/TfheRustTraits.h" +#include "include/Dialect/TfheRust/IR/TfheRustTypes.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project + +#define GET_OP_CLASSES +#include "include/Dialect/TfheRust/IR/TfheRustOps.h.inc" + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_ diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.td b/include/Dialect/TfheRust/IR/TfheRustOps.td new file mode 100644 index 000000000..2e27b62ce --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustOps.td @@ -0,0 +1,20 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_ + +include "TfheRustDialect.td" +include "TfheRustTypes.td" + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +class TfheRust_Op traits = []> : + Op { + + let assemblyFormat = [{ + operands attr-dict `:` `(` type(operands) `)` `->` type(results) + }]; + let cppNamespace = "::mlir::heir::thfe_rs"; +} + + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_ diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.h b/include/Dialect/TfheRust/IR/TfheRustTypes.h new file mode 100644 index 000000000..867e54373 --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.h @@ -0,0 +1,9 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_ + +#include "include/Dialect/TfheRust/IR/TfheRustDialect.h" + +#define GET_TYPEDEF_CLASSES +#include "include/Dialect/TfheRust/IR/TfheRustTypes.h.inc" + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_ diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.td b/include/Dialect/TfheRust/IR/TfheRustTypes.td new file mode 100644 index 000000000..2d90e26c9 --- /dev/null +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.td @@ -0,0 +1,17 @@ +#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ +#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ + +include "TfheRustDialect.td" + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" + +// A base class for all types in this dialect +class TfheRust_Type + : TypeDef { + let mnemonic = typeMnemonic; +} + +#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ From 3ff8a72615c8095c83f5103facf803275967792b Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 2 Nov 2023 08:07:19 -0700 Subject: [PATCH 3/6] Add encrypted int types --- .../Dialect/TfheRust/IR/TfheRustDialect.td | 1 - include/Dialect/TfheRust/IR/TfheRustOps.h | 2 - include/Dialect/TfheRust/IR/TfheRustTypes.td | 20 ++++++++++ lib/Dialect/TfheRust/IR/BUILD | 23 ++++++++++++ lib/Dialect/TfheRust/IR/TfheRustDialect.cpp | 37 +++++++++++++++++++ tests/tfhe_rs/BUILD | 13 +++++++ tests/tfhe_rs/types.mlir | 27 ++++++++++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 9 files changed, 123 insertions(+), 3 deletions(-) create mode 100644 lib/Dialect/TfheRust/IR/BUILD create mode 100644 lib/Dialect/TfheRust/IR/TfheRustDialect.cpp create mode 100644 tests/tfhe_rs/BUILD create mode 100644 tests/tfhe_rs/types.mlir diff --git a/include/Dialect/TfheRust/IR/TfheRustDialect.td b/include/Dialect/TfheRust/IR/TfheRustDialect.td index 151802e69..aa97110f5 100644 --- a/include/Dialect/TfheRust/IR/TfheRustDialect.td +++ b/include/Dialect/TfheRust/IR/TfheRustDialect.td @@ -16,7 +16,6 @@ def TfheRust_Dialect : Dialect { let cppNamespace = "::mlir::heir::tfhe_rs"; let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 1; } #endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_ diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.h b/include/Dialect/TfheRust/IR/TfheRustOps.h index 17faff51f..74ecfc4bd 100644 --- a/include/Dialect/TfheRust/IR/TfheRustOps.h +++ b/include/Dialect/TfheRust/IR/TfheRustOps.h @@ -2,12 +2,10 @@ #define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_ #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" -#include "include/Dialect/TfheRust/IR/TfheRustTraits.h" #include "include/Dialect/TfheRust/IR/TfheRustTypes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #define GET_OP_CLASSES #include "include/Dialect/TfheRust/IR/TfheRustOps.h.inc" diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.td b/include/Dialect/TfheRust/IR/TfheRustTypes.td index 2d90e26c9..ce96c1154 100644 --- a/include/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.td @@ -14,4 +14,24 @@ class TfheRust_Type let mnemonic = typeMnemonic; } + +class TfheRust_EncryptedUInt : TfheRust_Type<"TfheRust_EncryptedUInt" # width, "eui" # width> { + let summary = "An encrypted unsigned integer corresponding to tfhe-rs's FHEUint" # width # " type."; +} + +// Available options are https://docs.rs/tfhe/latest/tfhe/index.html#types +foreach i = [2, 3, 4, 8, 10, 12, 14, 16, 32, 64, 128, 256] in { + def TfheRust_EncryptedUInt # i : TfheRust_EncryptedUInt; +} + +class TfheRust_EncryptedInt : TfheRust_Type<"TfheRust_EncryptedInt" # width, "ei" # width> { + let summary = "An encrypted signed integer corresponding to tfhe-rs's FHEInt" # width # " type."; +} + +// Available options are https://docs.rs/tfhe/latest/tfhe/index.html#types +foreach i = [8, 16, 32, 64, 128, 256] in { + def TfheRust_EncryptedInt # i : TfheRust_EncryptedInt; +} + + #endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ diff --git a/lib/Dialect/TfheRust/IR/BUILD b/lib/Dialect/TfheRust/IR/BUILD new file mode 100644 index 000000000..a237a57b1 --- /dev/null +++ b/lib/Dialect/TfheRust/IR/BUILD @@ -0,0 +1,23 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = [ + "TfheRustDialect.cpp", + ], + hdrs = [ + "@heir//include/Dialect/TfheRust/IR:TfheRustDialect.h", + "@heir//include/Dialect/TfheRust/IR:TfheRustOps.h", + "@heir//include/Dialect/TfheRust/IR:TfheRustTypes.h", + ], + deps = [ + "@heir//include/Dialect/TfheRust/IR:dialect_inc_gen", + "@heir//include/Dialect/TfheRust/IR:ops_inc_gen", + "@heir//include/Dialect/TfheRust/IR:types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp new file mode 100644 index 000000000..2adac1cb2 --- /dev/null +++ b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp @@ -0,0 +1,37 @@ +#include "include/Dialect/TfheRust/IR/TfheRustDialect.h" + +#include "include/Dialect/TfheRust/IR/TfheRustDialect.cpp.inc" +#include "include/Dialect/TfheRust/IR/TfheRustOps.h" +#include "include/Dialect/TfheRust/IR/TfheRustTypes.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#define GET_TYPEDEF_CLASSES +#include "include/Dialect/TfheRust/IR/TfheRustTypes.cpp.inc" +#define GET_OP_CLASSES +#include "include/Dialect/TfheRust/IR/TfheRustOps.cpp.inc" + +namespace mlir { +namespace heir { +namespace tfhe_rs { + +//===----------------------------------------------------------------------===// +// TfheRust dialect. +//===----------------------------------------------------------------------===// + +// Dialect construction: there is one instance per context and it registers its +// operations, types, and interfaces here. +void TfheRustDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "include/Dialect/TfheRust/IR/TfheRustTypes.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "include/Dialect/TfheRust/IR/TfheRustOps.cpp.inc" + >(); +} + +} // namespace tfhe_rs +} // namespace heir +} // namespace mlir diff --git a/tests/tfhe_rs/BUILD b/tests/tfhe_rs/BUILD new file mode 100644 index 000000000..6c9032391 --- /dev/null +++ b/tests/tfhe_rs/BUILD @@ -0,0 +1,13 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/tfhe_rs/types.mlir b/tests/tfhe_rs/types.mlir new file mode 100644 index 000000000..61ac98170 --- /dev/null +++ b/tests/tfhe_rs/types.mlir @@ -0,0 +1,27 @@ +// RUN: heir-opt %s | FileCheck %s + +// This simply tests for syntax. +module { + // CHECK-LABEL: func @test + func.func @test( + %arg_eui2: !tfhe_rs.eui2, + %arg_eui3: !tfhe_rs.eui3, + %arg_eui4: !tfhe_rs.eui4, + %arg_eui8: !tfhe_rs.eui8, + %arg_eui10: !tfhe_rs.eui10, + %arg_eui12: !tfhe_rs.eui12, + %arg_eui14: !tfhe_rs.eui14, + %arg_eui16: !tfhe_rs.eui16, + %arg_eui32: !tfhe_rs.eui32, + %arg_eui64: !tfhe_rs.eui64, + %arg_eui128: !tfhe_rs.eui128, + %arg_eui256: !tfhe_rs.eui256, + %arg_ei8: !tfhe_rs.ei8, + %arg_ei16: !tfhe_rs.ei16, + %arg_ei32: !tfhe_rs.ei32, + %arg_ei64: !tfhe_rs.ei64, + %arg_ei128: !tfhe_rs.ei128, + %arg_ei256: !tfhe_rs.ei256) { + return + } +} diff --git a/tools/BUILD b/tools/BUILD index bb3e29084..ab808300c 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -44,6 +44,7 @@ cc_binary( "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/Secret/Transforms", + "@heir//lib/Dialect/TfheRust/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index bab302bd3..f353569eb 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -10,6 +10,7 @@ #include "include/Dialect/Polynomial/IR/PolynomialDialect.h" #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/Transforms/Passes.h" +#include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project @@ -155,6 +156,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Add expected MLIR dialects to the registry. registry.insert(); From 9fc4e39d88ceeb121f3f8e9c599bacfea0f8be89 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Fri, 27 Oct 2023 17:16:35 -0700 Subject: [PATCH 4/6] add create_trivial op --- include/Dialect/TfheRust/IR/TfheRustOps.td | 7 ++++- include/Dialect/TfheRust/IR/TfheRustTypes.td | 28 +++++++++++++++++++- tests/tfhe_rs/ops.mlir | 17 ++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) create mode 100644 tests/tfhe_rs/ops.mlir diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.td b/include/Dialect/TfheRust/IR/TfheRustOps.td index 2e27b62ce..c2ebfce3a 100644 --- a/include/Dialect/TfheRust/IR/TfheRustOps.td +++ b/include/Dialect/TfheRust/IR/TfheRustOps.td @@ -4,8 +4,8 @@ include "TfheRustDialect.td" include "TfheRustTypes.td" +include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" class TfheRust_Op traits = []> : Op { @@ -16,5 +16,10 @@ class TfheRust_Op traits = []> : let cppNamespace = "::mlir::heir::thfe_rs"; } +def CreateTrivial : TfheRust_Op<"create_trivial"> { + let arguments = (ins ServerKey:$serverKey, AnyInteger:$value); + let results = (outs TfheRust_CiphertextType:$output); +} + #endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_ diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.td b/include/Dialect/TfheRust/IR/TfheRustTypes.td index ce96c1154..381bc51b2 100644 --- a/include/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.td @@ -4,9 +4,10 @@ include "TfheRustDialect.td" include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/DialectBase.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" // A base class for all types in this dialect class TfheRust_Type @@ -33,5 +34,30 @@ foreach i = [8, 16, 32, 64, 128, 256] in { def TfheRust_EncryptedInt # i : TfheRust_EncryptedInt; } +def ServerKey : TfheRust_Type<"ServerKey", "server_key"> { + let summary = "The server key required to perform homomorphic operations."; +} + +def TfheRust_CiphertextType : + AnyTypeOf<[ + EncryptedUInt2, + EncryptedUInt3, + EncryptedUInt4, + EncryptedUInt8, + EncryptedUInt10, + EncryptedUInt12, + EncryptedUInt14, + EncryptedUInt16, + EncryptedUInt32, + EncryptedUInt64, + EncryptedUInt128, + EncryptedUInt256, + EncryptedInt8, + EncryptedInt16, + EncryptedInt32, + EncryptedInt64, + EncryptedInt128, + EncryptedInt256, + ]>; #endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ diff --git a/tests/tfhe_rs/ops.mlir b/tests/tfhe_rs/ops.mlir new file mode 100644 index 000000000..77afc7fc1 --- /dev/null +++ b/tests/tfhe_rs/ops.mlir @@ -0,0 +1,17 @@ +// RUN: heir-opt %s | FileCheck %s + +// This simply tests for syntax. + +!sks = !tfhe_rs.server_key +module { + // CHECK-LABEL: func @test_create_trivial + func.func @test_create_trivial(%sks : !sks) { + %0 = arith.constant 1 : i8 + %1 = arith.constant 1 : i3 + %2 = arith.constant 1 : i128 + %e1 = tfhe_rs.create_trivial %sks, %0 : (!sks, i8) -> !tfhe_rs.ei8 + %eu1 = tfhe_rs.create_trivial %sks, %1 : (!sks, i3) -> !tfhe_rs.eui8 + %e2 = tfhe_rs.create_trivial %sks, %2 : (!sks, i128) -> !tfhe_rs.ei128 + return + } +} From 18a3c6d80c087098ca21f2ae12ad139cb30f4ce8 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 1 Nov 2023 16:52:02 -0700 Subject: [PATCH 5/6] Add minimal ops to do PBS --- .../Dialect/TfheRust/IR/TfheRustDialect.td | 6 +-- include/Dialect/TfheRust/IR/TfheRustOps.td | 27 +++++++++- include/Dialect/TfheRust/IR/TfheRustTypes.td | 49 ++++++++++--------- lib/Dialect/TfheRust/IR/TfheRustDialect.cpp | 4 +- tests/tfhe_rs/ops.mlir | 24 +++++++-- tests/tfhe_rs/types.mlir | 36 +++++++------- tools/heir-opt.cpp | 2 +- 7 files changed, 96 insertions(+), 52 deletions(-) diff --git a/include/Dialect/TfheRust/IR/TfheRustDialect.td b/include/Dialect/TfheRust/IR/TfheRustDialect.td index aa97110f5..cc5f25ef7 100644 --- a/include/Dialect/TfheRust/IR/TfheRustDialect.td +++ b/include/Dialect/TfheRust/IR/TfheRustDialect.td @@ -5,15 +5,15 @@ include "mlir/IR/DialectBase.td" include "mlir/IR/OpBase.td" def TfheRust_Dialect : Dialect { - let name = "tfhe_rs"; + let name = "tfhe_rust"; let description = [{ - The `thfe_rs` dialect is an exit dialect for generating rust code against the tfhe-rs library API. + The `thfe_rust` dialect is an exit dialect for generating rust code against the tfhe-rs library API. See https://github.com/zama-ai/tfhe-rs }]; - let cppNamespace = "::mlir::heir::tfhe_rs"; + let cppNamespace = "::mlir::heir::tfhe_rust"; let useDefaultTypePrinterParser = 1; } diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.td b/include/Dialect/TfheRust/IR/TfheRustOps.td index c2ebfce3a..cde72140c 100644 --- a/include/Dialect/TfheRust/IR/TfheRustOps.td +++ b/include/Dialect/TfheRust/IR/TfheRustOps.td @@ -13,11 +13,34 @@ class TfheRust_Op traits = []> : let assemblyFormat = [{ operands attr-dict `:` `(` type(operands) `)` `->` type(results) }]; - let cppNamespace = "::mlir::heir::thfe_rs"; + let cppNamespace = "::mlir::heir::thfe_rust"; } def CreateTrivial : TfheRust_Op<"create_trivial"> { - let arguments = (ins ServerKey:$serverKey, AnyInteger:$value); + let arguments = (ins TfheRust_ServerKey:$serverKey, AnyInteger:$value); + let results = (outs TfheRust_CiphertextType:$output); +} + +def ScalarLeftShift : TfheRust_Op<"scalar_left_shift"> { + let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, AnyI8:$shiftAmount); + let results = (outs TfheRust_CiphertextType:$output); +} + +def Add : TfheRust_Op<"add"> { + let arguments = (ins + TfheRust_ServerKey:$serverKey, + TfheRust_CiphertextType:$lhs, + TfheRust_CiphertextType:$rhs + ); + let results = (outs TfheRust_CiphertextType:$output); +} + +def ApplyLookupTable : TfheRust_Op<"apply_lookup_table"> { + let arguments = ( + ins TfheRust_ServerKey:$serverKey, + TfheRust_CiphertextType:$input, + TfheRust_LookupTable:$lookupTable + ); let results = (outs TfheRust_CiphertextType:$output); } diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.td b/include/Dialect/TfheRust/IR/TfheRustTypes.td index 381bc51b2..40af2378e 100644 --- a/include/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.td @@ -34,30 +34,35 @@ foreach i = [8, 16, 32, 64, 128, 256] in { def TfheRust_EncryptedInt # i : TfheRust_EncryptedInt; } -def ServerKey : TfheRust_Type<"ServerKey", "server_key"> { - let summary = "The server key required to perform homomorphic operations."; -} - def TfheRust_CiphertextType : AnyTypeOf<[ - EncryptedUInt2, - EncryptedUInt3, - EncryptedUInt4, - EncryptedUInt8, - EncryptedUInt10, - EncryptedUInt12, - EncryptedUInt14, - EncryptedUInt16, - EncryptedUInt32, - EncryptedUInt64, - EncryptedUInt128, - EncryptedUInt256, - EncryptedInt8, - EncryptedInt16, - EncryptedInt32, - EncryptedInt64, - EncryptedInt128, - EncryptedInt256, + TfheRust_EncryptedUInt2, + TfheRust_EncryptedUInt3, + TfheRust_EncryptedUInt4, + TfheRust_EncryptedUInt8, + TfheRust_EncryptedUInt10, + TfheRust_EncryptedUInt12, + TfheRust_EncryptedUInt14, + TfheRust_EncryptedUInt16, + TfheRust_EncryptedUInt32, + TfheRust_EncryptedUInt64, + TfheRust_EncryptedUInt128, + TfheRust_EncryptedUInt256, + TfheRust_EncryptedInt8, + TfheRust_EncryptedInt16, + TfheRust_EncryptedInt32, + TfheRust_EncryptedInt64, + TfheRust_EncryptedInt128, + TfheRust_EncryptedInt256, ]>; + +def TfheRust_ServerKey : TfheRust_Type<"ServerKey", "server_key"> { + let summary = "The server key required to perform homomorphic operations."; +} + +def TfheRust_LookupTable : TfheRust_Type<"LookupTable", "lookup_table"> { + let summary = "A univariate lookup table used for programmable bootstrapping."; +} + #endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_ diff --git a/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp index 2adac1cb2..318b01e18 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp +++ b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp @@ -13,7 +13,7 @@ namespace mlir { namespace heir { -namespace tfhe_rs { +namespace tfhe_rust { //===----------------------------------------------------------------------===// // TfheRust dialect. @@ -32,6 +32,6 @@ void TfheRustDialect::initialize() { >(); } -} // namespace tfhe_rs +} // namespace tfhe_rust } // namespace heir } // namespace mlir diff --git a/tests/tfhe_rs/ops.mlir b/tests/tfhe_rs/ops.mlir index 77afc7fc1..993851d14 100644 --- a/tests/tfhe_rs/ops.mlir +++ b/tests/tfhe_rs/ops.mlir @@ -2,16 +2,32 @@ // This simply tests for syntax. -!sks = !tfhe_rs.server_key +!sks = !tfhe_rust.server_key module { // CHECK-LABEL: func @test_create_trivial func.func @test_create_trivial(%sks : !sks) { %0 = arith.constant 1 : i8 %1 = arith.constant 1 : i3 %2 = arith.constant 1 : i128 - %e1 = tfhe_rs.create_trivial %sks, %0 : (!sks, i8) -> !tfhe_rs.ei8 - %eu1 = tfhe_rs.create_trivial %sks, %1 : (!sks, i3) -> !tfhe_rs.eui8 - %e2 = tfhe_rs.create_trivial %sks, %2 : (!sks, i128) -> !tfhe_rs.ei128 + %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i8) -> !tfhe_rust.ei8 + %eu1 = tfhe_rust.create_trivial %sks, %1 : (!sks, i3) -> !tfhe_rust.eui8 + %e2 = tfhe_rust.create_trivial %sks, %2 : (!sks, i128) -> !tfhe_rust.ei128 + return + } + + // CHECK-LABEL: func @test_apply_lookup_table + func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) { + %0 = arith.constant 1 : i3 + %1 = arith.constant 2 : i3 + // TODO: add shift + add for bivariate input + %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 + + %shiftAmount = arith.constant 1 : i8 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + + %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 return } } diff --git a/tests/tfhe_rs/types.mlir b/tests/tfhe_rs/types.mlir index 61ac98170..a8318adea 100644 --- a/tests/tfhe_rs/types.mlir +++ b/tests/tfhe_rs/types.mlir @@ -4,24 +4,24 @@ module { // CHECK-LABEL: func @test func.func @test( - %arg_eui2: !tfhe_rs.eui2, - %arg_eui3: !tfhe_rs.eui3, - %arg_eui4: !tfhe_rs.eui4, - %arg_eui8: !tfhe_rs.eui8, - %arg_eui10: !tfhe_rs.eui10, - %arg_eui12: !tfhe_rs.eui12, - %arg_eui14: !tfhe_rs.eui14, - %arg_eui16: !tfhe_rs.eui16, - %arg_eui32: !tfhe_rs.eui32, - %arg_eui64: !tfhe_rs.eui64, - %arg_eui128: !tfhe_rs.eui128, - %arg_eui256: !tfhe_rs.eui256, - %arg_ei8: !tfhe_rs.ei8, - %arg_ei16: !tfhe_rs.ei16, - %arg_ei32: !tfhe_rs.ei32, - %arg_ei64: !tfhe_rs.ei64, - %arg_ei128: !tfhe_rs.ei128, - %arg_ei256: !tfhe_rs.ei256) { + %arg_eui2: !tfhe_rust.eui2, + %arg_eui3: !tfhe_rust.eui3, + %arg_eui4: !tfhe_rust.eui4, + %arg_eui8: !tfhe_rust.eui8, + %arg_eui10: !tfhe_rust.eui10, + %arg_eui12: !tfhe_rust.eui12, + %arg_eui14: !tfhe_rust.eui14, + %arg_eui16: !tfhe_rust.eui16, + %arg_eui32: !tfhe_rust.eui32, + %arg_eui64: !tfhe_rust.eui64, + %arg_eui128: !tfhe_rust.eui128, + %arg_eui256: !tfhe_rust.eui256, + %arg_ei8: !tfhe_rust.ei8, + %arg_ei16: !tfhe_rust.ei16, + %arg_ei32: !tfhe_rust.ei32, + %arg_ei64: !tfhe_rust.ei64, + %arg_ei128: !tfhe_rust.ei128, + %arg_ei256: !tfhe_rust.ei256) { return } } diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index f353569eb..a992bdea4 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -156,7 +156,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); // Add expected MLIR dialects to the registry. registry.insert(); From 2068327c5623a3b075f6396a59afcbb76b60d162 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 2 Nov 2023 16:19:32 -0700 Subject: [PATCH 6/6] remove extra comments --- include/Dialect/TfheRust/IR/TfheRustTypes.td | 2 -- lib/Dialect/TfheRust/IR/TfheRustDialect.cpp | 6 ------ tests/tfhe_rs/ops.mlir | 1 - 3 files changed, 9 deletions(-) diff --git a/include/Dialect/TfheRust/IR/TfheRustTypes.td b/include/Dialect/TfheRust/IR/TfheRustTypes.td index 40af2378e..2dc3fccd1 100644 --- a/include/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/include/Dialect/TfheRust/IR/TfheRustTypes.td @@ -9,13 +9,11 @@ include "mlir/IR/DialectBase.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" -// A base class for all types in this dialect class TfheRust_Type : TypeDef { let mnemonic = typeMnemonic; } - class TfheRust_EncryptedUInt : TfheRust_Type<"TfheRust_EncryptedUInt" # width, "eui" # width> { let summary = "An encrypted unsigned integer corresponding to tfhe-rs's FHEUint" # width # " type."; } diff --git a/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp index 318b01e18..e0bde0c18 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp +++ b/lib/Dialect/TfheRust/IR/TfheRustDialect.cpp @@ -15,12 +15,6 @@ namespace mlir { namespace heir { namespace tfhe_rust { -//===----------------------------------------------------------------------===// -// TfheRust dialect. -//===----------------------------------------------------------------------===// - -// Dialect construction: there is one instance per context and it registers its -// operations, types, and interfaces here. void TfheRustDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/tests/tfhe_rs/ops.mlir b/tests/tfhe_rs/ops.mlir index 993851d14..7b047467c 100644 --- a/tests/tfhe_rs/ops.mlir +++ b/tests/tfhe_rs/ops.mlir @@ -19,7 +19,6 @@ module { func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) { %0 = arith.constant 1 : i3 %1 = arith.constant 2 : i3 - // TODO: add shift + add for bivariate input %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3