Skip to content

Commit

Permalink
First draft of ML-DSA verification.
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed Jun 20, 2024
1 parent be1e630 commit 64b6dd9
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 41 deletions.
63 changes: 63 additions & 0 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,60 @@ pub(crate) fn make_hint<const DIMENSION: usize, const GAMMA2: i32>(
(hint, true_hints)
}

#[inline(always)]
pub(crate) fn use_hint_value<const GAMMA2: i32>(r: i32, hint: bool) -> i32 {
let (r0, r1) = decompose::<GAMMA2>(r);

if hint == false {
return r1;
}

match GAMMA2 {
95_232 => {
if r0 > 0 {
if r1 == 43 {
0
} else {
r1 + 1
}
} else {
if r1 == 0 {
43
} else {
r1 - 1
}
}
}

261_888 => {
if r0 > 0 {
(r1 + 1) & 15
} else {
(r1 - 1) & 15
}
}

_ => unreachable!(),
}
}

#[inline(always)]
pub(crate) fn use_hint<const DIMENSION: usize, const GAMMA2: i32>(
hint: [[bool; COEFFICIENTS_IN_RING_ELEMENT]; DIMENSION],
re_vector: [PolynomialRingElement; DIMENSION],
) -> [PolynomialRingElement; DIMENSION] {
let mut result = [PolynomialRingElement::ZERO; DIMENSION];

for i in 0..DIMENSION {
for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
result[i].coefficients[j] =
use_hint_value::<GAMMA2>(re_vector[i].coefficients[j], hint[i][j]);
}
}

result
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -310,4 +364,13 @@ mod tests {
assert_eq!(decompose::<261_888>(6645076), (-164012, 13));
assert_eq!(decompose::<261_888>(7806985), (-49655, 15));
}

#[test]
fn test_use_hint_value() {
assert_eq!(use_hint_value::<95_232>(7622170, false), 40);
assert_eq!(use_hint_value::<95_232>(2332762, true), 13);

assert_eq!(use_hint_value::<261_888>(7691572, false), 15);
assert_eq!(use_hint_value::<261_888>(6635697, true), 12);
}
}
13 changes: 8 additions & 5 deletions libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
arithmetic::PolynomialRingElement,
constants::BITS_IN_LOWER_PART_OF_T,
ntt::{invert_ntt_montgomery, ntt, ntt_multiply_montgomery},
sample::sample_ring_element_uniform,
};
Expand Down Expand Up @@ -117,21 +118,23 @@ pub(crate) fn subtract_vectors<const DIMENSION: usize>(
#[inline(always)]
pub(crate) fn compute_w_approx<const ROWS_IN_A: usize, const COLUMNS_IN_A: usize>(
A_as_ntt: &[[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A],
signer_response_as_ntt: [PolynomialRingElement; COLUMNS_IN_A],
signer_response: [PolynomialRingElement; COLUMNS_IN_A],
verifier_challenge_as_ntt: PolynomialRingElement,
t1_shifted_as_ntt: [PolynomialRingElement; ROWS_IN_A],
t1: [PolynomialRingElement; ROWS_IN_A],
) -> [PolynomialRingElement; ROWS_IN_A] {
let mut result = [PolynomialRingElement::ZERO; ROWS_IN_A];

for (i, row) in A_as_ntt.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product = ntt_multiply_montgomery(&ring_element, &signer_response_as_ntt[j]);
let product = ntt_multiply_montgomery(&ring_element, &ntt::<0>(signer_response[j]));

result[i] = result[i].add(&product);
}

let challenge_times_t1_shifted =
ntt_multiply_montgomery(&verifier_challenge_as_ntt, &t1_shifted_as_ntt[i]);
let challenge_times_t1_shifted = ntt_multiply_montgomery(
&verifier_challenge_as_ntt,
&ntt::<BITS_IN_LOWER_PART_OF_T>(t1[i]),
);
result[i] = invert_ntt_montgomery(result[i].sub(&challenge_times_t1_shifted));
}

Expand Down
6 changes: 6 additions & 0 deletions libcrux-ml-dsa/src/ml_dsa_65.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ const ONES_IN_VERIFIER_CHALLENGE: usize = 49;

const GAMMA2: i32 = (FIELD_MODULUS - 1) / 32;

const BETA: i32 = (ONES_IN_VERIFIER_CHALLENGE * ETA) as i32;

// Commitment coefficients are in the interval: [0, ((FIELD_MODULUS − 1)/2γ2) − 1]
// ((FIELD_MODULUS − 1)/2γ2) − 1 = 15, which means we need 4 bits to represent a
// coefficient.
Expand Down Expand Up @@ -123,6 +125,10 @@ pub fn verify(
VERIFICATION_KEY_SIZE,
GAMMA1_EXPONENT,
GAMMA1_RING_ELEMENT_SIZE,
GAMMA2,
BETA,
COMMITMENT_RING_ELEMENT_SIZE,
COMMITMENT_VECTOR_SIZE,
COMMITMENT_HASH_SIZE,
ONES_IN_VERIFIER_CHALLENGE,
MAX_ONES_IN_HINT,
Expand Down
104 changes: 68 additions & 36 deletions libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::{
arithmetic::{
decompose_vector, make_hint, power2round_vector, vector_infinity_norm_exceeds,
decompose_vector, make_hint, power2round_vector, use_hint, vector_infinity_norm_exceeds,
PolynomialRingElement,
},
constants::*,
encoding,
hash_functions::H,
matrix::{
add_vectors, compute_A_times_mask, compute_As1_plus_s2, expand_to_A, subtract_vectors,
vector_times_ring_element,
add_vectors, compute_A_times_mask, compute_As1_plus_s2, compute_w_approx, expand_to_A,
subtract_vectors, vector_times_ring_element,
},
ntt::ntt,
sample::{sample_challenge_ring_element, sample_error_vector, sample_mask_vector},
Expand Down Expand Up @@ -48,9 +48,9 @@ pub(super) fn deserialize_verification_key<
let (seed_for_A, serialized_remaining) = serialized.split_at(SEED_FOR_A_SIZE);

for i in 0..ROWS_IN_A {
t1[i] = ntt::<BITS_IN_LOWER_PART_OF_T>(encoding::t1::deserialize(
t1[i] = encoding::t1::deserialize(
&serialized_remaining[i * RING_ELEMENT_OF_T1S_SIZE..(i + 1) * RING_ELEMENT_OF_T1S_SIZE],
));
);
}

(seed_for_A.try_into().unwrap(), t1)
Expand Down Expand Up @@ -166,8 +166,11 @@ pub(crate) fn generate_key_pair<
(signing_key_serialized, verification_key_serialized)
}

#[derive(Debug)]
pub enum VerificationError {
MalformedHintError,
SignerResponseExceedsBoundError,
CommitmentHashesDontMatchError,
}

struct Signature<
Expand All @@ -185,7 +188,7 @@ impl<const COMMITMENT_HASH_SIZE: usize, const COLUMNS_IN_A: usize, const ROWS_IN
{
#[allow(non_snake_case)]
#[inline(always)]
pub(super) fn serialize<
pub(crate) fn serialize<
const GAMMA1_EXPONENT: usize,
const GAMMA1_RING_ELEMENT_SIZE: usize,
const MAX_ONES_IN_HINT: usize,
Expand All @@ -208,29 +211,25 @@ impl<const COMMITMENT_HASH_SIZE: usize, const COLUMNS_IN_A: usize, const ROWS_IN
offset += GAMMA1_RING_ELEMENT_SIZE;
}

for i in offset..offset + (MAX_ONES_IN_HINT + ROWS_IN_A) {
signature[i] = 0;
}

let mut one_count = 0;
let mut true_hints_seen = 0;
let hint_serialized = &mut signature[offset..];

for i in 0..ROWS_IN_A {
for (j, hint) in self.hint[i].into_iter().enumerate() {
if hint == true {
hint_serialized[one_count] = j as u8;
one_count += 1;
hint_serialized[true_hints_seen] = j as u8;
true_hints_seen += 1;
}
}
hint_serialized[MAX_ONES_IN_HINT + i] = one_count as u8;
hint_serialized[MAX_ONES_IN_HINT + i] = true_hints_seen as u8;
}

signature
}

#[allow(non_snake_case)]
#[inline(always)]
pub(super) fn deserialize<
pub(crate) fn deserialize<
const GAMMA1_EXPONENT: usize,
const GAMMA1_RING_ELEMENT_SIZE: usize,
const MAX_ONES_IN_HINT: usize,
Expand All @@ -239,32 +238,29 @@ impl<const COMMITMENT_HASH_SIZE: usize, const COLUMNS_IN_A: usize, const ROWS_IN
serialized: [u8; SIGNATURE_SIZE],
) -> Result<Self, VerificationError> {
let (commitment_hash, rest_of_serialized) = serialized.split_at(COMMITMENT_HASH_SIZE);
let (signer_response_serialized, hint_serialized) =
rest_of_serialized.split_at(GAMMA1_RING_ELEMENT_SIZE * COLUMNS_IN_A);

let mut signer_response = [PolynomialRingElement::ZERO; COLUMNS_IN_A];

let mut offset = 0;

for i in 0..COLUMNS_IN_A {
signer_response[i] = encoding::gamma1::deserialize::<GAMMA1_EXPONENT>(
&rest_of_serialized[offset..offset + GAMMA1_RING_ELEMENT_SIZE],
&signer_response_serialized
[i * GAMMA1_RING_ELEMENT_SIZE..(i + 1) * GAMMA1_RING_ELEMENT_SIZE],
);

offset += GAMMA1_RING_ELEMENT_SIZE;
}

// While there are several ways to encode the same hint vector, we only
// allow one such encoding, to ensure "strong unforgeability".
let hint_serialized = &serialized[offset..];

// While there are several ways to encode the same hint vector, we
// allow only one such encoding, to ensure strong unforgeability.
let mut hint = [[false; COEFFICIENTS_IN_RING_ELEMENT]; ROWS_IN_A];

let mut previous_true_hints_count = 0usize;
let mut previous_true_hints_seen = 0usize;

for i in 0..ROWS_IN_A {
let current_true_hints_count = hint_serialized[MAX_ONES_IN_HINT + i] as usize;
let current_true_hints_seen = hint_serialized[MAX_ONES_IN_HINT + i] as usize;

if (current_true_hints_count < previous_true_hints_count)
|| (previous_true_hints_count > MAX_ONES_IN_HINT)
if (current_true_hints_seen < previous_true_hints_seen)
|| (previous_true_hints_seen > MAX_ONES_IN_HINT)
{
// the true hints seen should be increasing
//
Expand All @@ -273,19 +269,19 @@ impl<const COMMITMENT_HASH_SIZE: usize, const COLUMNS_IN_A: usize, const ROWS_IN
return Err(VerificationError::MalformedHintError);
}

for j in previous_true_hints_count..current_true_hints_count {
if j > previous_true_hints_count && hint_serialized[j] <= hint_serialized[j - 1] {
for j in previous_true_hints_seen..current_true_hints_seen {
if j > previous_true_hints_seen && hint_serialized[j] <= hint_serialized[j - 1] {
// indices of true hints for a specific polynomial should be
// increasing
return Err(VerificationError::MalformedHintError);
}

hint[i][hint_serialized[j] as usize] = true;
}
previous_true_hints_count = current_true_hints_count;
previous_true_hints_seen = current_true_hints_seen;
}

for j in previous_true_hints_count..MAX_ONES_IN_HINT {
for j in previous_true_hints_seen..MAX_ONES_IN_HINT {
if hint_serialized[j] != 0 {
// ensures padding indices are zero
return Err(VerificationError::MalformedHintError);
Expand Down Expand Up @@ -388,14 +384,14 @@ pub(crate) fn sign<
let (w0, commitment) = decompose_vector::<ROWS_IN_A, GAMMA2>(A_times_mask);

let commitment_hash: [u8; COMMITMENT_HASH_SIZE] = {
let commitment_encoded = encoding::commitment::serialize_vector::<
let commitment_serialized = encoding::commitment::serialize_vector::<
ROWS_IN_A,
COMMITMENT_RING_ELEMENT_SIZE,
COMMITMENT_VECTOR_SIZE,
>(commitment);

let mut hash_input = message_representative.to_vec();
hash_input.extend_from_slice(&commitment_encoded);
hash_input.extend_from_slice(&commitment_serialized);

H::<COMMITMENT_HASH_SIZE>(&hash_input[..])
};
Expand Down Expand Up @@ -459,6 +455,10 @@ pub(crate) fn verify<
const VERIFICATION_KEY_SIZE: usize,
const GAMMA1_EXPONENT: usize,
const GAMMA1_RING_ELEMENT_SIZE: usize,
const GAMMA2: i32,
const BETA: i32,
const COMMITMENT_RING_ELEMENT_SIZE: usize,
const COMMITMENT_VECTOR_SIZE: usize,
const COMMITMENT_HASH_SIZE: usize,
const ONES_IN_VERIFIER_CHALLENGE: usize,
const MAX_ONES_IN_HINT: usize,
Expand All @@ -467,7 +467,7 @@ pub(crate) fn verify<
message: &[u8],
signature_serialized: [u8; SIGNATURE_SIZE],
) -> Result<(), VerificationError> {
let (seed_for_A, t1_as_ntt) = deserialize_verification_key::<ROWS_IN_A, VERIFICATION_KEY_SIZE>(
let (seed_for_A, t1) = deserialize_verification_key::<ROWS_IN_A, VERIFICATION_KEY_SIZE>(
verification_key_serialized,
);

Expand All @@ -478,6 +478,13 @@ pub(crate) fn verify<
SIGNATURE_SIZE,
>(signature_serialized)?;

if vector_infinity_norm_exceeds::<COLUMNS_IN_A>(
signature.signer_response,
(2 << GAMMA1_EXPONENT) - BETA,
) {
return Err(VerificationError::SignerResponseExceedsBoundError);
}

let A_as_ntt = expand_to_A::<ROWS_IN_A, COLUMNS_IN_A>(into_padded_array(&seed_for_A));

let verification_key_hash = H::<BYTES_FOR_VERIFICATION_KEY_HASH>(&verification_key_serialized);
Expand All @@ -495,5 +502,30 @@ pub(crate) fn verify<
.unwrap(),
));

todo!();
let w_approx = compute_w_approx::<ROWS_IN_A, COLUMNS_IN_A>(
&A_as_ntt,
signature.signer_response,
verifier_challenge_as_ntt,
t1,
);

let commitment_hash: [u8; COMMITMENT_HASH_SIZE] = {
let commitment = use_hint::<ROWS_IN_A, GAMMA2>(signature.hint, w_approx);
let commitment_serialized = encoding::commitment::serialize_vector::<
ROWS_IN_A,
COMMITMENT_RING_ELEMENT_SIZE,
COMMITMENT_VECTOR_SIZE,
>(commitment);

let mut hash_input = message_representative.to_vec();
hash_input.extend_from_slice(&commitment_serialized);

H::<COMMITMENT_HASH_SIZE>(&hash_input[..])
};

if signature.commitment_hash != commitment_hash {
return Err(VerificationError::CommitmentHashesDontMatchError);
}

Ok(())
}
3 changes: 3 additions & 0 deletions libcrux-ml-dsa/tests/nistkats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ fn ml_dsa_65_nist_known_answer_tests() {
signature_hash, kat.sha3_256_hash_of_signature,
"signature_hash != kat.sha3_256_hash_of_signature"
);

libcrux_ml_dsa::ml_dsa_65::verify(key_pair.verification_key, &message, signature)
.expect("Signature was generated honestly, so verification should pass");
}
}

Expand Down

0 comments on commit 64b6dd9

Please sign in to comment.