diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index d8dabc4a9..5808fcd60 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -13,6 +13,7 @@ #include "lib/Utils/ConversionUtils/ConversionUtils.h" #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -27,6 +28,8 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#define DEBUG_TYPE "cggi-to-tfhe-rust" + namespace mlir::heir { #define GEN_PASS_DEF_CGGITOTFHERUST @@ -41,6 +44,15 @@ Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { // Only supporting unsigned types because the LWE dialect does not have a // notion of signedness. switch (width) { + case 1: + // The minimum bit width of the integer tfhe_rust API is UInt2 + // https://docs.rs/tfhe/latest/tfhe/index.html#types + // This may happen if there are no LUT or boolean gate operations that + // require a minimum bit width (e.g. shuffling bits in a program that + // multiplies by two). + LLVM_DEBUG(llvm::dbgs() + << "Upgrading ciphertext with bit width 1 to UInt2"); + [[fallthrough]]; case 2: return tfhe_rust::EncryptedUInt2Type::get(ctx); case 3: diff --git a/lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp b/lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp index 701134494..11d38f531 100644 --- a/lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp +++ b/lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp @@ -1,10 +1,12 @@ #include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h" +#include #include #include #include #include "lib/Dialect/CGGI/IR/CGGIOps.h" +#include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Comb/IR/CombOps.h" #include "lib/Dialect/LWE/IR/LWEAttributes.h" #include "lib/Dialect/LWE/IR/LWEOps.h" @@ -536,23 +538,23 @@ struct ConvertSecretCastOp : public OpConversionPattern { }; int findLUTSize(MLIRContext *context, Operation *module) { - int max_int_size = 0; + int maxIntSize = 1; auto processOperation = [&](Operation *op) { if (isa(op->getDialect())) { - int current_size = 0; + int currentSize = 0; if (dyn_cast(op)) - current_size = 3; + currentSize = 3; else - current_size = op->getResults().getTypes()[0].getIntOrFloatBitWidth(); + currentSize = op->getResults().getTypes()[0].getIntOrFloatBitWidth(); - max_int_size = std::max(max_int_size, current_size); + maxIntSize = std::max(maxIntSize, currentSize); } }; // Walk all operations within the module in post-order (default) module->walk(processOperation); - return max_int_size; + return maxIntSize; } struct SecretToCGGI : public impl::SecretToCGGIBase { @@ -572,6 +574,8 @@ struct SecretToCGGI : public impl::SecretToCGGIBase { .add, SecretGenericOpConversion, + SecretGenericOpConversion, SecretGenericOpMemRefLoadConversion, SecretGenericOpAffineStoreConversion, SecretGenericOpAffineLoadConversion, diff --git a/lib/Dialect/Secret/IR/SecretPatterns.cpp b/lib/Dialect/Secret/IR/SecretPatterns.cpp index 97d157662..28bea2ca7 100644 --- a/lib/Dialect/Secret/IR/SecretPatterns.cpp +++ b/lib/Dialect/Secret/IR/SecretPatterns.cpp @@ -736,6 +736,7 @@ LogicalResult extractGenericBody(secret::GenericOp genericOp, std::string funcName = llvm::formatv( "internal_generic_{0}", mlir::hash_value(yieldOp.getValues()[0])); auto func = builder.create(module.getLoc(), funcName, type); + func.setPrivate(); // Populate function body by cloning the ops in the inner body and mapping // the func args and func outputs. diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index 92917a453..e10047c6f 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -739,10 +739,7 @@ LogicalResult TfheRustEmitter::printOperation(memref::LoadOp op) { emitAssignPrefix(op.getResult()); bool isRef = isa(op.getResult().getType().getDialect()); - bool storeUse = llvm::all_of(op.getResult().getUsers(), [](Operation *op) { - return isa(*op); - }); - os << ((isRef && !storeUse) ? "&" : ""); + os << (isRef ? "&" : ""); printLoadOp(op); os << ";\n"; } diff --git a/lib/Target/Verilog/VerilogEmitter.cpp b/lib/Target/Verilog/VerilogEmitter.cpp index c9a020a9b..25a4a09c9 100644 --- a/lib/Target/Verilog/VerilogEmitter.cpp +++ b/lib/Target/Verilog/VerilogEmitter.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,7 @@ #include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/ADT/ilist.h" // from @llvm-project +#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -37,6 +39,7 @@ #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project @@ -131,6 +134,39 @@ CtlzValueStruct ctlzStructForResult(StringRef result) { .temp4 = llvm::formatv("{0}_{1}", result, "temp4")}; } +func::FuncOp getCalledFunction(func::CallOp callOp) { + SymbolRefAttr sym = + llvm::dyn_cast_if_present(callOp.getCallableForCallee()); + if (!sym) return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + +int32_t getMaxMemrefIndexed(Value index) { + int32_t maxSize = 0; + for (auto &use : index.getUses()) { + Operation *user = use.getOwner(); + int32_t memrefSize = + llvm::TypeSwitch(user) + .Case( + [&](auto op) { return op.getMemRefType().getNumElements(); }) + .Case([&](func::CallOp op) { + // Index is passed into a function, get largest use. + func::FuncOp func = getCalledFunction(op); + auto &operand = + func.getBody().getArguments()[use.getOperandNumber()]; + assert(isa(operand.getType()) && + "expected block arg of index type use to be index type"); + return getMaxMemrefIndexed(operand); + }) + .Default([&](Operation *) { return 0; }); + maxSize = std::max(maxSize, memrefSize); + } + + return maxSize; +} + } // namespace void registerToVerilogTranslation() { @@ -918,17 +954,7 @@ LogicalResult VerilogEmitter::emitType(Type type, raw_ostream &os) { LogicalResult VerilogEmitter::emitIndexType(Value indexValue, raw_ostream &os) { // Operations on index types are not supported in this emitter, so we just // need to check the immediate users and inspect the memrefs they contain. - int32_t biggestMemrefSize = 0; - for (auto *user : indexValue.getUsers()) { - int32_t memrefSize = - llvm::TypeSwitch(user) - .Case( - [&](auto op) { return op.getMemRefType().getNumElements(); }) - .Default([&](Operation *) { return 0; }); - biggestMemrefSize = std::max(biggestMemrefSize, memrefSize); - } - + int32_t biggestMemrefSize = getMaxMemrefIndexed(indexValue); assert(biggestMemrefSize >= 0 && "unexpected index value unused by any memref ops"); auto widthBigint = APInt(64, biggestMemrefSize); diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp index 8d1a8a944..90ee1c405 100644 --- a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp +++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp @@ -79,6 +79,7 @@ read_verilog -sv {0}; hierarchy -check -top \{1}; proc; memory; stat; techmap -map {2}/techmap.v; stat; +opt_expr; opt_clean -purge; stat; splitnets -ports \{1} %n; flatten; opt_expr; opt; opt_clean -purge; rename -hide */w:*; rename -enumerate */w:*; @@ -96,13 +97,16 @@ stat; // $3: yosys runfiles path // $4: abc fast option -fast constexpr std::string_view kYosysBooleanTemplate = R"( -read_verilog {0}; +read_verilog -sv {0}; hierarchy -check -top \{1}; proc; memory; stat; -techmap -map {3}/techmap.v; opt; stat; +techmap -map {3}/techmap.v; stat; +opt_expr; opt_clean -purge; stat; +splitnets -ports \{1} %n; +flatten; opt_expr; opt; opt_clean -purge; +rename -hide */w:*; rename -enumerate */w:*; abc -exe {2} -g AND,NAND,OR,NOR,XOR,XNOR {4}; opt_clean -purge; stat; -rename -hide */c:*; rename -enumerate */c:*; hierarchy -generate * o:Y i:*; opt; opt_clean -purge; clean; stat; @@ -142,12 +146,14 @@ struct YosysOptimizer : public impl::YosysOptimizerBase { using YosysOptimizerBase::YosysOptimizerBase; YosysOptimizer(std::string yosysFilesPath, std::string abcPath, bool abcFast, - int unrollFactor, Mode mode, bool printStats) + int unrollFactor, bool useSubmodules, Mode mode, + bool printStats) : yosysFilesPath(std::move(yosysFilesPath)), abcPath(std::move(abcPath)), abcFast(abcFast), printStats(printStats), unrollFactor(unrollFactor), + useSubmodules(useSubmodules), mode(mode) {} void runOnOperation() override; @@ -163,6 +169,7 @@ struct YosysOptimizer : public impl::YosysOptimizerBase { bool abcFast; bool printStats; int unrollFactor; + bool useSubmodules; Mode mode; llvm::SmallVector optStatistics; }; @@ -464,6 +471,7 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) { // Translate Yosys result back to MLIR and insert into the func LLVM_DEBUG(Yosys::run_pass("dump;")); + Yosys::log_streams.clear(); std::stringstream cellOrder; Yosys::log_streams.push_back(&cellOrder); Yosys::run_pass("torder -stop * P*;"); @@ -554,6 +562,10 @@ void YosysOptimizer::runOnOperation() { auto *ctx = &getContext(); auto *op = getOperation(); + // Absorb any memref deallocs into generic's that allocate and use the memref. + mlir::IRRewriter builder(&getContext()); + op->walk([&](secret::GenericOp op) { genericAbsorbDealloc(op, builder); }); + mlir::RewritePatternSet cleanupPatterns(ctx); if (unrollFactor > 1) { if (failed(unrollAndMergeGenerics(op, unrollFactor, @@ -596,12 +608,47 @@ void YosysOptimizer::runOnOperation() { return; } + // Extract generics body's into function calls. + if (useSubmodules) { + auto result = op->walk([&](secret::GenericOp op) { + genericAbsorbConstants(op, builder); + + auto isTrivial = op.getBody()->walk([&](Operation *body) { + if (isa(body->getDialect()) && + !isa(body)) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (isTrivial.wasInterrupted()) { + if (failed(extractGenericBody(op, builder))) { + return WalkResult::interrupt(); + } + } + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + } + + // Merge generics after the function bodies are extracted. + mlir::RewritePatternSet mergePatterns(ctx); + mergePatterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(op, std::move(mergePatterns)))) { + signalPassFailure(); + getOperation()->emitError() + << "Failed to merge generic ops before yosys optimizer"; + return; + } + } + LLVM_DEBUG({ llvm::dbgs() << "IR after cleanup in preparation for yosys optimizer\n"; getOperation()->dump(); }); - mlir::IRRewriter builder(&getContext()); auto result = op->walk([&](secret::GenericOp op) { // Now pass through any constants used after capturing the ambient scope. // This way Yosys can optimize constants away instead of treating them as @@ -633,9 +680,10 @@ void YosysOptimizer::runOnOperation() { std::unique_ptr createYosysOptimizer( const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast, - int unrollFactor, Mode mode, bool printStats) { + int unrollFactor, bool useSubmodules, Mode mode, bool printStats) { return std::make_unique(yosysFilesPath, abcPath, abcFast, - unrollFactor, mode, printStats); + unrollFactor, useSubmodules, mode, + printStats); } void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, @@ -644,9 +692,9 @@ void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, "yosys-optimizer", "The yosys optimizer pipeline.", [yosysFilesPath, abcPath](OpPassManager &pm, const YosysOptimizerPipelineOptions &options) { - pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, - options.abcFast, options.unrollFactor, - options.mode, options.printStats)); + pm.addPass(createYosysOptimizer( + yosysFilesPath, abcPath, options.abcFast, options.unrollFactor, + options.useSubmodules, options.mode, options.printStats)); pm.addPass(mlir::createCSEPass()); }); } diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.h b/lib/Transforms/YosysOptimizer/YosysOptimizer.h index c50e79eaa..d7c55a293 100644 --- a/lib/Transforms/YosysOptimizer/YosysOptimizer.h +++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.h @@ -13,7 +13,8 @@ enum Mode { Boolean, LUT }; std::unique_ptr createYosysOptimizer( const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast, - int unrollFactor = 0, Mode mode = LUT, bool printStats = false); + int unrollFactor = 0, bool useSubmodules = true, Mode mode = LUT, + bool printStats = false); #define GEN_PASS_DECL #include "lib/Transforms/YosysOptimizer/YosysOptimizer.h.inc" @@ -30,6 +31,12 @@ struct YosysOptimizerPipelineOptions "value of zero (default) prevents unrolling."), llvm::cl::init(0)}; + PassOptions::Option useSubmodules{ + *this, "use-submodules", + llvm::cl::desc("Extracts secret.generic bodies into submodules before " + "optimizing. Default is true."), + llvm::cl::init(true)}; + PassOptions::Option mode{ *this, "mode", llvm::cl::desc("Map gates to boolean gates or lookup table gates."), diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.td b/lib/Transforms/YosysOptimizer/YosysOptimizer.td index f88736241..69a20ebda 100644 --- a/lib/Transforms/YosysOptimizer/YosysOptimizer.td +++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.td @@ -25,6 +25,10 @@ def YosysOptimizer : Pass<"yosys-optimizer"> { factor. If unset, this pass will not unroll any loops. - `print-stats`: Prints statistics about the optimized circuits. - `mode={Boolean,LUT}`: Map gates to boolean gates or lookup table gates. + - `use-submodules`: Extract the body of a generic op into submodules. + Useful for large programs with generics that can be isolated. This should + not be used when distributing generics through loops to avoid index + arguments in the function body. }]; // TODO(#257): add option for the pass to select the unroll factor // automatically. diff --git a/tests/Transforms/tosa_to_boolean_tfhe/BUILD b/tests/Transforms/tosa_to_boolean_tfhe/BUILD index bfd872b02..fbdd04049 100644 --- a/tests/Transforms/tosa_to_boolean_tfhe/BUILD +++ b/tests/Transforms/tosa_to_boolean_tfhe/BUILD @@ -9,6 +9,7 @@ glob_lit_tests( driver = "@heir//tests:run_lit.sh", size_override = { "fully_connected.mlir": "large", + "hello_world_small.mlir": "large", }, tags_override = { "hello_world.mlir": [ diff --git a/tests/Transforms/tosa_to_boolean_tfhe/hello_world_small.mlir b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_small.mlir new file mode 100644 index 000000000..6f132ac38 --- /dev/null +++ b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_small.mlir @@ -0,0 +1,26 @@ +// RUN: heir-opt --tosa-to-boolean-tfhe=abc-fast=true %s | FileCheck %s + +// A reduced dimension version of hello world to speed Yosys up. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + + func.func @main(%arg0: tensor<1x1xi8> {iree.identifier = "serving_default_dense_input:0", tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi8> {iree.identifier = "StatefulPartitionedCall:0", tf_saved_model.index_path = ["dense_2"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tosa.const"() {value = dense<429> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tosa.const"() {value = dense<[[-39, 59, 39]]> : tensor<1x3xi8>} : () -> tensor<1x3xi8> + %2 = "tosa.const"() {value = dense<[-729, 1954, 610]> : tensor<3xi32>} : () -> tensor<3xi32> + %3 = "tosa.const"() {value = dense<"0xF41AED091921F424E0"> : tensor<3x3xi8>} : () -> tensor<3x3xi8> + %4 = "tosa.const"() {value = dense<[0, 0, -5438]> : tensor<3xi32>} : () -> tensor<3xi32> + %5 = "tosa.const"() {value = dense<[[-9], [-54], [57]]> : tensor<3x1xi8>} : () -> tensor<3x1xi8> + %6 = "tosa.fully_connected"(%arg0, %5, %4) {quantization_info = #tosa.conv_quant} : (tensor<1x1xi8>, tensor<3x1xi8>, tensor<3xi32>) -> tensor<1x3xi32> + %7 = "tosa.rescale"(%6) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x3xi32>) -> tensor<1x3xi8> + %8 = "tosa.clamp"(%7) {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} : (tensor<1x3xi8>) -> tensor<1x3xi8> + %9 = "tosa.fully_connected"(%8, %3, %2) {quantization_info = #tosa.conv_quant} : (tensor<1x3xi8>, tensor<3x3xi8>, tensor<3xi32>) -> tensor<1x3xi32> + %10 = "tosa.rescale"(%9) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x3xi32>) -> tensor<1x3xi8> + %11 = "tosa.clamp"(%10) {max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64} : (tensor<1x3xi8>) -> tensor<1x3xi8> + %12 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant} : (tensor<1x3xi8>, tensor<1x3xi8>, tensor<1xi32>) -> tensor<1x1xi32> + %13 = "tosa.rescale"(%12) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x1xi32>) -> tensor<1x1xi8> + // CHECK: return + return %13 : tensor<1x1xi8> + } +} diff --git a/tests/Transforms/yosys_optimizer/global_absorb.mlir b/tests/Transforms/yosys_optimizer/global_absorb.mlir index 31bf9d3ef..05ec5ff1f 100644 --- a/tests/Transforms/yosys_optimizer/global_absorb.mlir +++ b/tests/Transforms/yosys_optimizer/global_absorb.mlir @@ -2,16 +2,16 @@ // optimization. This should optimize the multiplications instead of performing // a 32-bit generic multiplication. -// RUN: heir-opt --yosys-optimizer="abc-fast=True" %s | FileCheck %s +// RUN: heir-opt --yosys-optimizer %s | FileCheck %s module attributes {tf_saved_model.semantics} { // Use a weight vector with multiple weights to avoid constant folding. memref.global "private" constant @__constant_1xi8 : memref<2xi8> = dense<[3, 2]> {alignment = 64 : i64} - func.func @global_mul_32(%arg0 : !secret.secret>, %weight : i8) -> (!secret.secret>, !secret.secret>) { + // CHECK-LABEL: @global_mul_32 + func.func @global_mul_32(%arg0 : !secret.secret>, %weight : i8) -> (!secret.secret>) { // Generic 8-bit multiplication // CHECK: secret.generic - // CHECK-COUNT-86: comb.truth_table - // CHECK-NOT: comb.truth_table + // CHECK-COUNT-85: comb.truth_table // CHECK: secret.yield %0 = secret.generic ins(%arg0 : !secret.secret>) { ^bb0(%ARG0 : memref<1xi8>) : @@ -23,10 +23,14 @@ module attributes {tf_saved_model.semantics} { } secret.yield %alloc_0 : memref<1xi8> } -> !secret.secret> + // CHECK: return + return %0 : !secret.secret> + } + // CHECK-LABEL: @global_mul_32_constants + func.func @global_mul_32_constants(%arg0 : !secret.secret>, %weight : i8) -> (!secret.secret>) { // 8-bit multiplication with constant weights // CHECK: secret.generic - // CHECK-COUNT-28: comb.truth_table - // CHECK-NOT: comb.truth_table + // CHECK-COUNT-24: comb.truth_table // CHECK: secret.yield %4 = memref.get_global @__constant_1xi8 : memref<2xi8> %5 = secret.generic ins(%arg0 : !secret.secret>) { @@ -41,6 +45,6 @@ module attributes {tf_saved_model.semantics} { secret.yield %alloc_0 : memref<1xi8> } -> !secret.secret> // CHECK: return - return %0, %5 : !secret.secret>, !secret.secret> + return %5 : !secret.secret> } } diff --git a/tests/Transforms/yosys_optimizer/many_inputs.mlir b/tests/Transforms/yosys_optimizer/many_inputs.mlir index 0d118c153..440480cdf 100644 --- a/tests/Transforms/yosys_optimizer/many_inputs.mlir +++ b/tests/Transforms/yosys_optimizer/many_inputs.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt -yosys-optimizer=abc-fast=true %s | FileCheck %s +// RUN: heir-opt -yosys-optimizer="abc-fast=true use-submodules=false" %s | FileCheck %s // Regression test for https://github.com/google/heir/issues/359 When there are // > 10 ports, the RTLIL wire ordering is not the same as the original generic's diff --git a/tests/Transforms/yosys_optimizer/stats.mlir b/tests/Transforms/yosys_optimizer/stats.mlir index a76d88f6e..fbf779e33 100644 --- a/tests/Transforms/yosys_optimizer/stats.mlir +++ b/tests/Transforms/yosys_optimizer/stats.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=4 print-stats=true" -o /dev/null %s 2>&1 | FileCheck %s +// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=4 use-submodules=false print-stats=true" -o /dev/null %s 2>&1 | FileCheck %s !in_ty = !secret.secret> !out_ty = !secret.secret> diff --git a/tests/Transforms/yosys_optimizer/submodules.mlir b/tests/Transforms/yosys_optimizer/submodules.mlir new file mode 100644 index 000000000..1ee34ed20 --- /dev/null +++ b/tests/Transforms/yosys_optimizer/submodules.mlir @@ -0,0 +1,35 @@ +// Tests use-submodules error when optimizing a generic distributed through +// affine for loops. + +// RUN: heir-opt -yosys-optimizer="abc-fast=true use-submodules=true" %s --verify-diagnostics + +module attributes {tf_saved_model.semantics} { + func.func @main(%arg0: !secret.secret>) -> (!secret.secret>) { + %c127_i32 = arith.constant 127 : i16 + %0 = secret.generic { + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi8> + secret.yield %alloc : memref<1x1xi8> + } -> !secret.secret> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 1 { + %1 = secret.generic ins(%arg0, %arg1, %arg2 : !secret.secret>, index, index) { + ^bb0(%arg3: memref<1x1xi16>, %arg4: index, %arg5: index): + %3 = memref.load %arg3[%arg4, %arg5] : memref<1x1xi16> + secret.yield %3 : i16 + } -> !secret.secret + %2 = secret.generic ins(%1, %c127_i32 : !secret.secret, i16) { + ^bb0(%arg3: i16, %arg4: i16): + %3 = arith.addi %arg3, %arg4 : i16 + %4 = arith.trunci %3 : i16 to i8 + secret.yield %4 : i8 + } -> !secret.secret + secret.generic ins(%0, %2, %arg1, %arg2 : !secret.secret>, !secret.secret, index, index) { + ^bb0(%arg3: memref<1x1xi8>, %arg4: i8, %arg5: index, %arg6: index): + memref.store %arg4, %arg3[%arg5, %arg6] : memref<1x1xi8> + secret.yield + } + } + } + return %0 : !secret.secret> + } +} diff --git a/tests/Transforms/yosys_optimizer/unroll_and_optimize.mlir b/tests/Transforms/yosys_optimizer/unroll_and_optimize.mlir index 343fd17d5..7ac926182 100644 --- a/tests/Transforms/yosys_optimizer/unroll_and_optimize.mlir +++ b/tests/Transforms/yosys_optimizer/unroll_and_optimize.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=2" --canonicalize %s | FileCheck %s +// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=2 use-submodules=false" --canonicalize %s | FileCheck %s !in_ty = !secret.secret> !out_ty = !secret.secret> diff --git a/tests/Transforms/yosys_optimizer/unroll_factor.mlir b/tests/Transforms/yosys_optimizer/unroll_factor.mlir index 6a6fe212a..a74f348ae 100644 --- a/tests/Transforms/yosys_optimizer/unroll_factor.mlir +++ b/tests/Transforms/yosys_optimizer/unroll_factor.mlir @@ -1,4 +1,4 @@ -// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=3" --canonicalize %s | FileCheck %s +// RUN: heir-opt --secret-distribute-generic=distribute-through="affine.for" --yosys-optimizer="unroll-factor=3 use-submodules=false" --canonicalize %s | FileCheck %s // Regression test for #444 testing the RTLIL imported through an unroll factor // larger than the loop size. diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 4de8d4843..c16c6458a 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -127,9 +127,7 @@ using namespace tosa; using namespace heir; using mlir::func::FuncOp; -static std::vector opsToDistribute = { - "affine.for", "affine.load", "memref.load", "memref.store", - "affine.store", "memref.get_global", "memref.dealloc", "memref.alloc"}; +static std::vector opsToDistribute = {"secret.separator"}; static std::vector bitWidths = {1, 2, 4, 8, 16}; // RLWE scheme selector @@ -342,59 +340,92 @@ struct TosaToBooleanTfheOptions llvm::cl::init("main")}; }; +void tosaToCGGIPipelineBuilder(OpPassManager &pm, + const TosaToBooleanTfheOptions &options, + const std::string &yosysFilesPath, + const std::string &abcPath, + bool abcBooleanGates) { + // Secretize inputs + pm.addPass(createSecretize(SecretizeOptions{options.entryFunction})); + + // TOSA to linalg + tosaToLinalg(pm); + + // Bufferize + oneShotBufferize(pm); + + // Affine + pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); + pm.addNestedPass(memref::createExpandStridedMetadataPass()); + pm.addNestedPass(affine::createAffineExpandIndexOpsPass()); + pm.addNestedPass(memref::createExpandOpsPass()); + pm.addPass(createExpandCopyPass()); + pm.addNestedPass(affine::createSimplifyAffineStructuresPass()); + pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); + pm.addPass(memref::createFoldMemRefAliasOpsPass()); + + // Affine loop optimizations + pm.addNestedPass( + affine::createLoopFusionPass(0, 0, true, affine::FusionMode::Greedy)); + pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); + pm.addPass(createForwardStoreToLoad()); + pm.addPass(affine::createAffineParallelizePass()); + pm.addPass(createFullLoopUnroll()); + pm.addPass(createForwardStoreToLoad()); + pm.addNestedPass(createRemoveUnusedMemRef()); + + // Cleanup + pm.addPass(createMemrefGlobalReplacePass()); + arith::ArithIntNarrowingOptions arithOps; + arithOps.bitwidthsSupported = {4, 8, 16}; + pm.addPass(arith::createArithIntNarrowing(arithOps)); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createSCCPPass()); + pm.addPass(createCSEPass()); + pm.addPass(createSymbolDCEPass()); + + // Wrap with secret.generic and then distribute-generic. + pm.addPass(createWrapGeneric()); + auto distributeOpts = secret::SecretDistributeGenericOptions{ + .opsToDistribute = llvm::to_vector(opsToDistribute)}; + pm.addPass(secret::createSecretDistributeGeneric(distributeOpts)); + pm.addPass(createCanonicalizerPass()); + + // Booleanize and Yosys Optimize + pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, options.abcFast, + options.unrollFactor, /*useSubmodules=*/true, + abcBooleanGates ? Mode::Boolean : Mode::LUT)); + + // Cleanup + pm.addPass(mlir::createCSEPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createSCCPPass()); + pm.addPass(createSymbolDCEPass()); + pm.addPass(memref::createFoldMemRefAliasOpsPass()); + pm.addPass(createForwardStoreToLoad()); + + // Lower combinational circuit to CGGI + pm.addPass(secret::createSecretDistributeGeneric()); + pm.addPass(createSecretToCGGI()); + + // Cleanup CombToCGGI + pm.addPass( + createExpandCopyPass(ExpandCopyPassOptions{.disableAffineLoop = true})); + pm.addPass(memref::createFoldMemRefAliasOpsPass()); + pm.addPass(createForwardStoreToLoad()); + pm.addPass(createRemoveUnusedMemRef()); + pm.addPass(createCSEPass()); + pm.addPass(createSCCPPass()); +} + void tosaToBooleanTfhePipeline(const std::string &yosysFilesPath, const std::string &abcPath) { PassPipelineRegistration( "tosa-to-boolean-tfhe", "Arithmetic modules to boolean tfhe-rs pipeline.", [yosysFilesPath, abcPath](OpPassManager &pm, const TosaToBooleanTfheOptions &options) { - // Secretize inputs - pm.addPass(createSecretize(SecretizeOptions{options.entryFunction})); - - // TOSA to linalg - tosaToLinalg(pm); - - // Bufferize - oneShotBufferize(pm); - - // Affine - pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); - pm.addNestedPass(memref::createExpandStridedMetadataPass()); - pm.addNestedPass(affine::createAffineExpandIndexOpsPass()); - pm.addNestedPass(memref::createExpandOpsPass()); - pm.addNestedPass(affine::createSimplifyAffineStructuresPass()); - pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); - pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(createExpandCopyPass()); - - // Cleanup - pm.addPass(createMemrefGlobalReplacePass()); - arith::ArithIntNarrowingOptions arithOps; - arithOps.bitwidthsSupported = {4, 8, 16}; - pm.addPass(arith::createArithIntNarrowing(arithOps)); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createSCCPPass()); - pm.addPass(createCSEPass()); - pm.addPass(createSymbolDCEPass()); - - // Wrap with secret.generic and then distribute-generic. - pm.addPass(createWrapGeneric()); - auto distributeOpts = secret::SecretDistributeGenericOptions{ - .opsToDistribute = llvm::to_vector(opsToDistribute)}; - pm.addPass(secret::createSecretDistributeGeneric(distributeOpts)); - pm.addPass(createCanonicalizerPass()); - - // Booleanize and Yosys Optimize - pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, - options.abcFast, options.unrollFactor)); - - // Lower combinational circuit to CGGI - pm.addPass(createCanonicalizerPass()); - pm.addPass(createSCCPPass()); - - pm.addPass(mlir::createCSEPass()); - pm.addPass(secret::createSecretDistributeGeneric()); - pm.addPass(createSecretToCGGI()); + tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath, + /*abcBooleanGates=*/false); // CGGI to Tfhe-Rust exit dialect pm.addPass(createCGGIToTfheRust()); @@ -433,155 +464,47 @@ struct TosaToBooleanFpgaTfheOptions void tosaToBooleanFpgaTfhePipeline(const std::string &yosysFilesPath, const std::string &abcPath) { - PassPipelineRegistration( + PassPipelineRegistration( "tosa-to-boolean-fpga-tfhe", "Arithmetic modules to boolean tfhe-rs for FPGA backend pipeline.", [yosysFilesPath, abcPath](OpPassManager &pm, - const TosaToBooleanFpgaTfheOptions &options) { - // Secretize inputs - pm.addPass(createSecretize(SecretizeOptions{options.entryFunction})); - - // TOSA to linalg - tosaToLinalg(pm); - - // Bufferize - oneShotBufferize(pm); - - // Affine - pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); - pm.addNestedPass(memref::createExpandStridedMetadataPass()); - pm.addNestedPass(affine::createAffineExpandIndexOpsPass()); - pm.addNestedPass(memref::createExpandOpsPass()); - pm.addNestedPass(affine::createSimplifyAffineStructuresPass()); - pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(createExpandCopyPass()); - pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); - pm.addNestedPass(affine::createLoopFusionPass( - 0, 0, true, affine::FusionMode::Greedy)); - pm.addPass(affine::createAffineScalarReplacementPass()); - pm.addPass(createForwardStoreToLoad()); + const TosaToBooleanTfheOptions &options) { + tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath, + /*abcBooleanGates=*/true); - // Cleanup - pm.addPass(createMemrefGlobalReplacePass()); - arith::ArithIntNarrowingOptions arithOps; - arithOps.bitwidthsSupported = {4, 8, 16}; - pm.addPass(arith::createArithIntNarrowing(arithOps)); + // Vectorize CGGI operations + pm.addPass(createStraightLineVectorizer( + StraightLineVectorizerOptions{.dialect = "cggi"})); pm.addPass(createCanonicalizerPass()); - pm.addPass(createSCCPPass()); pm.addPass(createCSEPass()); - pm.addPass(createSymbolDCEPass()); + pm.addPass(createSCCPPass()); - pm.addPass(createWrapGeneric()); - auto distributeOpts = secret::SecretDistributeGenericOptions{ - .opsToDistribute = llvm::to_vector(opsToDistribute)}; - pm.addPass(secret::createSecretDistributeGeneric(distributeOpts)); + // CGGI to Tfhe-Rust exit dialect + pm.addPass(createCGGIToTfheRustBool()); + // CSE must be run before canonicalizer, so that redundant ops are + // cleared before the canonicalizer hoists TfheRust ops. + pm.addPass(createCSEPass()); pm.addPass(createCanonicalizerPass()); - // Booleanize and Yosys Optimize - pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, - options.abcFast, options.unrollFactor, - Mode::Boolean)); - - // Lower combinational circuit to CGGI - pm.addPass(createForwardStoreToLoad()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(secret::createSecretDistributeGeneric()); - pm.addPass(createSecretToCGGI()); - // Cleanup SecretToCGGI + // Cleanup loads and stores pm.addPass(createExpandCopyPass( ExpandCopyPassOptions{.disableAffineLoop = true})); pm.addPass(memref::createFoldMemRefAliasOpsPass()); pm.addPass(createForwardStoreToLoad()); - pm.addPass(createRemoveUnusedMemRef()); - - pm.addPass(createStraightLineVectorizer( - StraightLineVectorizerOptions{.dialect = "cggi"})); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - pm.addPass(createSCCPPass()); - - // CGGI to Tfhe-Rust exit dialect - pm.addPass(createCGGIToTfheRustBool()); pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); pm.addPass(createSCCPPass()); }); } -struct TosaToJaxiteOptions : public PassPipelineOptions { - PassOptions::Option abcFast{*this, "abc-fast", - llvm::cl::desc("Run abc in fast mode."), - llvm::cl::init(false)}; - - PassOptions::Option unrollFactor{ - *this, "unroll-factor", - llvm::cl::desc("Unroll loops by a given factor before optimizing. A " - "value of zero (default) prevents unrolling."), - llvm::cl::init(0)}; - - PassOptions::Option entryFunction{ - *this, "entry-function", llvm::cl::desc("Entry function to secretize"), - llvm::cl::init("main")}; -}; - void tosaToJaxitePipeline(const std::string &yosysFilesPath, const std::string &abcPath) { - PassPipelineRegistration( + PassPipelineRegistration( "tosa-to-boolean-jaxite", "Arithmetic modules to jaxite pipeline.", [yosysFilesPath, abcPath](OpPassManager &pm, - const TosaToJaxiteOptions &options) { - // Secretize inputs - pm.addPass(createSecretize(SecretizeOptions{options.entryFunction})); - - // TOSA to linalg - tosaToLinalg(pm); - - // Bufferize - oneShotBufferize(pm); - - // Affine - pm.addNestedPass(createConvertLinalgToAffineLoopsPass()); - pm.addNestedPass(memref::createExpandStridedMetadataPass()); - pm.addNestedPass(affine::createAffineExpandIndexOpsPass()); - pm.addNestedPass(memref::createExpandOpsPass()); - pm.addNestedPass(affine::createSimplifyAffineStructuresPass()); - pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); - pm.addPass(memref::createFoldMemRefAliasOpsPass()); - pm.addPass(createExpandCopyPass()); - pm.addNestedPass(affine::createAffineLoopNormalizePass(true)); - pm.addNestedPass(affine::createLoopFusionPass( - 0, 0, true, affine::FusionMode::Greedy)); - pm.addPass(affine::createAffineScalarReplacementPass()); - pm.addPass(createForwardStoreToLoad()); - - // Cleanup - pm.addPass(createMemrefGlobalReplacePass()); - arith::ArithIntNarrowingOptions arithOps; - arithOps.bitwidthsSupported = llvm::to_vector(bitWidths); - pm.addPass(arith::createArithIntNarrowing(arithOps)); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createSCCPPass()); - pm.addPass(createCSEPass()); - pm.addPass(createSymbolDCEPass()); - pm.addPass(affine::createAffineScalarReplacementPass()); - - // Wrap with secret.generic and then distribute-generic. - pm.addPass(createWrapGeneric()); - auto distributeOpts = secret::SecretDistributeGenericOptions{ - .opsToDistribute = llvm::to_vector(opsToDistribute)}; - pm.addPass(secret::createSecretDistributeGeneric(distributeOpts)); - pm.addPass(createCanonicalizerPass()); - // Booleanize and Yosys Optimize - pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath, - options.abcFast, options.unrollFactor)); - - // Lower combinational circuit to CGGI - pm.addPass(createCanonicalizerPass()); - pm.addPass(createSCCPPass()); - - pm.addPass(mlir::createCSEPass()); - pm.addPass(secret::createSecretDistributeGeneric()); - pm.addPass(createSecretToCGGI()); + const TosaToBooleanTfheOptions &options) { + tosaToCGGIPipelineBuilder(pm, options, yosysFilesPath, abcPath, + /*abcBooleanGates=*/false); // CGGI to Jaxite exit dialect pm.addPass(createCGGIToJaxite());