Skip to content

Commit

Permalink
Merge pull request google#217 from j2kun:tfhe-rs-dialect
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579897199
  • Loading branch information
copybara-github committed Nov 6, 2023
2 parents c3290bf + 2068327 commit 9020c83
Show file tree
Hide file tree
Showing 15 changed files with 410 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/Dialect/BGV/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ gentbl_cc_library(
),
(
["-gen-op-doc"],
"SecretOps.md",
"BGVOps.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
Expand Down
111 changes: 111 additions & 0 deletions include/Dialect/TfheRust/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# TfheRust, an exit dialect to the tfhe-rs API

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"TfheRustDialect.h",
"TfheRustOps.h",
"TfheRustTypes.h",
],
)

td_library(
name = "td_files",
srcs = [
"TfheRustDialect.td",
"TfheRustOps.td",
"TfheRustTypes.td",
],
deps = [
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)

gentbl_cc_library(
name = "dialect_inc_gen",
tbl_outs = [
(
[
"-gen-dialect-decls",
],
"TfheRustDialect.h.inc",
),
(
[
"-gen-dialect-defs",
],
"TfheRustDialect.cpp.inc",
),
(
[
"-gen-dialect-doc",
],
"TfheRustDialect.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TfheRustDialect.td",
deps = [
":td_files",
],
)

gentbl_cc_library(
name = "types_inc_gen",
tbl_outs = [
(
[
"-gen-typedef-decls",
],
"TfheRustTypes.h.inc",
),
(
[
"-gen-typedef-defs",
],
"TfheRustTypes.cpp.inc",
),
(
["-gen-typedef-doc"],
"TfheRustTypes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TfheRustTypes.td",
deps = [
":dialect_inc_gen",
":td_files",
],
)

gentbl_cc_library(
name = "ops_inc_gen",
tbl_outs = [
(
["-gen-op-decls"],
"TfheRustOps.h.inc",
),
(
["-gen-op-defs"],
"TfheRustOps.cpp.inc",
),
(
["-gen-op-doc"],
"TfheRustOps.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TfheRustOps.td",
deps = [
":dialect_inc_gen",
":td_files",
":types_inc_gen",
],
)
12 changes: 12 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustDialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_

#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h.inc"

#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_H_
21 changes: 21 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_

include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"

def TfheRust_Dialect : Dialect {
let name = "tfhe_rust";

let description = [{
The `thfe_rust` dialect is an exit dialect for generating rust code against the tfhe-rs library API.

See https://github.com/zama-ai/tfhe-rs
}];

let cppNamespace = "::mlir::heir::tfhe_rust";

let useDefaultTypePrinterParser = 1;
}

#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTDIALECT_TD_
13 changes: 13 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_

#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project

#define GET_OP_CLASSES
#include "include/Dialect/TfheRust/IR/TfheRustOps.h.inc"

#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_H_
48 changes: 48 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_

include "TfheRustDialect.td"
include "TfheRustTypes.td"

include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/OpBase.td"

class TfheRust_Op<string mnemonic, list<Trait> traits = []> :
Op<TfheRust_Dialect, mnemonic, traits> {

let assemblyFormat = [{
operands attr-dict `:` `(` type(operands) `)` `->` type(results)
}];
let cppNamespace = "::mlir::heir::thfe_rust";
}

def CreateTrivial : TfheRust_Op<"create_trivial"> {
let arguments = (ins TfheRust_ServerKey:$serverKey, AnyInteger:$value);
let results = (outs TfheRust_CiphertextType:$output);
}

def ScalarLeftShift : TfheRust_Op<"scalar_left_shift"> {
let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, AnyI8:$shiftAmount);
let results = (outs TfheRust_CiphertextType:$output);
}

def Add : TfheRust_Op<"add"> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
TfheRust_CiphertextType:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
}

def ApplyLookupTable : TfheRust_Op<"apply_lookup_table"> {
let arguments = (
ins TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$input,
TfheRust_LookupTable:$lookupTable
);
let results = (outs TfheRust_CiphertextType:$output);
}


#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTOPS_TD_
9 changes: 9 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_

#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"

#define GET_TYPEDEF_CLASSES
#include "include/Dialect/TfheRust/IR/TfheRustTypes.h.inc"

#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_H_
66 changes: 66 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustTypes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_
#define INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_

include "TfheRustDialect.td"

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"

class TfheRust_Type<string name, string typeMnemonic>
: TypeDef<TfheRust_Dialect, name> {
let mnemonic = typeMnemonic;
}

class TfheRust_EncryptedUInt<int width> : TfheRust_Type<"TfheRust_EncryptedUInt" # width, "eui" # width> {
let summary = "An encrypted unsigned integer corresponding to tfhe-rs's FHEUint" # width # " type.";
}

// Available options are https://docs.rs/tfhe/latest/tfhe/index.html#types
foreach i = [2, 3, 4, 8, 10, 12, 14, 16, 32, 64, 128, 256] in {
def TfheRust_EncryptedUInt # i : TfheRust_EncryptedUInt<i>;
}

class TfheRust_EncryptedInt<int width> : TfheRust_Type<"TfheRust_EncryptedInt" # width, "ei" # width> {
let summary = "An encrypted signed integer corresponding to tfhe-rs's FHEInt" # width # " type.";
}

// Available options are https://docs.rs/tfhe/latest/tfhe/index.html#types
foreach i = [8, 16, 32, 64, 128, 256] in {
def TfheRust_EncryptedInt # i : TfheRust_EncryptedInt<i>;
}

def TfheRust_CiphertextType :
AnyTypeOf<[
TfheRust_EncryptedUInt2,
TfheRust_EncryptedUInt3,
TfheRust_EncryptedUInt4,
TfheRust_EncryptedUInt8,
TfheRust_EncryptedUInt10,
TfheRust_EncryptedUInt12,
TfheRust_EncryptedUInt14,
TfheRust_EncryptedUInt16,
TfheRust_EncryptedUInt32,
TfheRust_EncryptedUInt64,
TfheRust_EncryptedUInt128,
TfheRust_EncryptedUInt256,
TfheRust_EncryptedInt8,
TfheRust_EncryptedInt16,
TfheRust_EncryptedInt32,
TfheRust_EncryptedInt64,
TfheRust_EncryptedInt128,
TfheRust_EncryptedInt256,
]>;


def TfheRust_ServerKey : TfheRust_Type<"ServerKey", "server_key"> {
let summary = "The server key required to perform homomorphic operations.";
}

def TfheRust_LookupTable : TfheRust_Type<"LookupTable", "lookup_table"> {
let summary = "A univariate lookup table used for programmable bootstrapping.";
}

#endif // INCLUDE_DIALECT_TFHERUST_IR_TFHERUSTTYPES_TD_
23 changes: 23 additions & 0 deletions lib/Dialect/TfheRust/IR/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Dialect",
srcs = [
"TfheRustDialect.cpp",
],
hdrs = [
"@heir//include/Dialect/TfheRust/IR:TfheRustDialect.h",
"@heir//include/Dialect/TfheRust/IR:TfheRustOps.h",
"@heir//include/Dialect/TfheRust/IR:TfheRustTypes.h",
],
deps = [
"@heir//include/Dialect/TfheRust/IR:dialect_inc_gen",
"@heir//include/Dialect/TfheRust/IR:ops_inc_gen",
"@heir//include/Dialect/TfheRust/IR:types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
31 changes: 31 additions & 0 deletions lib/Dialect/TfheRust/IR/TfheRustDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"

#include "include/Dialect/TfheRust/IR/TfheRustDialect.cpp.inc"
#include "include/Dialect/TfheRust/IR/TfheRustOps.h"
#include "include/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project
#define GET_TYPEDEF_CLASSES
#include "include/Dialect/TfheRust/IR/TfheRustTypes.cpp.inc"
#define GET_OP_CLASSES
#include "include/Dialect/TfheRust/IR/TfheRustOps.cpp.inc"

namespace mlir {
namespace heir {
namespace tfhe_rust {

void TfheRustDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "include/Dialect/TfheRust/IR/TfheRustTypes.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "include/Dialect/TfheRust/IR/TfheRustOps.cpp.inc"
>();
}

} // namespace tfhe_rust
} // namespace heir
} // namespace mlir
13 changes: 13 additions & 0 deletions tests/tfhe_rs/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
32 changes: 32 additions & 0 deletions tests/tfhe_rs/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: heir-opt %s | FileCheck %s

// This simply tests for syntax.

!sks = !tfhe_rust.server_key
module {
// CHECK-LABEL: func @test_create_trivial
func.func @test_create_trivial(%sks : !sks) {
%0 = arith.constant 1 : i8
%1 = arith.constant 1 : i3
%2 = arith.constant 1 : i128
%e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i8) -> !tfhe_rust.ei8
%eu1 = tfhe_rust.create_trivial %sks, %1 : (!sks, i3) -> !tfhe_rust.eui8
%e2 = tfhe_rust.create_trivial %sks, %2 : (!sks, i128) -> !tfhe_rust.ei128
return
}

// CHECK-LABEL: func @test_apply_lookup_table
func.func @test_apply_lookup_table(%sks : !sks, %lut: !tfhe_rust.lookup_table) {
%0 = arith.constant 1 : i3
%1 = arith.constant 2 : i3
%e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3
%e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3

%shiftAmount = arith.constant 1 : i8
%e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3
%eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3

%out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3
return
}
}
Loading

0 comments on commit 9020c83

Please sign in to comment.