diff --git a/libcrux-ml-dsa/benches/ml-dsa.rs b/libcrux-ml-dsa/benches/ml-dsa.rs index e861244a1..ab870bf4b 100644 --- a/libcrux-ml-dsa/benches/ml-dsa.rs +++ b/libcrux-ml-dsa/benches/ml-dsa.rs @@ -60,7 +60,7 @@ pub fn comparisons_verification(c: &mut Criterion) { let mut message = [0u8; 511]; rng.fill_bytes(&mut message); - group.bench_function("libcrux portable (external random)", |b| { + group.bench_function("libcrux portable", |b| { let mut randomness = [0; 32]; rng.fill_bytes(&mut randomness); let keypair = ml_dsa_65::generate_key_pair(randomness); @@ -72,7 +72,7 @@ pub fn comparisons_verification(c: &mut Criterion) { }) }); - group.bench_function("pqclean reference implementation (internal random)", |b| { + group.bench_function("pqclean reference implementation", |b| { let (vk, sk) = pqcrypto_dilithium::dilithium3::keypair(); let signature = pqcrypto_dilithium::dilithium3::detached_sign(&message, &sk); b.iter(|| { diff --git a/libcrux-ml-dsa/src/arithmetic.rs b/libcrux-ml-dsa/src/arithmetic.rs index 4fb7d278e..97f42d749 100644 --- a/libcrux-ml-dsa/src/arithmetic.rs +++ b/libcrux-ml-dsa/src/arithmetic.rs @@ -79,8 +79,8 @@ pub(crate) fn vector_infinity_norm_exceeds( // TODO: We can break out of this loop early if need be, but the most // straightforward way to do so (returning false) will not go through hax; // revisit if performance is impacted. - for i in 0..DIMENSION { - exceeds |= vector[i].infinity_norm_exceeds(value); + for ring_element in vector.iter() { + exceeds |= ring_element.infinity_norm_exceeds(value); } exceeds @@ -129,12 +129,14 @@ pub(crate) fn montgomery_multiply_fe_by_fer( montgomery_reduce((fe as i64) * (fer as i64)) } +#[inline(always)] fn reduce(fe: FieldElement) -> FieldElement { let quotient = (fe + (1 << 22)) >> 23; fe - (quotient * FIELD_MODULUS) } +#[inline(always)] pub(crate) fn shift_coefficients_left_then_reduce( re: PolynomialRingElement, shift_by: usize, @@ -312,7 +314,7 @@ pub(crate) fn make_hint( pub(crate) fn use_hint_value(r: i32, hint: bool) -> i32 { let (r0, r1) = decompose::(r); - if hint == false { + if !hint { return r1; } @@ -324,12 +326,10 @@ pub(crate) fn use_hint_value(r: i32, hint: bool) -> i32 { } else { r1 + 1 } + } else if r1 == 0 { + 43 } else { - if r1 == 0 { - 43 - } else { - r1 - 1 - } + r1 - 1 } } diff --git a/libcrux-ml-dsa/src/encoding/commitment.rs b/libcrux-ml-dsa/src/encoding/commitment.rs index 6b9beb471..96145f946 100644 --- a/libcrux-ml-dsa/src/encoding/commitment.rs +++ b/libcrux-ml-dsa/src/encoding/commitment.rs @@ -27,7 +27,7 @@ fn serialize(re: PolynomialRingElement) -> [u8; OUTPUT let coefficient2 = coefficients[2] as u8; let coefficient3 = coefficients[3] as u8; - out[3 * i + 0] = (coefficient1 << 6) | coefficient0; + out[3 * i] = (coefficient1 << 6) | coefficient0; out[3 * i + 1] = (coefficient2 << 4) | coefficient1 >> 2; out[3 * i + 2] = (coefficient3 << 2) | coefficient2 >> 4; } @@ -50,9 +50,9 @@ pub(crate) fn serialize_vector< let mut serialized = [0u8; OUTPUT_SIZE]; let mut offset: usize = 0; - for i in 0..DIMENSION { + for ring_element in vector.iter() { serialized[offset..offset + RING_ELEMENT_SIZE] - .copy_from_slice(&serialize::(vector[i])); + .copy_from_slice(&serialize::(*ring_element)); offset += RING_ELEMENT_SIZE; } diff --git a/libcrux-ml-dsa/src/encoding/error.rs b/libcrux-ml-dsa/src/encoding/error.rs index 9862e73fb..cb01e2362 100644 --- a/libcrux-ml-dsa/src/encoding/error.rs +++ b/libcrux-ml-dsa/src/encoding/error.rs @@ -19,7 +19,7 @@ fn serialize_when_eta_is_2( let coefficient6 = (ETA - coefficients[6]) as u8; let coefficient7 = (ETA - coefficients[7]) as u8; - serialized[3 * i + 0] = (coefficient2 << 6) | (coefficient1 << 3) | coefficient0; + serialized[3 * i] = (coefficient2 << 6) | (coefficient1 << 3) | coefficient0; serialized[3 * i + 1] = (coefficient5 << 7) | (coefficient4 << 4) | (coefficient3 << 1) | (coefficient2 >> 2); serialized[3 * i + 2] = (coefficient7 << 5) | (coefficient6 << 2) | (coefficient5 >> 1); @@ -65,7 +65,7 @@ fn deserialize_when_eta_is_2(serialized: &[u8]) -> PolynomialRingElement { let byte1 = bytes[1] as i32; let byte2 = bytes[2] as i32; - re.coefficients[8 * i + 0] = (byte0 >> 0) & 7; + re.coefficients[8 * i] = byte0 & 7; re.coefficients[8 * i + 1] = (byte0 >> 3) & 7; re.coefficients[8 * i + 2] = ((byte0 >> 6) | (byte1 << 2)) & 7; re.coefficients[8 * i + 3] = (byte1 >> 1) & 7; @@ -74,7 +74,7 @@ fn deserialize_when_eta_is_2(serialized: &[u8]) -> PolynomialRingElement { re.coefficients[8 * i + 6] = (byte2 >> 2) & 7; re.coefficients[8 * i + 7] = (byte2 >> 5) & 7; - re.coefficients[8 * i + 0] = ETA - re.coefficients[8 * i + 0]; + re.coefficients[8 * i] = ETA - re.coefficients[8 * i]; re.coefficients[8 * i + 1] = ETA - re.coefficients[8 * i + 1]; re.coefficients[8 * i + 2] = ETA - re.coefficients[8 * i + 2]; re.coefficients[8 * i + 3] = ETA - re.coefficients[8 * i + 3]; @@ -92,8 +92,8 @@ fn deserialize_when_eta_is_4(serialized: &[u8]) -> PolynomialRingElement { let mut re = PolynomialRingElement::ZERO; const ETA: i32 = 4; - for (i, byte) in serialized.into_iter().enumerate() { - re.coefficients[2 * i + 0] = ETA - ((byte & 0xF) as i32); + for (i, byte) in serialized.iter().enumerate() { + re.coefficients[2 * i] = ETA - ((byte & 0xF) as i32); re.coefficients[2 * i + 1] = ETA - ((byte >> 4) as i32); } diff --git a/libcrux-ml-dsa/src/encoding/gamma1.rs b/libcrux-ml-dsa/src/encoding/gamma1.rs index f8ebbee64..2bb4b8299 100644 --- a/libcrux-ml-dsa/src/encoding/gamma1.rs +++ b/libcrux-ml-dsa/src/encoding/gamma1.rs @@ -13,7 +13,7 @@ fn serialize_when_gamma1_is_2_pow_17( let coefficient2 = GAMMA1 - coefficients[2]; let coefficient3 = GAMMA1 - coefficients[3]; - serialized[9 * i + 0] = coefficient0 as u8; + serialized[9 * i] = coefficient0 as u8; serialized[9 * i + 1] = (coefficient0 >> 8) as u8; serialized[9 * i + 2] = (coefficient0 >> 16) as u8; @@ -47,7 +47,7 @@ fn serialize_when_gamma1_is_2_pow_19( let coefficient0 = GAMMA1 - coefficients[0]; let coefficient1 = GAMMA1 - coefficients[1]; - serialized[5 * i + 0] = coefficient0 as u8; + serialized[5 * i] = coefficient0 as u8; serialized[5 * i + 1] = (coefficient0 >> 8) as u8; serialized[5 * i + 2] = (coefficient0 >> 16) as u8; @@ -79,10 +79,10 @@ fn deserialize_when_gamma1_is_2_pow_17(serialized: &[u8]) -> PolynomialRingEleme let mut re = PolynomialRingElement::ZERO; for (i, bytes) in serialized.chunks_exact(9).enumerate() { - re.coefficients[4 * i + 0] = bytes[0] as i32; - re.coefficients[4 * i + 0] |= (bytes[1] as i32) << 8; - re.coefficients[4 * i + 0] |= (bytes[2] as i32) << 16; - re.coefficients[4 * i + 0] &= GAMMA1_TIMES_2_BITMASK; + re.coefficients[4 * i] = bytes[0] as i32; + re.coefficients[4 * i] |= (bytes[1] as i32) << 8; + re.coefficients[4 * i] |= (bytes[2] as i32) << 16; + re.coefficients[4 * i] &= GAMMA1_TIMES_2_BITMASK; re.coefficients[4 * i + 1] = (bytes[2] as i32) >> 2; re.coefficients[4 * i + 1] |= (bytes[3] as i32) << 6; @@ -99,7 +99,7 @@ fn deserialize_when_gamma1_is_2_pow_17(serialized: &[u8]) -> PolynomialRingEleme re.coefficients[4 * i + 3] |= (bytes[8] as i32) << 10; re.coefficients[4 * i + 3] &= GAMMA1_TIMES_2_BITMASK; - re.coefficients[4 * i + 0] = GAMMA1 - re.coefficients[4 * i + 0]; + re.coefficients[4 * i] = GAMMA1 - re.coefficients[4 * i]; re.coefficients[4 * i + 1] = GAMMA1 - re.coefficients[4 * i + 1]; re.coefficients[4 * i + 2] = GAMMA1 - re.coefficients[4 * i + 2]; re.coefficients[4 * i + 3] = GAMMA1 - re.coefficients[4 * i + 3]; @@ -116,16 +116,16 @@ fn deserialize_when_gamma1_is_2_pow_19(serialized: &[u8]) -> PolynomialRingEleme let mut re = PolynomialRingElement::ZERO; for (i, bytes) in serialized.chunks_exact(5).enumerate() { - re.coefficients[2 * i + 0] = bytes[0] as i32; - re.coefficients[2 * i + 0] |= (bytes[1] as i32) << 8; - re.coefficients[2 * i + 0] |= (bytes[2] as i32) << 16; - re.coefficients[2 * i + 0] &= GAMMA1_TIMES_2_BITMASK; + re.coefficients[2 * i] = bytes[0] as i32; + re.coefficients[2 * i] |= (bytes[1] as i32) << 8; + re.coefficients[2 * i] |= (bytes[2] as i32) << 16; + re.coefficients[2 * i] &= GAMMA1_TIMES_2_BITMASK; re.coefficients[2 * i + 1] = (bytes[2] as i32) >> 4; re.coefficients[2 * i + 1] |= (bytes[3] as i32) << 4; re.coefficients[2 * i + 1] |= (bytes[4] as i32) << 12; - re.coefficients[2 * i + 0] = GAMMA1 - re.coefficients[2 * i + 0]; + re.coefficients[2 * i] = GAMMA1 - re.coefficients[2 * i]; re.coefficients[2 * i + 1] = GAMMA1 - re.coefficients[2 * i + 1]; } diff --git a/libcrux-ml-dsa/src/encoding/signing_key.rs b/libcrux-ml-dsa/src/encoding/signing_key.rs index 92970efba..afb7bdb0a 100644 --- a/libcrux-ml-dsa/src/encoding/signing_key.rs +++ b/libcrux-ml-dsa/src/encoding/signing_key.rs @@ -27,37 +27,88 @@ pub(crate) fn generate_serialized< let mut signing_key_serialized = [0u8; SIGNING_KEY_SIZE]; let mut offset = 0; - signing_key_serialized[offset..offset + SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A); + signing_key_serialized[offset..offset + SEED_FOR_A_SIZE].copy_from_slice(seed_for_A); offset += SEED_FOR_A_SIZE; signing_key_serialized[offset..offset + SEED_FOR_SIGNING_SIZE] - .copy_from_slice(&seed_for_signing); + .copy_from_slice(seed_for_signing); offset += SEED_FOR_SIGNING_SIZE; - let verification_key_hash = H::(verification_key); + let verification_key_hash = H::one_shot::(verification_key); signing_key_serialized[offset..offset + BYTES_FOR_VERIFICATION_KEY_HASH] .copy_from_slice(&verification_key_hash); offset += BYTES_FOR_VERIFICATION_KEY_HASH; - for i in 0..COLUMNS_IN_A { + for ring_element in s1.iter() { signing_key_serialized[offset..offset + ERROR_RING_ELEMENT_SIZE].copy_from_slice( - &encoding::error::serialize::(s1[i]), + &encoding::error::serialize::(*ring_element), ); offset += ERROR_RING_ELEMENT_SIZE; } - for i in 0..ROWS_IN_A { + for ring_element in s2.iter() { signing_key_serialized[offset..offset + ERROR_RING_ELEMENT_SIZE].copy_from_slice( - &encoding::error::serialize::(s2[i]), + &encoding::error::serialize::(*ring_element), ); offset += ERROR_RING_ELEMENT_SIZE; } - for i in 0..ROWS_IN_A { + for ring_element in t0.iter() { signing_key_serialized[offset..offset + RING_ELEMENT_OF_T0S_SIZE] - .copy_from_slice(&encoding::t0::serialize(t0[i])); + .copy_from_slice(&encoding::t0::serialize(*ring_element)); offset += RING_ELEMENT_OF_T0S_SIZE; } signing_key_serialized } + +#[allow(non_snake_case)] +#[inline(always)] +pub(crate) fn deserialize_then_ntt< + const ROWS_IN_A: usize, + const COLUMNS_IN_A: usize, + const ETA: usize, + const ERROR_RING_ELEMENT_SIZE: usize, + const SIGNING_KEY_SIZE: usize, +>( + serialized: [u8; SIGNING_KEY_SIZE], +) -> ( + [u8; SEED_FOR_A_SIZE], // seed_for_A + [u8; SEED_FOR_SIGNING_SIZE], // seed_for_signing + [u8; BYTES_FOR_VERIFICATION_KEY_HASH], // verification_key_hash + [PolynomialRingElement; COLUMNS_IN_A], // s1 + [PolynomialRingElement; ROWS_IN_A], // s2 + [PolynomialRingElement; ROWS_IN_A], // t0_as_ntt +) { + let (seed_for_A, remaining_serialized) = serialized.split_at(SEED_FOR_A_SIZE); + let (seed_for_signing, remaining_serialized) = + remaining_serialized.split_at(SEED_FOR_SIGNING_SIZE); + let (verification_key_hash, remaining_serialized) = + remaining_serialized.split_at(BYTES_FOR_VERIFICATION_KEY_HASH); + + let (s1_serialized, remaining_serialized) = + remaining_serialized.split_at(ERROR_RING_ELEMENT_SIZE * COLUMNS_IN_A); + let (s2_serialized, t0_serialized) = + remaining_serialized.split_at(ERROR_RING_ELEMENT_SIZE * ROWS_IN_A); + + let s1_as_ntt = encoding::error::deserialize_to_vector_then_ntt::< + COLUMNS_IN_A, + ETA, + ERROR_RING_ELEMENT_SIZE, + >(s1_serialized); + let s2_as_ntt = + encoding::error::deserialize_to_vector_then_ntt::( + s2_serialized, + ); + + let t0_as_ntt = encoding::t0::deserialize_to_vector_then_ntt::(t0_serialized); + + ( + seed_for_A.try_into().unwrap(), + seed_for_signing.try_into().unwrap(), + verification_key_hash.try_into().unwrap(), + s1_as_ntt, + s2_as_ntt, + t0_as_ntt, + ) +} diff --git a/libcrux-ml-dsa/src/encoding/t0.rs b/libcrux-ml-dsa/src/encoding/t0.rs index 85910d314..3b6914bb9 100644 --- a/libcrux-ml-dsa/src/encoding/t0.rs +++ b/libcrux-ml-dsa/src/encoding/t0.rs @@ -29,7 +29,7 @@ pub(crate) fn serialize(re: PolynomialRingElement) -> [u8; RING_ELEMENT_OF_T0S_S let coefficient6 = change_t0_interval(coefficients[6]); let coefficient7 = change_t0_interval(coefficients[7]); - serialized[13 * i + 0] = coefficient0 as u8; + serialized[13 * i] = coefficient0 as u8; serialized[13 * i + 1] = (coefficient0 >> 8) as u8; serialized[13 * i + 1] |= (coefficient1 << 5) as u8; @@ -87,9 +87,9 @@ fn deserialize(serialized: &[u8]) -> PolynomialRingElement { let byte11 = bytes[11] as i32; let byte12 = bytes[12] as i32; - re.coefficients[8 * i + 0] = byte0; - re.coefficients[8 * i + 0] |= byte1 << 8; - re.coefficients[8 * i + 0] &= BITS_IN_LOWER_PART_OF_T_MASK; + re.coefficients[8 * i] = byte0; + re.coefficients[8 * i] |= byte1 << 8; + re.coefficients[8 * i] &= BITS_IN_LOWER_PART_OF_T_MASK; re.coefficients[8 * i + 1] = byte1 >> 5; re.coefficients[8 * i + 1] |= byte2 << 3; @@ -123,7 +123,7 @@ fn deserialize(serialized: &[u8]) -> PolynomialRingElement { re.coefficients[8 * i + 7] |= byte12 << 5; re.coefficients[8 * i + 7] &= BITS_IN_LOWER_PART_OF_T_MASK; - re.coefficients[8 * i + 0] = change_t0_interval(re.coefficients[8 * i + 0]); + re.coefficients[8 * i] = change_t0_interval(re.coefficients[8 * i]); re.coefficients[8 * i + 1] = change_t0_interval(re.coefficients[8 * i + 1]); re.coefficients[8 * i + 2] = change_t0_interval(re.coefficients[8 * i + 2]); re.coefficients[8 * i + 3] = change_t0_interval(re.coefficients[8 * i + 3]); diff --git a/libcrux-ml-dsa/src/encoding/t1.rs b/libcrux-ml-dsa/src/encoding/t1.rs index e393c6e12..8f3b448c2 100644 --- a/libcrux-ml-dsa/src/encoding/t1.rs +++ b/libcrux-ml-dsa/src/encoding/t1.rs @@ -36,7 +36,7 @@ pub(crate) fn deserialize(serialized: &[u8]) -> PolynomialRingElement { let byte3 = bytes[3] as i32; let byte4 = bytes[4] as i32; - out.coefficients[4 * i + 0] = ((byte0 >> 0) | (byte1 << 8)) & mask; + out.coefficients[4 * i] = (byte0 | (byte1 << 8)) & mask; out.coefficients[4 * i + 1] = ((byte1 >> 2) | (byte2 << 6)) & mask; out.coefficients[4 * i + 2] = ((byte2 >> 4) | (byte3 << 4)) & mask; out.coefficients[4 * i + 3] = ((byte3 >> 6) | (byte4 << 2)) & mask; diff --git a/libcrux-ml-dsa/src/encoding/verification_key.rs b/libcrux-ml-dsa/src/encoding/verification_key.rs index 0132f3c7e..7a11b8f80 100644 --- a/libcrux-ml-dsa/src/encoding/verification_key.rs +++ b/libcrux-ml-dsa/src/encoding/verification_key.rs @@ -11,12 +11,12 @@ pub(crate) fn generate_serialized [u8; VERIFICATION_KEY_SIZE] { let mut verification_key_serialized = [0u8; VERIFICATION_KEY_SIZE]; - verification_key_serialized[0..SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A); + verification_key_serialized[0..SEED_FOR_A_SIZE].copy_from_slice(seed_for_A); - for i in 0..ROWS_IN_A { + for (i, ring_element) in t1.iter().enumerate() { let offset = SEED_FOR_A_SIZE + (i * RING_ELEMENT_OF_T1S_SIZE); verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE] - .copy_from_slice(&t1::serialize(t1[i])); + .copy_from_slice(&t1::serialize(*ring_element)); } verification_key_serialized diff --git a/libcrux-ml-dsa/src/hash_functions.rs b/libcrux-ml-dsa/src/hash_functions.rs index 0e9cab467..0eeaf06e2 100644 --- a/libcrux-ml-dsa/src/hash_functions.rs +++ b/libcrux-ml-dsa/src/hash_functions.rs @@ -1,9 +1,40 @@ #![allow(non_snake_case)] -pub(crate) fn H(input: &[u8]) -> [u8; OUTPUT_LENGTH] { - let mut out = [0u8; OUTPUT_LENGTH]; - libcrux_sha3::portable::shake256(&mut out, input); - out +pub(crate) mod H { + use libcrux_sha3::portable::{incremental, shake256, KeccakState}; + + const BLOCK_SIZE: usize = 136; + + pub(crate) fn one_shot(input: &[u8]) -> [u8; OUTPUT_LENGTH] { + let mut out = [0u8; OUTPUT_LENGTH]; + shake256(&mut out, input); + + out + } + + #[inline(always)] + pub(crate) fn new(seed: &[u8]) -> KeccakState { + let mut state = incremental::shake256_init(); + incremental::shake256_absorb_final(&mut state, seed); + + state + } + + #[inline(always)] + pub(crate) fn squeeze_first_block(state: &mut KeccakState) -> [u8; BLOCK_SIZE] { + let mut out = [0u8; BLOCK_SIZE]; + incremental::shake256_squeeze_first_block(state, &mut out); + + out + } + + #[inline(always)] + pub(crate) fn squeeze_next_block(state: &mut KeccakState) -> [u8; BLOCK_SIZE] { + let mut out = [0u8; BLOCK_SIZE]; + incremental::shake256_squeeze_next_block(state, &mut out); + + out + } } pub(crate) mod H_128 { diff --git a/libcrux-ml-dsa/src/matrix.rs b/libcrux-ml-dsa/src/matrix.rs index f461d55c8..24586bee2 100644 --- a/libcrux-ml-dsa/src/matrix.rs +++ b/libcrux-ml-dsa/src/matrix.rs @@ -12,6 +12,7 @@ pub(crate) fn expand_to_A( ) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] { let mut A = [[PolynomialRingElement::ZERO; COLUMNS_IN_A]; ROWS_IN_A]; + #[allow(clippy::needless_range_loop)] for i in 0..ROWS_IN_A { for j in 0..COLUMNS_IN_A { seed[32] = j as u8; @@ -76,8 +77,9 @@ pub(crate) fn vector_times_ring_element( ) -> [PolynomialRingElement; DIMENSION] { let mut result = [PolynomialRingElement::ZERO; DIMENSION]; - for (i, vector_element) in vector.iter().enumerate() { - result[i] = invert_ntt_montgomery(ntt_multiply_montgomery(&vector_element, ring_element)); + for (i, vector_ring_element) in vector.iter().enumerate() { + result[i] = + invert_ntt_montgomery(ntt_multiply_montgomery(vector_ring_element, ring_element)); } result @@ -126,7 +128,7 @@ pub(crate) fn compute_w_approx MLDSA44KeyPair { let (signing_key, verification_key) = crate::ml_dsa_generic::generate_key_pair::< ROWS_IN_A, diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index 8e45bf70b..72a272a59 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -27,7 +27,7 @@ pub(crate) fn generate_key_pair< randomness: [u8; KEY_GENERATION_RANDOMNESS_SIZE], ) -> ([u8; SIGNING_KEY_SIZE], [u8; VERIFICATION_KEY_SIZE]) { // 128 = SEED_FOR_A_SIZE + SEED_FOR_ERROR_VECTORS_SIZE + SEED_FOR_SIGNING_SIZE - let seed_expanded = H::<128>(&randomness); + let seed_expanded = H::one_shot::<128>(&randomness); let (seed_for_A, seed_expanded) = seed_expanded.split_at(SEED_FOR_A_SIZE); let (seed_for_error_vectors, seed_for_signing) = @@ -123,7 +123,7 @@ impl previous_true_hints_seen && hint_serialized[j] <= hint_serialized[j - 1] { // indices of true hints for a specific polynomial should be // increasing + // TODO: This return won't pass through hax; it'll need + // to be rewritten. See https://github.com/cryspen/libcrux/issues/341 return Err(VerificationError::MalformedHintError); } @@ -188,9 +190,15 @@ impl [u8; SIGNATURE_SIZE] { - let (seed_for_A, remaining_signing_key) = signing_key.split_at(SEED_FOR_A_SIZE); - let (seed_for_signing, remaining_signing_key) = - remaining_signing_key.split_at(SEED_FOR_SIGNING_SIZE); - let (verification_key_hash, remaining_signing_key) = - remaining_signing_key.split_at(BYTES_FOR_VERIFICATION_KEY_HASH); - - let (s1_serialized, remaining_signing_key) = - remaining_signing_key.split_at(ERROR_RING_ELEMENT_SIZE * COLUMNS_IN_A); - let (s2_serialized, t0_serialized) = - remaining_signing_key.split_at(ERROR_RING_ELEMENT_SIZE * ROWS_IN_A); - - let s1_as_ntt = encoding::error::deserialize_to_vector_then_ntt::< - COLUMNS_IN_A, - ETA, - ERROR_RING_ELEMENT_SIZE, - >(s1_serialized); - let s2_as_ntt = - encoding::error::deserialize_to_vector_then_ntt::( - s2_serialized, - ); - - let t0_as_ntt = encoding::t0::deserialize_to_vector_then_ntt::(t0_serialized); + let (seed_for_A, seed_for_signing, verification_key_hash, s1_as_ntt, s2_as_ntt, t0_as_ntt) = + encoding::signing_key::deserialize_then_ntt::< + ROWS_IN_A, + COLUMNS_IN_A, + ETA, + ERROR_RING_ELEMENT_SIZE, + SIGNING_KEY_SIZE, + >(signing_key); - let A_as_ntt = expand_to_A::(into_padded_array(seed_for_A)); + let A_as_ntt = expand_to_A::(into_padded_array(&seed_for_A)); + // TODO: Remove the use of to_vec with an incremental SHAKE-256 absorb API. let message_representative = { let mut hash_input = verification_key_hash.to_vec(); hash_input.extend_from_slice(message); - H::(&hash_input[..]) + H::one_shot::(&hash_input[..]) }; let mask_seed: [u8; MASK_SEED_SIZE] = { @@ -261,7 +256,7 @@ pub(crate) fn sign< hash_input.extend_from_slice(&randomness); hash_input.extend_from_slice(&message_representative); - H::(&hash_input[..]) + H::one_shot::(&hash_input[..]) }; let mut domain_separator_for_mask: u16 = 0; @@ -270,16 +265,18 @@ pub(crate) fn sign< let mut attempt = 0; + // TODO: This style of rejection sampling, with the break and the continues, + // won't pass through hax; it'll need to be rewritten. + // See https://github.com/cryspen/libcrux/issues/341 let (commitment_hash, signer_response, hint) = loop { attempt += 1; - if attempt >= 576 { - // Depending on the mode, one try has a chance between 1/7 and 1/4 - // of succeeding. Thus it is safe to say that 576 iterations - // are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸[1]. - // - // [1]: https://github.com/cloudflare/circl/blob/main/sign/dilithium/mode2/internal/dilithium.go#L341 - panic!("At least 576 signing attempts were made; this should only happen 1 in 2^{{128}} times: something is wrong.") - } + + // Depending on the mode, one try has a chance between 1/7 and 1/4 + // of succeeding. Thus it is safe to say that 576 iterations + // are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸[1]. + // + // [1]: https://github.com/cloudflare/circl/blob/main/sign/dilithium/mode2/internal/dilithium.go#L341 + debug_assert!(attempt < 576); let mask = sample_mask_vector::( into_padded_array(&mask_seed), @@ -300,7 +297,7 @@ pub(crate) fn sign< let mut hash_input = message_representative.to_vec(); hash_input.extend_from_slice(&commitment_serialized); - H::(&hash_input[..]) + H::one_shot::(&hash_input[..]) }; let verifier_challenge_as_ntt = @@ -386,55 +383,56 @@ pub(crate) fn verify< SIGNATURE_SIZE, >(signature_serialized)?; - if vector_infinity_norm_exceeds::( + // We use if-else branches because early returns will not go through hax. + if !vector_infinity_norm_exceeds::( signature.signer_response, (2 << GAMMA1_EXPONENT) - BETA, ) { - // TODO: These early returns won't go through verification, fix them. - return Err(VerificationError::SignerResponseExceedsBoundError); - } - - let A_as_ntt = expand_to_A::(into_padded_array(&seed_for_A)); + let A_as_ntt = expand_to_A::(into_padded_array(&seed_for_A)); - let verification_key_hash = H::(&verification_key_serialized); - let message_representative = { - let mut hash_input = verification_key_hash.to_vec(); - hash_input.extend_from_slice(message); + let verification_key_hash = + H::one_shot::(&verification_key_serialized); + let message_representative = { + let mut hash_input = verification_key_hash.to_vec(); + hash_input.extend_from_slice(message); - H::(&hash_input[..]) - }; + H::one_shot::(&hash_input[..]) + }; - let verifier_challenge_as_ntt = - ntt(sample_challenge_ring_element::( - signature.commitment_hash[0..VERIFIER_CHALLENGE_SEED_SIZE] - .try_into() - .unwrap(), - )); + let verifier_challenge_as_ntt = + ntt(sample_challenge_ring_element::( + signature.commitment_hash[0..VERIFIER_CHALLENGE_SEED_SIZE] + .try_into() + .unwrap(), + )); - let w_approx = compute_w_approx::( - &A_as_ntt, - signature.signer_response, - verifier_challenge_as_ntt, - t1, - ); + let w_approx = compute_w_approx::( + &A_as_ntt, + signature.signer_response, + verifier_challenge_as_ntt, + t1, + ); - let commitment_hash: [u8; COMMITMENT_HASH_SIZE] = { - let commitment = use_hint::(signature.hint, w_approx); - let commitment_serialized = encoding::commitment::serialize_vector::< - ROWS_IN_A, - COMMITMENT_RING_ELEMENT_SIZE, - COMMITMENT_VECTOR_SIZE, - >(commitment); + let commitment_hash: [u8; COMMITMENT_HASH_SIZE] = { + let commitment = use_hint::(signature.hint, w_approx); + let commitment_serialized = encoding::commitment::serialize_vector::< + ROWS_IN_A, + COMMITMENT_RING_ELEMENT_SIZE, + COMMITMENT_VECTOR_SIZE, + >(commitment); - let mut hash_input = message_representative.to_vec(); - hash_input.extend_from_slice(&commitment_serialized); + let mut hash_input = message_representative.to_vec(); + hash_input.extend_from_slice(&commitment_serialized); - H::(&hash_input[..]) - }; + H::one_shot::(&hash_input[..]) + }; - if signature.commitment_hash != commitment_hash { - return Err(VerificationError::CommitmentHashesDontMatchError); + if signature.commitment_hash != commitment_hash { + Err(VerificationError::CommitmentHashesDontMatchError) + } else { + Ok(()) + } + } else { + Err(VerificationError::SignerResponseExceedsBoundError) } - - Ok(()) } diff --git a/libcrux-ml-dsa/src/ntt.rs b/libcrux-ml-dsa/src/ntt.rs index 405de096f..abd20d1b6 100644 --- a/libcrux-ml-dsa/src/ntt.rs +++ b/libcrux-ml-dsa/src/ntt.rs @@ -54,7 +54,7 @@ fn ntt_at_layer( ZETAS_TIMES_MONTGOMERY_R[*zeta_i], ); re.coefficients[j + step] = re.coefficients[j] - t; - re.coefficients[j] = re.coefficients[j] + t; + re.coefficients[j] += t; } } @@ -93,7 +93,7 @@ fn invert_ntt_at_layer( for j in offset..offset + step { let a_minus_b = re.coefficients[j + step] - re.coefficients[j]; - re.coefficients[j] = re.coefficients[j] + re.coefficients[j + step]; + re.coefficients[j] += re.coefficients[j + step]; re.coefficients[j + step] = montgomery_multiply_fe_by_fer(a_minus_b, ZETAS_TIMES_MONTGOMERY_R[*zeta_i]); } diff --git a/libcrux-ml-dsa/src/sample.rs b/libcrux-ml-dsa/src/sample.rs index 650da0833..e4ac888cb 100644 --- a/libcrux-ml-dsa/src/sample.rs +++ b/libcrux-ml-dsa/src/sample.rs @@ -1,6 +1,6 @@ use crate::{ arithmetic::PolynomialRingElement, - constants::{COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS}, + constants::FIELD_MODULUS, encoding, hash_functions::{H, H_128}, }; @@ -21,12 +21,12 @@ fn rejection_sample_less_than_field_modulus( let potential_coefficient = ((b2 << 16) | (b1 << 8) | b0) & 0x00_7F_FF_FF; - if potential_coefficient < FIELD_MODULUS && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + if potential_coefficient < FIELD_MODULUS && *sampled < out.coefficients.len() { out.coefficients[*sampled] = potential_coefficient; *sampled += 1; } - if *sampled == COEFFICIENTS_IN_RING_ELEMENT { + if *sampled == out.coefficients.len() { done = true; } } @@ -66,7 +66,7 @@ fn rejection_sample_less_than_eta_equals_2( let try_0 = byte & 0xF; let try_1 = byte >> 4; - if try_0 < 15 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + if try_0 < 15 && *sampled < out.coefficients.len() { let try_0 = try_0 as i32; // (try_0 * 26) >> 7 computes ⌊try_0 / 5⌋ @@ -77,7 +77,7 @@ fn rejection_sample_less_than_eta_equals_2( *sampled += 1; } - if try_1 < 15 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + if try_1 < 15 && *sampled < out.coefficients.len() { let try_1 = try_1 as i32; let try_1_mod_5 = try_1 - ((try_1 * 26) >> 7) * 5; @@ -86,7 +86,7 @@ fn rejection_sample_less_than_eta_equals_2( *sampled += 1; } - if *sampled == COEFFICIENTS_IN_RING_ELEMENT { + if *sampled == out.coefficients.len() { done = true; } } @@ -108,17 +108,17 @@ fn rejection_sample_less_than_eta_equals_4( let try_0 = byte & 0xF; let try_1 = byte >> 4; - if try_0 < 9 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + if try_0 < 9 && *sampled < out.coefficients.len() { out.coefficients[*sampled] = 4 - (try_0 as i32); *sampled += 1; } - if try_1 < 9 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + if try_1 < 9 && *sampled < out.coefficients.len() { out.coefficients[*sampled] = 4 - (try_1 as i32); *sampled += 1; } - if *sampled == COEFFICIENTS_IN_RING_ELEMENT { + if *sampled == out.coefficients.len() { done = true; } } @@ -143,17 +143,17 @@ pub(crate) fn rejection_sample_less_than_eta( #[allow(non_snake_case)] #[inline(always)] fn sample_error_ring_element(seed: [u8; 66]) -> PolynomialRingElement { - // TODO: Use incremental API to squeeze one block at a time. - let randomness = H::<272>(&seed); + let mut state = H::new(&seed); + let randomness = H::squeeze_first_block(&mut state); let mut out = PolynomialRingElement::ZERO; let mut sampled = 0; - let done = rejection_sample_less_than_eta::(&randomness, &mut sampled, &mut out); + let mut done = rejection_sample_less_than_eta::(&randomness, &mut sampled, &mut out); - // TODO: Remove this panic using the incremental API. - if !done { - panic!("Not enough randomness for sampling short vector."); + while !done { + let randomness = H::squeeze_next_block(&mut state); + done = rejection_sample_less_than_eta::(&randomness, &mut sampled, &mut out); } out @@ -165,6 +165,8 @@ pub(crate) fn sample_error_vector( domain_separator: &mut u16, ) -> [PolynomialRingElement; DIMENSION] { let mut error = [PolynomialRingElement::ZERO; DIMENSION]; + + #[allow(clippy::needless_range_loop)] for i in 0..DIMENSION { seed[64] = *domain_separator as u8; seed[65] = (*domain_separator >> 8) as u8; @@ -179,8 +181,8 @@ pub(crate) fn sample_error_vector( #[inline(always)] fn sample_mask_ring_element(seed: [u8; 66]) -> PolynomialRingElement { match GAMMA1_EXPONENT { - 17 => encoding::gamma1::deserialize::(&H::<576>(&seed)), - 19 => encoding::gamma1::deserialize::(&H::<640>(&seed)), + 17 => encoding::gamma1::deserialize::(&H::one_shot::<576>(&seed)), + 19 => encoding::gamma1::deserialize::(&H::one_shot::<640>(&seed)), _ => unreachable!(), } } @@ -192,6 +194,7 @@ pub(crate) fn sample_mask_vector [PolynomialRingElement; DIMENSION] { let mut error = [PolynomialRingElement::ZERO; DIMENSION]; + #[allow(clippy::needless_range_loop)] for i in 0..DIMENSION { seed[64] = *domain_separator as u8; seed[65] = (*domain_separator >> 8) as u8; @@ -204,48 +207,58 @@ pub(crate) fn sample_mask_vector( - seed: [u8; 32], -) -> PolynomialRingElement { - // TODO: Use incremental API to squeeze one block at a time. - let mut randomness = H::<136>(&seed).into_iter(); +fn inside_out_shuffle( + randomness: &[u8], + out_index: &mut usize, + signs: &mut u64, + result: &mut PolynomialRingElement, +) -> bool { + let mut done = false; + + for byte in randomness { + if !done { + let sample_at = *byte as usize; + if sample_at <= *out_index { + result.coefficients[*out_index] = result.coefficients[sample_at]; + *out_index += 1; + + result.coefficients[sample_at] = 1 - 2 * ((*signs & 1) as i32); + *signs >>= 1; + } - let mut signs: u64 = 0; - for i in 0..8 { - signs |= (randomness.next().unwrap() as u64) << (8 * i); + done = *out_index == result.coefficients.len(); + } } - let mut out = PolynomialRingElement::ZERO; + done +} +#[inline(always)] +pub(crate) fn sample_challenge_ring_element( + seed: [u8; 32], +) -> PolynomialRingElement { + let mut state = H::new(&seed); + let randomness = H::squeeze_first_block(&mut state); - for index in (out.coefficients.len() - NUMBER_OF_ONES)..out.coefficients.len() { - // TODO: Rewrite this without using `break`. It's doable, just probably - // not as nice. - let sample_at = loop { - let i = match randomness.next() { - Some(byte) => byte as usize, + let mut signs = u64::from_le_bytes(randomness[0..8].try_into().unwrap()); - // TODO: We need to incrementally sample here instead of panicking. - None => panic!("Insufficient randomness to sample challenge ring element."), - }; + let mut result = PolynomialRingElement::ZERO; - if i <= index { - break i; - } - }; + let mut out_index = result.coefficients.len() - NUMBER_OF_ONES; + let mut done = inside_out_shuffle(&randomness[8..], &mut out_index, &mut signs, &mut result); - out.coefficients[index] = out.coefficients[sample_at]; - out.coefficients[sample_at] = 1 - 2 * ((signs & 1) as i32); - signs >>= 1; + while !done { + let randomness = H::squeeze_next_block(&mut state); + done = inside_out_shuffle(&randomness, &mut out_index, &mut signs, &mut result); } - out + result } #[cfg(test)] mod tests { use super::*; - use crate::arithmetic::FieldElement; + use crate::{arithmetic::FieldElement, constants::COEFFICIENTS_IN_RING_ELEMENT}; #[test] fn test_sample_ring_element_uniform() { diff --git a/libcrux-ml-dsa/tests/self.rs b/libcrux-ml-dsa/tests/self.rs index 54b141972..31f7082ee 100644 --- a/libcrux-ml-dsa/tests/self.rs +++ b/libcrux-ml-dsa/tests/self.rs @@ -1,4 +1,5 @@ -use rand::{rngs::OsRng, RngCore}; +use libcrux_ml_dsa::{ml_dsa_44, ml_dsa_65, ml_dsa_87}; +use rand::{rngs::OsRng, Rng, RngCore}; fn random_array() -> [u8; L] { let mut rng = OsRng; @@ -6,16 +7,55 @@ fn random_array() -> [u8; L] { rng.try_fill_bytes(&mut seed).unwrap(); seed } +fn random_message() -> Vec { + let mut rng = OsRng; + + let mut length = [0u8; 2]; + rng.try_fill_bytes(&mut length).unwrap(); + let length = ((length[1] as u16) << 8) | length[0] as u16; + + let mut message = Vec::with_capacity(length.into()); + rng.try_fill_bytes(&mut message).unwrap(); + + message +} + +fn modify_signing_key(signing_key: &mut [u8; SIGNING_KEY_SIZE]) { + let option = rand::thread_rng().gen_range(0..2); + + let position = match option { + // Change the seed used for generating A + 0 => rand::thread_rng().gen_range(0..32), + + // Change the verification key hash + 1 => rand::thread_rng().gen_range(64..128), + + // TODO: Changing s1, s2, and t0 could still result in valid + // signatures. Look into this further. + _ => unreachable!(), + }; + + let random_byte = { + let byte = random_array::<1>()[0]; + + if byte == 0 { + byte + 1 + } else { + byte + } + }; -macro_rules! impl_consistency { + signing_key[position] ^= random_byte; +} + +macro_rules! impl_consistency_test { ($name:ident, $key_gen:expr, $sign:expr, $verify:expr) => { #[test] fn $name() { let key_generation_seed = random_array(); let signing_randomness = random_array(); - // TODO: Choose the length randomly - let message = random_array::<94883>(); + let message = random_message(); let key_pair = $key_gen(key_generation_seed); @@ -27,23 +67,63 @@ macro_rules! impl_consistency { }; } -impl_consistency!( +macro_rules! impl_modified_signing_key_test { + ($name:ident, $key_gen:expr, $signing_key_size: expr, $sign:expr, $verify:expr) => { + #[test] + fn $name() { + let key_generation_seed = random_array(); + let signing_randomness = random_array(); + + let message = random_message(); + + let mut key_pair = $key_gen(key_generation_seed); + + modify_signing_key::<{ $signing_key_size }>(&mut key_pair.signing_key.0); + + let signature = $sign(key_pair.signing_key, &message, signing_randomness); + + assert!($verify(key_pair.verification_key, &message, signature).is_err()); + } + }; +} + +impl_consistency_test!( consistency_44, - libcrux_ml_dsa::ml_dsa_44::generate_key_pair, - libcrux_ml_dsa::ml_dsa_44::sign, - libcrux_ml_dsa::ml_dsa_44::verify + ml_dsa_44::generate_key_pair, + ml_dsa_44::sign, + ml_dsa_44::verify ); - -impl_consistency!( +impl_consistency_test!( consistency_65, - libcrux_ml_dsa::ml_dsa_65::generate_key_pair, - libcrux_ml_dsa::ml_dsa_65::sign, - libcrux_ml_dsa::ml_dsa_65::verify + ml_dsa_65::generate_key_pair, + ml_dsa_65::sign, + ml_dsa_65::verify ); - -impl_consistency!( +impl_consistency_test!( consistency_87, - libcrux_ml_dsa::ml_dsa_87::generate_key_pair, - libcrux_ml_dsa::ml_dsa_87::sign, - libcrux_ml_dsa::ml_dsa_87::verify + ml_dsa_87::generate_key_pair, + ml_dsa_87::sign, + ml_dsa_87::verify +); + +impl_modified_signing_key_test!( + modified_signing_key_44, + ml_dsa_44::generate_key_pair, + ml_dsa_44::SIGNING_KEY_SIZE, + ml_dsa_44::sign, + ml_dsa_44::verify +); +impl_modified_signing_key_test!( + modified_signing_key_65, + ml_dsa_65::generate_key_pair, + ml_dsa_65::SIGNING_KEY_SIZE, + ml_dsa_65::sign, + ml_dsa_65::verify +); +impl_modified_signing_key_test!( + modified_signing_key_87, + ml_dsa_87::generate_key_pair, + ml_dsa_87::SIGNING_KEY_SIZE, + ml_dsa_87::sign, + ml_dsa_87::verify ); diff --git a/libcrux-ml-dsa/tests/wycheproof_sign.rs b/libcrux-ml-dsa/tests/wycheproof_sign.rs index 3d436a31a..1e77d7b65 100644 --- a/libcrux-ml-dsa/tests/wycheproof_sign.rs +++ b/libcrux-ml-dsa/tests/wycheproof_sign.rs @@ -50,7 +50,11 @@ macro_rules! wycheproof_sign_test { signature.0.as_slice(), hex::decode(test.sig).unwrap().as_slice() ); - } // TODO: else, how should invalid signatures be handled? + } + // TODO: else, the generated signature is invalid; we can + // check that our own implementation agrees with this judgement, + // but in order to do so we'd need the verification key. + // This is being tracked in https://github.com/cryspen/libcrux/issues/340 } } } diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index dcc01ba5f..e8ec69013 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -237,12 +237,13 @@ pub mod portable { /// An incremental API for SHAKE pub mod incremental { use generic_keccak::{ - absorb_final, squeeze_first_five_blocks, squeeze_first_three_blocks, squeeze_next_block, + absorb_final, squeeze_first_block, squeeze_first_five_blocks, + squeeze_first_three_blocks, squeeze_next_block, }; use super::*; - /// Initialise the SHAKE state. + /// Create a new SHAKE-128 state object. #[inline(always)] pub fn shake128_init() -> KeccakState { KeccakState { @@ -273,6 +274,31 @@ pub mod portable { pub fn shake128_squeeze_next_block(s: &mut KeccakState, out0: &mut [u8]) { squeeze_next_block::<1, u64, 168>(&mut s.state, [out0]) } + + /// Create a new SHAKE-256 state object. + #[inline(always)] + pub fn shake256_init() -> KeccakState { + KeccakState { + state: GenericState::<1, u64>::new(), + } + } + /// Absorb some data for SHAKE-256 for the last time + #[inline(always)] + pub fn shake256_absorb_final(s: &mut KeccakState, data0: &[u8]) { + absorb_final::<1, u64, 136, 0x1fu8>(&mut s.state, [data0]); + } + + /// Squeeze the first SHAKE-256 block + #[inline(always)] + pub fn shake256_squeeze_first_block(s: &mut KeccakState, out0: &mut [u8]) { + squeeze_first_block::<1, u64, 136>(&mut s.state, [out0]) + } + + /// Squeeze the next SHAKE-256 block + #[inline(always)] + pub fn shake256_squeeze_next_block(s: &mut KeccakState, out0: &mut [u8]) { + squeeze_next_block::<1, u64, 136>(&mut s.state, [out0]) + } } } @@ -454,8 +480,8 @@ pub mod neon { // XXX: These functions could alternatively implement the same with // the portable implementation // { - // let s0 = KeccakState1::new(); - // let s1 = KeccakState1::new(); + // let s0 = KeccakState::new(); + // let s1 = KeccakState::new(); // [s0, s1] // } #[cfg(feature = "simd128")] @@ -822,10 +848,10 @@ pub mod avx2 { // } // #[cfg(not(any(feature = "simd128", feature = "simd256")))] // { - // let s0 = KeccakState1::new(); - // let s1 = KeccakState1::new(); - // let s2 = KeccakState1::new(); - // let s3 = KeccakState1::new(); + // let s0 = KeccakState::new(); + // let s1 = KeccakState::new(); + // let s2 = KeccakState::new(); + // let s3 = KeccakState::new(); // [s0, s1, s2, s3] // } #[cfg(feature = "simd256")]