Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-DSA: Vectorizing more parts of the code. #391

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 47 additions & 127 deletions libcrux-ml-dsa/src/encoding/gamma1.rs
Original file line number Diff line number Diff line change
@@ -1,146 +1,66 @@
use crate::polynomial::PolynomialRingElement;
use crate::{
polynomial::{PolynomialRingElement, SIMDPolynomialRingElement},
simd::{portable::PortableSIMDUnit, traits::Operations},
};

#[inline(always)]
fn serialize_when_gamma1_is_2_pow_17<const OUTPUT_SIZE: usize>(
pub(crate) fn serialize<const GAMMA1_EXPONENT: usize, const OUTPUT_BYTES: usize>(
re: PolynomialRingElement,
) -> [u8; OUTPUT_SIZE] {
let mut serialized = [0u8; OUTPUT_SIZE];
const GAMMA1: i32 = 1 << 17;
) -> [u8; OUTPUT_BYTES] {
let mut serialized = [0u8; OUTPUT_BYTES];

for (i, coefficients) in re.coefficients.chunks_exact(4).enumerate() {
let coefficient0 = GAMMA1 - coefficients[0];
let coefficient1 = GAMMA1 - coefficients[1];
let coefficient2 = GAMMA1 - coefficients[2];
let coefficient3 = GAMMA1 - coefficients[3];
let mut v_re = SIMDPolynomialRingElement::<PortableSIMDUnit>::from_polynomial_ring_element(re);

serialized[9 * i] = coefficient0 as u8;
serialized[9 * i + 1] = (coefficient0 >> 8) as u8;

serialized[9 * i + 2] = (coefficient0 >> 16) as u8;
serialized[9 * i + 2] |= (coefficient1 << 2) as u8;

serialized[9 * i + 3] = (coefficient1 >> 6) as u8;

serialized[9 * i + 4] = (coefficient1 >> 14) as u8;
serialized[9 * i + 4] |= (coefficient2 << 4) as u8;

serialized[9 * i + 5] = (coefficient2 >> 4) as u8;

serialized[9 * i + 6] = (coefficient2 >> 12) as u8;
serialized[9 * i + 6] |= (coefficient3 << 6) as u8;

serialized[9 * i + 7] = (coefficient3 >> 2) as u8;
serialized[9 * i + 8] = (coefficient3 >> 10) as u8;
}

serialized
}

/// Packs the coefficients of `re` into bytes for GAMMA1 = 2^19.
///
/// Every coefficient `c` is first replaced by `GAMMA1 - c`. The coefficients
/// are then consumed two at a time, and each pair fills five consecutive
/// output bytes: the first value occupies bytes 0 and 1 plus the low part of
/// byte 2, and the second value (shifted left by 4) shares byte 2 and fills
/// bytes 3 and 4.
///
/// Indexing panics if `OUTPUT_SIZE` is smaller than five bytes per pair.
#[inline(always)]
fn serialize_when_gamma1_is_2_pow_19<const OUTPUT_SIZE: usize>(
    re: PolynomialRingElement,
) -> [u8; OUTPUT_SIZE] {
    const GAMMA1: i32 = 1 << 19;

    let mut serialized = [0u8; OUTPUT_SIZE];

    for (pair_index, pair) in re.coefficients.chunks_exact(2).enumerate() {
        let first = GAMMA1 - pair[0];
        let second = GAMMA1 - pair[1];
        let base = 5 * pair_index;

        serialized[base] = first as u8;
        serialized[base + 1] = (first >> 8) as u8;
        // Byte 2 is shared: high bits of `first`, low nibble of `second`.
        serialized[base + 2] = ((first >> 16) as u8) | ((second << 4) as u8);
        serialized[base + 3] = (second >> 4) as u8;
        serialized[base + 4] = (second >> 12) as u8;
    }

    serialized
}

#[inline(always)]
pub(crate) fn serialize<const GAMMA1_EXPONENT: usize, const OUTPUT_SIZE: usize>(
re: PolynomialRingElement,
) -> [u8; OUTPUT_SIZE] {
match GAMMA1_EXPONENT {
17 => serialize_when_gamma1_is_2_pow_17(re),
19 => serialize_when_gamma1_is_2_pow_19(re),
17 => {
const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 18;

for (i, simd_unit) in v_re.simd_units.iter().enumerate() {
serialized[i * OUTPUT_BYTES_PER_SIMD_UNIT..(i + 1) * OUTPUT_BYTES_PER_SIMD_UNIT]
.copy_from_slice(&PortableSIMDUnit::gamma1_serialize::<
GAMMA1_EXPONENT,
OUTPUT_BYTES_PER_SIMD_UNIT,
>(*simd_unit));
}

serialized
}
19 => {
const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 20;

for (i, simd_unit) in v_re.simd_units.iter().enumerate() {
serialized[i * OUTPUT_BYTES_PER_SIMD_UNIT..(i + 1) * OUTPUT_BYTES_PER_SIMD_UNIT]
.copy_from_slice(&PortableSIMDUnit::gamma1_serialize::<
GAMMA1_EXPONENT,
OUTPUT_BYTES_PER_SIMD_UNIT,
>(*simd_unit));
}

serialized
}
_ => unreachable!(),
}
}

/// Unpacks a byte string into a ring element for GAMMA1 = 2^17.
///
/// The input is read in groups of nine bytes (any trailing bytes that do not
/// fill a whole group are ignored). Each group holds four 18-bit fields; a
/// field value `v` is stored as the coefficient `GAMMA1 - v`.
///
/// Indexing panics if `serialized` contains more groups than the coefficient
/// array of `PolynomialRingElement` has room for.
#[inline(always)]
fn deserialize_when_gamma1_is_2_pow_17(serialized: &[u8]) -> PolynomialRingElement {
    const GAMMA1: i32 = 1 << 17;
    const GAMMA1_TIMES_2_BITMASK: i32 = (GAMMA1 << 1) - 1;

    let mut re = PolynomialRingElement::ZERO;

    for (group, bytes) in serialized.chunks_exact(9).enumerate() {
        let byte = |k: usize| i32::from(bytes[k]);

        // Raw (unmasked) fields; neighbouring fields share bytes 2, 4 and 6.
        let fields = [
            byte(0) | (byte(1) << 8) | (byte(2) << 16),
            (byte(2) >> 2) | (byte(3) << 6) | (byte(4) << 14),
            (byte(4) >> 4) | (byte(5) << 4) | (byte(6) << 12),
            (byte(6) >> 6) | (byte(7) << 2) | (byte(8) << 10),
        ];

        for (offset, &field) in fields.iter().enumerate() {
            re.coefficients[4 * group + offset] = GAMMA1 - (field & GAMMA1_TIMES_2_BITMASK);
        }
    }

    re
}

/// Unpacks a byte string into a ring element for GAMMA1 = 2^19.
///
/// The input is read in groups of five bytes (any trailing bytes that do not
/// fill a whole group are ignored). Each group holds two 20-bit fields; a
/// field value `v` is stored as the coefficient `GAMMA1 - v`.
///
/// Indexing panics if `serialized` contains more groups than the coefficient
/// array of `PolynomialRingElement` has room for.
#[inline(always)]
fn deserialize_when_gamma1_is_2_pow_19(serialized: &[u8]) -> PolynomialRingElement {
    const GAMMA1: i32 = 1 << 19;
    const GAMMA1_TIMES_2_BITMASK: i32 = (GAMMA1 << 1) - 1;

    let mut re = PolynomialRingElement::ZERO;

    for (pair, bytes) in serialized.chunks_exact(5).enumerate() {
        let byte = |k: usize| i32::from(bytes[k]);

        // The first field also picks up the high nibble of byte 2, so it is masked.
        let first = (byte(0) | (byte(1) << 8) | (byte(2) << 16)) & GAMMA1_TIMES_2_BITMASK;
        // The second field is built from 4 + 8 + 8 bits and is left unmasked.
        let second = (byte(2) >> 4) | (byte(3) << 4) | (byte(4) << 12);

        re.coefficients[2 * pair] = GAMMA1 - first;
        re.coefficients[2 * pair + 1] = GAMMA1 - second;
    }

    re
}

#[inline(always)]
pub(crate) fn deserialize<const GAMMA1_EXPONENT: usize>(
serialized: &[u8],
) -> PolynomialRingElement {
match GAMMA1_EXPONENT {
17 => deserialize_when_gamma1_is_2_pow_17(serialized),
19 => deserialize_when_gamma1_is_2_pow_19(serialized),
let mut serialized_chunks = match GAMMA1_EXPONENT {
17 => serialized.chunks(18),
19 => serialized.chunks(20),
_ => unreachable!(),
};

let mut result = SIMDPolynomialRingElement::ZERO();

for i in 0..result.simd_units.len() {
result.simd_units[i] = PortableSIMDUnit::gamma1_deserialize::<GAMMA1_EXPONENT>(
&serialized_chunks.next().unwrap(),
);
}

result.to_polynomial_ring_element()
}

#[cfg(test)]
Expand Down
85 changes: 35 additions & 50 deletions libcrux-ml-dsa/src/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,97 +61,82 @@ pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingEleme
#[inline(always)]
fn rejection_sample_less_than_eta_equals_2(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
sampled_coefficients: &mut usize,
out: &mut [i32; 263],
) -> bool {
let mut done = false;

for byte in randomness {
// Since each byte can be used to sample up to 2 coefficients, and since
// a single SIMDUnit can hold 8 coefficients, we pass in 4 bytes of randomness.
for random_bytes in randomness.chunks(4) {
if !done {
let try_0 = byte & 0xF;
let try_1 = byte >> 4;

if try_0 < 15 && *sampled < out.coefficients.len() {
let try_0 = try_0 as i32;

// (try_0 * 26) >> 7 computes ⌊try_0 / 5⌋
let try_0_mod_5 = try_0 - ((try_0 * 26) >> 7) * 5;

out.coefficients[*sampled] = 2 - try_0_mod_5;

*sampled += 1;
}

if try_1 < 15 && *sampled < out.coefficients.len() {
let try_1 = try_1 as i32;
let try_1_mod_5 = try_1 - ((try_1 * 26) >> 7) * 5;

out.coefficients[*sampled] = 2 - try_1_mod_5;

*sampled += 1;
}
let sampled = PortableSIMDUnit::rejection_sample_less_than_eta_equals_2(
random_bytes,
&mut out[*sampled_coefficients..],
);
*sampled_coefficients += sampled;

if *sampled == out.coefficients.len() {
if *sampled_coefficients >= COEFFICIENTS_IN_RING_ELEMENT {
done = true;
}
}
}

done
}

#[inline(always)]
fn rejection_sample_less_than_eta_equals_4(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
sampled_coefficients: &mut usize,
out: &mut [i32; 263],
) -> bool {
let mut done = false;

for byte in randomness {
// Since each byte can be used to sample up to 2 coefficients, and since
// a single SIMDUnit can hold 8 coefficients, we pass in 4 bytes of randomness.
for random_bytes in randomness.chunks(4) {
if !done {
let try_0 = byte & 0xF;
let try_1 = byte >> 4;

if try_0 < 9 && *sampled < out.coefficients.len() {
out.coefficients[*sampled] = 4 - (try_0 as i32);
*sampled += 1;
}

if try_1 < 9 && *sampled < out.coefficients.len() {
out.coefficients[*sampled] = 4 - (try_1 as i32);
*sampled += 1;
}
let sampled = PortableSIMDUnit::rejection_sample_less_than_eta_equals_4(
random_bytes,
&mut out[*sampled_coefficients..],
);
*sampled_coefficients += sampled;

if *sampled == out.coefficients.len() {
if *sampled_coefficients >= COEFFICIENTS_IN_RING_ELEMENT {
done = true;
}
}
}

done
}

#[inline(always)]
pub(crate) fn rejection_sample_less_than_eta<const ETA: usize>(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
out: &mut [i32; 263],
) -> bool {
match ETA {
2 => rejection_sample_less_than_eta_equals_2(randomness, sampled, out),
4 => rejection_sample_less_than_eta_equals_4(randomness, sampled, out),
_ => unreachable!(),
}
}

#[allow(non_snake_case)]
#[inline(always)]
fn sample_error_ring_element<const ETA: usize>(seed: [u8; 66]) -> PolynomialRingElement {
let mut state = H::new(&seed);
let randomness = H::squeeze_first_block(&mut state);

let mut out = PolynomialRingElement::ZERO;
// Every call to |rejection_sample_less_than_eta| will result in calls to
// |PortableSIMDUnit::rejection_sample_less_than_eta_equals_2| or
// |PortableSIMDUnit::rejection_sample_less_than_eta_equals_4|;
// this latter function performs no bounds checking and can write up to 8
// elements to its output. It is therefore possible that 255 elements have
// already been sampled and we call the function again.
//
// To ensure we don't overflow the buffer in this case, we allocate 255 + 8
// = 263 elements.
let mut out = [0i32; 263];

let mut sampled = 0;
let mut done = rejection_sample_less_than_eta::<ETA>(&randomness, &mut sampled, &mut out);
Expand All @@ -161,9 +146,10 @@ fn sample_error_ring_element<const ETA: usize>(seed: [u8; 66]) -> PolynomialRing
done = rejection_sample_less_than_eta::<ETA>(&randomness, &mut sampled, &mut out);
}

out
PolynomialRingElement {
coefficients: out[0..256].try_into().unwrap(),
}
}

#[inline(always)]
pub(crate) fn sample_error_vector<const DIMENSION: usize, const ETA: usize>(
mut seed: [u8; 66],
Expand Down Expand Up @@ -191,7 +177,6 @@ fn sample_mask_ring_element<const GAMMA1_EXPONENT: usize>(seed: [u8; 66]) -> Pol
_ => unreachable!(),
}
}

#[inline(always)]
pub(crate) fn sample_mask_vector<const DIMENSION: usize, const GAMMA1_EXPONENT: usize>(
mut seed: [u8; 66],
Expand Down
16 changes: 16 additions & 0 deletions libcrux-ml-dsa/src/simd/portable.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::simd::traits::Operations;

mod arithmetic;
mod encoding;
mod ntt;
mod sample;
mod simd_unit_type;
Expand Down Expand Up @@ -46,6 +47,21 @@ impl Operations for PortableSIMDUnit {
/// Portable implementation: forwards `randomness` and `out` unchanged to
/// `sample::rejection_sample_less_than_field_modulus` and returns its result.
fn rejection_sample_less_than_field_modulus(randomness: &[u8], out: &mut [i32]) -> usize {
sample::rejection_sample_less_than_field_modulus(randomness, out)
}
/// Portable implementation: forwards `randomness` and `out` unchanged to
/// `sample::rejection_sample_less_than_eta_equals_2` and returns its result.
// NOTE(review): the returned `usize` is presumably the number of coefficients
// written to `out` — confirm against the `sample` module.
fn rejection_sample_less_than_eta_equals_2(randomness: &[u8], out: &mut [i32]) -> usize {
sample::rejection_sample_less_than_eta_equals_2(randomness, out)
}
/// Portable implementation: forwards `randomness` and `out` unchanged to
/// `sample::rejection_sample_less_than_eta_equals_4` and returns its result.
// NOTE(review): the returned `usize` is presumably the number of coefficients
// written to `out` — confirm against the `sample` module.
fn rejection_sample_less_than_eta_equals_4(randomness: &[u8], out: &mut [i32]) -> usize {
sample::rejection_sample_less_than_eta_equals_4(randomness, out)
}

/// Portable implementation: serializes one SIMD unit by forwarding it, along
/// with the `GAMMA1_EXPONENT` and `OUTPUT_BYTES` const parameters, to
/// `encoding::gamma1::serialize`. Returns the `OUTPUT_BYTES`-long byte array
/// produced by that function.
fn gamma1_serialize<const GAMMA1_EXPONENT: usize, const OUTPUT_BYTES: usize>(
simd_unit: Self,
) -> [u8; OUTPUT_BYTES] {
encoding::gamma1::serialize::<GAMMA1_EXPONENT, OUTPUT_BYTES>(simd_unit)
}
/// Portable implementation: rebuilds one SIMD unit from `serialized` by
/// forwarding it, along with `GAMMA1_EXPONENT`, to
/// `encoding::gamma1::deserialize`.
fn gamma1_deserialize<const GAMMA1_EXPONENT: usize>(serialized: &[u8]) -> Self {
encoding::gamma1::deserialize::<GAMMA1_EXPONENT>(serialized)
}

fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self {
ntt::ntt_at_layer_0(simd_unit, zeta0, zeta1, zeta2, zeta3)
Expand Down
1 change: 1 addition & 0 deletions libcrux-ml-dsa/src/simd/portable/encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod gamma1;
Loading
Loading