diff --git a/lib/Analysis/SelectVariableNames/BUILD b/lib/Analysis/SelectVariableNames/BUILD index e7b28499a..9705208e5 100644 --- a/lib/Analysis/SelectVariableNames/BUILD +++ b/lib/Analysis/SelectVariableNames/BUILD @@ -10,6 +10,9 @@ cc_library( hdrs = ["SelectVariableNames.h"], deps = [ "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", ], ) diff --git a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp index 96c9a2934..2c7100646 100644 --- a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp +++ b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp @@ -4,6 +4,7 @@ #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { diff --git a/lib/Analysis/SelectVariableNames/SelectVariableNames.h b/lib/Analysis/SelectVariableNames/SelectVariableNames.h index b2a9ba69d..25347ef5e 100644 --- a/lib/Analysis/SelectVariableNames/SelectVariableNames.h +++ b/lib/Analysis/SelectVariableNames/SelectVariableNames.h @@ -1,11 +1,14 @@ #ifndef LIB_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_ #define LIB_ANALYSIS_SELECTVARIABLENAMES_SELECTVARIABLENAMES_H_ +#include #include -#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { diff --git a/lib/Conversion/SecretToCKKS/SecretToCKKS.cpp b/lib/Conversion/SecretToCKKS/SecretToCKKS.cpp index 2cb503696..97886ace1 100644 --- a/lib/Conversion/SecretToCKKS/SecretToCKKS.cpp +++ b/lib/Conversion/SecretToCKKS/SecretToCKKS.cpp @@ -225,6 +225,7 @@ struct SecretToCKKS : public impl::SecretToCKKSBase { SecretGenericOpCipherConversion, SecretGenericOpCipherConversion, SecretGenericOpCipherConversion, + SecretGenericOpCipherConversion, SecretGenericTensorExtractConversion, SecretGenericTensorInsertConversion, SecretGenericOpRotateConversion, diff --git a/lib/Target/OpenFhePke/BUILD b/lib/Target/OpenFhePke/BUILD index d9e687006..2a641bebb 100644 --- a/lib/Target/OpenFhePke/BUILD +++ b/lib/Target/OpenFhePke/BUILD @@ -5,10 +5,38 @@ package( default_visibility = ["//visibility:public"], ) +cc_library( + name = "OpenFheRegistration", + srcs = [ + "OpenFheTranslateRegistration.cpp", + ], + hdrs = [ + "OpenFheTranslateRegistration.h", + ], + deps = [ + ":OpenFhePkeEmitter", + ":OpenFhePkeHeaderEmitter", + ":OpenFheUtils", + "@heir//lib/Analysis/SelectVariableNames", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Openfhe/IR:Dialect", + "@heir//lib/Target:Utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:PolynomialDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], +) + cc_library( name = "OpenFheUtils", srcs = ["OpenFheUtils.cpp"], hdrs = [ + "OpenFhePkeTemplates.h", "OpenFheUtils.h", ], deps = [ @@ -26,7 +54,6 @@ cc_library( srcs = ["OpenFhePkeEmitter.cpp"], hdrs = [ "OpenFhePkeEmitter.h", - "OpenFhePkeTemplates.h", ], deps = [ ":OpenFheUtils", @@ -38,10 +65,8 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:PolynomialDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TranslateLib", ], ) @@ -55,15 +80,12 @@ cc_library( deps = [ ":OpenFheUtils", "@heir//lib/Analysis/SelectVariableNames", - "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Target:Utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:PolynomialDialect", "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", ], ) diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 6449bcc54..b76aa3b28 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -3,61 +3,70 @@ #include #include #include -#include #include #include #include #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" -#include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/LWE/IR/LWEOps.h" -#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" -#include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/Utils.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h" // from @llvm-project +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project namespace mlir { namespace heir { namespace openfhe { -void registerToOpenFhePkeTranslation() { - TranslateFromMLIRRegistration reg( - "emit-openfhe-pke", - "translate the openfhe dialect to C++ code against the OpenFHE pke API", - [](Operation *op, llvm::raw_ostream &output) { - return translateToOpenFhePke(op, output); - }, - [](DialectRegistry ®istry) { - registry.insert(); - }); -} - -LogicalResult translateToOpenFhePke(Operation *op, llvm::raw_ostream &os) { +namespace { + +FailureOr printFloatAttr(FloatAttr floatAttr) { + if (!floatAttr.getType().isF32() || !floatAttr.getType().isF64()) { + return failure(); + } + + SmallString<128> strValue; + auto apValue = APFloat(floatAttr.getValueAsDouble()); + apValue.toString(strValue, /*FormatPrecision=*/0, /*FormatMaxPadding=*/15, + /*TruncateZero=*/true); + return std::string(strValue); +} + +FailureOr getStringForConstant(Value value) { + if (auto constantOp = + dyn_cast_or_null(value.getDefiningOp())) { + auto valueAttr = constantOp.getValue(); + if (auto intAttr = dyn_cast(valueAttr)) { + return std::to_string(intAttr.getInt()); + } else if (auto floatAttr = dyn_cast(valueAttr)) { + return printFloatAttr(floatAttr); + } + } + return failure(); +} + +} // namespace + +LogicalResult translateToOpenFhePke(Operation *op, llvm::raw_ostream &os, + const OpenfheScheme &scheme) { SelectVariableNames variableNames(op); - OpenFhePkeEmitter emitter(os, &variableNames); + OpenFhePkeEmitter emitter(os, &variableNames, scheme); LogicalResult result = emitter.translate(*op); return result; } @@ -71,8 +80,11 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) { .Case( [&](auto op) { return printOperation(op); }) // Arith ops - .Case( - [&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) + // Tensor ops + .Case([&](auto op) { return printOperation(op); }) // LWE ops .Case( [&](auto op) { return printOperation(op); }) @@ -81,21 +93,21 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) { NegateOp, MulConstOp, RelinOp, ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, EncryptOp, DecryptOp, GenParamsOp, GenContextOp, GenMulKeyOp, GenRotKeyOp, - MakePackedPlaintextOp>( + MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp>( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { - return op.emitOpError("unable to find printer for op"); + return emitError(op.getLoc(), "unable to find printer for op"); }); if (failed(status)) { - op.emitOpError(llvm::formatv("Failed to translate op {0}", op.getName())); - return failure(); + return emitError(op.getLoc(), + llvm::formatv("Failed to translate op {0}", op.getName())); } return success(); } LogicalResult OpenFhePkeEmitter::printOperation(ModuleOp moduleOp) { - os << kModulePrelude << "\n"; + os << getModulePrelude(scheme_) << "\n"; for (Operation &op : moduleOp) { if (failed(translate(op))) { return failure(); @@ -107,15 +119,17 @@ LogicalResult OpenFhePkeEmitter::printOperation(ModuleOp moduleOp) { LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { if (funcOp.getNumResults() != 1) { - return funcOp.emitOpError() << "Only functions with a single return type " - "are supported, but this function has " - << funcOp.getNumResults(); + return emitError(funcOp.getLoc(), + llvm::formatv("Only functions with a single return type " + "are supported, but this function has ", + funcOp.getNumResults())); return failure(); } Type result = funcOp.getResultTypes()[0]; if (failed(emitType(result))) { - return funcOp.emitOpError() << "Failed to emit type " << result; + return emitError(funcOp.getLoc(), + llvm::formatv("Failed to emit type {0}", result)); } os << " " << funcOp.getName() << "("; @@ -127,7 +141,8 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { // emitter. for (Value arg : funcOp.getArguments()) { if (failed(convertType(arg.getType()))) { - return funcOp.emitOpError() << "Failed to emit type " << arg.getType(); + return emitError(funcOp.getLoc(), + llvm::formatv("Failed to emit type {0}", arg.getType())); } } @@ -154,8 +169,7 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { LogicalResult OpenFhePkeEmitter::printOperation(func::ReturnOp op) { if (op.getNumOperands() != 1) { - op.emitError() << "Only one return value supported"; - return failure(); + return emitError(op.getLoc(), "Only one return value supported"); } os << "return " << variableNames->getNameForValue(op.getOperands()[0]) << ";\n"; @@ -295,32 +309,52 @@ LogicalResult OpenFhePkeEmitter::printOperation(KeySwitchOp op) { LogicalResult OpenFhePkeEmitter::printOperation(arith::ConstantOp op) { auto valueAttr = op.getValue(); if (auto intAttr = dyn_cast(valueAttr)) { + // Constant integers may be unused if their uses directly output the + // constant value (e.g. tensor.insert and tensor.extract use the defining + // constant values of indices if available). + os << "[[maybe_unused]] "; if (failed(emitTypedAssignPrefix(op.getResult()))) { return failure(); } os << intAttr.getValue() << ";\n"; - } else if (auto denseElementsAttr = dyn_cast(valueAttr)) { - if (denseElementsAttr.getType().getRank() != 1) { - return op.emitError() << "Only 1D dense elements supported"; - } - + } else if (auto floatAttr = dyn_cast(valueAttr)) { if (failed(emitTypedAssignPrefix(op.getResult()))) { return failure(); } - os << "{"; - - auto cstIter = denseElementsAttr.value_begin(); - auto cstIterEnd = denseElementsAttr.value_end(); - SmallString<10> first; - APInt firstVal = *cstIter; - firstVal.toStringSigned(first); - os << std::accumulate(std::next(cstIter), cstIterEnd, std::string(first), - [&](const std::string &a, const APInt &b) { - SmallString<10> str; - b.toStringSigned(str); - return a + ", " + std::string(str); - }); - os << "};\n"; + auto floatStr = printFloatAttr(floatAttr); + if (failed(floatStr)) { + return failure(); + } + os << floatStr.value() << ";\n"; + } else if (auto denseElementsAttr = dyn_cast(valueAttr)) { + if (denseElementsAttr.getType().getRank() == 1) { + // Print a 1-D constant. + if (failed(emitType(op.getResult().getType()))) { + return failure(); + } + os << " " << variableNames->getNameForValue(op.getResult()); + + std::string value_str; + llvm::raw_string_ostream ss(value_str); + denseElementsAttr.print(ss); + + if (denseElementsAttr.isSplat()) { + // SplatElementsAttr are printed as dense<2> : tensor<1xi32>. + // Output as `std::vector constant(2, 1);` + int start = value_str.find('<') + 1; + int end = value_str.find('>') - start; + os << "(" << denseElementsAttr.getNumElements() << ", " + << value_str.substr(start, end) << ");\n"; + } else { + // DenseElementsAttr are printed as dense<[1, 2]> : tensor<2xi32>. + // Output as `std::vector constant = {1, 2};` + int start = value_str.find_last_of('[') + 1; + int end = value_str.find_first_of(']') - start; + os << " = {" << value_str.substr(start, end) << "};\n"; + } + return success(); + } + return failure(); } else { return op.emitError() << "Unsupported constant type " << valueAttr.getType(); @@ -346,6 +380,24 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::ExtSIOp op) { return success(); } +LogicalResult OpenFhePkeEmitter::printOperation(arith::ExtFOp op) { + // OpenFHE has a convention that all inputs to MakeCKKSPackedPlaintext are + // std::vector, so earlier stages in the pipeline emit typecasts + + std::string inputVarName = variableNames->getNameForValue(op.getOperand()); + std::string resultVarName = variableNames->getNameForValue(op.getResult()); + + // If it's a vector, we can use a copy constructor to upcast. + if (auto tensorTy = dyn_cast(op.getOperand().getType())) { + os << "std::vector " << resultVarName << "(std::begin(" + << inputVarName << "), std::end(" << inputVarName << "));\n"; + } else { + return op.emitOpError() << "Unsupported input type"; + } + + return success(); +} + LogicalResult OpenFhePkeEmitter::printOperation(arith::IndexCastOp op) { Type outputType = op.getOut().getType(); if (failed(emitTypedAssignPrefix(op.getResult()))) { @@ -359,6 +411,75 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::IndexCastOp op) { return success(); } +LogicalResult OpenFhePkeEmitter::printOperation(tensor::EmptyOp op) { + // std::vector> result(dim0, + // std::vector(dim1)); initStr = (dim1) initStr = (dim0, + // std::vector{initStr}) + RankedTensorType resultTy = op.getResult().getType(); + auto elementTy = convertType(resultTy.getElementType()); + if (failed(elementTy)) { + return failure(); + } + if (failed(emitType(resultTy))) { + return failure(); + } + os << " " << variableNames->getNameForValue(op.getResult()); + std::string initStr = llvm::formatv("({0})", resultTy.getShape().back()); + for (auto dim : + llvm::reverse(op.getResult().getType().getShape().drop_back(1))) { + initStr = llvm::formatv("({0}, std::vector<{1}>{2})", dim, + elementTy.value(), initStr); + } + os << initStr << ";\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(tensor::ExtractOp op) { + // const auto& v1 = in[0, 1]; + emitAutoAssignPrefix(op.getResult()); + os << variableNames->getNameForValue(op.getTensor()); + os << bracketEnclosedValues(op.getIndices(), [&](Value value) { + auto constantStr = getStringForConstant(value); + return constantStr.value_or(variableNames->getNameForValue(value)); + }); + os << ";\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(tensor::InsertOp op) { + // For a tensor.insert MLIR statement, we assign the destination vector and + // then move the vector to the result. + // %result = tensor.insert %scalar into %dest[%idx] + // dest[idx] = scalar; + // Type result = std::move(dest); + os << variableNames->getNameForValue(op.getDest()); + os << bracketEnclosedValues(op.getIndices(), [&](Value value) { + auto constantStr = getStringForConstant(value); + return constantStr.value_or(variableNames->getNameForValue(value)); + }); + os << " = " << variableNames->getNameForValue(op.getScalar()) << ";\n"; + if (failed(emitTypedAssignPrefix(op.getResult()))) { + return failure(); + } + os << "std::move(" << variableNames->getNameForValue(op.getDest()) << ");\n"; + return success(); +} + +LogicalResult OpenFhePkeEmitter::printOperation(tensor::SplatOp op) { + // std::vector result(num, value); + auto result = op.getResult(); + if (failed(emitType(result.getType()))) { + return failure(); + } + if (result.getType().getRank() != 1) { + return failure(); + } + os << " " << variableNames->getNameForValue(result) << "(" + << result.getType().getNumElements() << ", " + << variableNames->getNameForValue(op.getInput()) << ");\n"; + return success(); +} + LogicalResult OpenFhePkeEmitter::printOperation( lwe::ReinterpretUnderlyingTypeOp op) { emitAutoAssignPrefix(op.getResult()); @@ -378,6 +499,23 @@ LogicalResult OpenFhePkeEmitter::printOperation( return success(); } +LogicalResult OpenFhePkeEmitter::printOperation( + openfhe::MakeCKKSPackedPlaintextOp op) { + if (scheme_ != OpenfheScheme::CKKS) { + return emitError(op.getLoc(), + "encoding CKKS plaintext not supported by chosen scheme"); + } + + std::string inputVarName = variableNames->getNameForValue(op.getValue()); + + emitAutoAssignPrefix(op.getResult()); + FailureOr resultCC = getContextualCryptoContext(op.getOperation()); + if (failed(resultCC)) return resultCC; + os << variableNames->getNameForValue(resultCC.value()) + << "->MakeCKKSPackedPlaintext(" << inputVarName << ");\n"; + return success(); +} + LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { // In OpenFHE a plaintext is already decoded by decrypt. The internal OpenFHE // implementation is simple enough (and dependent on currently-hard-coded @@ -386,7 +524,7 @@ LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { auto tensorTy = dyn_cast(op.getResult().getType()); if (tensorTy) { if (tensorTy.getRank() != 1) { - return op.emitOpError() << "Only 1D tensors supported"; + return emitError(op.getLoc(), "Only 1D tensors supported"); } // OpenFHE plaintexts must be manually resized to the decoded output size // via plaintext->SetLength(); @@ -489,8 +627,9 @@ LogicalResult OpenFhePkeEmitter::emitType(Type type) { } OpenFhePkeEmitter::OpenFhePkeEmitter(raw_ostream &os, - SelectVariableNames *variableNames) - : os(os), variableNames(variableNames) {} + SelectVariableNames *variableNames, + const OpenfheScheme &scheme) + : scheme_(scheme), os(os), variableNames(variableNames) {} } // namespace openfhe } // namespace heir } // namespace mlir diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h index e28ccfacc..3eb588bd2 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h @@ -6,35 +6,40 @@ #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" #include "lib/Dialect/LWE/IR/LWEOps.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "lib/Target/OpenFhePke/OpenFheUtils.h" +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace heir { namespace openfhe { -void registerToOpenFhePkeTranslation(); - /// Translates the given operation to OpenFhePke. ::mlir::LogicalResult translateToOpenFhePke(::mlir::Operation *op, - llvm::raw_ostream &os); + llvm::raw_ostream &os, + const OpenfheScheme &scheme); class OpenFhePkeEmitter { public: - OpenFhePkeEmitter(raw_ostream &os, SelectVariableNames *variableNames); + OpenFhePkeEmitter(raw_ostream &os, SelectVariableNames *variableNames, + const OpenfheScheme &scheme); LogicalResult translate(::mlir::Operation &operation); private: + /// OpenFHE scheme to emit. + OpenfheScheme scheme_; + /// Output stream to emit to. raw_indented_ostream os; @@ -46,7 +51,12 @@ class OpenFhePkeEmitter { LogicalResult printOperation(::mlir::ModuleOp op); LogicalResult printOperation(::mlir::arith::ConstantOp op); LogicalResult printOperation(::mlir::arith::ExtSIOp op); + LogicalResult printOperation(::mlir::arith::ExtFOp op); LogicalResult printOperation(::mlir::arith::IndexCastOp op); + LogicalResult printOperation(::mlir::tensor::EmptyOp op); + LogicalResult printOperation(::mlir::tensor::ExtractOp op); + LogicalResult printOperation(::mlir::tensor::InsertOp op); + LogicalResult printOperation(::mlir::tensor::SplatOp op); LogicalResult printOperation(::mlir::func::FuncOp op); LogicalResult printOperation(::mlir::func::ReturnOp op); LogicalResult printOperation(::mlir::heir::lwe::RLWEDecodeOp op); @@ -63,6 +73,7 @@ class OpenFhePkeEmitter { LogicalResult printOperation(KeySwitchOp op); LogicalResult printOperation(LevelReduceOp op); LogicalResult printOperation(MakePackedPlaintextOp op); + LogicalResult printOperation(MakeCKKSPackedPlaintextOp op); LogicalResult printOperation(ModReduceOp op); LogicalResult printOperation(MulConstOp op); LogicalResult printOperation(MulNoRelinOp op); diff --git a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp index 0889db9a6..efd5c4a4a 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.cpp @@ -1,51 +1,29 @@ #include "lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h" #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" -#include "lib/Dialect/LWE/IR/LWEDialect.h" -#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" -#include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Target/Utils.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace heir { namespace openfhe { -void registerToOpenFhePkeHeaderTranslation() { - TranslateFromMLIRRegistration reg( - "emit-openfhe-pke-header", - "Emit a header corresponding to the C++ file generated by " - "--emit-openfhe-pke", - [](Operation *op, llvm::raw_ostream &output) { - return translateToOpenFhePkeHeader(op, output); - }, - [](DialectRegistry ®istry) { - registry.insert(); - }); -} - -LogicalResult translateToOpenFhePkeHeader(Operation *op, - llvm::raw_ostream &os) { +LogicalResult translateToOpenFhePkeHeader(Operation *op, llvm::raw_ostream &os, + OpenfheScheme scheme) { SelectVariableNames variableNames(op); - OpenFhePkeHeaderEmitter emitter(os, &variableNames); + OpenFhePkeHeaderEmitter emitter(os, &variableNames, scheme); return emitter.translate(*op); } @@ -66,7 +44,7 @@ LogicalResult OpenFhePkeHeaderEmitter::translate(Operation &op) { } LogicalResult OpenFhePkeHeaderEmitter::printOperation(ModuleOp moduleOp) { - os << kModulePrelude << "\n"; + os << getModulePrelude(scheme_) << "\n"; for (Operation &op : moduleOp) { if (failed(translate(op))) { return failure(); @@ -117,8 +95,9 @@ LogicalResult OpenFhePkeHeaderEmitter::emitType(Type type) { } OpenFhePkeHeaderEmitter::OpenFhePkeHeaderEmitter( - raw_ostream &os, SelectVariableNames *variableNames) - : os(os), variableNames(variableNames) {} + raw_ostream &os, SelectVariableNames *variableNames, OpenfheScheme scheme) + : scheme_(scheme), os(os), variableNames(variableNames) {} + } // namespace openfhe } // namespace heir } // namespace mlir diff --git a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h index 46973cdd3..8f4713747 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h @@ -2,6 +2,7 @@ #define LIB_TARGET_OPENFHEPKE_OPENFHEPKEHEADEREMITTER_H_ #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" +#include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project @@ -16,21 +17,24 @@ namespace mlir { namespace heir { namespace openfhe { -void registerToOpenFhePkeHeaderTranslation(); - /// Translates the given operation to OpenFhePke. ::mlir::LogicalResult translateToOpenFhePkeHeader(::mlir::Operation *op, - llvm::raw_ostream &os); + llvm::raw_ostream &os, + OpenfheScheme scheme); /// For each function in the mlir module, emits a function header declaration /// along with any necessary includes. class OpenFhePkeHeaderEmitter { public: - OpenFhePkeHeaderEmitter(raw_ostream &os, SelectVariableNames *variableNames); + OpenFhePkeHeaderEmitter(raw_ostream &os, SelectVariableNames *variableNames, + OpenfheScheme scheme); LogicalResult translate(::mlir::Operation &operation); private: + /// OpenFHE scheme to emit. + OpenfheScheme scheme_; + /// Output stream to emit to. raw_indented_ostream os; diff --git a/lib/Target/OpenFhePke/OpenFhePkeTemplates.h b/lib/Target/OpenFhePke/OpenFhePkeTemplates.h index 8bcc2b2cb..705c22593 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeTemplates.h +++ b/lib/Target/OpenFhePke/OpenFhePkeTemplates.h @@ -11,12 +11,12 @@ namespace openfhe { // transforms for HEIR includes. // clang-format off -constexpr std::string_view kModulePrelude = R"cpp( +constexpr std::string_view kModulePreludeTemplate = R"cpp( #include "src/pke/include/openfhe.h" // from @openfhe using namespace lbcrypto; using CiphertextT = ConstCiphertext; -using CCParamsT = CCParams; +using CCParamsT = CCParams; using CryptoContextT = CryptoContext; using EvalKeyT = EvalKey; using PlaintextT = Plaintext; diff --git a/lib/Target/OpenFhePke/OpenFheTranslateRegistration.cpp b/lib/Target/OpenFhePke/OpenFheTranslateRegistration.cpp new file mode 100644 index 000000000..9ecfd42b1 --- /dev/null +++ b/lib/Target/OpenFhePke/OpenFheTranslateRegistration.cpp @@ -0,0 +1,74 @@ +#include "lib/Target/OpenFhePke/OpenFheTranslateRegistration.h" + +#include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" +#include "lib/Target/OpenFhePke/OpenFhePkeEmitter.h" +#include "lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h" +#include "lib/Target/OpenFhePke/OpenFheUtils.h" +#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project +#include "llvm/include/llvm/Support/ManagedStatic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialDialect.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace openfhe { + +struct TranslateOptions { + llvm::cl::opt openfheScheme{ + "openfhe-scheme", llvm::cl::desc("The OpenFHE scheme API to use"), + llvm::cl::init(mlir::heir::openfhe::OpenfheScheme::BGV), + llvm::cl::values(clEnumValN(mlir::heir::openfhe::OpenfheScheme::BGV, + "bgv", "Emit with OpenFHE BGV scheme"), + clEnumValN(mlir::heir::openfhe::OpenfheScheme::CKKS, + "ckks", "Emit with OpenFHE CKKS scheme"))}; +}; +static llvm::ManagedStatic options; + +void registerTranslateOptions() { + // Forces initialization of options. + *options; +} + +void registerToOpenFhePkeTranslation() { + TranslateFromMLIRRegistration reg( + "emit-openfhe-pke", + "translate the openfhe dialect to C++ code against the OpenFHE pke API", + [](Operation *op, llvm::raw_ostream &output) { + return translateToOpenFhePke(op, output, options->openfheScheme); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} + +void registerToOpenFhePkeHeaderTranslation() { + TranslateFromMLIRRegistration reg( + "emit-openfhe-pke-header", + "Emit a header corresponding to the C++ file generated by " + "--emit-openfhe-pke", + [](Operation *op, llvm::raw_ostream &output) { + return translateToOpenFhePkeHeader(op, output, options->openfheScheme); + }, + [](DialectRegistry ®istry) { + registry + .insert(); + }); +} + +} // namespace openfhe +} // namespace heir +} // namespace mlir diff --git a/lib/Target/OpenFhePke/OpenFheTranslateRegistration.h b/lib/Target/OpenFhePke/OpenFheTranslateRegistration.h new file mode 100644 index 000000000..563130cb9 --- /dev/null +++ b/lib/Target/OpenFhePke/OpenFheTranslateRegistration.h @@ -0,0 +1,19 @@ + +#ifndef LIB_TARGET_OPENFHEPKE_OPENFHETRANSLATEREGISTRATION_H_ +#define LIB_TARGET_OPENFHEPKE_OPENFHETRANSLATEREGISTRATION_H_ + +namespace mlir { +namespace heir { +namespace openfhe { + +void registerTranslateOptions(); + +void registerToOpenFhePkeTranslation(); + +void registerToOpenFhePkeHeaderTranslation(); + +} // namespace openfhe +} // namespace heir +} // namespace mlir + +#endif // LIB_TARGET_OPENFHEPKE_OPENFHETRANSLATEREGISTRATION_H_ diff --git a/lib/Target/OpenFhePke/OpenFheUtils.cpp b/lib/Target/OpenFhePke/OpenFheUtils.cpp index a5847cbe2..4efa50a42 100644 --- a/lib/Target/OpenFhePke/OpenFheUtils.cpp +++ b/lib/Target/OpenFhePke/OpenFheUtils.cpp @@ -4,7 +4,9 @@ #include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" +#include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -18,6 +20,11 @@ namespace mlir { namespace heir { namespace openfhe { +std::string getModulePrelude(OpenfheScheme scheme) { + return llvm::formatv(kModulePreludeTemplate.data(), + scheme == OpenfheScheme::CKKS ? "CKKS" : "BGV"); +} + FailureOr convertType(Type type) { return llvm::TypeSwitch>(type) // For now, these types are defined in the prelude as aliases. @@ -45,19 +52,25 @@ FailureOr convertType(Type type) { os << "int" << width << "_t"; return FailureOr(std::string(result)); }) - .Case([&](auto ty) { - if (ty.getRank() != 1) { - return FailureOr(); + .Case([&](auto ty) -> FailureOr { + auto width = ty.getWidth(); + if (width == 32) { + return std::string("float"); + } else if (width == 64) { + return std::string("double"); } - + return failure(); + }) + .Case([&](auto ty) { auto eltTyResult = convertType(ty.getElementType()); if (failed(eltTyResult)) { return FailureOr(); } - - SmallString<8> result; - llvm::raw_svector_ostream os(result); - os << "std::vector<" << eltTyResult.value() << ">"; + auto result = eltTyResult.value(); + for (int i = 0; i < ty.getRank(); ++i) { + // For each dimension, wrap the element in a std::vector. + result = "std::vector<" + result + ">"; + } return FailureOr(std::string(result)); }) .Default([&](Type &) { return failure(); }); diff --git a/lib/Target/OpenFhePke/OpenFheUtils.h b/lib/Target/OpenFhePke/OpenFheUtils.h index 93917cb0b..37f0becec 100644 --- a/lib/Target/OpenFhePke/OpenFheUtils.h +++ b/lib/Target/OpenFhePke/OpenFheUtils.h @@ -11,6 +11,10 @@ namespace mlir { namespace heir { namespace openfhe { +enum class OpenfheScheme { BGV, CKKS }; + +std::string getModulePrelude(OpenfheScheme scheme); + /// Convert a type to a string. ::mlir::FailureOr convertType(::mlir::Type type); diff --git a/tests/jaxite/end_to_end/test.bzl b/tests/jaxite/end_to_end/test.bzl index 59265f39f..bd1191931 100644 --- a/tests/jaxite/end_to_end/test.bzl +++ b/tests/jaxite/end_to_end/test.bzl @@ -38,7 +38,7 @@ def jaxite_end_to_end_test(name, mlir_src, test_src, heir_opt_flags = "", tags = heir_translate( name = py_codegen_target, src = generated_heir_opt_name, - pass_flag = "--emit-jaxite", + pass_flags = ["--emit-jaxite"], generated_filename = generated_py_filename, ) py_library( diff --git a/tests/openfhe/emit_openfhe_pke.mlir b/tests/openfhe/emit_openfhe_pke.mlir index 194864f48..eb2e29a81 100644 --- a/tests/openfhe/emit_openfhe_pke.mlir +++ b/tests/openfhe/emit_openfhe_pke.mlir @@ -1,4 +1,4 @@ -// RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s +// RUN: heir-translate %s --emit-openfhe-pke --split-input-file | FileCheck %s #encoding = #lwe.polynomial_evaluation_encoding @@ -52,6 +52,8 @@ func.func @test_basic_emitter(%cc : !cc, %input1 : !ct, %input2 : !ct, %input3: return %key_switch_res: !ct } +// ----- + #degree_32_poly = #polynomial.int_polynomial<1 + x**32> #eval_encoding = #lwe.polynomial_evaluation_encoding #ring2 = #polynomial.ring @@ -99,3 +101,21 @@ func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !scalar_ct %1 = lwe.rlwe_decode %0 {encoding = #eval_encoding, ring = #ring2} : !scalar_pt_ty -> i16 return %1 : i16 } + +// ----- + +// CHECK-LABEL: test_constant +// CHECK-NEXT: std::vector [[splat:.*]](2, 1.500000e+00); +// CHECK-NEXT: std::vector [[ints:.*]] = {1, 2}; +// CHECK-NEXT: std::vector [[floats:.*]] = {1.500000e+00, 2.500000e+00}; +// CHECK-NEXT: std::vector [[double1:.*]](16, -0.38478666543960571); +// CHECK-NEXT: std::vector [[double2:.*]](16, -1.1268185335211456E-4); +// CHECK-NEXT: return [[splat]]; +func.func @test_constant() -> tensor<2xf32> { + %splat = arith.constant dense<1.5> : tensor<2xf32> + %ints = arith.constant dense<[1, 2]> : tensor<2xi32> + %floats = arith.constant dense<[1.5, 2.5]> : tensor<2xf32> + %cst_175 = arith.constant dense<-0.38478666543960571> : tensor<16xf64> + %cst_176 = arith.constant dense<-1.1268185335211456E-4> : tensor<16xf64> + return %splat : tensor<2xf32> +} diff --git a/tests/openfhe/end_to_end/BUILD b/tests/openfhe/end_to_end/BUILD index fb9c6b414..cc71ac49f 100644 --- a/tests/openfhe/end_to_end/BUILD +++ b/tests/openfhe/end_to_end/BUILD @@ -47,3 +47,13 @@ openfhe_end_to_end_test( tags = ["notap"], test_src = "roberts_cross_test.cpp", ) + +openfhe_end_to_end_test( + name = "naive_matmul_test", + generated_lib_header = "naive_matmul_lib.h", + heir_opt_flags = "--mlir-to-openfhe-ckks=entry-function=matmul", + heir_translate_flags = ["--openfhe-scheme=ckks"], + mlir_src = "naive_matmul.mlir", + tags = ["notap"], + test_src = "naive_matmul_test.cpp", +) diff --git a/tests/openfhe/end_to_end/naive_matmul.mlir b/tests/openfhe/end_to_end/naive_matmul.mlir new file mode 100644 index 000000000..e9f195d77 --- /dev/null +++ b/tests/openfhe/end_to_end/naive_matmul.mlir @@ -0,0 +1,21 @@ +module { + func.func @matmul(%arg0: tensor<1x16xf32>, %arg1: tensor<1x16xf32>) -> tensor<1x16xf32> { + %0 = "tosa.const"() <{value = dense<"0x5036CB3DE147C3BEE4A9393E47C021BE40F376BFA1078D3E8D53EB3DD6E0493EEFFC3CBFBEB947BE4597B5BBD185903E9B9C1BBEEB0713BD23B418BF66736C3EABF141BEED693F3E584F72BF3CB9E83EBD0E8D3E4D87BDBE5A0439BFBE94AABECDCA91BE695FA93E870B93BE576920BF6294083F4C08633DCBACC6BDD8C9243F6CAA17BE63FE853E647E8E3F27116D3DBA00FA3DDDD4C93EA96AA03E1FD4A7BE3C3297BD387D02BFA695923E3402CB3E6A4E0F3E8D0700BF195E3E3ECA2E0EBF28F39CBC21AE853F26F7803F1C7029BFAA05383F0DBFF0BEBF82CB3F8D9F843D3640A63F75BF7FBEF615053D1937CB3FC68B41BEEC66B9BED998223F90944FBD511F9DBFE8A4C23F3C11793F68822ABFBEE1923F32109FBF79DF193F726237BFBF6FFB3F55B69FBF1EFEF83E7D4EB63F553A37BF50F054BF4072EE3FACE7A1BF79CC633C8E44723F8D844E3FBE51FABE7BFA0F3F83F258BEC956703F4E00073FE1645E3F9C8203BF6D8B66BD1936893F0042113E6EC745BF161EB23E570AFE3D961D7C3E1039C43D665C36BF2791AF3D47B452BE34128B3E77DDA1BFADE2DCBE29DA10BEA4569FBD24B92ABE4DF072BFAE8AA13E4C661B3F3DCF823E4FDC1F3EB562A1BD5FEE0ABD8FB21CBFCE54193FF31C79BE2A0A763EA3B655BFD7EDB93CEE8F443E1C9693BDB0863BBF7F70C5BEC166A93EAF6CACBDBF5F023FEC98153FC2D49C3FD9A115C01EC09EBF36CD1B3F9ABE19C05129963F4CA4BDBFAB2F1C3E309239C03EF9903F12360BBFD15A1FC0733F6C3F8D4BDF3F615C2DC083868D3F497814BD8523E03DCCA8D6BDEA77253E99D43DBEECACB53E74A657BF7AE739BECC272F3F842D833E90A07CBFEEFCFBBD97BE063E7CE7DA3EBF4AEABD473F593FE32D25BF911CAB3F07ED413DD65AFDBE532AA23F85E451BF88925B3E3D09BD3D6C0AC33F2B3A19BFB0C3163F7803133F051EBDBE94A451BF1F83C13F9FD976BFB809763EDF71D0BD4BC424BE13E9853ED757033F15A656BE522F40BEA19AA4BEF1F9953E0FDF2E3D198BD9BED1DB2ABFCDB2E83EAE500E3FE4AA0B3F0284113EF339193FCD4C10BF382D6CBE7A020C3F016DA2BE590DF63E1923163D8B94383D1AD4EABEFDA50DBBBDA0BABCA75C0DBF5D971B3FDC29103F598190BF0C8726BEA2AD41BE1B19E13CC88265BF2DD392BD6509A73ED73A6F3EC4280CBFC24284BE76727CBEC5DE023F79B7B8BE6E8E23BF12C739BD1091853ECC190A3F369C0D3F65AB74BF1A15F63EA1F5FA3E6B2BABBE9C4FECB895E499BE2D0268BE8EA7EABE374FFD3DADBB19BFC759C83F4A69D73E37B836BFE5F4E1BEED900CBEECD986BFAE4E853D022E55BE1CDB073F6E31C9BEC202C5BE4BF853BEEE54DB3EEBC9613E74C317BEB9F2A3BE755B6F3F37CE383FF01E2F3D989532BE1C591EBE19464BBF"> : tensor<16x16xf32>}> : () -> tensor<16x16xf32> + %1 = affine.for %arg2 = 0 to 1 iter_args(%arg3 = %arg1) -> (tensor<1x16xf32>) { + %2 = affine.for %arg4 = 0 to 16 iter_args(%arg5 = %arg3) -> (tensor<1x16xf32>) { + %3 = affine.for %arg6 = 0 to 16 iter_args(%arg7 = %arg5) -> (tensor<1x16xf32>) { + %extracted = tensor.extract %arg0[%arg2, %arg6] : tensor<1x16xf32> + %extracted_0 = tensor.extract %0[%arg6, %arg4] : tensor<16x16xf32> + %extracted_1 = tensor.extract %arg7[%arg2, %arg4] : tensor<1x16xf32> + %4 = arith.mulf %extracted, %extracted_0 : f32 + %5 = arith.addf %extracted_1, %4 : f32 + %inserted = tensor.insert %5 into %arg7[%arg2, %arg4] : tensor<1x16xf32> + affine.yield %inserted : tensor<1x16xf32> + } + affine.yield %3 : tensor<1x16xf32> + } + affine.yield %2 : tensor<1x16xf32> + } + return %1 : tensor<1x16xf32> + } +} diff --git a/tests/openfhe/end_to_end/naive_matmul_test.cpp b/tests/openfhe/end_to_end/naive_matmul_test.cpp new file mode 100644 index 000000000..f2fb75eb7 --- /dev/null +++ b/tests/openfhe/end_to_end/naive_matmul_test.cpp @@ -0,0 +1,100 @@ +#include +#include +#include +#include + +#include "gtest/gtest.h" // from @googletest +#include "src/core/include/lattice/hal/lat-backend.h" // from @openfhe +#include "src/pke/include/constants.h" // from @openfhe +#include "src/pke/include/cryptocontext-fwd.h" // from @openfhe +#include "src/pke/include/gen-cryptocontext.h" // from @openfhe +#include "src/pke/include/key/keypair.h" // from @openfhe +#include "src/pke/include/scheme/ckksrns/gen-cryptocontext-ckksrns-params.h" // from @openfhe +#include "src/pke/include/scheme/ckksrns/gen-cryptocontext-ckksrns.h" // from @openfhe + +// Generated headers (block clang-format from messing up order) +#include "tests/openfhe/end_to_end/naive_matmul_lib.h" + +namespace mlir { +namespace heir { +namespace openfhe { + +std::vector> dot_product__encrypt__arg0( + CryptoContextT v16, std::vector v17, PublicKeyT v18) { + std::vector v17_cast(std::begin(v17), std::end(v17)); + int32_t n = + v16->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2; + // create a 1x16 vector of ciphertexts encrypting each value + std::vector> outputs; + outputs.reserve(1); + std::vector inner; + inner.reserve(16); + for (auto v17_val : v17_cast) { + std::vector single(n, v17_val); + const auto& v19 = v16->MakeCKKSPackedPlaintext(single); + const auto& v20 = v16->Encrypt(v18, v19); + inner.push_back(v20); + } + outputs.push_back(inner); + return outputs; +} + +double dot_product__decrypt__result0(CryptoContextT v26, + std::vector> v27, + PrivateKeyT v28) { + PlaintextT v29; + v26->Decrypt(v28, v27[0][0], &v29); // just decrypt first element + double v30 = v29->GetCKKSPackedValue()[0].real(); + return v30; +} + +TEST(NaiveMatmulTest, RunTest) { + CCParams parameters; + parameters.SetMultiplicativeDepth(1); + // needs to be large enough to accommodate overflow in the plaintext space + // pick a 32-bit prime for which (p-1) / 65536 is an integer. + parameters.SetPlaintextModulus(4295294977); + CryptoContext cryptoContext = GenCryptoContext(parameters); + cryptoContext->Enable(PKE); + cryptoContext->Enable(KEYSWITCH); + cryptoContext->Enable(LEVELEDSHE); + + KeyPair keyPair; + keyPair = cryptoContext->KeyGen(); + cryptoContext->EvalMultKeyGen(keyPair.secretKey); + + auto publicKey = keyPair.publicKey; + auto secretKey = keyPair.secretKey; + + std::vector arg0Vals = {1.0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; // input + std::vector arg1Vals = {0.25, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; // bias + + // This select the first element of the matrix (0x5036cb3d = 0.0992247) and + // adds 0.25 + double expected = 0.3492247; + + // TODO(#645): support cyclic repetition in add-client-interface + // TODO(#891): support other schemes besides BGV in add-client-interface + auto arg0Encrypted = + dot_product__encrypt__arg0(cryptoContext, arg0Vals, publicKey); + auto arg1Encrypted = + dot_product__encrypt__arg0(cryptoContext, arg1Vals, publicKey); + + // Insert timing info + std::clock_t c_start = std::clock(); + auto outputEncrypted = matmul(cryptoContext, arg0Encrypted, arg1Encrypted); + std::clock_t c_end = std::clock(); + double time_elapsed_ms = 1000.0 * (c_end - c_start) / CLOCKS_PER_SEC; + std::cout << "CPU time used: " << time_elapsed_ms << " ms\n"; + + auto actual = + dot_product__decrypt__result0(cryptoContext, outputEncrypted, secretKey); + + EXPECT_NEAR(expected, actual, 1e-6); +} + +} // namespace openfhe +} // namespace heir +} // namespace mlir diff --git a/tests/openfhe/end_to_end/test.bzl b/tests/openfhe/end_to_end/test.bzl index 1dd2f4f21..d8ca3701f 100644 --- a/tests/openfhe/end_to_end/test.bzl +++ b/tests/openfhe/end_to_end/test.bzl @@ -3,7 +3,7 @@ load("@heir//tools:heir-opt.bzl", "heir_opt") load("@heir//tools:heir-translate.bzl", "heir_translate") -def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir_opt_flags = "", data = [], tags = [], deps = [], **kwargs): +def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir_opt_flags = "", heir_translate_flags = [], data = [], tags = [], deps = [], **kwargs): """A rule for running generating OpenFHE and running a test on it. Args: @@ -13,6 +13,7 @@ def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir generated_lib_header: The name of the generated .h file (explicit because it needs to be manually #include'd in the test_src file) heir_opt_flags: Flags to pass to heir-opt before heir-translate + heir_translate_flags: Flags to pass to heir-translate data: Data dependencies to be passed to cc_test tags: Tags to pass to cc_test deps: Deps to pass to cc_test and cc_library @@ -24,6 +25,7 @@ def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir generated_cc_filename = "%s_lib.inc.cc" % name heir_opt_name = "%s_heir_opt" % name generated_heir_opt_name = "%s_heir_opt.mlir" % name + heir_translate_flags = heir_translate_flags + ["--emit-openfhe-pke"] if heir_opt_flags: heir_opt( @@ -38,13 +40,13 @@ def openfhe_end_to_end_test(name, mlir_src, test_src, generated_lib_header, heir heir_translate( name = cc_codegen_target, src = generated_heir_opt_name, - pass_flag = "--emit-openfhe-pke", + pass_flags = heir_translate_flags, generated_filename = generated_cc_filename, ) heir_translate( name = h_codegen_target, src = generated_heir_opt_name, - pass_flag = "--emit-openfhe-pke-header", + pass_flags = heir_translate_flags, generated_filename = generated_lib_header, ) native.cc_library( diff --git a/tests/openfhe/naive_matmul.mlir b/tests/openfhe/naive_matmul.mlir new file mode 100644 index 000000000..ceb50bac5 --- /dev/null +++ b/tests/openfhe/naive_matmul.mlir @@ -0,0 +1,46 @@ +// RUN: heir-opt --mlir-to-openfhe-ckks='ciphertext-degree=16 entry-function=matmul' %s | heir-translate --emit-openfhe-pke --openfhe-scheme=ckks | FileCheck %s + +// CHECK-LABEL: std::vector> matmul( +// CHECK-SAME: CryptoContextT [[v0:[^,]*]], +// CHECK-SAME: std::vector> [[v1:[^,]*]], +// CHECK-SAME: std::vector> [[v2:[^,]*]]) +// CHECK-DAG: std::vector [[v3:.*]](16, 6.000000e+00); +// CHECK-DAG: std::vector [[v4:.*]](16, 3.000000e+00); +// CHECK-DAG: std::vector [[v5:.*]](16, 4.000000e+00); +// CHECK-DAG: std::vector [[v6:.*]](16, 2.000000e+00); +// CHECK-DAG: size_t [[v7:.*]] = 1; +// CHECK-DAG: size_t [[v8:.*]] = 0; +// CHECK-DAG: const auto& [[v9:.*]] = [[v1]][0][0]; +// CHECK-DAG: const auto& [[v10:.*]] = [[v2]][0][0]; +// CHECK: const auto& [[v11:.*]] = [[v0]]->MakeCKKSPackedPlaintext([[v6]]); +// CHECK-NEXT: const auto& [[v12:.*]] = [[v0]]->EvalMult([[v9]], [[v11]]); +// CHECK-NEXT: const auto& [[v13:.*]] = [[v0]]->EvalAdd([[v10]], [[v12]]); +// CHECK-NEXT: [[v2]][0][0] = [[v13]]; +// CHECK-COUNT-3: [[v0]]->EvalMult +// CHECK: return + +// CHECK-LABEL: matmul__generate_crypto_context +// CHECK: SetMultiplicativeDepth(1) +// CHECK-LABEL: matmul__configure_crypto_context + +module { + func.func @matmul(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = "tosa.const"() <{value = dense<[[2.0, 3.0], [4.0, 6.0]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> + %1 = affine.for %arg2 = 0 to 1 iter_args(%arg3 = %arg1) -> (tensor<1x2xf32>) { + %2 = affine.for %arg4 = 0 to 2 iter_args(%arg5 = %arg3) -> (tensor<1x2xf32>) { + %3 = affine.for %arg6 = 0 to 2 iter_args(%arg7 = %arg5) -> (tensor<1x2xf32>) { + %extracted = tensor.extract %arg0[%arg2, %arg6] : tensor<1x2xf32> + %extracted_0 = tensor.extract %0[%arg6, %arg4] : tensor<2x2xf32> + %extracted_1 = tensor.extract %arg7[%arg2, %arg4] : tensor<1x2xf32> + %4 = arith.mulf %extracted, %extracted_0 : f32 + %5 = arith.addf %extracted_1, %4 : f32 + %inserted = tensor.insert %5 into %arg7[%arg2, %arg4] : tensor<1x2xf32> + affine.yield %inserted : tensor<1x2xf32> + } + affine.yield %3 : tensor<1x2xf32> + } + affine.yield %2 : tensor<1x2xf32> + } + return %1 : tensor<1x2xf32> + } +} diff --git a/tools/BUILD b/tools/BUILD index f7694422d..70454f5d8 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -149,11 +149,11 @@ cc_binary( "@heir//lib/Source/AutoHog:AutoHogImporter", "@heir//lib/Target/Jaxite:JaxiteEmitter", "@heir//lib/Target/Metadata:MetadataEmitter", - "@heir//lib/Target/OpenFhePke:OpenFhePkeEmitter", - "@heir//lib/Target/OpenFhePke:OpenFhePkeHeaderEmitter", + "@heir//lib/Target/OpenFhePke:OpenFheRegistration", "@heir//lib/Target/TfheRust:TfheRustEmitter", "@heir//lib/Target/TfheRustBool:TfheRustBoolEmitter", "@heir//lib/Target/Verilog:VerilogEmitter", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", ], diff --git a/tools/heir-translate.bzl b/tools/heir-translate.bzl index a5f394b89..86a721bd9 100644 --- a/tools/heir-translate.bzl +++ b/tools/heir-translate.bzl @@ -14,7 +14,7 @@ _HEIR_TRANSLATE = "@heir//tools:heir-translate" def _heir_translate_impl(ctx): generated_file = ctx.outputs.generated_filename args = ctx.actions.args() - args.add(ctx.attr.pass_flag) + args.add_all(ctx.attr.pass_flags) args.add_all(["-o", generated_file.path]) args.add(ctx.file.src) @@ -40,9 +40,9 @@ heir_translate = rule( doc = "A single MLIR source file to translate.", allow_single_file = [".mlir"], ), - "pass_flag": attr.string( + "pass_flags": attr.string_list( doc = """ - The pass flag passed to heir-translate, e.g., --emit-openfhe-pke + The pass flags passed to heir-translate, e.g., --emit-openfhe-pke """, ), "generated_filename": attr.output( diff --git a/tools/heir-translate.cpp b/tools/heir-translate.cpp index 0290f6300..0cb5548ef 100644 --- a/tools/heir-translate.cpp +++ b/tools/heir-translate.cpp @@ -1,12 +1,11 @@ #include "lib/Source/AutoHog/AutoHogImporter.h" #include "lib/Target/Jaxite/JaxiteEmitter.h" #include "lib/Target/Metadata/MetadataEmitter.h" -#include "lib/Target/OpenFhePke/OpenFhePkeEmitter.h" -#include "lib/Target/OpenFhePke/OpenFhePkeHeaderEmitter.h" +#include "lib/Target/OpenFhePke/OpenFheTranslateRegistration.h" #include "lib/Target/TfheRust/TfheRustEmitter.h" #include "lib/Target/TfheRustBool/TfheRustBoolEmitter.h" #include "lib/Target/Verilog/VerilogEmitter.h" -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "llvm/include/llvm/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/MlirTranslateMain.h" // from @llvm-project int main(int argc, char **argv) { @@ -22,6 +21,7 @@ int main(int argc, char **argv) { mlir::heir::jaxite::registerToJaxiteTranslation(); // OpenFHE + mlir::heir::openfhe::registerTranslateOptions(); mlir::heir::openfhe::registerToOpenFhePkeTranslation(); mlir::heir::openfhe::registerToOpenFhePkeHeaderTranslation();