Skip to content

Commit

Permalink
Setting up scaffolding for vectorization. (#367)
Browse files Browse the repository at this point in the history
I moved the code for `PolynomialRingElement` into `polynomial.rs`, and
in order to facilitate the transition towards a vectorized codebase, I
added another struct to polynomial.rs called
`VectorPolynomialRingElement`:

`VectorPolynomialRingElement` contains functions to convert to and from
a `PolynomialRingElement`. This way I can incrementally make changes to
the codebase; eventually `PolynomialRingElement` will be removed and
`VectorPolynomialRingElement` will take its place. An example of how
this works can be seen
[here](https://github.com/cryspen/libcrux/compare/main...goutam/ml-dsa-vec-setup?expand=1#diff-3d667db8aecf82b13428fd3d918d06bb0e7a8b76b8c784e8a3ae3af0873a5adcR126).
  • Loading branch information
franziskuskiefer authored Jul 8, 2024
2 parents d1f5027 + edf08a2 commit 1727de0
Show file tree
Hide file tree
Showing 21 changed files with 472 additions and 97 deletions.
70 changes: 1 addition & 69 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,5 @@
use crate::constants::{BITS_IN_LOWER_PART_OF_T, COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS};

/// A ring element stored as a fixed-size array of
/// `COEFFICIENTS_IN_RING_ELEMENT` field-element coefficients.
#[derive(Clone, Copy, Debug)]
pub struct PolynomialRingElement {
    pub(crate) coefficients: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
}

impl PolynomialRingElement {
    /// The ring element whose coefficients are all zero.
    pub const ZERO: Self = Self {
        // FIXME: hax issue, 256 is COEFFICIENTS_IN_RING_ELEMENT
        coefficients: [0i32; 256],
    };

    /// Returns the coefficient-wise sum of `self` and `rhs`.
    ///
    /// The coefficients are added as plain integers; no reduction is
    /// performed here.
    #[inline(always)]
    pub(crate) fn add(&self, rhs: &Self) -> Self {
        let mut coefficients = Self::ZERO.coefficients;

        for index in 0..rhs.coefficients.len() {
            let left = self.coefficients[index];
            let right = rhs.coefficients[index];

            coefficients[index] = left + right;
        }

        Self { coefficients }
    }

    /// Returns the coefficient-wise difference `self - rhs`.
    ///
    /// The coefficients are subtracted as plain integers; no reduction is
    /// performed here.
    #[inline(always)]
    pub(crate) fn sub(&self, rhs: &Self) -> Self {
        let mut coefficients = Self::ZERO.coefficients;

        for index in 0..rhs.coefficients.len() {
            let left = self.coefficients[index];
            let right = rhs.coefficients[index];

            coefficients[index] = left - right;
        }

        Self { coefficients }
    }

    /// Returns `true` if the absolute value of at least one coefficient is
    /// greater than or equal to `bound`.
    ///
    /// Every coefficient is always inspected; the result is accumulated
    /// rather than returned early.
    // TODO: Revisit this function when doing the range analysis and testing
    // additional KATs.
    #[inline(always)]
    pub(crate) fn infinity_norm_exceeds(&self, bound: i32) -> bool {
        let mut any_exceeds = false;

        // Revealing *which* coefficient is out of bounds is acceptable: that
        // probability is independent of secret data. What must stay hidden is
        // the sign of the centralized representative, hence the branch-free
        // absolute value below.
        //
        // TODO: We can break out of this loop early if need be, but the most
        // straightforward way to do so (returning false) will not go through hax;
        // revisit if performance is impacted.
        for index in 0..self.coefficients.len() {
            let coefficient = self.coefficients[index];

            debug_assert!(
                coefficient > -FIELD_MODULUS && coefficient < FIELD_MODULUS,
                "coefficient is {}",
                coefficient
            );

            // The norm is taken over the signed representative in the range
            //
            //     -FIELD_MODULUS / 2 < r <= FIELD_MODULUS / 2,
            //
            // so a negative coefficient is replaced by its absolute value
            // without switching to another representation. The arithmetic
            // shift yields an all-ones mask for negative values and zero
            // otherwise; masking `2 * coefficient` with it and subtracting
            // turns `c` into `c - 2c = -c` only when `c` is negative.
            let sign_mask = coefficient >> 31;
            let magnitude = coefficient - (sign_mask & (2 * coefficient));

            any_exceeds |= magnitude >= bound;
        }

        any_exceeds
    }
}
use crate::polynomial::PolynomialRingElement;

#[inline(always)]
pub(crate) fn vector_infinity_norm_exceeds<const DIMENSION: usize>(
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/encoding/commitment.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::PolynomialRingElement;
use crate::polynomial::PolynomialRingElement;

#[inline(always)]
fn serialize<const OUTPUT_SIZE: usize>(re: PolynomialRingElement) -> [u8; OUTPUT_SIZE] {
Expand Down Expand Up @@ -63,7 +63,7 @@ pub(crate) fn serialize_vector<
mod tests {
use super::*;

use crate::arithmetic::PolynomialRingElement;
use crate::polynomial::PolynomialRingElement;

#[test]
fn test_serialize_commitment() {
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Functions for serializing and deserializing an error ring element.

use crate::{arithmetic::PolynomialRingElement, ntt::ntt};
use crate::{ntt::ntt, polynomial::PolynomialRingElement};

#[inline(always)]
fn serialize_when_eta_is_2<const OUTPUT_SIZE: usize>(
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/gamma1.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::arithmetic::PolynomialRingElement;
use crate::polynomial::PolynomialRingElement;

#[inline(always)]
fn serialize_when_gamma1_is_2_pow_17<const OUTPUT_SIZE: usize>(
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/signing_key.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::{
arithmetic::PolynomialRingElement,
constants::{
BYTES_FOR_VERIFICATION_KEY_HASH, RING_ELEMENT_OF_T0S_SIZE, SEED_FOR_A_SIZE,
SEED_FOR_SIGNING_SIZE,
},
encoding,
hash_functions::H,
polynomial::PolynomialRingElement,
};

#[allow(non_snake_case)]
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/encoding/t0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// ---------------------------------------------------------------------------

use crate::{
arithmetic::PolynomialRingElement,
constants::{BITS_IN_LOWER_PART_OF_T, RING_ELEMENT_OF_T0S_SIZE},
ntt::ntt,
polynomial::PolynomialRingElement,
};

// If t0 is a signed representative, change it to an unsigned one and
Expand Down Expand Up @@ -153,7 +153,7 @@ pub(crate) fn deserialize_to_vector_then_ntt<const DIMENSION: usize>(
mod tests {
use super::*;

use crate::arithmetic::PolynomialRingElement;
use crate::polynomial::PolynomialRingElement;

#[test]
fn test_serialize() {
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/t1.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
arithmetic::PolynomialRingElement,
constants::{BITS_IN_UPPER_PART_OF_T, RING_ELEMENT_OF_T1S_SIZE},
polynomial::PolynomialRingElement,
};

// Each coefficient takes up 10 bits.
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/verification_key.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
arithmetic::PolynomialRingElement,
constants::{RING_ELEMENT_OF_T1S_SIZE, SEED_FOR_A_SIZE},
encoding::t1,
polynomial::PolynomialRingElement,
};

#[allow(non_snake_case)]
Expand Down
2 changes: 2 additions & 0 deletions libcrux-ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ mod hash_functions;
mod matrix;
mod ml_dsa_generic;
mod ntt;
mod polynomial;
mod sample;
mod utils;
mod vector;

pub use ml_dsa_generic::VerificationError;

Expand Down
23 changes: 20 additions & 3 deletions libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::{
arithmetic::{shift_coefficients_left_then_reduce, PolynomialRingElement},
arithmetic::shift_coefficients_left_then_reduce,
constants::BITS_IN_LOWER_PART_OF_T,
ntt::{invert_ntt_montgomery, ntt, ntt_multiply_montgomery},
polynomial::{PolynomialRingElement, VectorPolynomialRingElement},
sample::sample_ring_element_uniform,
vector::portable::PortableVector,
};

#[allow(non_snake_case)]
Expand All @@ -12,6 +14,7 @@ pub(crate) fn expand_to_A<const ROWS_IN_A: usize, const COLUMNS_IN_A: usize>(
) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] {
let mut A = [[PolynomialRingElement::ZERO; COLUMNS_IN_A]; ROWS_IN_A];

// Mutable iterators won't go through hax, so we need these range loops.
#[allow(clippy::needless_range_loop)]
for i in 0..ROWS_IN_A {
for j in 0..COLUMNS_IN_A {
Expand Down Expand Up @@ -94,7 +97,14 @@ pub(crate) fn add_vectors<const DIMENSION: usize>(
let mut result = [PolynomialRingElement::ZERO; DIMENSION];

for i in 0..DIMENSION {
result[i] = lhs[i].add(&rhs[i]);
let lhs_vectorized =
VectorPolynomialRingElement::<PortableVector>::from_polynomial_ring_element(lhs[i]);
let rhs_vectorized =
VectorPolynomialRingElement::<PortableVector>::from_polynomial_ring_element(rhs[i]);

result[i] = lhs_vectorized
.add(&rhs_vectorized)
.to_polynomial_ring_element();
}

result
Expand All @@ -109,7 +119,14 @@ pub(crate) fn subtract_vectors<const DIMENSION: usize>(
let mut result = [PolynomialRingElement::ZERO; DIMENSION];

for i in 0..DIMENSION {
result[i] = lhs[i].sub(&rhs[i]);
let lhs_vectorized =
VectorPolynomialRingElement::<PortableVector>::from_polynomial_ring_element(lhs[i]);
let rhs_vectorized =
VectorPolynomialRingElement::<PortableVector>::from_polynomial_ring_element(rhs[i]);

result[i] = lhs_vectorized
.subtract(&rhs_vectorized)
.to_polynomial_ring_element();
}

result
Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{
arithmetic::{
decompose_vector, make_hint, power2round_vector, use_hint, vector_infinity_norm_exceeds,
PolynomialRingElement,
},
constants::*,
encoding,
Expand All @@ -11,6 +10,7 @@ use crate::{
subtract_vectors, vector_times_ring_element,
},
ntt::ntt,
polynomial::PolynomialRingElement,
sample::{sample_challenge_ring_element, sample_error_vector, sample_mask_vector},
utils::into_padded_array,
};
Expand Down
89 changes: 76 additions & 13 deletions libcrux-ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use super::{
arithmetic::{
montgomery_multiply_fe_by_fer, montgomery_reduce, FieldElementTimesMontgomeryR,
PolynomialRingElement,
},
use crate::{
arithmetic::{montgomery_multiply_fe_by_fer, montgomery_reduce, FieldElementTimesMontgomeryR},
constants::COEFFICIENTS_IN_RING_ELEMENT,
polynomial::{PolynomialRingElement, VectorPolynomialRingElement, VECTORS_IN_RING_ELEMENT},
vector::{
portable::PortableVector,
traits::{montgomery_multiply_by_fer, Operations, COEFFICIENTS_PER_VECTOR},
},
};

const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
Expand Down Expand Up @@ -35,6 +37,62 @@ const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
-1362209, 3937738, 1400424, -846154, 1976782,
];

/// Applies `Vector::ntt_at_layer_1` to every vector of `re`.
///
/// Each vector consumes two consecutive entries of
/// `ZETAS_TIMES_MONTGOMERY_R`, starting at `*zeta_i + 1`. On return,
/// `*zeta_i` has advanced by `2 * VECTORS_IN_RING_ELEMENT`, i.e. it is the
/// index of the last entry that was used.
#[inline(always)]
pub(crate) fn ntt_at_layer_1<Vector: Operations>(
    zeta_i: &mut usize,
    re: &mut VectorPolynomialRingElement<Vector>,
) {
    // Index of the first zeta for the vector about to be processed.
    let mut next_zeta = *zeta_i + 1;

    for vector_index in 0..VECTORS_IN_RING_ELEMENT {
        let first_zeta = ZETAS_TIMES_MONTGOMERY_R[next_zeta];
        let second_zeta = ZETAS_TIMES_MONTGOMERY_R[next_zeta + 1];

        re.coefficients[vector_index] =
            Vector::ntt_at_layer_1(re.coefficients[vector_index], first_zeta, second_zeta);

        next_zeta += 2;
    }

    // `next_zeta` points one past the last zeta consumed; callers expect
    // `*zeta_i` to name the last one that was used.
    *zeta_i = next_zeta - 1;
}
/// Applies `Vector::ntt_at_layer_2` to every vector of `re`.
///
/// `*zeta_i` is incremented before each vector is processed, and the entry
/// of `ZETAS_TIMES_MONTGOMERY_R` at the new index is used for that vector,
/// so `*zeta_i` advances by `VECTORS_IN_RING_ELEMENT` in total.
#[inline(always)]
pub(crate) fn ntt_at_layer_2<Vector: Operations>(
    zeta_i: &mut usize,
    re: &mut VectorPolynomialRingElement<Vector>,
) {
    for vector_index in 0..VECTORS_IN_RING_ELEMENT {
        *zeta_i += 1;

        let zeta = ZETAS_TIMES_MONTGOMERY_R[*zeta_i];
        let transformed = Vector::ntt_at_layer_2(re.coefficients[vector_index], zeta);

        re.coefficients[vector_index] = transformed;
    }
}
/// Performs one NTT layer in which each butterfly pairs two whole vectors
/// of `re`.
///
/// For the given `layer`, paired coefficients lie `1 << layer` positions
/// apart. The layer is split into `128 >> layer` rounds; `*zeta_i` is
/// incremented once per round and the entry of `ZETAS_TIMES_MONTGOMERY_R`
/// at the new index is used for every butterfly in that round.
///
/// Distances and offsets are converted to vector indices by dividing by
/// `COEFFICIENTS_PER_VECTOR`.
#[inline(always)]
pub(crate) fn ntt_at_layer_3_plus<Vector: Operations>(
    zeta_i: &mut usize,
    re: &mut VectorPolynomialRingElement<Vector>,
    layer: usize,
) {
    let distance = 1 << layer;

    for block in 0..(128 >> layer) {
        *zeta_i += 1;

        // First vector of this round, and the number of vectors separating
        // the two halves of each butterfly.
        let first = (block * distance * 2) / COEFFICIENTS_PER_VECTOR;
        let half = distance / COEFFICIENTS_PER_VECTOR;

        for low in first..first + half {
            let high = low + half;

            let product = montgomery_multiply_by_fer::<Vector>(
                re.coefficients[high],
                ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
            );

            // The upper half must be written first: it reads the lower
            // half's value before that value is overwritten.
            re.coefficients[high] = Vector::subtract(&re.coefficients[low], &product);
            re.coefficients[low] = Vector::add(&re.coefficients[low], &product);
        }
    }
}

#[inline(always)]
fn ntt_at_layer(
zeta_i: &mut usize,
Expand Down Expand Up @@ -62,16 +120,21 @@ fn ntt_at_layer(
}

#[inline(always)]
pub(crate) fn ntt(mut re: PolynomialRingElement) -> PolynomialRingElement {
pub(crate) fn ntt(re: PolynomialRingElement) -> PolynomialRingElement {
let mut zeta_i = 0;

re = ntt_at_layer(&mut zeta_i, re, 7);
re = ntt_at_layer(&mut zeta_i, re, 6);
re = ntt_at_layer(&mut zeta_i, re, 5);
re = ntt_at_layer(&mut zeta_i, re, 4);
re = ntt_at_layer(&mut zeta_i, re, 3);
re = ntt_at_layer(&mut zeta_i, re, 2);
re = ntt_at_layer(&mut zeta_i, re, 1);
let mut v_re = VectorPolynomialRingElement::<PortableVector>::from_polynomial_ring_element(re);

ntt_at_layer_3_plus::<PortableVector>(&mut zeta_i, &mut v_re, 7);
ntt_at_layer_3_plus::<PortableVector>(&mut zeta_i, &mut v_re, 6);
ntt_at_layer_3_plus::<PortableVector>(&mut zeta_i, &mut v_re, 5);
ntt_at_layer_3_plus::<PortableVector>(&mut zeta_i, &mut v_re, 4);
ntt_at_layer_3_plus::<PortableVector>(&mut zeta_i, &mut v_re, 3);
ntt_at_layer_2::<PortableVector>(&mut zeta_i, &mut v_re);
ntt_at_layer_1::<PortableVector>(&mut zeta_i, &mut v_re);

let mut re = v_re.to_polynomial_ring_element();

re = ntt_at_layer(&mut zeta_i, re, 0);

re
Expand Down
Loading

0 comments on commit 1727de0

Please sign in to comment.