Skip to content

Commit

Permalink
Merge pull request google#298 from asraa:secretize-tf-models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 586521579
  • Loading branch information
copybara-github committed Nov 30, 2023
2 parents c057ae2 + 8051069 commit 0911ec3
Show file tree
Hide file tree
Showing 13 changed files with 216 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/Dialect/Secret/IR/SecretDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def Secret_Dialect : Dialect {
custom types for arithmetic on secret integers of various bit widths.
}];

let extraClassDeclaration = [{
/// Name of the attribute indicate whether an argument of a function is a
//secret.
constexpr const static ::llvm::StringLiteral
kArgSecretAttrName = "secret.secret";
}];

let cppNamespace = "::mlir::heir::secret";
let useDefaultTypePrinterParser = 1;
}
Expand Down
35 changes: 35 additions & 0 deletions include/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Secretize tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files([
"Secretize.h",
])

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Secretize",
],
"Secretize.h.inc",
),
(
["-gen-pass-doc"],
"Secretize.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Secretize.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
18 changes: 18 additions & 0 deletions include/Transforms/Secretize/Secretize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_
#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "include/Transforms/Secretize/Secretize.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Transforms/Secretize/Secretize.h.inc"

} // namespace heir
} // namespace mlir

#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_
25 changes: 25 additions & 0 deletions include/Transforms/Secretize/Secretize.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_
#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_

include "mlir/Pass/PassBase.td"

def Secretize : Pass<"secretize", "ModuleOp"> {
let summary = "Adds secret argument attributes to entry function";

let description = [{
Adds a secret.secret attribute argument to each argument in the entry
function of an MLIR module. By default, the function is `main`. This may be
overridden with the option -entry-function=top_level_func.
}];

let dependentDialects = [
"mlir::heir::secret::SecretDialect",
"mlir::func::FuncDialect"
];

let options = [
Option<"entryFunction", "entry-function", "std::string", "\"main\"", "entry function of the module">
];
}

#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_
21 changes: 21 additions & 0 deletions lib/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Secretize",
srcs = ["Secretize.cpp"],
hdrs = [
"@heir//include/Transforms/Secretize:Secretize.h",
],
deps = [
"@heir//include/Transforms/Secretize:pass_inc_gen",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)
39 changes: 39 additions & 0 deletions lib/Transforms/Secretize/Secretize.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "include/Transforms/Secretize/Secretize.h"

#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_SECRETIZE
#include "include/Transforms/Secretize/Secretize.h.inc"

struct Secretize : impl::SecretizeBase<Secretize> {
using SecretizeBase::SecretizeBase;

void runOnOperation() override {
MLIRContext* ctx = &getContext();
ModuleOp module = getOperation();
OpBuilder builder(module);

auto mainFunction = dyn_cast_or_null<func::FuncOp>(
SymbolTable::lookupSymbolIn(module, entryFunction));
if (!mainFunction) {
module.emitError("could not find entry point function");
signalPassFailure();
return;
}

auto secretArgAttr =
StringAttr::get(ctx, secret::SecretDialect::kArgSecretAttrName);
for (unsigned i = 0; i < mainFunction.getNumArguments(); i++) {
mainFunction.setArgAttr(i, secretArgAttr, UnitAttr::get(ctx));
}
}
};

} // namespace heir
} // namespace mlir
13 changes: 13 additions & 0 deletions tests/secretize/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
22 changes: 22 additions & 0 deletions tests/secretize/main.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: heir-opt -secretize %s | FileCheck %s

module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
// CHECK: func.func @main(%arg0: tensor<1x1xi8> {secret.secret, tf_saved_model.index_path = ["dense_input"]})
func.func @main(%arg0: tensor<1x1xi8> {tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi8> {tf_saved_model.index_path = ["dense_2"]}) {
%0 = "tosa.const"() <{value = dense<429> : tensor<1xi32>}> : () -> tensor<1xi32>
%1 = "tosa.const"() <{value = dense<[[-39, 59, 39, 21, 28, -32, -34, -35, 15, 27, -59, -41, 18, -35, -7, 127]]> : tensor<1x16xi8>}> : () -> tensor<1x16xi8>
%2 = "tosa.const"() <{value = dense<[-729, 1954, 610, 0, 241, -471, -35, -867, 571, 581, 4260, 3943, 591, 0, -889, -5103]> : tensor<16xi32>}> : () -> tensor<16xi32>
%3 = "tosa.const"() <{value = dense<"0xF41AED091921F424E021EFBCF7F5FA1903DCD20206F9F402FFFAEFF1EFD327E1FB27DDEBDBE4051A17FC241215EF1EE410FE14DA1CF8F3F1EFE2F309E3E9EDE3E415070B041B1AFEEB01DE21E60BEC03230A22241E2703E60324FFC011F8FCF1110CF5E0F30717E5E8EDFADCE823FB07DDFBFD0014261117E7F111EA0226040425211D0ADB1DDC2001FAE3370BF11A16EF1CE703E01602032118092ED9E5140BEA1AFCD81300C4D8ECD9FE0D1920D8D6E21FE9D7CAE2DDC613E7043E000114C7DBE71515F506D61ADC0922FE080213EF191EE209FDF314DDDA20D90FE3F9F7EEE924E629000716E21E0D23D3DDF714FA0822262109080F0BE012F47FDC58E526"> : tensor<16x16xi8>}> : () -> tensor<16x16xi8>
%4 = "tosa.const"() <{value = dense<[0, 0, -5438, -5515, -1352, -1500, -4152, -84, 3396, 0, 1981, -5581, 0, -6964, 3407, -7217]> : tensor<16xi32>}> : () -> tensor<16xi32>
%5 = "tosa.const"() <{value = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> : tensor<16x1xi8>}> : () -> tensor<16x1xi8>
%6 = "tosa.fully_connected"(%arg0, %5, %4) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x1xi8>, tensor<16x1xi8>, tensor<16xi32>) -> tensor<1x16xi32>
%7 = "tosa.rescale"(%6) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2039655736>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i32: 38>}> : (tensor<1x16xi32>) -> tensor<1x16xi8>
%8 = "tosa.clamp"(%7) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8>
%9 = "tosa.fully_connected"(%8, %3, %2) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x16xi8>, tensor<16x16xi8>, tensor<16xi32>) -> tensor<1x16xi32>
%10 = "tosa.rescale"(%9) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1561796795>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i32: 37>}> : (tensor<1x16xi32>) -> tensor<1x16xi8>
%11 = "tosa.clamp"(%10) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8>
%12 = "tosa.fully_connected"(%11, %1, %0) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x16xi8>, tensor<1x16xi8>, tensor<1xi32>) -> tensor<1x1xi32>
%13 = "tosa.rescale"(%12) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1630361836>, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array<i32: 36>}> : (tensor<1x1xi32>) -> tensor<1x1xi8>
return %13 : tensor<1x1xi8>
}
}
9 changes: 9 additions & 0 deletions tests/secretize/missing.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: heir-opt -secretize -verify-diagnostics %s

// expected-error@+1 {{could not find entry point function}}
module {
func.func @comb(%a: i1, %b: i1) -> () {
%0 = comb.truth_table %a, %b -> 6 : ui4
return
}
}
15 changes: 15 additions & 0 deletions tests/secretize/multiple.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: heir-opt -secretize %s | FileCheck %s

module {
// CHECK: func.func @inner(%arg0: i1, %arg1: i1)
func.func @inner(%a: i1, %b: i1) -> () {
%0 = comb.truth_table %a, %b -> 6 : ui4
return
}

// CHECK: func.func @main(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret})
func.func @main(%a: i1, %b: i1) -> () {
func.call @inner(%a, %b) : (i1, i1) -> ()
return
}
}
9 changes: 9 additions & 0 deletions tests/secretize/named.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: heir-opt -secretize=entry-function=comb %s | FileCheck %s

module {
// CHECK: func.func @comb(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret})
func.func @comb(%a: i1, %b: i1) -> () {
%0 = comb.truth_table %a, %b -> 6 : ui4
return
}
}
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ cc_binary(
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/Secret/Transforms",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Transforms/Secretize",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineToStandard",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/Transforms/Passes.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Transforms/Secretize/Secretize.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
Expand Down Expand Up @@ -182,6 +183,7 @@ int main(int argc, char **argv) {
bgv::registerBGVToPolynomialPasses();
comb::registerCombToCGGIPasses();
registerCGGIToTfheRustPasses();
registerSecretizePasses();
// Register yosys optimizer pipeline if configured.
#ifndef HEIR_NO_YOSYS
const char *abcEnvPath = std::getenv("HEIR_ABC_BINARY");
Expand Down

0 comments on commit 0911ec3

Please sign in to comment.