From 2fb64740a1c3478b2e77a9374e3b0d9c0bfd4af2 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Wed, 9 Oct 2024 11:24:10 -0700 Subject: [PATCH] lwe: add new LWE Attributes & Types See #785 New Attributes & Types: ``` #generator = #polynomial.int_polynomial<1 + x**1024> #ring = #polynomial.ring // Application Data Information #preserve_overflow = #lwe.preserve_overflow<> #application_data = #lwe.application_data // Plaintext Space Information #inverse_canonical_enc = #lwe.inverse_canonical_encoding #plaintext_space = #lwe.plaintext_space // Ciphertext Space Information #ciphertext_space = #lwe.ciphertext_space // Modulus Chain info (for RLWE) #modulus_chain = #lwe.modulus_chain, current = 0> // Key Information #key = #lwe.key // New Types !secret_key = !lwe.new_lwe_secret_key !public_key = !lwe.new_lwe_public_key !new_lwe_plaintext = !lwe.new_lwe_plaintext !new_lwe_ciphertext = !lwe.new_lwe_ciphertext ``` PiperOrigin-RevId: 684108448 --- lib/Dialect/LWE/IR/BUILD | 14 + lib/Dialect/LWE/IR/LWEAttributes.h | 3 +- lib/Dialect/LWE/IR/LWEAttributes.td | 11 +- lib/Dialect/LWE/IR/LWEDialect.cpp | 50 +++- lib/Dialect/LWE/IR/LWETypes.td | 8 +- lib/Dialect/LWE/IR/NewLWEAttributes.td | 351 +++++++++++++++++++++++++ lib/Dialect/LWE/IR/NewLWETypes.td | 73 +++++ tests/lwe/attributes.mlir | 75 ++++++ tests/lwe/attributes_errors.mlir | 21 ++ tests/lwe/types.mlir | 38 +++ 10 files changed, 630 insertions(+), 14 deletions(-) create mode 100644 lib/Dialect/LWE/IR/NewLWEAttributes.td create mode 100644 lib/Dialect/LWE/IR/NewLWETypes.td diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 0af2abe27..4fb566ce2 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -52,6 +52,8 @@ td_library( "LWEOps.td", "LWETraits.td", "LWETypes.td", + "NewLWEAttributes.td", + "NewLWETypes.td", ], # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], @@ -101,6 +103,18 @@ gentbl_cc_library( ], "LWEAttributes.cpp.inc", ), + ( + [ + "-gen-enum-decls", + ], + "LWEEnums.h.inc", + ), + ( + [ + "-gen-enum-defs", + ], + "LWEEnums.cpp.inc", + ), ( ["-gen-attrdef-doc"], "LWEAttributes.md", diff --git a/lib/Dialect/LWE/IR/LWEAttributes.h b/lib/Dialect/LWE/IR/LWEAttributes.h index d2b4244fc..d4bddf6b3 100644 --- a/lib/Dialect/LWE/IR/LWEAttributes.h +++ b/lib/Dialect/LWE/IR/LWEAttributes.h @@ -2,11 +2,10 @@ #define LIB_DIALECT_LWE_IR_LWEATTRIBUTES_H_ #include "lib/Dialect/LWE/IR/LWEDialect.h" +#include "lib/Dialect/LWE/IR/LWEEnums.h.inc" #include "mlir/include/mlir/IR/TensorEncoding.h" // from @llvm-project - // Required to pull in poly's Ring_Attr #include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" // from @llvm-project - #define GET_ATTRDEF_CLASSES #include "lib/Dialect/LWE/IR/LWEAttributes.h.inc" diff --git a/lib/Dialect/LWE/IR/LWEAttributes.td b/lib/Dialect/LWE/IR/LWEAttributes.td index d245b18c1..67b9a8fae 100644 --- a/lib/Dialect/LWE/IR/LWEAttributes.td +++ b/lib/Dialect/LWE/IR/LWEAttributes.td @@ -8,6 +8,7 @@ include "mlir/IR/DialectBase.td" include "mlir/IR/OpBase.td" include "mlir/IR/TensorEncoding.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "lib/Dialect/LWE/IR/NewLWEAttributes.td" class LWE_EncodingAttr traits = []> : AttrDef traits class LWE_EncodingAttrWithScalingFactor traits = []> : LWE_EncodingAttr { - // These parameters represent the base-2 logarithm of the scaling factor to - // scale cleartexts in preparation for noise growth of FHE schemes. This - // representation restricts the scaling factors to being powers of two. let parameters = (ins "unsigned":$cleartext_start, "unsigned":$cleartext_bitwidth @@ -294,6 +292,13 @@ def RLWE_InverseCanonicalEmbeddingEncoding def AnyRLWEEncodingAttr : AnyAttrOf<[RLWE_PolynomialCoefficientEncoding, RLWE_PolynomialEvaluationEncoding, RLWE_InverseCanonicalEmbeddingEncoding]>; +def AnyPlaintextEncodingInfo : AnyAttrOf<[ + LWE_BitFieldEncoding, + RLWE_PolynomialCoefficientEncoding, + RLWE_PolynomialEvaluationEncoding, + RLWE_InverseCanonicalEmbeddingEncoding +]>; + def LWE_LWEParams : AttrDef { let mnemonic = "lwe_params"; diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index fabd27363..665b78c5b 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -10,12 +10,14 @@ #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project // Generated definitions #include "lib/Dialect/LWE/IR/LWEDialect.cpp.inc" +#include "lib/Dialect/LWE/IR/LWEEnums.cpp.inc" #include "mlir/include/mlir/IR/Location.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project @@ -61,8 +63,8 @@ LogicalResult RMulOp::verify() { } LogicalResult RMulOp::inferReturnTypes( - MLIRContext *ctx, std::optional, RMulOp::Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes) { + MLIRContext* ctx, std::optional, RMulOp::Adaptor adaptor, + SmallVectorImpl& inferredReturnTypes) { auto x = cast(adaptor.getLhs().getType()); auto y = cast(adaptor.getRhs().getType()); auto newDim = @@ -230,6 +232,50 @@ LogicalResult RLWEEncryptOp::verify() { return success(); } +LogicalResult ApplicationDataAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + mlir::Type messageType, Attribute overflow) { + if (!mlir::isa(overflow)) { + return emitError() << "overflow must be either preserve_overflow or " + << "no_overflow, but found " << overflow << "\n"; + } + + return success(); +} + +LogicalResult PlaintextSpaceAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + mlir::polynomial::RingAttr ring, Attribute encoding) { + if (mlir::isa(encoding)) { + // For full CRT packing, the ring must be of the form x^n + 1 and the + // modulus must be 1 mod n. + auto polyMod = ring.getPolynomialModulus(); + auto poly = polyMod.getPolynomial(); + auto polyTerms = poly.getTerms(); + if (polyTerms.size() != 2) { + return emitError() << "polynomial modulus must be of the form x^n + 1, " + << "but found " << polyMod << "\n"; + } + const auto& constantTerm = polyTerms[0]; + const auto& constantCoeff = constantTerm.getCoefficient(); + if (!(constantTerm.getExponent().isZero() && constantCoeff.isOne() && + polyTerms[1].getCoefficient().isOne())) { + return emitError() << "polynomial modulus must be of the form x^n + 1, " + << "but found " << polyMod << "\n"; + } + // Check that the modulus is 1 mod n. + APInt modulus = ring.getCoefficientModulus().getValue(); + unsigned n = poly.getDegree(); + if (!modulus.urem(APInt(modulus.getBitWidth(), n)).isOne()) { + return emitError() + << "modulus must be 1 mod n for full CRT packing, mod = " + << ring.getCoefficientModulus() << " n = " << n << "\n"; + } + } + + return success(); +} + } // namespace lwe } // namespace heir } // namespace mlir diff --git a/lib/Dialect/LWE/IR/LWETypes.td b/lib/Dialect/LWE/IR/LWETypes.td index 6a6a41d1f..aad534545 100644 --- a/lib/Dialect/LWE/IR/LWETypes.td +++ b/lib/Dialect/LWE/IR/LWETypes.td @@ -3,18 +3,12 @@ include "lib/Dialect/LWE/IR/LWEDialect.td" include "lib/Dialect/LWE/IR/LWEAttributes.td" +include "lib/Dialect/LWE/IR/NewLWETypes.td" include "mlir/IR/DialectBase.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/AttrTypeBase.td" -// A base class for all types in this dialect -class LWE_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; - let assemblyFormat = "`<` struct(params) `>`"; -} - // LWE Ciphertexts are ranked tensors of integers representing the LWE samples // and the bias. def LWECiphertext : LWE_Type<"LWECiphertext", "lwe_ciphertext", [MemRefElementTypeInterface]> { diff --git a/lib/Dialect/LWE/IR/NewLWEAttributes.td b/lib/Dialect/LWE/IR/NewLWEAttributes.td new file mode 100644 index 000000000..6870f1f7d --- /dev/null +++ b/lib/Dialect/LWE/IR/NewLWEAttributes.td @@ -0,0 +1,351 @@ +#ifndef LIB_DIALECT_LWE_IR_NEWLWEATTRIBUTES_TD_ +#define LIB_DIALECT_LWE_IR_NEWLWEATTRIBUTES_TD_ + +include "lib/Dialect/LWE/IR/LWEDialect.td" + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/TensorEncoding.td" +include "mlir/Interfaces/InferTypeOpInterface.td" + +// Below defines new LWE attributes following +// [#785](https://github.com/google/heir/issues/785). + +class LWE_OverflowAttr + : AttrDef { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LWE_NoOverflowAttr : LWE_OverflowAttr<"NoOverflow", "no_overflow"> { + let summary = "An attribute informing that application data never overflows."; + let description = [{ + This attribute informs lowerings that a program is written so that the message data + will never overflow beyond the message type. + + // FIXME: Have a separate WraparoundOverflow, which lowers the same as NoOverflow? + }]; +} + +def LWE_PreserveOverflowAttr : LWE_OverflowAttr<"PreserveOverflow", "preserve_overflow"> { + let summary = "An attribute informing that application data overflows in the message type."; + let description = [{ + This attribute informs lowerings that a program is written so that the message data + may overflow beyond the message type. + }]; +} + +def LWE_ApplicationDataAttr : AttrDef { + let mnemonic = "application_data"; + let description = [{ + An attribute describing the semantics of the underlying application data. + + The `messageType` parameter is used to describe the type and bits of the + original application data, e.g. i1, i32, f32. This type is later mapped + into the plaintext space of an FHE scheme by embedding, scaling, or other + techniques. + + This attribute also contains information about the overflow semantics of the + data in the application. By default, we assume that the application program + was written so that the overflow is not expected and the overflow attribute + can can be `no_overflow`. For LWE-based CGGI ciphertexts, the overflow + attribute will usually be `preserve_overflow`, since messages will overflow + into padding bits. + }]; + + let parameters = (ins + "mlir::Type":$message_type, + DefaultValuedParameter<"Attribute", "NoOverflowAttr::get($_ctxt)">:$overflow + ); + + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$message_type, "Attribute":$overflow), [{ + return $_get(message_type.getContext(), message_type, overflow); + }]> + ]; + + let assemblyFormat = "`<` struct(params) `>`"; + + // Verify that the overflow attribute is one of preserve or no overflow. + let genVerifyDecl = 1; +} + +class LWE_EncodingAttrForLWE traits = []> + : AttrDef { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +class LWE_EncodingAttrWithScalingParam traits = []> + : LWE_EncodingAttrForLWE { + let mnemonic = attrMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; + + let parameters = (ins + "unsigned":$scaling_factor + ); +} + +def LWE_ConstantCoefficientEncoding + : LWE_EncodingAttrWithScalingParam<"ConstantCoefficientEncoding", "constant_coefficient_encoding"> { + let summary = "An encoding of a scalar in the constant coefficient"; + let description = [{ + An encoding of a single scalar into the constant coefficient of the plaintext. + + All other coefficients of the plaintext are set to be zero. This encoding is + used to encode scalar LWE ciphertexts where the plaintext space is viewed + as a polynomial ring modulo `x`. + + The scalar is first multiplied by the `scaling_factor` and then rounded to + the nearest integer before encoding into the plaintext coefficient. + + Example: + + ``` + #coeff_encoding = #lwe.constant_coefficient_encoding + ``` + }]; +} + +def LWE_CoefficientEncoding + : LWE_EncodingAttrWithScalingParam<"CoefficientEncoding", "coefficient_encoding"> { + let summary = "An encoding of cleartexts directly as coefficients."; + let description = [{ + A coefficient encoding of a list of integers asserts that the coefficients + of the polynomials contain the integers, with the same semantics as + `constant_coefficient_encoding` for per-coefficient encodings. + + A `scaling_factor` is optionally applied on the scalar when converting from + a rounded floating point to an integer. + + Example: + + ``` + #coeff_encoding = #lwe.coefficient_encoding + ``` + }]; +} + +def LWE_InverseCanonicalEmbeddingEncoding + : LWE_EncodingAttrWithScalingParam<"InverseCanonicalEncoding", "inverse_canonical_encoding"> { + let summary = "An encoding of cleartexts via the inverse canonical embedding."; + let description = [{ + Let $n$ be the degree of the polynomials in the plaintext space. An + "inverse_canonical_encoding" of a list of real or complex values + $v_1, \dots, v_{n/2}$ is (almost) the inverse of the following decoding + map. + + Define a map $\tau_N$ that maps a polynomial $p \in \mathbb{Z}[x] / (x^N + 1) + \to \mathbb{C}^{N/2}$ by evaluating it at the following $N/2$ points, + where $\omega = e^{2 \pi i / 2N}$ is the primitive $2N$th root of unity: + + \[ + \omega, \omega^3, \omega^5, \dots, \omega^{N-1} + \] + + Then the complete decoding operation is $\textup{Decode}(p) = + (1/\Delta)\tau_N(p)$, where $\Delta$ is a scaling parameter and $\tau_N$ is + the truncated canonical embedding above. The encoding operation is the + inverse of the decoding operation, with some caveats explained below. + + The map $\tau_N$ is derived from the so-called _canonical embedding_ + $\tau$, though in the standard canonical embedding, we evaluate at all odd + powers of the root of unity, $\omega, \omega^3, \dots, \omega^{2N-1}$. For + polynomials in the slightly larger space $\mathbb{R}[x] / (x^N + 1)$, the + image of the canonical embedding is the subspace $H \subset \mathbb{C}^N$ + defined by tuples $(z_1, \dots, z_N)$ such that $\overline{z_i} = + \overline{z_{N-i+1}}$. Note that this property holds because polynomial + evaluation commutes with complex conjugates, and the second half of the + roots of unity evaluate are complex conjugates of the first half. The + converse, that any such tuple with complex conjugate symmetry has an + inverse under $\tau$ with all real coefficients, makes $\tau$ is a + bijection onto $H$. $\tau$ and its inverse are explicitly computable as + discrete Fourier Transforms. + + Because of the symmetry in canonical embedding for real polynomials, inputs + to this encoding can be represented as a list of $N/2$ complex points, with + the extra symmetric structure left implicit. $\tau_N$ and its inverse can + also be explicitly computed without need to expand the vectors to length + $N$. + + The rounding step is required to invert the decoding because, while + cleartexts must be (implicitly) in the subspace $H$, they need not be the + output of $\tau_N$ for an _integer_ polynomial. The rounding step ensures + we can use integer polynomial plaintexts for the FHE operations. There are + multiple rounding mechanisms, and this attribute does not specify which is + used, because in theory two ciphertexts that have used different roundings + are still compatible, though they may have different noise growth patterns. + + The scaling parameter $\Delta$ is specified by the `scaling_factor`, which + are applied coefficient-wise using the same semantics as the + `constant_coefficient_encoding`. + + A typical flow for the CKKS scheme using this encoding would be to apply an + inverse FFT operation to invert the canonical embedding to be a polynomial + with real coefficients, then encrypt scale the resulting polynomial's + coefficients according to the scaling parameters, then round to get integer + coefficients. + + Example: + + ``` + #canonical_encoding = #lwe.inverse_canonical_encoding + ``` + }]; +} + +def LWE_FullCRTPackingEncoding + : LWE_EncodingAttrWithScalingParam<"FullCRTPackingEncoding", "full_crt_packing_encoding"> { + let summary = "An encoding of cleartexts via CRT slots."; + let description = [{ + This encoding maps a list of integers via the Chinese Remainder Theorem (CRT) into the plaintext space. + + Given a ring with irreducible ideal polynomial `f(x)` and coefficient + modulus `q`, `f(x)` can be decomposed modulo `q` into a direct product of + lower-degree polynomials. This allows full SIMD-style homomorphic operations + across the slots formed from each factor. + + This attribute can only be used in the context of on full CRT packing, where + the polynomial `f(x)` splits completely (into linear factors) and the number + of slots equals the degree of `f(x)`. This happens when `q` is prime and `q + = 1 mod n`. + + A `scaling_factor` is optionally applied on the scalar when converting from + a rounded floating point to an integer. + + Example: + + ``` + #coeff_encoding = #lwe.full_crt_packing_encoding + ``` + }]; +} + +def LWE_AnyPlaintextEncodingAttr : LWE_EncodingAttrForLWE<"PlaintextEncoding", "plaintext_encoding"> { + let returnType = "Attribute"; + let convertFromStorage = "$_self"; + string cppType = "Attribute"; + let predicate = Or<[ + LWE_ConstantCoefficientEncoding.predicate, + LWE_CoefficientEncoding.predicate, + LWE_InverseCanonicalEmbeddingEncoding.predicate, + LWE_FullCRTPackingEncoding.predicate + ]>; +} + +def LWE_PlaintextSpaceAttr : AttrDef { + let mnemonic = "plaintext_space"; + let description = [{ + An attribute describing the plaintext space and the transformation from + application data to plaintext space of an FHE scheme. + + The plaintext space information is the ring structure, which contains the + plaintext modulus $t$, which may be a power of two in the case of CGGI + ciphertexts, or a prime power for RLWE. LWE ciphertexts use the + ideal polynomial of degree 1 $x$. The plaintext modulus used in LWE-based + CGGI plaintexts describes the full message space $\mathbb{Z}_p$ including + the padding bits. The application data info attribute describes the space + $\mathbb{Z}_p'$ where $p' < p$ that the underlying message belongs to. + + For RLWE schemes, this will include the type of encoding of application data + integers to a plaintext space `Z_p[X]/X^N + 1`. This may be a constant + coefficient encoding, CRT-based packing for SIMD semantics, or other slot + packing. When using full CRT packing, the ring must split into linear + factors. The CKKS scheme will also include attributes describing the complex + encoding, including the scaling factor, which will change after + multiplication and rescaling. + }]; + + let parameters = (ins + "::mlir::polynomial::RingAttr":$ring, + LWE_AnyPlaintextEncodingAttr:$encoding + ); + + let assemblyFormat = "`<` struct(params) `>`"; + + let genVerifyDecl = 1; +} + +def LWE_KeyAttr : AttrDef { + let mnemonic = "key"; + let description = [{ + An attribute describing the key used for encrypting the ciphertext. + + This attribute includes a key identifier for the original key used to + encrypt the secret key. + + The `key_size` parameter is used to describe the number of polynomials of + the secret key. This is typically $1$ for RLWE ciphertexts and greater than + $1$ for LWE instances. A ciphertext encrypted with a `key_size` of $k$ will + have size $k+1$. + + The key basis describes the inner product used in the phase calculation in + decryption. This attribute is only supported for RLWE ciphertexts whose + `key_size` is $1$. An RLWE ciphertext is canonically encrypted against key + basis `(1, s)`. After a multiplication, its size will increase and the basis + will be `(1, s, s^2)`. The array that represents the key basis is + constructed by listing the powers of `s` at each position of the array. For + example, `(1, s, s^2)` corresponds to `[0, 1, 2]`, while `(1, s^2)` + corresponds to `[0, 2]`. + }]; + + let parameters = (ins + "::mlir::StringAttr":$id, + DefaultValuedParameter<"unsigned", "1">:$size, + OptionalArrayRefParameter<"unsigned int">:$basis + ); + + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LWE_EncryptionTypeEnum : I32EnumAttr<"LweEncryptionType", "An enum attribute representing an encryption method", [ + I32EnumAttrCase<"msb", 0>, + I32EnumAttrCase<"lsb", 1>, + I32EnumAttrCase<"mix", 2> +]> { + let cppNamespace = "::mlir::heir::lwe"; +} + +def LWE_CiphertextSpaceAttr : AttrDef { + let mnemonic = "ciphertext_space"; + let description = [{ + An attribute describing the ciphertext space and the transformation from + plaintext space to ciphertext space of an FHE scheme. + + The ciphertext space information includes the ring structure, which contains + the ciphertext modulus $q$. Ciphertexts using an RNS representation for $q$ + will represent their ciphertext components in the ring attribute. Scalar LWE + ciphertexts (as opposed to RLWE) use an ideal polynomial of degree 1, $x$. + CGGI ciphertexts will typically use a power of two modulus. + + The ciphertext encoding info is used to describe the way the plaintext data + is encoded into the ciphertext (in the MSB, LSB, or mixed). + }]; + + let parameters = (ins + "::mlir::polynomial::RingAttr":$ring, + "::mlir::heir::lwe::LweEncryptionType":$encryption_type + ); + + let assemblyFormat = "`<` struct(params) `>`"; +} + +def LWE_ModulusChainAttr : AttrDef { + let mnemonic = "modulus_chain"; + let description = [{ + An attribute describing the elements of the modulus chain of an RLWE scheme. + }]; + + let parameters = (ins + ArrayRefParameter<"mlir::IntegerAttr">:$elements, + "int":$current + ); + + let assemblyFormat = "`<` `elements` `=` `<` $elements `>``,` `current` `=` $current `>`"; + + // let genVerifyDecl = 1; // Verify index into list +} + +#endif // LIB_DIALECT_LWE_IR_NEWLWEATTRIBUTES_TD_ diff --git a/lib/Dialect/LWE/IR/NewLWETypes.td b/lib/Dialect/LWE/IR/NewLWETypes.td new file mode 100644 index 000000000..ce576d9b6 --- /dev/null +++ b/lib/Dialect/LWE/IR/NewLWETypes.td @@ -0,0 +1,73 @@ +#ifndef LIB_DIALECT_LWE_IR_NEWLWETYPES_TD_ +#define LIB_DIALECT_LWE_IR_NEWLWETYPES_TD_ + +include "lib/Dialect/LWE/IR/LWEDialect.td" +include "lib/Dialect/LWE/IR/NewLWEAttributes.td" + +include "mlir/IR/DialectBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/AttrTypeBase.td" + +// A base class for all types in this dialect +class LWE_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// This file defines new LWE types following +// [#785](https://github.com/google/heir/issues/785). + +def NewLWESecretKey : LWE_Type<"NewLWESecretKey", "new_lwe_secret_key"> { + let summary = "A secret key for LWE"; + let parameters = (ins + "KeyAttr":$key, + "::mlir::polynomial::RingAttr":$ring + ); +} + +def NewLWEPublicKey : LWE_Type<"NewLWEPublicKey", "new_lwe_public_key"> { + let summary = "A public key for LWE"; + let parameters = (ins + "KeyAttr":$key, + "::mlir::polynomial::RingAttr":$ring + ); +} + +def NewLWESecretOrPublicKey : AnyTypeOf<[NewLWESecretKey, NewLWEPublicKey]>; + +def NewLWEPlaintext : LWE_Type<"NewLWEPlaintext", "new_lwe_plaintext"> { + let summary = "A plaintext type"; + let parameters = (ins + "ApplicationDataAttr":$application_data, + "PlaintextSpaceAttr":$plaintext_space + ); +} + +def NewLWEPlaintextLike : TypeOrContainer; + +def NewLWECiphertext : LWE_Type<"NewLWECiphertext", "new_lwe_ciphertext"> { + let summary = "A ciphertext type"; + + let description = [{ + An LWE ciphertext will always contain the application data, plaintext space, + ciphertext space, and key information. + + A modulus chain is optionally specified for parameter choices in RLWE + schemes that use more than one of modulus. When no modulus chain is + specified, the ciphertext modulus is always the ciphertext ring's + coefficient modulus. + }]; + + let parameters = (ins + "ApplicationDataAttr":$application_data, + "PlaintextSpaceAttr":$plaintext_space, + "CiphertextSpaceAttr":$ciphertext_space, + "KeyAttr":$key, + OptionalParameter<"ModulusChainAttr">:$modulus_chain + ); +} + +def NewLWECiphertextLike : TypeOrContainer; + +#endif // LIB_DIALECT_LWE_IR_NEWLWETYPES_TD_ diff --git a/tests/lwe/attributes.mlir b/tests/lwe/attributes.mlir index a57629043..73afee962 100644 --- a/tests/lwe/attributes.mlir +++ b/tests/lwe/attributes.mlir @@ -99,3 +99,78 @@ func.func @test_valid_inverse_canonical_embedding_encoding(%coeffs1 : tensor<10x %rlwe_ciphertext = tensor.from_elements %poly1, %poly2 : tensor<2x!polynomial.polynomial, #inverse_canonical_enc> return } + +// ----- + +#preserve_overflow = #lwe.preserve_overflow<> +#application = #lwe.application_data + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} + +// ----- + +#application = #lwe.application_data + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} + +// ----- + +#generator4 = #polynomial.int_polynomial<1 + x**1024> +#ring4 = #polynomial.ring +#inverse_canonical_enc = #lwe.inverse_canonical_encoding + +#plaintext_space = #lwe.plaintext_space + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} + +// ----- + +#poly = #polynomial.int_polynomial +#ring = #polynomial.ring +#crt = #lwe.full_crt_packing_encoding +#plaintext_space = #lwe.plaintext_space + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} + +// ----- + +#key = #lwe.key +#key_rlwe_rotate = #lwe.key +#key_rlwe_2 = #lwe.key + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} + +// ----- + +#generator4 = #polynomial.int_polynomial<1 + x**1024> +#ring4 = #polynomial.ring + +#ciphertext_space = #lwe.ciphertext_space + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} +// ----- + +#modulus_chain = #lwe.modulus_chain, current = 0> + +// CHECK-LABEL: test_fn +func.func @test_fn() { + return +} diff --git a/tests/lwe/attributes_errors.mlir b/tests/lwe/attributes_errors.mlir index 583cbeff6..b8215be55 100644 --- a/tests/lwe/attributes_errors.mlir +++ b/tests/lwe/attributes_errors.mlir @@ -64,3 +64,24 @@ func.func @test_invalid_inverse_canonical_embedding_encoding() { %a = arith.constant dense<[2, 2, 5]> : tensor<3xi32, #inverse_canonical_enc2> return } + +// ----- + +// expected-error@below {{overflow must be either preserve_overflow or no_overflow, but found i1}} +#application = #lwe.application_data + +// ----- + +#poly = #polynomial.int_polynomial +#ring = #polynomial.ring +#crt = #lwe.full_crt_packing_encoding +// expected-error@below {{polynomial modulus must be of the form x^n + 1}} +#plaintext_space = #lwe.plaintext_space + +// ----- + +#poly = #polynomial.int_polynomial +#ring = #polynomial.ring +#crt = #lwe.full_crt_packing_encoding +// expected-error@below {{modulus must be 1 mod n for full CRT packing}} +#plaintext_space = #lwe.plaintext_space diff --git a/tests/lwe/types.mlir b/tests/lwe/types.mlir index b5b10d2fc..73e1ed6ad 100644 --- a/tests/lwe/types.mlir +++ b/tests/lwe/types.mlir @@ -32,3 +32,41 @@ func.func @test_valid_lwe_ciphertext_unspecified(%arg0 : !ciphertext_noparams) - func.func @test_valid_rlwe_ciphertext(%arg0 : !ciphertext_rlwe) -> !ciphertext_rlwe { return %arg0 : !ciphertext_rlwe } + +#key = #lwe.key +!secret_key = !lwe.new_lwe_secret_key + +// CHECK-LABEL: test_new_lwe_secret_key +func.func @test_new_lwe_secret_key(%arg0 : !secret_key) -> !secret_key { + return %arg0 :!secret_key +} + +!public_key = !lwe.new_lwe_public_key + +// CHECK-LABEL: test_new_lwe_public_key +func.func @test_new_lwe_public_key(%arg0 : !public_key) -> !public_key { + return %arg0 : !public_key +} + + +#preserve_overflow = #lwe.preserve_overflow<> +#application_data = #lwe.application_data +#inverse_canonical_enc = #lwe.inverse_canonical_encoding +#plaintext_space = #lwe.plaintext_space + +!new_lwe_plaintext = !lwe.new_lwe_plaintext + +// CHECK-LABEL: test_new_lwe_plaintext +func.func @test_new_lwe_plaintext(%arg0 : !new_lwe_plaintext) -> !new_lwe_plaintext { + return %arg0 : !new_lwe_plaintext +} + +#ciphertext_space = #lwe.ciphertext_space +#modulus_chain = #lwe.modulus_chain, current = 0> + +!new_lwe_ciphertext = !lwe.new_lwe_ciphertext + +// CHECK-LABEL: test_new_lwe_ciphertext +func.func @test_new_lwe_ciphertext(%arg0 : !new_lwe_ciphertext) -> !new_lwe_ciphertext { + return %arg0 : !new_lwe_ciphertext +}