Skip to content

Commit

Permalink
avx2 ntt cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed Dec 23, 2024
1 parent ae86962 commit 36fa5e0
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 201 deletions.
15 changes: 8 additions & 7 deletions libcrux-ml-dsa/src/encoding/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ impl<
}
}

#[allow(non_snake_case)]
#[inline(always)]
pub(crate) fn deserialize<
const GAMMA1_EXPONENT: usize,
Expand All @@ -75,7 +74,8 @@ impl<
const SIGNATURE_SIZE: usize,
>(
serialized: &[u8; SIGNATURE_SIZE],
) -> Result<Self, VerificationError> {
signature: &mut Self,
) -> Result<(), VerificationError> {
let (commitment_hash, rest_of_serialized) = serialized.split_at(COMMITMENT_HASH_SIZE);
let (signer_response_serialized, hint_serialized) =
rest_of_serialized.split_at(GAMMA1_RING_ELEMENT_SIZE * COLUMNS_IN_A);
Expand Down Expand Up @@ -141,10 +141,11 @@ impl<
return Err(VerificationError::MalformedHintError);
}

Ok(Signature {
commitment_hash: commitment_hash.try_into().unwrap(),
signer_response,
hint,
})
// Set output
signature.commitment_hash = commitment_hash.try_into().unwrap();
signature.signer_response = signer_response;
signature.hint = hint;

Ok(())
}
}
3 changes: 2 additions & 1 deletion libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ pub(crate) fn compute_w_approx<
const COLUMNS_IN_A: usize,
>(
A_as_ntt: &[[PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A]; ROWS_IN_A],
mut signer_response: [PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A],
signer_response: &[PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A],
verifier_challenge_as_ntt: &PolynomialRingElement<SIMDUnit>,
t1: &mut [PolynomialRingElement<SIMDUnit>; ROWS_IN_A],
) {
let mut signer_response = signer_response.clone();
// Move the signer response into the NTT domain
for i in 0..signer_response.len() {
ntt(&mut signer_response[i]);
Expand Down
39 changes: 19 additions & 20 deletions libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ pub(crate) fn sign_internal<
let mut message_representative = [0; MESSAGE_REPRESENTATIVE_SIZE];
derive_message_representative::<Shake256Xof>(
verification_key_hash,
domain_separation_context,
&domain_separation_context,
message,
&mut message_representative,
);
Expand Down Expand Up @@ -494,7 +494,7 @@ pub(crate) fn sign_internal<
#[inline(always)]
fn derive_message_representative<Shake256Xof: shake256::Xof>(
verification_key_hash: &[u8],
domain_separation_context: Option<DomainSeparationContext>,
domain_separation_context: &Option<DomainSeparationContext>,
message: &[u8],
message_representative: &mut [u8; 64],
) {
Expand Down Expand Up @@ -553,22 +553,21 @@ pub(crate) fn verify_internal<
&mut t1,
);

// let (seed_for_a, mut t1) =
// encoding::verification_key::deserialize::<SIMDUnit, ROWS_IN_A, VERIFICATION_KEY_SIZE>(
// verification_key_serialized,
// );

let signature =
match Signature::<SIMDUnit, COMMITMENT_HASH_SIZE, COLUMNS_IN_A, ROWS_IN_A>::deserialize::<
GAMMA1_EXPONENT,
GAMMA1_RING_ELEMENT_SIZE,
MAX_ONES_IN_HINT,
SIGNATURE_SIZE,
>(signature_serialized)
{
Ok(s) => s,
Err(e) => return Err(e),
};
let mut signature = Signature {
commitment_hash: [0u8; COMMITMENT_HASH_SIZE],
signer_response: [PolynomialRingElement::zero(); COLUMNS_IN_A],
hint: [[0i32; COEFFICIENTS_IN_RING_ELEMENT]; ROWS_IN_A],
};
match Signature::<SIMDUnit, COMMITMENT_HASH_SIZE, COLUMNS_IN_A, ROWS_IN_A>::deserialize::<
GAMMA1_EXPONENT,
GAMMA1_RING_ELEMENT_SIZE,
MAX_ONES_IN_HINT,
SIGNATURE_SIZE,
>(signature_serialized, &mut signature)
{
Ok(_) => (),
Err(e) => return Err(e),
};

// We use if-else branches because early returns will not go through hax.
if vector_infinity_norm_exceeds::<SIMDUnit, COLUMNS_IN_A>(
Expand All @@ -588,7 +587,7 @@ pub(crate) fn verify_internal<
let mut message_representative = [0; MESSAGE_REPRESENTATIVE_SIZE];
derive_message_representative::<Shake256Xof>(
&verification_key_hash,
domain_separation_context,
&domain_separation_context,
message,
&mut message_representative,
);
Expand All @@ -604,7 +603,7 @@ pub(crate) fn verify_internal<

compute_w_approx::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(
&matrix,
signature.signer_response,
&signature.signer_response,
&verifier_challenge,
&mut t1,
);
Expand Down
24 changes: 14 additions & 10 deletions libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ fn to_unsigned_representatives(t: &mut Vec256) {
}

#[inline(always)]
pub fn add(lhs: &mut Vec256, rhs: &Vec256) {
pub(super) fn add(lhs: &mut Vec256, rhs: &Vec256) {
*lhs = mm256_add_epi32(*lhs, *rhs)
}

#[inline(always)]
pub fn subtract(lhs: &mut Vec256, rhs: &Vec256) {
pub(super) fn subtract(lhs: &mut Vec256, rhs: &Vec256) {
*lhs = mm256_sub_epi32(*lhs, *rhs)
}

#[inline(always)]
pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
pub(super) fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
let rhs = mm256_set1_epi32(constant);
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
Expand All @@ -52,7 +52,7 @@ pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
}

#[inline(always)]
pub fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) {
pub(super) fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) {
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);
Expand All @@ -75,7 +75,7 @@ pub fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) {
}

#[inline(always)]
pub fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: &mut Vec256) {
pub(super) fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: &mut Vec256) {
let shifted = mm256_slli_epi32::<SHIFT_BY>(*simd_unit);

let quotient = mm256_add_epi32(shifted, mm256_set1_epi32(1 << 22));
Expand All @@ -90,7 +90,7 @@ pub fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: &mut Vec256) {
// TODO: Revisit this function when doing the range analysis and testing
// additional KATs.
#[inline(always)]
pub fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool {
pub(super) fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool {
let absolute_values = mm256_abs_epi32(*simd_unit);

// We will test if |simd_unit| > bound - 1, because if this is the case then
Expand All @@ -106,7 +106,7 @@ pub fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool {
}

#[inline(always)]
pub fn power2round(r0: &mut Vec256, r1: &mut Vec256) {
pub(super) fn power2round(r0: &mut Vec256, r1: &mut Vec256) {
to_unsigned_representatives(r0);

*r1 = mm256_add_epi32(
Expand All @@ -121,7 +121,7 @@ pub fn power2round(r0: &mut Vec256, r1: &mut Vec256) {

#[allow(non_snake_case)]
#[inline(always)]
pub fn decompose<const GAMMA2: i32>(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256) {
pub(super) fn decompose<const GAMMA2: i32>(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256) {
let mut r = r.clone();
to_unsigned_representatives(&mut r);

Expand Down Expand Up @@ -182,7 +182,11 @@ pub fn decompose<const GAMMA2: i32>(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256
}

#[inline(always)]
pub fn compute_hint<const GAMMA2: i32>(low: &Vec256, high: &Vec256, hint: &mut Vec256) -> usize {
pub(super) fn compute_hint<const GAMMA2: i32>(
low: &Vec256,
high: &Vec256,
hint: &mut Vec256,
) -> usize {
let gamma2 = mm256_set1_epi32(GAMMA2);
let minus_gamma2 = mm256_set1_epi32(-GAMMA2);

Expand All @@ -206,7 +210,7 @@ pub fn compute_hint<const GAMMA2: i32>(low: &Vec256, high: &Vec256, hint: &mut V
}

#[inline(always)]
pub(crate) fn use_hint<const GAMMA2: i32>(r: &Vec256, hint: &mut Vec256) {
pub(super) fn use_hint<const GAMMA2: i32>(r: &Vec256, hint: &mut Vec256) {
let (mut r0, mut r1) = (zero(), zero());
decompose::<GAMMA2>(r, &mut r0, &mut r1);

Expand Down
32 changes: 17 additions & 15 deletions libcrux-ml-dsa/src/simd/avx2/invntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use libcrux_intrinsics::avx2::*;
#[inline(always)]
#[allow(unsafe_code)]
pub(crate) fn invert_ntt_montgomery(re: &mut AVX2RingElement) {
unsafe {
#[cfg_attr(not(hax), target_feature(enable = "avx2"))]
#[allow(unsafe_code)]
unsafe fn inv_inner(re: &mut AVX2RingElement) {
invert_ntt_at_layer_0(re);
invert_ntt_at_layer_1(re);
invert_ntt_at_layer_2(re);
Expand All @@ -15,16 +17,19 @@ pub(crate) fn invert_ntt_montgomery(re: &mut AVX2RingElement) {
invert_ntt_at_layer_5(re);
invert_ntt_at_layer_6(re);
invert_ntt_at_layer_7(re);

for i in 0..re.len() {
// After the invert_ntt_at_layer_* calls, elements are of the form a * MONTGOMERY_R^{-1};
// we multiply by (MONTGOMERY_R^2) * (1/2^8) mod Q = 41,978 to both:
//
// - Divide the elements by 256, and
// - Convert the elements from the Montgomery domain to the standard domain.
const FACTOR: i32 = 41_978;
re[i] = arithmetic::montgomery_multiply_by_constant(re[i], FACTOR);
}
}
for i in 0..re.len() {
// After the invert_ntt_at_layer_* calls, elements are of the form a * MONTGOMERY_R^{-1};
// we multiply by (MONTGOMERY_R^2) * (1/2^8) mod Q = 41,978 to both:
//
// - Divide the elements by 256, and
// - Convert the elements from the Montgomery domain to the standard domain.
const FACTOR: i32 = 41_978;
re[i] = arithmetic::montgomery_multiply_by_constant(re[i], FACTOR);
}

unsafe { inv_inner(re) };
}

#[inline(always)]
Expand Down Expand Up @@ -270,11 +275,8 @@ fn outer_3_plus<const OFFSET: usize, const STEP_BY: usize, const ZETA: i32>(
re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT],
) {
for j in OFFSET..OFFSET + STEP_BY {
// XXX: make nicer
let rejs = re[j + STEP_BY];
let mut a_minus_b = rejs;
arithmetic::subtract(&mut a_minus_b, &re[j]);
arithmetic::add(&mut re[j], &rejs);
let a_minus_b = mm256_sub_epi32(re[j + STEP_BY], re[j]);
re[j] = mm256_add_epi32(re[j], re[j + STEP_BY]);
re[j + STEP_BY] = arithmetic::montgomery_multiply_by_constant(a_minus_b, ZETA);
}
()
Expand Down
Loading

0 comments on commit 36fa5e0

Please sign in to comment.