forked from google/heir
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request google#298 from asraa:secretize-tf-models
PiperOrigin-RevId: 586521579
- Loading branch information
Showing
13 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Secretize tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files([ | ||
"Secretize.h", | ||
]) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=Secretize", | ||
], | ||
"Secretize.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"Secretize.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "Secretize.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ | ||
#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Transforms/Secretize/Secretize.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Transforms/Secretize/Secretize.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ | ||
#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def Secretize : Pass<"secretize", "ModuleOp"> { | ||
let summary = "Adds secret argument attributes to entry function"; | ||
|
||
let description = [{ | ||
Adds a secret.secret attribute argument to each argument in the entry | ||
function of an MLIR module. By default, the function is `main`. This may be | ||
overridden with the option -entry-function=top_level_func. | ||
}]; | ||
|
||
let dependentDialects = [ | ||
"mlir::heir::secret::SecretDialect", | ||
"mlir::func::FuncDialect" | ||
]; | ||
|
||
let options = [ | ||
Option<"entryFunction", "entry-function", "std::string", "\"main\"", "entry function of the module"> | ||
]; | ||
} | ||
|
||
#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Secretize", | ||
srcs = ["Secretize.cpp"], | ||
hdrs = [ | ||
"@heir//include/Transforms/Secretize:Secretize.h", | ||
], | ||
deps = [ | ||
"@heir//include/Transforms/Secretize:pass_inc_gen", | ||
"@heir//lib/Dialect/Secret/IR:Dialect", | ||
"@llvm-project//mlir:FuncDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include "include/Transforms/Secretize/Secretize.h" | ||
|
||
#include "include/Dialect/Secret/IR/SecretDialect.h" | ||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_SECRETIZE | ||
#include "include/Transforms/Secretize/Secretize.h.inc" | ||
|
||
struct Secretize : impl::SecretizeBase<Secretize> { | ||
using SecretizeBase::SecretizeBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext* ctx = &getContext(); | ||
ModuleOp module = getOperation(); | ||
OpBuilder builder(module); | ||
|
||
auto mainFunction = dyn_cast_or_null<func::FuncOp>( | ||
SymbolTable::lookupSymbolIn(module, entryFunction)); | ||
if (!mainFunction) { | ||
module.emitError("could not find entry point function"); | ||
signalPassFailure(); | ||
return; | ||
} | ||
|
||
auto secretArgAttr = | ||
StringAttr::get(ctx, secret::SecretDialect::kArgSecretAttrName); | ||
for (unsigned i = 0; i < mainFunction.getNumArguments(); i++) { | ||
mainFunction.setArgAttr(i, secretArgAttr, UnitAttr::get(ctx)); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
load("//bazel:lit.bzl", "glob_lit_tests") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
glob_lit_tests( | ||
name = "all_tests", | ||
data = ["@heir//tests:test_utilities"], | ||
driver = "@heir//tests:run_lit.sh", | ||
test_file_exts = ["mlir"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// RUN: heir-opt -secretize %s | FileCheck %s | ||
|
||
module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { | ||
// CHECK: func.func @main(%arg0: tensor<1x1xi8> {secret.secret, tf_saved_model.index_path = ["dense_input"]}) | ||
func.func @main(%arg0: tensor<1x1xi8> {tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi8> {tf_saved_model.index_path = ["dense_2"]}) { | ||
%0 = "tosa.const"() <{value = dense<429> : tensor<1xi32>}> : () -> tensor<1xi32> | ||
%1 = "tosa.const"() <{value = dense<[[-39, 59, 39, 21, 28, -32, -34, -35, 15, 27, -59, -41, 18, -35, -7, 127]]> : tensor<1x16xi8>}> : () -> tensor<1x16xi8> | ||
%2 = "tosa.const"() <{value = dense<[-729, 1954, 610, 0, 241, -471, -35, -867, 571, 581, 4260, 3943, 591, 0, -889, -5103]> : tensor<16xi32>}> : () -> tensor<16xi32> | ||
%3 = "tosa.const"() <{value = dense<"0xF41AED091921F424E021EFBCF7F5FA1903DCD20206F9F402FFFAEFF1EFD327E1FB27DDEBDBE4051A17FC241215EF1EE410FE14DA1CF8F3F1EFE2F309E3E9EDE3E415070B041B1AFEEB01DE21E60BEC03230A22241E2703E60324FFC011F8FCF1110CF5E0F30717E5E8EDFADCE823FB07DDFBFD0014261117E7F111EA0226040425211D0ADB1DDC2001FAE3370BF11A16EF1CE703E01602032118092ED9E5140BEA1AFCD81300C4D8ECD9FE0D1920D8D6E21FE9D7CAE2DDC613E7043E000114C7DBE71515F506D61ADC0922FE080213EF191EE209FDF314DDDA20D90FE3F9F7EEE924E629000716E21E0D23D3DDF714FA0822262109080F0BE012F47FDC58E526"> : tensor<16x16xi8>}> : () -> tensor<16x16xi8> | ||
%4 = "tosa.const"() <{value = dense<[0, 0, -5438, -5515, -1352, -1500, -4152, -84, 3396, 0, 1981, -5581, 0, -6964, 3407, -7217]> : tensor<16xi32>}> : () -> tensor<16xi32> | ||
%5 = "tosa.const"() <{value = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> : tensor<16x1xi8>}> : () -> tensor<16x1xi8> | ||
%6 = "tosa.fully_connected"(%arg0, %5, %4) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x1xi8>, tensor<16x1xi8>, tensor<16xi32>) -> tensor<1x16xi32> | ||
%7 = "tosa.rescale"(%6) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2039655736>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i32: 38>}> : (tensor<1x16xi32>) -> tensor<1x16xi8> | ||
%8 = "tosa.clamp"(%7) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8> | ||
%9 = "tosa.fully_connected"(%8, %3, %2) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x16xi8>, tensor<16x16xi8>, tensor<16xi32>) -> tensor<1x16xi32> | ||
%10 = "tosa.rescale"(%9) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1561796795>, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array<i32: 37>}> : (tensor<1x16xi32>) -> tensor<1x16xi8> | ||
%11 = "tosa.clamp"(%10) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8> | ||
%12 = "tosa.fully_connected"(%11, %1, %0) <{quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 0>}> : (tensor<1x16xi8>, tensor<1x16xi8>, tensor<1xi32>) -> tensor<1x1xi32> | ||
%13 = "tosa.rescale"(%12) <{double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1630361836>, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array<i32: 36>}> : (tensor<1x1xi32>) -> tensor<1x1xi8> | ||
return %13 : tensor<1x1xi8> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// RUN: heir-opt -secretize -verify-diagnostics %s | ||
|
||
// expected-error@+1 {{could not find entry point function}} | ||
module { | ||
func.func @comb(%a: i1, %b: i1) -> () { | ||
%0 = comb.truth_table %a, %b -> 6 : ui4 | ||
return | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// RUN: heir-opt -secretize %s | FileCheck %s | ||
|
||
module { | ||
// CHECK: func.func @inner(%arg0: i1, %arg1: i1) | ||
func.func @inner(%a: i1, %b: i1) -> () { | ||
%0 = comb.truth_table %a, %b -> 6 : ui4 | ||
return | ||
} | ||
|
||
// CHECK: func.func @main(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret}) | ||
func.func @main(%a: i1, %b: i1) -> () { | ||
func.call @inner(%a, %b) : (i1, i1) -> () | ||
return | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// RUN: heir-opt -secretize=entry-function=comb %s | FileCheck %s | ||
|
||
module { | ||
// CHECK: func.func @comb(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret}) | ||
func.func @comb(%a: i1, %b: i1) -> () { | ||
%0 = comb.truth_table %a, %b -> 6 : ui4 | ||
return | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters