Skip to content

Commit

Permalink
lwe: add new LWE Attributes & Types
Browse files Browse the repository at this point in the history
See google#785

New Attributes & Types:

```
#generator = #polynomial.int_polynomial<1 + x**1024>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 65536 : i32, polynomialModulus=#generator>

// Application Data Information
#preserve_overflow = #lwe.preserve_overflow<>
#application_data = #lwe.application_data<message_type = i1, overflow = #preserve_overflow>

// Plaintext Space Information
#inverse_canonical_enc = #lwe.inverse_canonical_encoding<scaling_factor = 10000>
#plaintext_space = #lwe.plaintext_space<plaintext_ring = #ring, plaintext_encoding = #inverse_canonical_enc>

// Ciphertext Space Information
#ciphertext_space = #lwe.ciphertext_space<ciphertext_ring = #ring, encryption_type = msb>

// Modulus Chain info (for RLWE)
#modulus_chain = #lwe.modulus_chain<chain = <65536 : i32>, current = 0>

// Key Information
#key = #lwe.key<id = "1234", size = 1>

// New Types
!secret_key = !lwe.new_lwe_secret_key<key = #key, ring = #ring>
!public_key = !lwe.new_lwe_public_key<key = #key, ring = #ring>
!new_lwe_plaintext = !lwe.new_lwe_plaintext<application_data = #application_data, plaintext_space = #plaintext_space>
!new_lwe_ciphertext = !lwe.new_lwe_ciphertext<application_data = #application_data, plaintext_space = #plaintext_space, key = #key, ciphertext_space = #ciphertext_space, modulus_chain = #modulus_chain>
```

PiperOrigin-RevId: 684108448
  • Loading branch information
asraa authored and copybara-github committed Oct 9, 2024
1 parent 5280ed3 commit 2fb6474
Show file tree
Hide file tree
Showing 10 changed files with 630 additions and 14 deletions.
14 changes: 14 additions & 0 deletions lib/Dialect/LWE/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ td_library(
"LWEOps.td",
"LWETraits.td",
"LWETypes.td",
"NewLWEAttributes.td",
"NewLWETypes.td",
],
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
Expand Down Expand Up @@ -101,6 +103,18 @@ gentbl_cc_library(
],
"LWEAttributes.cpp.inc",
),
(
[
"-gen-enum-decls",
],
"LWEEnums.h.inc",
),
(
[
"-gen-enum-defs",
],
"LWEEnums.cpp.inc",
),
(
["-gen-attrdef-doc"],
"LWEAttributes.md",
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/LWE/IR/LWEAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
#define LIB_DIALECT_LWE_IR_LWEATTRIBUTES_H_

#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Dialect/LWE/IR/LWEEnums.h.inc"
#include "mlir/include/mlir/IR/TensorEncoding.h" // from @llvm-project

// Required to pull in poly's Ring_Attr
#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" // from @llvm-project

#define GET_ATTRDEF_CLASSES
#include "lib/Dialect/LWE/IR/LWEAttributes.h.inc"

Expand Down
11 changes: 8 additions & 3 deletions lib/Dialect/LWE/IR/LWEAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ include "mlir/IR/DialectBase.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/TensorEncoding.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "lib/Dialect/LWE/IR/NewLWEAttributes.td"

class LWE_EncodingAttr<string attrName, string attrMnemonic, list<Trait> traits = []>
: AttrDef<LWE_Dialect, attrName, traits # [
Expand All @@ -21,9 +22,6 @@ class LWE_EncodingAttr<string attrName, string attrMnemonic, list<Trait> traits

class LWE_EncodingAttrWithScalingFactor<string attrName, string attrMnemonic, list<Trait> traits = []>
: LWE_EncodingAttr<attrName, attrMnemonic, traits> {
// These parameters represent the base-2 logarithm of the scaling factor to
// scale cleartexts in preparation for noise growth of FHE schemes. This
// representation restricts the scaling factors to being powers of two.
let parameters = (ins
"unsigned":$cleartext_start,
"unsigned":$cleartext_bitwidth
Expand Down Expand Up @@ -294,6 +292,13 @@ def RLWE_InverseCanonicalEmbeddingEncoding

def AnyRLWEEncodingAttr : AnyAttrOf<[RLWE_PolynomialCoefficientEncoding, RLWE_PolynomialEvaluationEncoding, RLWE_InverseCanonicalEmbeddingEncoding]>;

def AnyPlaintextEncodingInfo : AnyAttrOf<[
LWE_BitFieldEncoding,
RLWE_PolynomialCoefficientEncoding,
RLWE_PolynomialEvaluationEncoding,
RLWE_InverseCanonicalEmbeddingEncoding
]>;

def LWE_LWEParams : AttrDef<LWE_Dialect, "LWEParams"> {
let mnemonic = "lwe_params";

Expand Down
50 changes: 48 additions & 2 deletions lib/Dialect/LWE/IR/LWEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

// Generated definitions
#include "lib/Dialect/LWE/IR/LWEDialect.cpp.inc"
#include "lib/Dialect/LWE/IR/LWEEnums.cpp.inc"
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
Expand Down Expand Up @@ -61,8 +63,8 @@ LogicalResult RMulOp::verify() {
}

LogicalResult RMulOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location>, RMulOp::Adaptor adaptor,
SmallVectorImpl<Type> &inferredReturnTypes) {
MLIRContext* ctx, std::optional<Location>, RMulOp::Adaptor adaptor,
SmallVectorImpl<Type>& inferredReturnTypes) {
auto x = cast<lwe::RLWECiphertextType>(adaptor.getLhs().getType());
auto y = cast<lwe::RLWECiphertextType>(adaptor.getRhs().getType());
auto newDim =
Expand Down Expand Up @@ -230,6 +232,50 @@ LogicalResult RLWEEncryptOp::verify() {
return success();
}

LogicalResult ApplicationDataAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
mlir::Type messageType, Attribute overflow) {
if (!mlir::isa<PreserveOverflowAttr, NoOverflowAttr>(overflow)) {
return emitError() << "overflow must be either preserve_overflow or "
<< "no_overflow, but found " << overflow << "\n";
}

return success();
}

LogicalResult PlaintextSpaceAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
mlir::polynomial::RingAttr ring, Attribute encoding) {
if (mlir::isa<FullCRTPackingEncodingAttr>(encoding)) {
// For full CRT packing, the ring must be of the form x^n + 1 and the
// modulus must be 1 mod n.
auto polyMod = ring.getPolynomialModulus();
auto poly = polyMod.getPolynomial();
auto polyTerms = poly.getTerms();
if (polyTerms.size() != 2) {
return emitError() << "polynomial modulus must be of the form x^n + 1, "
<< "but found " << polyMod << "\n";
}
const auto& constantTerm = polyTerms[0];
const auto& constantCoeff = constantTerm.getCoefficient();
if (!(constantTerm.getExponent().isZero() && constantCoeff.isOne() &&
polyTerms[1].getCoefficient().isOne())) {
return emitError() << "polynomial modulus must be of the form x^n + 1, "
<< "but found " << polyMod << "\n";
}
// Check that the modulus is 1 mod n.
APInt modulus = ring.getCoefficientModulus().getValue();
unsigned n = poly.getDegree();
if (!modulus.urem(APInt(modulus.getBitWidth(), n)).isOne()) {
return emitError()
<< "modulus must be 1 mod n for full CRT packing, mod = "
<< ring.getCoefficientModulus() << " n = " << n << "\n";
}
}

return success();
}

} // namespace lwe
} // namespace heir
} // namespace mlir
8 changes: 1 addition & 7 deletions lib/Dialect/LWE/IR/LWETypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,12 @@

include "lib/Dialect/LWE/IR/LWEDialect.td"
include "lib/Dialect/LWE/IR/LWEAttributes.td"
include "lib/Dialect/LWE/IR/NewLWETypes.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/AttrTypeBase.td"

// A base class for all types in this dialect
class LWE_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<LWE_Dialect, name, traits> {
let mnemonic = typeMnemonic;
let assemblyFormat = "`<` struct(params) `>`";
}

// LWE Ciphertexts are ranked tensors of integers representing the LWE samples
// and the bias.
def LWECiphertext : LWE_Type<"LWECiphertext", "lwe_ciphertext", [MemRefElementTypeInterface]> {
Expand Down
Loading

0 comments on commit 2fb6474

Please sign in to comment.