Skip to content

Commit

Permalink
Merge pull request google#931 from MeronZerihun:do-pipeline-dev
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 684103251
  • Loading branch information
copybara-github committed Oct 9, 2024
2 parents 6b6f253 + 41d5922 commit 5280ed3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/convert_to_data_oblivious/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
32 changes: 32 additions & 0 deletions tests/convert_to_data_oblivious/test.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: heir-opt --secretize=entry-function=test --wrap-generic --convert-to-data-oblivious %s | FileCheck %s

// CHECK-LABEL: test
func.func @test(%secretUpperBound : index{secret.secret}, %secretIndex : index {secret.secret}, %secretTensor : tensor<32xi16>{secret.secret}) -> i16{
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
%c0 = arith.constant 0 : i16
// CHECK: affine.for
// CHECK-NEXT: arith.cmpi eq
// CHECK-NEXT: tensor.extract
// CHECK-NEXT: arith.select
%extracted = tensor.extract %secretTensor[%secretIndex] : tensor<32xi16>
// CHECK: affine.for
// CHECK-NEXT: arith.cmpi slt
// CHECK-NOT: scf.if
%result = scf.for %i = %i0 to %secretIndex step %i1 iter_args(%sum = %c0) -> i16 {
%element = tensor.extract %secretTensor[%i] : tensor<32xi16>
%cond = arith.cmpi eq, %element, %extracted : i16
%if = scf.if %cond -> i16 {
%c2 = arith.constant 2 : i16
%mul = arith.muli %element, %c2 : i16
scf.yield %mul : i16
} else {
%add = arith.addi %element, %extracted : i16
scf.yield %add : i16
}
%for = arith.addi %sum, %if : i16
scf.yield %for : i16
}{lower = 0, upper = 32}
return %result : i16

}
61 changes: 61 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ void oneShotBufferize(OpPassManager &manager) {
manager.addPass(bufferization::createBufferDeallocationSimplificationPass());
manager.addPass(bufferization::createLowerDeallocationsPass());
manager.addPass(createCSEPass());
manager.addPass(mlir::createBufferizationToMemRefPass());
manager.addPass(createCanonicalizerPass());
}

Expand Down Expand Up @@ -185,6 +186,19 @@ void tosaPipelineBuilder(OpPassManager &manager) {
manager.addPass(createSymbolDCEPass());
}

void convertToDataObliviousPipelineBuilder(OpPassManager &manager) {
// Access Transformation
manager.addPass(createConvertSecretExtractToStaticExtract());
manager.addPass(createConvertSecretInsertToStaticInsert());

// Loop Transformation
manager.addPass(createConvertSecretWhileToStaticFor());
manager.addPass(createConvertSecretForToStaticFor());

// If Transformation
manager.addPass(createConvertIfToSelect());
}

void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
// Poly
manager.addPass(createElementwiseToAffine());
Expand Down Expand Up @@ -228,6 +242,43 @@ void polynomialToLLVMPipelineBuilder(OpPassManager &manager) {
manager.addPass(createSymbolDCEPass());
}

void basicMLIRToLLVMPipelineBuilder(OpPassManager &manager) {
// Linalg
manager.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
// Needed to lower affine.map and affine.apply
manager.addNestedPass<FuncOp>(affine::createAffineExpandIndexOpsPass());
manager.addNestedPass<FuncOp>(affine::createSimplifyAffineStructuresPass());
manager.addPass(createLowerAffinePass());
manager.addNestedPass<FuncOp>(memref::createExpandOpsPass());
manager.addNestedPass<FuncOp>(memref::createExpandStridedMetadataPass());

// Bufferize
oneShotBufferize(manager);

// Linalg must be bufferized before it can be lowered
// But lowering to loops also re-introduces affine.apply, so re-lower that
manager.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
manager.addPass(createLowerAffinePass());

// Cleanup
manager.addPass(createCanonicalizerPass());
manager.addPass(createSCCPPass());
manager.addPass(createCSEPass());
manager.addPass(createSymbolDCEPass());

// ToLLVM
manager.addPass(arith::createArithExpandOpsPass());
manager.addPass(createConvertSCFToCFPass());
manager.addNestedPass<FuncOp>(memref::createExpandStridedMetadataPass());
manager.addPass(createConvertToLLVMPass());

// Cleanup
manager.addPass(createCanonicalizerPass());
manager.addPass(createSCCPPass());
manager.addPass(createCSEPass());
manager.addPass(createSymbolDCEPass());
}

void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager) {
// For now we unroll loops to enable insert-rotate, but we would like to be
// smarter about this and do an affine loop analysis.
Expand Down Expand Up @@ -571,6 +622,7 @@ void mlirToRLWEPipeline(OpPassManager &pm,
// Secretize inputs
pm.addPass(createSecretize(SecretizeOptions{options.entryFunction}));
pm.addPass(createWrapGeneric());
convertToDataObliviousPipelineBuilder(pm);
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

Expand Down Expand Up @@ -806,6 +858,10 @@ int main(int argc, char **argv) {
"Run passes to lower the polynomial dialect to LLVM",
polynomialToLLVMPipelineBuilder);

PassPipelineRegistration<>("heir-basic-mlir-to-llvm",
"Lower basic MLIR to LLVM",
basicMLIRToLLVMPipelineBuilder);

PassPipelineRegistration<>(
"heir-simd-vectorizer",
"Run scheme-agnostic passes to convert FHE programs that operate on "
Expand Down Expand Up @@ -839,6 +895,11 @@ int main(int argc, char **argv) {
"to OpenFHE C++ code.",
mlirToOpenFheRLWEPipelineBuilder(RLWEScheme::ckks));

PassPipelineRegistration<>(
"convert-to-data-oblivious",
"Transforms a native program to data-oblivious program",
convertToDataObliviousPipelineBuilder);

return asMainReturnCode(
MlirOptMain(argc, argv, "HEIR Pass Driver", registry));
}

0 comments on commit 5280ed3

Please sign in to comment.