Skip to content

Commit

Permalink
Vectorized structure for rejection sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed Jul 10, 2024
1 parent 0ca7598 commit ff8309c
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 18 deletions.
14 changes: 14 additions & 0 deletions libcrux-ml-dsa/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ impl<SIMDUnit: Operations> SIMDPolynomialRingElement<SIMDUnit> {
}
}

/// Builds a ring element from a flat slice of coefficients.
///
/// The slice is split into consecutive groups of `COEFFICIENTS_IN_SIMD_UNIT`
/// values, and the `i`-th group is converted into the `i`-th SIMD unit of the
/// result via `SIMDUnit::from_i32_array`.
///
/// # Panics
///
/// In debug builds, panics if `array` does not hold exactly 256 values.
/// In any build, panics if `array` yields fewer than
/// `SIMD_UNITS_IN_RING_ELEMENT` groups.
#[inline(always)]
pub(crate) fn from_i32_array(array: &[i32]) -> Self {
    debug_assert!(array.len() == 256);

    let mut chunks = array.chunks(COEFFICIENTS_IN_SIMD_UNIT);
    let mut element = Self::ZERO();

    (0..SIMD_UNITS_IN_RING_ELEMENT).for_each(|index| {
        // One group of coefficients per SIMD unit, consumed in order.
        let chunk = chunks.next().unwrap();
        element.simd_units[index] = SIMDUnit::from_i32_array(chunk);
    });

    element
}

pub(crate) fn infinity_norm_exceeds(&self, bound: i32) -> bool {
let mut exceeds = false;

Expand Down
41 changes: 23 additions & 18 deletions libcrux-ml-dsa/src/sample.rs
Original file line number Diff line number Diff line change
@@ -1,46 +1,49 @@
use crate::{
constants::FIELD_MODULUS,
constants::COEFFICIENTS_IN_RING_ELEMENT,
encoding,
hash_functions::{H, H_128},
polynomial::PolynomialRingElement,
simd::{portable::PortableSIMDUnit, traits::Operations},
};

#[inline(always)]
fn rejection_sample_less_than_field_modulus(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
sampled_coefficients: &mut usize,
out: &mut [i32; 263],
) -> bool {
let mut done = false;

for bytes in randomness.chunks(3) {
for random_bytes in randomness.chunks(24) {
if !done {
let b0 = bytes[0] as i32;
let b1 = bytes[1] as i32;
let b2 = bytes[2] as i32;

let potential_coefficient = ((b2 << 16) | (b1 << 8) | b0) & 0x00_7F_FF_FF;

if potential_coefficient < FIELD_MODULUS && *sampled < out.coefficients.len() {
out.coefficients[*sampled] = potential_coefficient;
*sampled += 1;
}
let sampled = PortableSIMDUnit::rejection_sample_less_than_field_modulus(
random_bytes,
&mut out[*sampled_coefficients..],
);
*sampled_coefficients += sampled;

if *sampled == out.coefficients.len() {
if *sampled_coefficients >= COEFFICIENTS_IN_RING_ELEMENT {
done = true;
}
}
}

done
}

#[inline(always)]
pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingElement {
let mut state = H_128::new(seed);
let randomness = H_128::squeeze_first_five_blocks(&mut state);

let mut out = PolynomialRingElement::ZERO;
// Every call to |rejection_sample_less_than_field_modulus|
// will result in a call to |PortableSIMDUnit::rejection_sample_less_than_field_modulus|;
// this latter function performs no bounds checking and can write up to 8
// elements to its output. It is therefore possible that 255 elements have
// already been sampled and we call the function again.
//
// To ensure we don't overflow the buffer in this case, we allocate 255 + 8
// = 263 elements.
let mut out = [0i32; 263];

let mut sampled = 0;
let mut done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out);
Expand All @@ -50,7 +53,9 @@ pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingEleme
done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out);
}

out
PolynomialRingElement {
coefficients: out[0..256].try_into().unwrap(),
}
}

#[inline(always)]
Expand Down
5 changes: 5 additions & 0 deletions libcrux-ml-dsa/src/simd/portable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::simd::traits::Operations;

mod arithmetic;
mod ntt;
mod sample;
mod simd_unit_type;

pub(crate) use simd_unit_type::PortableSIMDUnit;
Expand Down Expand Up @@ -42,6 +43,10 @@ impl Operations for PortableSIMDUnit {
arithmetic::infinity_norm_exceeds(simd_unit, bound)
}

// Portable implementation of the trait's rejection sampler: forwards both
// arguments unchanged to the routine of the same name in the `sample`
// submodule and returns that routine's result.
fn rejection_sample_less_than_field_modulus(randomness: &[u8], out: &mut [i32]) -> usize {
sample::rejection_sample_less_than_field_modulus(randomness, out)
}

fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self {
ntt::ntt_at_layer_0(simd_unit, zeta0, zeta1, zeta2, zeta3)
}
Expand Down
21 changes: 21 additions & 0 deletions libcrux-ml-dsa/src/simd/portable/sample.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::simd::traits::FIELD_MODULUS;

/// Rejection-samples coefficients below `FIELD_MODULUS` from `randomness`.
///
/// Every complete group of 3 bytes is read as a little-endian 24-bit value
/// whose top bit is cleared, giving a 23-bit candidate. Candidates strictly
/// below `FIELD_MODULUS` are written to `out` in order, starting at index 0;
/// all other candidates are discarded.
///
/// Trailing bytes that do not form a complete 3-byte group are ignored.
///
/// Returns the number of coefficients written to `out`.
///
/// # Panics
///
/// Panics if `out` is too short to hold every accepted coefficient, i.e. it
/// may need room for up to `randomness.len() / 3` values.
#[inline(always)]
pub fn rejection_sample_less_than_field_modulus(randomness: &[u8], out: &mut [i32]) -> usize {
    let mut sampled = 0;

    // `chunks_exact` guarantees each chunk is exactly 3 bytes long, so the
    // indexing below cannot go out of bounds on a short final chunk.
    for bytes in randomness.chunks_exact(3) {
        let b0 = bytes[0] as i32;
        let b1 = bytes[1] as i32;
        let b2 = bytes[2] as i32;

        // Little-endian assembly, then mask down to 23 bits.
        let coefficient = ((b2 << 16) | (b1 << 8) | b0) & 0x00_7F_FF_FF;

        if coefficient < FIELD_MODULUS {
            out[sampled] = coefficient;
            sampled += 1;
        }
    }

    sampled
}
7 changes: 7 additions & 0 deletions libcrux-ml-dsa/src/simd/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pub(crate) trait Operations: Copy + Clone {
// Decomposition operations
fn power2round(simd_unit: Self) -> (Self, Self);

// Sampling

// Since each SIMD unit can hold 8 coefficients, and each coefficient needs
// (at least) 3 bytes to be sampled, we expect that |randomness| holds 24 bytes,
// and that |out| has room for at least 8 i32s.
//
// NOTE(review): the portable implementation returns the number of
// coefficients it wrote to the front of |out|; other implementations are
// presumably expected to do the same — confirm.
fn rejection_sample_less_than_field_modulus(randomness: &[u8], out: &mut [i32]) -> usize;

// NTT
fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self;
fn ntt_at_layer_1(simd_unit: Self, zeta0: i32, zeta1: i32) -> Self;
Expand Down

0 comments on commit ff8309c

Please sign in to comment.