Skip to content

Commit

Permalink
ML-DSA verifcation works for the KATs at least.
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed Jun 20, 2024
1 parent 64b6dd9 commit cdae0f3
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 42 deletions.
44 changes: 36 additions & 8 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ impl PolynomialRingElement {
difference
}

// TODO: Revisit this function when doing the range analysis and testing
// additional KATs.
#[inline(always)]
pub(crate) fn infinity_norm_exceeds(&self, value: i32) -> bool {
if value > (FIELD_MODULUS - 1) / 8 {
return true;
}

pub(crate) fn infinity_norm_exceeds(&self, bound: i32) -> bool {
let mut exceeds = false;

// It is ok to leak which coefficient violates the bound since
Expand All @@ -48,12 +46,23 @@ impl PolynomialRingElement {
// TODO: We can break out of this loop early if need be, but the most
// straightforward way to do so (returning false) will not go through hax;
// revisit if performance is impacted.
for coefficient in self.coefficients.iter() {
// Normalize the coefficient
for coefficient in self.coefficients.into_iter() {
debug_assert!(
coefficient > -FIELD_MODULUS && coefficient < FIELD_MODULUS,
"coefficient is {}",
coefficient
);
// This norm is calculated using the absolute value of the
// signed representative in the range:
//
// -FIELD_MODULUS / 2 < r <= FIELD_MODULUS / 2.
//
// So if the coefficient is negative, get its absolute value, but
// don't convert it into a different representation.
let sign = coefficient >> 31;
let normalized = coefficient - (sign & (2 * coefficient));

exceeds |= normalized >= value;
exceeds |= normalized >= bound;
}

exceeds
Expand Down Expand Up @@ -120,6 +129,25 @@ pub(crate) fn montgomery_multiply_fe_by_fer(
montgomery_reduce((fe as i64) * (fer as i64))
}

fn reduce(fe: FieldElement) -> FieldElement {
let quotient = (fe + (1 << 22)) >> 23;

fe - (quotient * FIELD_MODULUS)
}

pub(crate) fn shift_coefficients_left_then_reduce(
re: PolynomialRingElement,
shift_by: usize,
) -> PolynomialRingElement {
let mut out = PolynomialRingElement::ZERO;

for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
out.coefficients[i] = reduce(re.coefficients[i] << shift_by);
}

out
}

// Splits t ∈ {0, ..., q-1} into t0 and t1 with a = t1*2ᴰ + t0
// and -2ᴰ⁻¹ < t0 < 2ᴰ⁻¹. Returns t0 and t1 computed as.
//
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ pub(crate) fn deserialize_to_vector_then_ntt<
let mut ring_elements = [PolynomialRingElement::ZERO; DIMENSION];

for (i, bytes) in serialized.chunks(RING_ELEMENT_SIZE).enumerate() {
ring_elements[i] = ntt::<0>(deserialize::<ETA>(bytes));
ring_elements[i] = ntt(deserialize::<ETA>(bytes));
}

ring_elements
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/t0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ pub(crate) fn deserialize_to_vector_then_ntt<const DIMENSION: usize>(
let mut ring_elements = [PolynomialRingElement::ZERO; DIMENSION];

for (i, bytes) in serialized.chunks(RING_ELEMENT_OF_T0S_SIZE).enumerate() {
ring_elements[i] = ntt::<0>(deserialize(bytes));
ring_elements[i] = ntt(deserialize(bytes));
}

ring_elements
Expand Down
2 changes: 2 additions & 0 deletions libcrux-ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ mod ntt;
mod sample;
mod utils;

pub use ml_dsa_generic::VerificationError;

pub mod ml_dsa_44;
pub mod ml_dsa_65;
pub mod ml_dsa_87;
15 changes: 7 additions & 8 deletions libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
arithmetic::PolynomialRingElement,
arithmetic::{shift_coefficients_left_then_reduce, PolynomialRingElement},
constants::BITS_IN_LOWER_PART_OF_T,
ntt::{invert_ntt_montgomery, ntt, ntt_multiply_montgomery},
sample::sample_ring_element_uniform,
Expand Down Expand Up @@ -36,7 +36,7 @@ pub(crate) fn compute_As1_plus_s2<const ROWS_IN_A: usize, const COLUMNS_IN_A: us

for (i, row) in A_as_ntt.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product = ntt_multiply_montgomery(ring_element, &ntt::<0>(s1[j]));
let product = ntt_multiply_montgomery(ring_element, &ntt(s1[j]));
result[i] = result[i].add(&product);
}

Expand All @@ -58,7 +58,7 @@ pub(crate) fn compute_A_times_mask<const ROWS_IN_A: usize, const COLUMNS_IN_A: u

for (i, row) in A_as_ntt.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product = ntt_multiply_montgomery(ring_element, &ntt::<0>(mask[j]));
let product = ntt_multiply_montgomery(ring_element, &ntt(mask[j]));
result[i] = result[i].add(&product);
}

Expand Down Expand Up @@ -126,15 +126,14 @@ pub(crate) fn compute_w_approx<const ROWS_IN_A: usize, const COLUMNS_IN_A: usize

for (i, row) in A_as_ntt.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product = ntt_multiply_montgomery(&ring_element, &ntt::<0>(signer_response[j]));
let product = ntt_multiply_montgomery(&ring_element, &ntt(signer_response[j]));

result[i] = result[i].add(&product);
}

let challenge_times_t1_shifted = ntt_multiply_montgomery(
&verifier_challenge_as_ntt,
&ntt::<BITS_IN_LOWER_PART_OF_T>(t1[i]),
);
let t1_shifted = shift_coefficients_left_then_reduce(t1[i], BITS_IN_LOWER_PART_OF_T);
let challenge_times_t1_shifted =
ntt_multiply_montgomery(&verifier_challenge_as_ntt, &ntt(t1_shifted));
result[i] = invert_ntt_montgomery(result[i].sub(&challenge_times_t1_shifted));
}

Expand Down
27 changes: 26 additions & 1 deletion libcrux-ml-dsa/src/ml_dsa_44.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::constants::*;
use crate::{constants::*, VerificationError};

// ML-DSA-44-specific parameters

Expand All @@ -17,6 +17,8 @@ const ERROR_RING_ELEMENT_SIZE: usize =
const GAMMA1_EXPONENT: usize = 17;
const GAMMA2: i32 = (FIELD_MODULUS - 1) / 88;

const BETA: i32 = (ONES_IN_VERIFIER_CHALLENGE * ETA) as i32;

// To sample a value in the interval [-(GAMMA - 1), GAMMA], we can sample a
// value (say 'v') in the interval [0, (2 * GAMMA) - 1] and then compute
// GAMMA - v. This can be done in 18 bits when GAMMA is 2^{17}.
Expand Down Expand Up @@ -105,3 +107,26 @@ pub fn sign(

MLDSA44Signature(signature)
}

/// Verify an ML-DSA-44 Signature
pub fn verify(
verification_key: MLDSA44VerificationKey,
message: &[u8],
signature: MLDSA44Signature,
) -> Result<(), VerificationError> {
crate::ml_dsa_generic::verify::<
ROWS_IN_A,
COLUMNS_IN_A,
SIGNATURE_SIZE,
VERIFICATION_KEY_SIZE,
GAMMA1_EXPONENT,
GAMMA1_RING_ELEMENT_SIZE,
GAMMA2,
BETA,
COMMITMENT_RING_ELEMENT_SIZE,
COMMITMENT_VECTOR_SIZE,
COMMITMENT_HASH_SIZE,
ONES_IN_VERIFIER_CHALLENGE,
MAX_ONES_IN_HINT,
>(verification_key.0, message, signature.0)
}
4 changes: 1 addition & 3 deletions libcrux-ml-dsa/src/ml_dsa_65.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::constants::*;
use crate::{constants::*, VerificationError};

// ML-DSA-65-specific parameters

Expand Down Expand Up @@ -110,8 +110,6 @@ pub fn sign(
MLDSA65Signature(signature)
}

pub use crate::ml_dsa_generic::VerificationError;

/// Verify an ML-DSA-65 Signature
pub fn verify(
verification_key: MLDSA65VerificationKey,
Expand Down
27 changes: 26 additions & 1 deletion libcrux-ml-dsa/src/ml_dsa_87.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::constants::*;
use crate::{constants::*, VerificationError};

// ML-DSA-87 parameters

Expand Down Expand Up @@ -29,6 +29,8 @@ const ONES_IN_VERIFIER_CHALLENGE: usize = 60;

const GAMMA2: i32 = (FIELD_MODULUS - 1) / 32;

const BETA: i32 = (ONES_IN_VERIFIER_CHALLENGE * ETA) as i32;

// Commitment coefficients are in the interval: [0, ((FIELD_MODULUS − 1)/2γ2) − 1]
// ((FIELD_MODULUS − 1)/2γ2) − 1 = 15, which means we need 4 bits to represent a
// coefficient.
Expand Down Expand Up @@ -107,3 +109,26 @@ pub fn sign(

MLDSA87Signature(signature)
}

/// Verify an ML-DSA-87 Signature
pub fn verify(
verification_key: MLDSA87VerificationKey,
message: &[u8],
signature: MLDSA87Signature,
) -> Result<(), VerificationError> {
crate::ml_dsa_generic::verify::<
ROWS_IN_A,
COLUMNS_IN_A,
SIGNATURE_SIZE,
VERIFICATION_KEY_SIZE,
GAMMA1_EXPONENT,
GAMMA1_RING_ELEMENT_SIZE,
GAMMA2,
BETA,
COMMITMENT_RING_ELEMENT_SIZE,
COMMITMENT_VECTOR_SIZE,
COMMITMENT_HASH_SIZE,
ONES_IN_VERIFIER_CHALLENGE,
MAX_ONES_IN_HINT,
>(verification_key.0, message, signature.0)
}
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ pub(crate) fn sign<
};

let verifier_challenge_as_ntt =
ntt::<0>(sample_challenge_ring_element::<ONES_IN_VERIFIER_CHALLENGE>(
ntt(sample_challenge_ring_element::<ONES_IN_VERIFIER_CHALLENGE>(
commitment_hash[0..VERIFIER_CHALLENGE_SEED_SIZE]
.try_into()
.unwrap(),
Expand Down Expand Up @@ -496,7 +496,7 @@ pub(crate) fn verify<
};

let verifier_challenge_as_ntt =
ntt::<0>(sample_challenge_ring_element::<ONES_IN_VERIFIER_CHALLENGE>(
ntt(sample_challenge_ring_element::<ONES_IN_VERIFIER_CHALLENGE>(
signature.commitment_hash[0..VERIFIER_CHALLENGE_SEED_SIZE]
.try_into()
.unwrap(),
Expand Down
30 changes: 14 additions & 16 deletions libcrux-ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
];

#[inline(always)]
fn ntt_at_layer<const SHIFT_COEFFICIENT_BY: usize>(
fn ntt_at_layer(
zeta_i: &mut usize,
mut re: PolynomialRingElement,
layer: usize,
Expand All @@ -50,31 +50,29 @@ fn ntt_at_layer<const SHIFT_COEFFICIENT_BY: usize>(

for j in offset..offset + step {
let t = montgomery_multiply_fe_by_fer(
re.coefficients[j + step] << SHIFT_COEFFICIENT_BY,
re.coefficients[j + step],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
);
re.coefficients[j + step] = (re.coefficients[j] << SHIFT_COEFFICIENT_BY) - t;
re.coefficients[j] = (re.coefficients[j] << SHIFT_COEFFICIENT_BY) + t;
re.coefficients[j + step] = re.coefficients[j] - t;
re.coefficients[j] = re.coefficients[j] + t;
}
}

re
}

#[inline(always)]
pub(crate) fn ntt<const SHIFT_COEFFICIENT_BY: usize>(
mut re: PolynomialRingElement,
) -> PolynomialRingElement {
pub(crate) fn ntt(mut re: PolynomialRingElement) -> PolynomialRingElement {
let mut zeta_i = 0;

re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 7);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 6);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 5);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 4);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 3);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 2);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 1);
re = ntt_at_layer::<SHIFT_COEFFICIENT_BY>(&mut zeta_i, re, 0);
re = ntt_at_layer(&mut zeta_i, re, 7);
re = ntt_at_layer(&mut zeta_i, re, 6);
re = ntt_at_layer(&mut zeta_i, re, 5);
re = ntt_at_layer(&mut zeta_i, re, 4);
re = ntt_at_layer(&mut zeta_i, re, 3);
re = ntt_at_layer(&mut zeta_i, re, 2);
re = ntt_at_layer(&mut zeta_i, re, 1);
re = ntt_at_layer(&mut zeta_i, re, 0);

re
}
Expand Down Expand Up @@ -213,7 +211,7 @@ mod tests {
15979738, 1459696, 8351548, 3335586, 1150210, -2462074, -4642922, 4538634, 1858098,
];

assert_eq!(ntt::<0>(re).coefficients, expected_coefficients);
assert_eq!(ntt(re).coefficients, expected_coefficients);
}

#[test]
Expand Down
8 changes: 7 additions & 1 deletion libcrux-ml-dsa/tests/nistkats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ fn ml_dsa_44_nist_known_answer_tests() {
signature_hash, kat.sha3_256_hash_of_signature,
"signature_hash != kat.sha3_256_hash_of_signature"
);

libcrux_ml_dsa::ml_dsa_44::verify(key_pair.verification_key, &message, signature)
.expect("Verification should pass since the signature was honestly generated");
}
}

Expand Down Expand Up @@ -104,7 +107,7 @@ fn ml_dsa_65_nist_known_answer_tests() {
);

libcrux_ml_dsa::ml_dsa_65::verify(key_pair.verification_key, &message, signature)
.expect("Signature was generated honestly, so verification should pass");
.expect("Verification should pass since the signature was honestly generated");
}
}

Expand Down Expand Up @@ -144,5 +147,8 @@ fn ml_dsa_87_nist_known_answer_tests() {
signature_hash, kat.sha3_256_hash_of_signature,
"signature_hash != kat.sha3_256_hash_of_signature"
);

libcrux_ml_dsa::ml_dsa_87::verify(key_pair.verification_key, &message, signature)
.expect("Verification should pass since the signature was honestly generated");
}
}

0 comments on commit cdae0f3

Please sign in to comment.