Skip to content

Commit

Permalink
tosa-to-boolean-tfhe: optimize pipeline & unify fpga pipeline
Browse files Browse the repository at this point in the history
* Adds optimizations by changing secret generic splitting to be across separators rather than affine.for operations (TODO: add automatic submodule splitting)
* Unifies tosa-to-boolean-tfhe and tosa-to-boolean-fpga-tfhe by factoring out the common tosa-to-cggi pipeline

PiperOrigin-RevId: 688262006
  • Loading branch information
asraa authored and copybara-github committed Oct 21, 2024
1 parent 6620b3f commit 5ffc8f5
Show file tree
Hide file tree
Showing 17 changed files with 307 additions and 219 deletions.
12 changes: 12 additions & 0 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "lib/Utils/ConversionUtils/ConversionUtils.h"
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Expand All @@ -27,6 +28,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "cggi-to-tfhe-rust"

namespace mlir::heir {

#define GEN_PASS_DEF_CGGITOTFHERUST
Expand All @@ -41,6 +44,15 @@ Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) {
// Only supporting unsigned types because the LWE dialect does not have a
// notion of signedness.
switch (width) {
case 1:
// The minimum bit width of the integer tfhe_rust API is UInt2
// https://docs.rs/tfhe/latest/tfhe/index.html#types
// This may happen if there are no LUT or boolean gate operations that
// require a minimum bit width (e.g. shuffling bits in a program that
// multiplies by two).
LLVM_DEBUG(llvm::dbgs()
<< "Upgrading ciphertext with bit width 1 to UInt2");
[[fallthrough]];
case 2:
return tfhe_rust::EncryptedUInt2Type::get(ctx);
case 3:
Expand Down
16 changes: 10 additions & 6 deletions lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include "lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

#include "lib/Dialect/CGGI/IR/CGGIOps.h"
#include "lib/Dialect/Comb/IR/CombDialect.h"
#include "lib/Dialect/Comb/IR/CombOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
Expand Down Expand Up @@ -536,23 +538,23 @@ struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {
};

// Scans every operation nested under `module` and returns the largest "size"
// seen among operations that belong to the comb dialect. A comb truth table
// counts as size 3; any other comb op counts as the bit width of its first
// result type. `context` is not referenced in the body.
//
// NOTE(review): this span shows two spellings of every statement
// (snake_case and camelCase), as in a rendered diff without +/- markers.
// The snake_case accumulator starts at 0 and the camelCase one at 1 —
// confirm which revision is intended before compiling.
int findLUTSize(MLIRContext *context, Operation *module) {
int max_int_size = 0;
int maxIntSize = 1;
// Callback applied to each visited operation; ops outside the comb dialect
// leave the running maximum untouched.
auto processOperation = [&](Operation *op) {
if (isa<comb::CombDialect>(op->getDialect())) {
int current_size = 0;
int currentSize = 0;
// Truth tables are assigned a fixed size of 3; every other comb op
// contributes the bit width of its first result.
if (dyn_cast<comb::TruthTableOp>(op))
current_size = 3;
currentSize = 3;
else
current_size = op->getResults().getTypes()[0].getIntOrFloatBitWidth();
currentSize = op->getResults().getTypes()[0].getIntOrFloatBitWidth();

// Keep the largest size seen so far.
max_int_size = std::max(max_int_size, current_size);
maxIntSize = std::max(maxIntSize, currentSize);
}
};

// Walk all operations within the module in post-order (default)
module->walk(processOperation);

return max_int_size;
return maxIntSize;
}

struct SecretToCGGI : public impl::SecretToCGGIBase<SecretToCGGI> {
Expand All @@ -572,6 +574,8 @@ struct SecretToCGGI : public impl::SecretToCGGIBase<SecretToCGGI> {
.add<SecretGenericOpLUTConversion,
SecretGenericOpConversion<memref::AllocOp, memref::AllocOp>,
SecretGenericOpConversion<memref::DeallocOp, memref::DeallocOp>,
SecretGenericOpConversion<memref::CollapseShapeOp,
memref::CollapseShapeOp>,
SecretGenericOpMemRefLoadConversion,
SecretGenericOpAffineStoreConversion,
SecretGenericOpAffineLoadConversion,
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ LogicalResult extractGenericBody(secret::GenericOp genericOp,
std::string funcName = llvm::formatv(
"internal_generic_{0}", mlir::hash_value(yieldOp.getValues()[0]));
auto func = builder.create<func::FuncOp>(module.getLoc(), funcName, type);
func.setPrivate();

// Populate function body by cloning the ops in the inner body and mapping
// the func args and func outputs.
Expand Down
5 changes: 1 addition & 4 deletions lib/Target/TfheRust/TfheRustEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,7 @@ LogicalResult TfheRustEmitter::printOperation(memref::LoadOp op) {
emitAssignPrefix(op.getResult());
bool isRef =
isa<tfhe_rust::TfheRustDialect>(op.getResult().getType().getDialect());
bool storeUse = llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
return isa<memref::StoreOp>(*op);
});
os << ((isRef && !storeUse) ? "&" : "");
os << (isRef ? "&" : "");
printLoadOp(op);
os << ";\n";
}
Expand Down
48 changes: 37 additions & 11 deletions lib/Target/Verilog/VerilogEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
Expand All @@ -21,6 +22,7 @@
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/ADT/ilist.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
Expand All @@ -37,6 +39,7 @@
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
Expand Down Expand Up @@ -131,6 +134,39 @@ CtlzValueStruct ctlzStructForResult(StringRef result) {
.temp4 = llvm::formatv("{0}_{1}", result, "temp4")};
}

// Resolves the func.func that `callOp` targets. Yields a null FuncOp when the
// callee is not given as a symbol reference, or when the nearest symbol lookup
// starting at the call does not produce a func.func.
func::FuncOp getCalledFunction(func::CallOp callOp) {
  auto callee = callOp.getCallableForCallee();
  if (auto symbol = llvm::dyn_cast_if_present<SymbolRefAttr>(callee)) {
    auto *target = SymbolTable::lookupNearestSymbolFrom(callOp, symbol);
    return dyn_cast_or_null<func::FuncOp>(target);
  }
  return nullptr;
}

// Returns the largest element count among the memrefs that `index` is used
// with, or 0 when none of its uses is a recognized load, store, or call.
//
// Each use of `index` contributes as follows:
//   * affine.load / affine.store / memref.load / memref.store: the number of
//     elements of the memref type the op accesses;
//   * func.call: the result of this same query applied to the callee's entry
//     block argument at the use's operand position;
//   * any other user: 0.
//
// Calls whose callee cannot be resolved to a func.func, or whose callee has no
// body (a declaration), contribute 0 instead of dereferencing a null function
// or indexing into an empty region.
//
// NOTE(review): the call case recurses through callees and has no cycle
// detection — presumably the call graph is acyclic here; confirm.
int32_t getMaxMemrefIndexed(Value index) {
  int32_t maxSize = 0;
  for (auto &use : index.getUses()) {
    Operation *user = use.getOwner();
    int32_t memrefSize =
        llvm::TypeSwitch<Operation *, int32_t>(user)
            .Case<affine::AffineLoadOp, affine::AffineStoreOp, memref::LoadOp,
                  memref::StoreOp>(
                [&](auto op) { return op.getMemRefType().getNumElements(); })
            .Case<func::CallOp>([&](func::CallOp op) -> int32_t {
              // Index is passed into a function, get largest use.
              func::FuncOp func = getCalledFunction(op);
              // Nothing to inspect for unresolved or body-less callees.
              if (!func || func.getBody().empty()) return 0;
              auto &operand =
                  func.getBody().getArguments()[use.getOperandNumber()];
              assert(isa<IndexType>(operand.getType()) &&
                     "expected block arg of index type use to be index type");
              return getMaxMemrefIndexed(operand);
            })
            .Default([&](Operation *) { return 0; });
    maxSize = std::max(maxSize, memrefSize);
  }

  return maxSize;
}

} // namespace

void registerToVerilogTranslation() {
Expand Down Expand Up @@ -918,17 +954,7 @@ LogicalResult VerilogEmitter::emitType(Type type, raw_ostream &os) {
LogicalResult VerilogEmitter::emitIndexType(Value indexValue, raw_ostream &os) {
// Operations on index types are not supported in this emitter, so we just
// need to check the immediate users and inspect the memrefs they contain.
int32_t biggestMemrefSize = 0;
for (auto *user : indexValue.getUsers()) {
int32_t memrefSize =
llvm::TypeSwitch<Operation *, int32_t>(user)
.Case<affine::AffineLoadOp, affine::AffineStoreOp, memref::LoadOp,
memref::StoreOp>(
[&](auto op) { return op.getMemRefType().getNumElements(); })
.Default([&](Operation *) { return 0; });
biggestMemrefSize = std::max(biggestMemrefSize, memrefSize);
}

int32_t biggestMemrefSize = getMaxMemrefIndexed(indexValue);
assert(biggestMemrefSize >= 0 &&
"unexpected index value unused by any memref ops");
auto widthBigint = APInt(64, biggestMemrefSize);
Expand Down
68 changes: 58 additions & 10 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ read_verilog -sv {0};
hierarchy -check -top \{1};
proc; memory; stat;
techmap -map {2}/techmap.v; stat;
opt_expr; opt_clean -purge; stat;
splitnets -ports \{1} %n;
flatten; opt_expr; opt; opt_clean -purge;
rename -hide */w:*; rename -enumerate */w:*;
Expand All @@ -96,13 +97,16 @@ stat;
// $3: yosys runfiles path
// $4: abc fast option -fast
constexpr std::string_view kYosysBooleanTemplate = R"(
read_verilog {0};
read_verilog -sv {0};
hierarchy -check -top \{1};
proc; memory; stat;
techmap -map {3}/techmap.v; opt; stat;
techmap -map {3}/techmap.v; stat;
opt_expr; opt_clean -purge; stat;
splitnets -ports \{1} %n;
flatten; opt_expr; opt; opt_clean -purge;
rename -hide */w:*; rename -enumerate */w:*;
abc -exe {2} -g AND,NAND,OR,NOR,XOR,XNOR {4};
opt_clean -purge; stat;
rename -hide */c:*; rename -enumerate */c:*;
hierarchy -generate * o:Y i:*; opt; opt_clean -purge;
clean;
stat;
Expand Down Expand Up @@ -142,12 +146,14 @@ struct YosysOptimizer : public impl::YosysOptimizerBase<YosysOptimizer> {
using YosysOptimizerBase::YosysOptimizerBase;

YosysOptimizer(std::string yosysFilesPath, std::string abcPath, bool abcFast,
int unrollFactor, Mode mode, bool printStats)
int unrollFactor, bool useSubmodules, Mode mode,
bool printStats)
: yosysFilesPath(std::move(yosysFilesPath)),
abcPath(std::move(abcPath)),
abcFast(abcFast),
printStats(printStats),
unrollFactor(unrollFactor),
useSubmodules(useSubmodules),
mode(mode) {}

void runOnOperation() override;
Expand All @@ -163,6 +169,7 @@ struct YosysOptimizer : public impl::YosysOptimizerBase<YosysOptimizer> {
bool abcFast;
bool printStats;
int unrollFactor;
bool useSubmodules;
Mode mode;
llvm::SmallVector<RelativeOptimizationStatistics> optStatistics;
};
Expand Down Expand Up @@ -464,6 +471,7 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {

// Translate Yosys result back to MLIR and insert into the func
LLVM_DEBUG(Yosys::run_pass("dump;"));
Yosys::log_streams.clear();
std::stringstream cellOrder;
Yosys::log_streams.push_back(&cellOrder);
Yosys::run_pass("torder -stop * P*;");
Expand Down Expand Up @@ -554,6 +562,10 @@ void YosysOptimizer::runOnOperation() {
auto *ctx = &getContext();
auto *op = getOperation();

// Absorb any memref deallocs into generics that allocate and use the memref.
mlir::IRRewriter builder(&getContext());
op->walk([&](secret::GenericOp op) { genericAbsorbDealloc(op, builder); });

mlir::RewritePatternSet cleanupPatterns(ctx);
if (unrollFactor > 1) {
if (failed(unrollAndMergeGenerics(op, unrollFactor,
Expand Down Expand Up @@ -596,12 +608,47 @@ void YosysOptimizer::runOnOperation() {
return;
}

// Extract generic bodies into function calls.
if (useSubmodules) {
auto result = op->walk([&](secret::GenericOp op) {
genericAbsorbConstants(op, builder);

auto isTrivial = op.getBody()->walk([&](Operation *body) {
if (isa<arith::ArithDialect>(body->getDialect()) &&
!isa<arith::ConstantOp>(body)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (isTrivial.wasInterrupted()) {
if (failed(extractGenericBody(op, builder))) {
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (result.wasInterrupted()) {
signalPassFailure();
}

// Merge generics after the function bodies are extracted.
mlir::RewritePatternSet mergePatterns(ctx);
mergePatterns.add<secret::MergeAdjacentGenerics>(ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(mergePatterns)))) {
signalPassFailure();
getOperation()->emitError()
<< "Failed to merge generic ops before yosys optimizer";
return;
}
}

LLVM_DEBUG({
llvm::dbgs() << "IR after cleanup in preparation for yosys optimizer\n";
getOperation()->dump();
});

mlir::IRRewriter builder(&getContext());
auto result = op->walk([&](secret::GenericOp op) {
// Now pass through any constants used after capturing the ambient scope.
// This way Yosys can optimize constants away instead of treating them as
Expand Down Expand Up @@ -633,9 +680,10 @@ void YosysOptimizer::runOnOperation() {

// Factory for the Yosys optimizer pass: constructs a YosysOptimizer with the
// given settings and returns it as a generic mlir::Pass owner.
//
// NOTE(review): this span shows two variants of the parameter list and of the
// constructor arguments (with and without `useSubmodules`), as in a rendered
// diff without +/- markers — confirm which revision is intended.
std::unique_ptr<mlir::Pass> createYosysOptimizer(
const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
int unrollFactor, Mode mode, bool printStats) {
int unrollFactor, bool useSubmodules, Mode mode, bool printStats) {
// Arguments are forwarded to the YosysOptimizer constructor in the order
// listed here.
return std::make_unique<YosysOptimizer>(yosysFilesPath, abcPath, abcFast,
unrollFactor, mode, printStats);
unrollFactor, useSubmodules, mode,
printStats);
}

void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
Expand All @@ -644,9 +692,9 @@ void registerYosysOptimizerPipeline(const std::string &yosysFilesPath,
"yosys-optimizer", "The yosys optimizer pipeline.",
[yosysFilesPath, abcPath](OpPassManager &pm,
const YosysOptimizerPipelineOptions &options) {
pm.addPass(createYosysOptimizer(yosysFilesPath, abcPath,
options.abcFast, options.unrollFactor,
options.mode, options.printStats));
pm.addPass(createYosysOptimizer(
yosysFilesPath, abcPath, options.abcFast, options.unrollFactor,
options.useSubmodules, options.mode, options.printStats));
pm.addPass(mlir::createCSEPass());
});
}
Expand Down
9 changes: 8 additions & 1 deletion lib/Transforms/YosysOptimizer/YosysOptimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ enum Mode { Boolean, LUT };

std::unique_ptr<mlir::Pass> createYosysOptimizer(
const std::string &yosysFilesPath, const std::string &abcPath, bool abcFast,
int unrollFactor = 0, Mode mode = LUT, bool printStats = false);
int unrollFactor = 0, bool useSubmodules = true, Mode mode = LUT,
bool printStats = false);

#define GEN_PASS_DECL
#include "lib/Transforms/YosysOptimizer/YosysOptimizer.h.inc"
Expand All @@ -30,6 +31,12 @@ struct YosysOptimizerPipelineOptions
"value of zero (default) prevents unrolling."),
llvm::cl::init(0)};

PassOptions::Option<bool> useSubmodules{
*this, "use-submodules",
llvm::cl::desc("Extracts secret.generic bodies into submodules before "
"optimizing. Default is true."),
llvm::cl::init(true)};

PassOptions::Option<enum Mode> mode{
*this, "mode",
llvm::cl::desc("Map gates to boolean gates or lookup table gates."),
Expand Down
4 changes: 4 additions & 0 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
factor. If unset, this pass will not unroll any loops.
- `print-stats`: Prints statistics about the optimized circuits.
- `mode={Boolean,LUT}`: Map gates to boolean gates or lookup table gates.
- `use-submodules`: Extract the body of a generic op into submodules.
Useful for large programs with generics that can be isolated. This should
not be used when distributing generics through loops to avoid index
arguments in the function body.
}];
// TODO(#257): add option for the pass to select the unroll factor
// automatically.
Expand Down
1 change: 1 addition & 0 deletions tests/Transforms/tosa_to_boolean_tfhe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ glob_lit_tests(
driver = "@heir//tests:run_lit.sh",
size_override = {
"fully_connected.mlir": "large",
"hello_world_small.mlir": "large",
},
tags_override = {
"hello_world.mlir": [
Expand Down
Loading

0 comments on commit 5ffc8f5

Please sign in to comment.