Skip to content

Commit

Permalink
Improvements to the ML-DSA code. (#339)
Browse files Browse the repository at this point in the history
* Move signing key deserialization into separate function.

* Use incremental squeeze API where applicable.

* Rewrite sample_challenge_ring_element without using a break statement.

* Change early returns in verify() to if-else branches.

* Apply changes suggested by cargo clippy.

* Add invalid signing key self-test.

---------

Co-authored-by: Franziskus Kiefer <[email protected]>
  • Loading branch information
xvzcf and franziskuskiefer authored Jul 3, 2024
1 parent b051a44 commit 4e5be59
Show file tree
Hide file tree
Showing 18 changed files with 413 additions and 208 deletions.
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/benches/ml-dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub fn comparisons_verification(c: &mut Criterion) {
let mut message = [0u8; 511];
rng.fill_bytes(&mut message);

group.bench_function("libcrux portable (external random)", |b| {
group.bench_function("libcrux portable", |b| {
let mut randomness = [0; 32];
rng.fill_bytes(&mut randomness);
let keypair = ml_dsa_65::generate_key_pair(randomness);
Expand All @@ -72,7 +72,7 @@ pub fn comparisons_verification(c: &mut Criterion) {
})
});

group.bench_function("pqclean reference implementation (internal random)", |b| {
group.bench_function("pqclean reference implementation", |b| {
let (vk, sk) = pqcrypto_dilithium::dilithium3::keypair();
let signature = pqcrypto_dilithium::dilithium3::detached_sign(&message, &sk);
b.iter(|| {
Expand Down
16 changes: 8 additions & 8 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ pub(crate) fn vector_infinity_norm_exceeds<const DIMENSION: usize>(
// TODO: We can break out of this loop early if need be, but the most
// straightforward way to do so (returning false) will not go through hax;
// revisit if performance is impacted.
for i in 0..DIMENSION {
exceeds |= vector[i].infinity_norm_exceeds(value);
for ring_element in vector.iter() {
exceeds |= ring_element.infinity_norm_exceeds(value);
}

exceeds
Expand Down Expand Up @@ -129,12 +129,14 @@ pub(crate) fn montgomery_multiply_fe_by_fer(
montgomery_reduce((fe as i64) * (fer as i64))
}

/// Subtracts a multiple of `FIELD_MODULUS` from `fe`.
///
/// The multiple is `(fe + 2^22) >> 23`, i.e. `fe / 2^23` rounded to the
/// nearest integer (assuming `FieldElement` is a signed integer type with an
/// arithmetic right shift — TODO confirm against its declaration).
///
/// NOTE(review): presumably the result is a small signed representative of
/// `fe` modulo `FIELD_MODULUS`; the exact output range is not visible here.
#[inline(always)]
fn reduce(fe: FieldElement) -> FieldElement {
    // Adding 2^22 before shifting by 23 rounds to nearest instead of flooring.
    let rounded_multiple = (fe + (1 << 22)) >> 23;

    fe - rounded_multiple * FIELD_MODULUS
}

#[inline(always)]
pub(crate) fn shift_coefficients_left_then_reduce(
re: PolynomialRingElement,
shift_by: usize,
Expand Down Expand Up @@ -312,7 +314,7 @@ pub(crate) fn make_hint<const DIMENSION: usize, const GAMMA2: i32>(
pub(crate) fn use_hint_value<const GAMMA2: i32>(r: i32, hint: bool) -> i32 {
let (r0, r1) = decompose::<GAMMA2>(r);

if hint == false {
if !hint {
return r1;
}

Expand All @@ -324,12 +326,10 @@ pub(crate) fn use_hint_value<const GAMMA2: i32>(r: i32, hint: bool) -> i32 {
} else {
r1 + 1
}
} else if r1 == 0 {
43
} else {
if r1 == 0 {
43
} else {
r1 - 1
}
r1 - 1
}
}

Expand Down
6 changes: 3 additions & 3 deletions libcrux-ml-dsa/src/encoding/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fn serialize<const OUTPUT_SIZE: usize>(re: PolynomialRingElement) -> [u8; OUTPUT
let coefficient2 = coefficients[2] as u8;
let coefficient3 = coefficients[3] as u8;

out[3 * i + 0] = (coefficient1 << 6) | coefficient0;
out[3 * i] = (coefficient1 << 6) | coefficient0;
out[3 * i + 1] = (coefficient2 << 4) | coefficient1 >> 2;
out[3 * i + 2] = (coefficient3 << 2) | coefficient2 >> 4;
}
Expand All @@ -50,9 +50,9 @@ pub(crate) fn serialize_vector<
let mut serialized = [0u8; OUTPUT_SIZE];
let mut offset: usize = 0;

for i in 0..DIMENSION {
for ring_element in vector.iter() {
serialized[offset..offset + RING_ELEMENT_SIZE]
.copy_from_slice(&serialize::<RING_ELEMENT_SIZE>(vector[i]));
.copy_from_slice(&serialize::<RING_ELEMENT_SIZE>(*ring_element));
offset += RING_ELEMENT_SIZE;
}

Expand Down
10 changes: 5 additions & 5 deletions libcrux-ml-dsa/src/encoding/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn serialize_when_eta_is_2<const OUTPUT_SIZE: usize>(
let coefficient6 = (ETA - coefficients[6]) as u8;
let coefficient7 = (ETA - coefficients[7]) as u8;

serialized[3 * i + 0] = (coefficient2 << 6) | (coefficient1 << 3) | coefficient0;
serialized[3 * i] = (coefficient2 << 6) | (coefficient1 << 3) | coefficient0;
serialized[3 * i + 1] =
(coefficient5 << 7) | (coefficient4 << 4) | (coefficient3 << 1) | (coefficient2 >> 2);
serialized[3 * i + 2] = (coefficient7 << 5) | (coefficient6 << 2) | (coefficient5 >> 1);
Expand Down Expand Up @@ -65,7 +65,7 @@ fn deserialize_when_eta_is_2(serialized: &[u8]) -> PolynomialRingElement {
let byte1 = bytes[1] as i32;
let byte2 = bytes[2] as i32;

re.coefficients[8 * i + 0] = (byte0 >> 0) & 7;
re.coefficients[8 * i] = byte0 & 7;
re.coefficients[8 * i + 1] = (byte0 >> 3) & 7;
re.coefficients[8 * i + 2] = ((byte0 >> 6) | (byte1 << 2)) & 7;
re.coefficients[8 * i + 3] = (byte1 >> 1) & 7;
Expand All @@ -74,7 +74,7 @@ fn deserialize_when_eta_is_2(serialized: &[u8]) -> PolynomialRingElement {
re.coefficients[8 * i + 6] = (byte2 >> 2) & 7;
re.coefficients[8 * i + 7] = (byte2 >> 5) & 7;

re.coefficients[8 * i + 0] = ETA - re.coefficients[8 * i + 0];
re.coefficients[8 * i] = ETA - re.coefficients[8 * i];
re.coefficients[8 * i + 1] = ETA - re.coefficients[8 * i + 1];
re.coefficients[8 * i + 2] = ETA - re.coefficients[8 * i + 2];
re.coefficients[8 * i + 3] = ETA - re.coefficients[8 * i + 3];
Expand All @@ -92,8 +92,8 @@ fn deserialize_when_eta_is_4(serialized: &[u8]) -> PolynomialRingElement {
let mut re = PolynomialRingElement::ZERO;
const ETA: i32 = 4;

for (i, byte) in serialized.into_iter().enumerate() {
re.coefficients[2 * i + 0] = ETA - ((byte & 0xF) as i32);
for (i, byte) in serialized.iter().enumerate() {
re.coefficients[2 * i] = ETA - ((byte & 0xF) as i32);
re.coefficients[2 * i + 1] = ETA - ((byte >> 4) as i32);
}

Expand Down
24 changes: 12 additions & 12 deletions libcrux-ml-dsa/src/encoding/gamma1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn serialize_when_gamma1_is_2_pow_17<const OUTPUT_SIZE: usize>(
let coefficient2 = GAMMA1 - coefficients[2];
let coefficient3 = GAMMA1 - coefficients[3];

serialized[9 * i + 0] = coefficient0 as u8;
serialized[9 * i] = coefficient0 as u8;
serialized[9 * i + 1] = (coefficient0 >> 8) as u8;

serialized[9 * i + 2] = (coefficient0 >> 16) as u8;
Expand Down Expand Up @@ -47,7 +47,7 @@ fn serialize_when_gamma1_is_2_pow_19<const OUTPUT_SIZE: usize>(
let coefficient0 = GAMMA1 - coefficients[0];
let coefficient1 = GAMMA1 - coefficients[1];

serialized[5 * i + 0] = coefficient0 as u8;
serialized[5 * i] = coefficient0 as u8;
serialized[5 * i + 1] = (coefficient0 >> 8) as u8;

serialized[5 * i + 2] = (coefficient0 >> 16) as u8;
Expand Down Expand Up @@ -79,10 +79,10 @@ fn deserialize_when_gamma1_is_2_pow_17(serialized: &[u8]) -> PolynomialRingEleme
let mut re = PolynomialRingElement::ZERO;

for (i, bytes) in serialized.chunks_exact(9).enumerate() {
re.coefficients[4 * i + 0] = bytes[0] as i32;
re.coefficients[4 * i + 0] |= (bytes[1] as i32) << 8;
re.coefficients[4 * i + 0] |= (bytes[2] as i32) << 16;
re.coefficients[4 * i + 0] &= GAMMA1_TIMES_2_BITMASK;
re.coefficients[4 * i] = bytes[0] as i32;
re.coefficients[4 * i] |= (bytes[1] as i32) << 8;
re.coefficients[4 * i] |= (bytes[2] as i32) << 16;
re.coefficients[4 * i] &= GAMMA1_TIMES_2_BITMASK;

re.coefficients[4 * i + 1] = (bytes[2] as i32) >> 2;
re.coefficients[4 * i + 1] |= (bytes[3] as i32) << 6;
Expand All @@ -99,7 +99,7 @@ fn deserialize_when_gamma1_is_2_pow_17(serialized: &[u8]) -> PolynomialRingEleme
re.coefficients[4 * i + 3] |= (bytes[8] as i32) << 10;
re.coefficients[4 * i + 3] &= GAMMA1_TIMES_2_BITMASK;

re.coefficients[4 * i + 0] = GAMMA1 - re.coefficients[4 * i + 0];
re.coefficients[4 * i] = GAMMA1 - re.coefficients[4 * i];
re.coefficients[4 * i + 1] = GAMMA1 - re.coefficients[4 * i + 1];
re.coefficients[4 * i + 2] = GAMMA1 - re.coefficients[4 * i + 2];
re.coefficients[4 * i + 3] = GAMMA1 - re.coefficients[4 * i + 3];
Expand All @@ -116,16 +116,16 @@ fn deserialize_when_gamma1_is_2_pow_19(serialized: &[u8]) -> PolynomialRingEleme
let mut re = PolynomialRingElement::ZERO;

for (i, bytes) in serialized.chunks_exact(5).enumerate() {
re.coefficients[2 * i + 0] = bytes[0] as i32;
re.coefficients[2 * i + 0] |= (bytes[1] as i32) << 8;
re.coefficients[2 * i + 0] |= (bytes[2] as i32) << 16;
re.coefficients[2 * i + 0] &= GAMMA1_TIMES_2_BITMASK;
re.coefficients[2 * i] = bytes[0] as i32;
re.coefficients[2 * i] |= (bytes[1] as i32) << 8;
re.coefficients[2 * i] |= (bytes[2] as i32) << 16;
re.coefficients[2 * i] &= GAMMA1_TIMES_2_BITMASK;

re.coefficients[2 * i + 1] = (bytes[2] as i32) >> 4;
re.coefficients[2 * i + 1] |= (bytes[3] as i32) << 4;
re.coefficients[2 * i + 1] |= (bytes[4] as i32) << 12;

re.coefficients[2 * i + 0] = GAMMA1 - re.coefficients[2 * i + 0];
re.coefficients[2 * i] = GAMMA1 - re.coefficients[2 * i];
re.coefficients[2 * i + 1] = GAMMA1 - re.coefficients[2 * i + 1];
}

Expand Down
69 changes: 60 additions & 9 deletions libcrux-ml-dsa/src/encoding/signing_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,88 @@ pub(crate) fn generate_serialized<
let mut signing_key_serialized = [0u8; SIGNING_KEY_SIZE];
let mut offset = 0;

signing_key_serialized[offset..offset + SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A);
signing_key_serialized[offset..offset + SEED_FOR_A_SIZE].copy_from_slice(seed_for_A);
offset += SEED_FOR_A_SIZE;

signing_key_serialized[offset..offset + SEED_FOR_SIGNING_SIZE]
.copy_from_slice(&seed_for_signing);
.copy_from_slice(seed_for_signing);
offset += SEED_FOR_SIGNING_SIZE;

let verification_key_hash = H::<BYTES_FOR_VERIFICATION_KEY_HASH>(verification_key);
let verification_key_hash = H::one_shot::<BYTES_FOR_VERIFICATION_KEY_HASH>(verification_key);
signing_key_serialized[offset..offset + BYTES_FOR_VERIFICATION_KEY_HASH]
.copy_from_slice(&verification_key_hash);
offset += BYTES_FOR_VERIFICATION_KEY_HASH;

for i in 0..COLUMNS_IN_A {
for ring_element in s1.iter() {
signing_key_serialized[offset..offset + ERROR_RING_ELEMENT_SIZE].copy_from_slice(
&encoding::error::serialize::<ETA, ERROR_RING_ELEMENT_SIZE>(s1[i]),
&encoding::error::serialize::<ETA, ERROR_RING_ELEMENT_SIZE>(*ring_element),
);
offset += ERROR_RING_ELEMENT_SIZE;
}

for i in 0..ROWS_IN_A {
for ring_element in s2.iter() {
signing_key_serialized[offset..offset + ERROR_RING_ELEMENT_SIZE].copy_from_slice(
&encoding::error::serialize::<ETA, ERROR_RING_ELEMENT_SIZE>(s2[i]),
&encoding::error::serialize::<ETA, ERROR_RING_ELEMENT_SIZE>(*ring_element),
);
offset += ERROR_RING_ELEMENT_SIZE;
}

for i in 0..ROWS_IN_A {
for ring_element in t0.iter() {
signing_key_serialized[offset..offset + RING_ELEMENT_OF_T0S_SIZE]
.copy_from_slice(&encoding::t0::serialize(t0[i]));
.copy_from_slice(&encoding::t0::serialize(*ring_element));
offset += RING_ELEMENT_OF_T0S_SIZE;
}

signing_key_serialized
}

#[allow(non_snake_case)]
#[inline(always)]
pub(crate) fn deserialize_then_ntt<
const ROWS_IN_A: usize,
const COLUMNS_IN_A: usize,
const ETA: usize,
const ERROR_RING_ELEMENT_SIZE: usize,
const SIGNING_KEY_SIZE: usize,
>(
serialized: [u8; SIGNING_KEY_SIZE],
) -> (
[u8; SEED_FOR_A_SIZE], // seed_for_A
[u8; SEED_FOR_SIGNING_SIZE], // seed_for_signing
[u8; BYTES_FOR_VERIFICATION_KEY_HASH], // verification_key_hash
[PolynomialRingElement; COLUMNS_IN_A], // s1
[PolynomialRingElement; ROWS_IN_A], // s2
[PolynomialRingElement; ROWS_IN_A], // t0_as_ntt
) {
let (seed_for_A, remaining_serialized) = serialized.split_at(SEED_FOR_A_SIZE);
let (seed_for_signing, remaining_serialized) =
remaining_serialized.split_at(SEED_FOR_SIGNING_SIZE);
let (verification_key_hash, remaining_serialized) =
remaining_serialized.split_at(BYTES_FOR_VERIFICATION_KEY_HASH);

let (s1_serialized, remaining_serialized) =
remaining_serialized.split_at(ERROR_RING_ELEMENT_SIZE * COLUMNS_IN_A);
let (s2_serialized, t0_serialized) =
remaining_serialized.split_at(ERROR_RING_ELEMENT_SIZE * ROWS_IN_A);

let s1_as_ntt = encoding::error::deserialize_to_vector_then_ntt::<
COLUMNS_IN_A,
ETA,
ERROR_RING_ELEMENT_SIZE,
>(s1_serialized);
let s2_as_ntt =
encoding::error::deserialize_to_vector_then_ntt::<ROWS_IN_A, ETA, ERROR_RING_ELEMENT_SIZE>(
s2_serialized,
);

let t0_as_ntt = encoding::t0::deserialize_to_vector_then_ntt::<ROWS_IN_A>(t0_serialized);

(
seed_for_A.try_into().unwrap(),
seed_for_signing.try_into().unwrap(),
verification_key_hash.try_into().unwrap(),
s1_as_ntt,
s2_as_ntt,
t0_as_ntt,
)
}
10 changes: 5 additions & 5 deletions libcrux-ml-dsa/src/encoding/t0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub(crate) fn serialize(re: PolynomialRingElement) -> [u8; RING_ELEMENT_OF_T0S_S
let coefficient6 = change_t0_interval(coefficients[6]);
let coefficient7 = change_t0_interval(coefficients[7]);

serialized[13 * i + 0] = coefficient0 as u8;
serialized[13 * i] = coefficient0 as u8;

serialized[13 * i + 1] = (coefficient0 >> 8) as u8;
serialized[13 * i + 1] |= (coefficient1 << 5) as u8;
Expand Down Expand Up @@ -87,9 +87,9 @@ fn deserialize(serialized: &[u8]) -> PolynomialRingElement {
let byte11 = bytes[11] as i32;
let byte12 = bytes[12] as i32;

re.coefficients[8 * i + 0] = byte0;
re.coefficients[8 * i + 0] |= byte1 << 8;
re.coefficients[8 * i + 0] &= BITS_IN_LOWER_PART_OF_T_MASK;
re.coefficients[8 * i] = byte0;
re.coefficients[8 * i] |= byte1 << 8;
re.coefficients[8 * i] &= BITS_IN_LOWER_PART_OF_T_MASK;

re.coefficients[8 * i + 1] = byte1 >> 5;
re.coefficients[8 * i + 1] |= byte2 << 3;
Expand Down Expand Up @@ -123,7 +123,7 @@ fn deserialize(serialized: &[u8]) -> PolynomialRingElement {
re.coefficients[8 * i + 7] |= byte12 << 5;
re.coefficients[8 * i + 7] &= BITS_IN_LOWER_PART_OF_T_MASK;

re.coefficients[8 * i + 0] = change_t0_interval(re.coefficients[8 * i + 0]);
re.coefficients[8 * i] = change_t0_interval(re.coefficients[8 * i]);
re.coefficients[8 * i + 1] = change_t0_interval(re.coefficients[8 * i + 1]);
re.coefficients[8 * i + 2] = change_t0_interval(re.coefficients[8 * i + 2]);
re.coefficients[8 * i + 3] = change_t0_interval(re.coefficients[8 * i + 3]);
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/t1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub(crate) fn deserialize(serialized: &[u8]) -> PolynomialRingElement {
let byte3 = bytes[3] as i32;
let byte4 = bytes[4] as i32;

out.coefficients[4 * i + 0] = ((byte0 >> 0) | (byte1 << 8)) & mask;
out.coefficients[4 * i] = (byte0 | (byte1 << 8)) & mask;
out.coefficients[4 * i + 1] = ((byte1 >> 2) | (byte2 << 6)) & mask;
out.coefficients[4 * i + 2] = ((byte2 >> 4) | (byte3 << 4)) & mask;
out.coefficients[4 * i + 3] = ((byte3 >> 6) | (byte4 << 2)) & mask;
Expand Down
6 changes: 3 additions & 3 deletions libcrux-ml-dsa/src/encoding/verification_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ pub(crate) fn generate_serialized<const ROWS_IN_A: usize, const VERIFICATION_KEY
t1: [PolynomialRingElement; ROWS_IN_A],
) -> [u8; VERIFICATION_KEY_SIZE] {
let mut verification_key_serialized = [0u8; VERIFICATION_KEY_SIZE];
verification_key_serialized[0..SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A);
verification_key_serialized[0..SEED_FOR_A_SIZE].copy_from_slice(seed_for_A);

for i in 0..ROWS_IN_A {
for (i, ring_element) in t1.iter().enumerate() {
let offset = SEED_FOR_A_SIZE + (i * RING_ELEMENT_OF_T1S_SIZE);
verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE]
.copy_from_slice(&t1::serialize(t1[i]));
.copy_from_slice(&t1::serialize(*ring_element));
}

verification_key_serialized
Expand Down
39 changes: 35 additions & 4 deletions libcrux-ml-dsa/src/hash_functions.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
#![allow(non_snake_case)]
pub(crate) fn H<const OUTPUT_LENGTH: usize>(input: &[u8]) -> [u8; OUTPUT_LENGTH] {
let mut out = [0u8; OUTPUT_LENGTH];
libcrux_sha3::portable::shake256(&mut out, input);

out
/// Wrappers around the portable SHAKE256 implementation of `libcrux_sha3`,
/// offering both a one-shot call and an incremental squeeze interface.
pub(crate) mod H {
    use libcrux_sha3::portable::{incremental, shake256, KeccakState};

    /// Number of bytes produced per squeeze call; 136 bytes is the SHAKE256
    /// rate (1088 bits).
    const BLOCK_SIZE: usize = 136;

    /// Hashes `input` with SHAKE256 and returns `OUTPUT_LENGTH` bytes of output.
    pub(crate) fn one_shot<const OUTPUT_LENGTH: usize>(input: &[u8]) -> [u8; OUTPUT_LENGTH] {
        let mut digest = [0u8; OUTPUT_LENGTH];
        shake256(&mut digest, input);

        digest
    }

    /// Creates an incremental SHAKE256 state and absorbs `seed` into it with
    /// `shake256_absorb_final`.
    ///
    /// NOTE(review): presumably no further input may be absorbed afterwards —
    /// confirm against the `libcrux_sha3` incremental API.
    #[inline(always)]
    pub(crate) fn new(seed: &[u8]) -> KeccakState {
        let mut xof = incremental::shake256_init();
        incremental::shake256_absorb_final(&mut xof, seed);

        xof
    }

    /// Squeezes `BLOCK_SIZE` bytes from `state` using
    /// `shake256_squeeze_first_block`.
    ///
    /// NOTE(review): presumably meant to be called exactly once, right after
    /// `new` — confirm against the `libcrux_sha3` incremental API.
    #[inline(always)]
    pub(crate) fn squeeze_first_block(state: &mut KeccakState) -> [u8; BLOCK_SIZE] {
        let mut block = [0u8; BLOCK_SIZE];
        incremental::shake256_squeeze_first_block(state, &mut block);

        block
    }

    /// Squeezes a further `BLOCK_SIZE` bytes from `state` using
    /// `shake256_squeeze_next_block`.
    #[inline(always)]
    pub(crate) fn squeeze_next_block(state: &mut KeccakState) -> [u8; BLOCK_SIZE] {
        let mut block = [0u8; BLOCK_SIZE];
        incremental::shake256_squeeze_next_block(state, &mut block);

        block
    }
}

pub(crate) mod H_128 {
Expand Down
Loading

0 comments on commit 4e5be59

Please sign in to comment.