From 36fa5e0684c117939afe80c696d1a277c5e65501 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Mon, 23 Dec 2024 16:31:18 +0000 Subject: [PATCH] avx2 ntt cleanup --- libcrux-ml-dsa/src/encoding/signature.rs | 15 +- libcrux-ml-dsa/src/matrix.rs | 3 +- libcrux-ml-dsa/src/ml_dsa_generic.rs | 39 ++-- libcrux-ml-dsa/src/simd/avx2/arithmetic.rs | 24 ++- libcrux-ml-dsa/src/simd/avx2/invntt.rs | 32 +-- libcrux-ml-dsa/src/simd/avx2/ntt.rs | 239 ++++++++------------- 6 files changed, 151 insertions(+), 201 deletions(-) diff --git a/libcrux-ml-dsa/src/encoding/signature.rs b/libcrux-ml-dsa/src/encoding/signature.rs index 12fe7f9e7..c8e9be9ee 100644 --- a/libcrux-ml-dsa/src/encoding/signature.rs +++ b/libcrux-ml-dsa/src/encoding/signature.rs @@ -66,7 +66,6 @@ impl< } } - #[allow(non_snake_case)] #[inline(always)] pub(crate) fn deserialize< const GAMMA1_EXPONENT: usize, @@ -75,7 +74,8 @@ impl< const SIGNATURE_SIZE: usize, >( serialized: &[u8; SIGNATURE_SIZE], - ) -> Result { + signature: &mut Self, + ) -> Result<(), VerificationError> { let (commitment_hash, rest_of_serialized) = serialized.split_at(COMMITMENT_HASH_SIZE); let (signer_response_serialized, hint_serialized) = rest_of_serialized.split_at(GAMMA1_RING_ELEMENT_SIZE * COLUMNS_IN_A); @@ -141,10 +141,11 @@ impl< return Err(VerificationError::MalformedHintError); } - Ok(Signature { - commitment_hash: commitment_hash.try_into().unwrap(), - signer_response, - hint, - }) + // Set output + signature.commitment_hash = commitment_hash.try_into().unwrap(); + signature.signer_response = signer_response; + signature.hint = hint; + + Ok(()) } } diff --git a/libcrux-ml-dsa/src/matrix.rs b/libcrux-ml-dsa/src/matrix.rs index 0728c56ee..713d2bafc 100644 --- a/libcrux-ml-dsa/src/matrix.rs +++ b/libcrux-ml-dsa/src/matrix.rs @@ -113,10 +113,11 @@ pub(crate) fn compute_w_approx< const COLUMNS_IN_A: usize, >( A_as_ntt: &[[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A], - mut signer_response: [PolynomialRingElement; COLUMNS_IN_A], + signer_response: &[PolynomialRingElement; COLUMNS_IN_A], verifier_challenge_as_ntt: &PolynomialRingElement, t1: &mut [PolynomialRingElement; ROWS_IN_A], ) { + let mut signer_response = signer_response.clone(); // Move signer response into NTT for i in 0..signer_response.len() { ntt(&mut signer_response[i]); diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index aef0d89b2..4d3aac122 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -301,7 +301,7 @@ pub(crate) fn sign_internal< let mut message_representative = [0; MESSAGE_REPRESENTATIVE_SIZE]; derive_message_representative::( verification_key_hash, - domain_separation_context, + &domain_separation_context, message, &mut message_representative, ); @@ -494,7 +494,7 @@ pub(crate) fn sign_internal< #[inline(always)] fn derive_message_representative( verification_key_hash: &[u8], - domain_separation_context: Option, + domain_separation_context: &Option, message: &[u8], message_representative: &mut [u8; 64], ) { @@ -553,22 +553,21 @@ pub(crate) fn verify_internal< &mut t1, ); - // let (seed_for_a, mut t1) = - // encoding::verification_key::deserialize::( - // verification_key_serialized, - // ); - - let signature = - match Signature::::deserialize::< - GAMMA1_EXPONENT, - GAMMA1_RING_ELEMENT_SIZE, - MAX_ONES_IN_HINT, - SIGNATURE_SIZE, - >(signature_serialized) - { - Ok(s) => s, - Err(e) => return Err(e), - }; + let mut signature = Signature { + commitment_hash: [0u8; COMMITMENT_HASH_SIZE], + signer_response: [PolynomialRingElement::zero(); COLUMNS_IN_A], + hint: [[0i32; COEFFICIENTS_IN_RING_ELEMENT]; ROWS_IN_A], + }; + match Signature::::deserialize::< + GAMMA1_EXPONENT, + GAMMA1_RING_ELEMENT_SIZE, + MAX_ONES_IN_HINT, + SIGNATURE_SIZE, + >(signature_serialized, &mut signature) + { + Ok(_) => (), + Err(e) => return Err(e), + }; // We use if-else branches because early returns will not go through hax. if vector_infinity_norm_exceeds::( @@ -588,7 +587,7 @@ pub(crate) fn verify_internal< let mut message_representative = [0; MESSAGE_REPRESENTATIVE_SIZE]; derive_message_representative::( &verification_key_hash, - domain_separation_context, + &domain_separation_context, message, &mut message_representative, ); @@ -604,7 +603,7 @@ pub(crate) fn verify_internal< compute_w_approx::( &matrix, - signature.signer_response, + &signature.signer_response, &verifier_challenge, &mut t1, ); diff --git a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs index b67140c46..88e1927d8 100644 --- a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs +++ b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs @@ -16,17 +16,17 @@ fn to_unsigned_representatives(t: &mut Vec256) { } #[inline(always)] -pub fn add(lhs: &mut Vec256, rhs: &Vec256) { +pub(super) fn add(lhs: &mut Vec256, rhs: &Vec256) { *lhs = mm256_add_epi32(*lhs, *rhs) } #[inline(always)] -pub fn subtract(lhs: &mut Vec256, rhs: &Vec256) { +pub(super) fn subtract(lhs: &mut Vec256, rhs: &Vec256) { *lhs = mm256_sub_epi32(*lhs, *rhs) } #[inline(always)] -pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 { +pub(super) fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 { let rhs = mm256_set1_epi32(constant); let field_modulus = mm256_set1_epi32(FIELD_MODULUS); let inverse_of_modulus_mod_montgomery_r = @@ -52,7 +52,7 @@ pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 { } #[inline(always)] -pub fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) { +pub(super) fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) { let field_modulus = mm256_set1_epi32(FIELD_MODULUS); let inverse_of_modulus_mod_montgomery_r = mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32); @@ -75,7 +75,7 @@ pub fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) { } #[inline(always)] -pub fn shift_left_then_reduce(simd_unit: &mut Vec256) { +pub(super) fn shift_left_then_reduce(simd_unit: &mut Vec256) { let shifted = mm256_slli_epi32::(*simd_unit); let quotient = mm256_add_epi32(shifted, mm256_set1_epi32(1 << 22)); @@ -90,7 +90,7 @@ pub fn shift_left_then_reduce(simd_unit: &mut Vec256) { // TODO: Revisit this function when doing the range analysis and testing // additional KATs. #[inline(always)] -pub fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool { +pub(super) fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool { let absolute_values = mm256_abs_epi32(*simd_unit); // We will test if |simd_unit| > bound - 1, because if this is the case then @@ -106,7 +106,7 @@ pub fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool { } #[inline(always)] -pub fn power2round(r0: &mut Vec256, r1: &mut Vec256) { +pub(super) fn power2round(r0: &mut Vec256, r1: &mut Vec256) { to_unsigned_representatives(r0); *r1 = mm256_add_epi32( @@ -121,7 +121,7 @@ pub fn power2round(r0: &mut Vec256, r1: &mut Vec256) { #[allow(non_snake_case)] #[inline(always)] -pub fn decompose(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256) { +pub(super) fn decompose(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256) { let mut r = r.clone(); to_unsigned_representatives(&mut r); @@ -182,7 +182,11 @@ pub fn decompose(r: &Vec256, r0: &mut Vec256, r1: &mut Vec256 } #[inline(always)] -pub fn compute_hint(low: &Vec256, high: &Vec256, hint: &mut Vec256) -> usize { +pub(super) fn compute_hint( + low: &Vec256, + high: &Vec256, + hint: &mut Vec256, +) -> usize { let gamma2 = mm256_set1_epi32(GAMMA2); let minus_gamma2 = mm256_set1_epi32(-GAMMA2); @@ -206,7 +210,7 @@ pub fn compute_hint(low: &Vec256, high: &Vec256, hint: &mut V } #[inline(always)] -pub(crate) fn use_hint(r: &Vec256, hint: &mut Vec256) { +pub(super) fn use_hint(r: &Vec256, hint: &mut Vec256) { let (mut r0, mut r1) = (zero(), zero()); decompose::(r, &mut r0, &mut r1); diff --git a/libcrux-ml-dsa/src/simd/avx2/invntt.rs b/libcrux-ml-dsa/src/simd/avx2/invntt.rs index 53f08c830..5337a68f8 100644 --- a/libcrux-ml-dsa/src/simd/avx2/invntt.rs +++ b/libcrux-ml-dsa/src/simd/avx2/invntt.rs @@ -6,7 +6,9 @@ use libcrux_intrinsics::avx2::*; #[inline(always)] #[allow(unsafe_code)] pub(crate) fn invert_ntt_montgomery(re: &mut AVX2RingElement) { - unsafe { + #[cfg_attr(not(hax), target_feature(enable = "avx2"))] + #[allow(unsafe_code)] + unsafe fn inv_inner(re: &mut AVX2RingElement) { invert_ntt_at_layer_0(re); invert_ntt_at_layer_1(re); invert_ntt_at_layer_2(re); @@ -15,16 +17,19 @@ pub(crate) fn invert_ntt_montgomery(re: &mut AVX2RingElement) { invert_ntt_at_layer_5(re); invert_ntt_at_layer_6(re); invert_ntt_at_layer_7(re); + + for i in 0..re.len() { + // After invert_ntt_at_layer, elements are of the form a * MONTGOMERY_R^{-1} + // we multiply by (MONTGOMERY_R^2) * (1/2^8) mod Q = 41,978 to both: + // + // - Divide the elements by 256 and + // - Convert the elements form montgomery domain to the standard domain. + const FACTOR: i32 = 41_978; + re[i] = arithmetic::montgomery_multiply_by_constant(re[i], FACTOR); + } } - for i in 0..re.len() { - // After invert_ntt_at_layer, elements are of the form a * MONTGOMERY_R^{-1} - // we multiply by (MONTGOMERY_R^2) * (1/2^8) mod Q = 41,978 to both: - // - // - Divide the elements by 256 and - // - Convert the elements form montgomery domain to the standard domain. - const FACTOR: i32 = 41_978; - re[i] = arithmetic::montgomery_multiply_by_constant(re[i], FACTOR); - } + + unsafe { inv_inner(re) }; } #[inline(always)] @@ -270,11 +275,8 @@ fn outer_3_plus( re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], ) { for j in OFFSET..OFFSET + STEP_BY { - // XXX: make nicer - let rejs = re[j + STEP_BY]; - let mut a_minus_b = rejs; - arithmetic::subtract(&mut a_minus_b, &re[j]); - arithmetic::add(&mut re[j], &rejs); + let a_minus_b = mm256_sub_epi32(re[j + STEP_BY], re[j]); + re[j] = mm256_add_epi32(re[j], re[j + STEP_BY]); re[j + STEP_BY] = arithmetic::montgomery_multiply_by_constant(a_minus_b, ZETA); } () diff --git a/libcrux-ml-dsa/src/simd/avx2/ntt.rs b/libcrux-ml-dsa/src/simd/avx2/ntt.rs index cf64b0088..ece8055c2 100644 --- a/libcrux-ml-dsa/src/simd/avx2/ntt.rs +++ b/libcrux-ml-dsa/src/simd/avx2/ntt.rs @@ -5,8 +5,8 @@ use libcrux_intrinsics::avx2::*; #[inline(always)] fn butterfly_2( - a: Vec256, - b: Vec256, + re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], + index: usize, zeta_a0: i32, zeta_a1: i32, zeta_a2: i32, @@ -15,7 +15,7 @@ fn butterfly_2( zeta_b1: i32, zeta_b2: i32, zeta_b3: i32, -) -> (Vec256, Vec256) { +) { // We shuffle the terms to group those that need to be multiplied // with zetas in the high QWORDS of the vectors, i.e. if the inputs are // a = (a7, a6, a5, a4, a3, a2, a1, a0) @@ -24,166 +24,133 @@ fn butterfly_2( // a_shuffled = ( a7, a5, a6, a4, a3, a1, a2, a0) // b_shuffled = ( b7, b5, b6, b4, b3, b1, b2, b0) const SHUFFLE: i32 = 0b11_01_10_00; - let a_shuffled = mm256_shuffle_epi32::(a); - let b_shuffled = mm256_shuffle_epi32::(b); + let a = mm256_shuffle_epi32::(re[index]); + let b = mm256_shuffle_epi32::(re[index + 1]); // Now we can use the same approach as for `butterfly_4`, only // zetas need to be adjusted. - let mut summands = mm256_unpacklo_epi64(a_shuffled, b_shuffled); - let mut zeta_products = mm256_unpackhi_epi64(a_shuffled, b_shuffled); + let summands = mm256_unpacklo_epi64(a, b); + let mut zeta_products = mm256_unpackhi_epi64(a, b); let zetas = mm256_set_epi32( zeta_b3, zeta_b2, zeta_a3, zeta_a2, zeta_b1, zeta_b0, zeta_a1, zeta_a0, ); arithmetic::montgomery_multiply(&mut zeta_products, &zetas); - let mut sub_terms = summands; - arithmetic::subtract(&mut sub_terms, &zeta_products); - arithmetic::add(&mut summands, &zeta_products); - let add_terms = summands; + let sub_terms = mm256_sub_epi32(summands, zeta_products); + let add_terms = mm256_add_epi32(summands, zeta_products); let a_terms_shuffled = mm256_unpacklo_epi64(add_terms, sub_terms); let b_terms_shuffled = mm256_unpackhi_epi64(add_terms, sub_terms); // Here, we undo the initial shuffle (it's self-inverse). - let a_out = mm256_shuffle_epi32::(a_terms_shuffled); - let b_out = mm256_shuffle_epi32::(b_terms_shuffled); - - (a_out, b_out) + re[index] = mm256_shuffle_epi32::(a_terms_shuffled); + re[index + 1] = mm256_shuffle_epi32::(b_terms_shuffled); } // Compute (a,b) ↦ (a + ζb, a - ζb) at layer 1 for 2 SIMD Units in one go. #[inline(always)] fn butterfly_4( - a: Vec256, - b: Vec256, + re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], + index: usize, zeta_a0: i32, zeta_a1: i32, zeta_b0: i32, zeta_b1: i32, -) -> (Vec256, Vec256) { - let mut summands = mm256_unpacklo_epi64(a, b); - let mut zeta_products = mm256_unpackhi_epi64(a, b); +) { + let summands = mm256_unpacklo_epi64(re[index], re[index + 1]); + let mut zeta_products = mm256_unpackhi_epi64(re[index], re[index + 1]); let zetas = mm256_set_epi32( zeta_b1, zeta_b1, zeta_a1, zeta_a1, zeta_b0, zeta_b0, zeta_a0, zeta_a0, ); arithmetic::montgomery_multiply(&mut zeta_products, &zetas); - let mut sub_terms = summands; - arithmetic::subtract(&mut sub_terms, &zeta_products); - arithmetic::add(&mut summands, &zeta_products); - let add_terms = summands; + let sub_terms = mm256_sub_epi32(summands, zeta_products); + let add_terms = mm256_add_epi32(summands, zeta_products); // Results are shuffled across the two SIMD registers. // We need to bring them in the right order. - let a_out = mm256_unpacklo_epi64(add_terms, sub_terms); - let b_out = mm256_unpackhi_epi64(add_terms, sub_terms); - - (a_out, b_out) + re[index] = mm256_unpacklo_epi64(add_terms, sub_terms); + re[index + 1] = mm256_unpackhi_epi64(add_terms, sub_terms); } // Compute (a,b) ↦ (a + ζb, a - ζb) at layer 2 for 2 SIMD Units in one go. #[inline(always)] -fn butterfly_8(a: Vec256, b: Vec256, zeta0: i32, zeta1: i32) -> (Vec256, Vec256) { - let mut summands = mm256_set_m128i(mm256_castsi256_si128(b), mm256_castsi256_si128(a)); - let mut zeta_products = mm256_permute2x128_si256::<0b0001_0011>(b, a); +fn butterfly_8( + re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], + index: usize, + zeta0: i32, + zeta1: i32, +) { + let summands = mm256_set_m128i( + mm256_castsi256_si128(re[index + 1]), + mm256_castsi256_si128(re[index]), + ); + let mut zeta_products = mm256_permute2x128_si256::<0b0001_0011>(re[index + 1], re[index]); let zetas = mm256_set_epi32(zeta1, zeta1, zeta1, zeta1, zeta0, zeta0, zeta0, zeta0); arithmetic::montgomery_multiply(&mut zeta_products, &zetas); - let mut sub_terms = summands; - arithmetic::subtract(&mut sub_terms, &zeta_products); - arithmetic::add(&mut summands, &zeta_products); - let add_terms = summands; + let sub_terms = mm256_sub_epi32(summands, zeta_products); + let add_terms = mm256_add_epi32(summands, zeta_products); - let a_out = mm256_set_m128i( + re[index] = mm256_set_m128i( mm256_castsi256_si128(sub_terms), mm256_castsi256_si128(add_terms), ); - let b_out = mm256_permute2x128_si256::<0b0001_0011>(sub_terms, add_terms); - - (a_out, b_out) + re[index + 1] = mm256_permute2x128_si256::<0b0001_0011>(sub_terms, add_terms); } #[cfg_attr(not(hax), target_feature(enable = "avx2"))] #[allow(unsafe_code)] unsafe fn ntt_at_layer_0(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { - #[inline(always)] - fn round( - re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], - index: usize, - zeta_0: i32, - zeta_1: i32, - zeta_2: i32, - zeta_3: i32, - zeta_4: i32, - zeta_5: i32, - zeta_6: i32, - zeta_7: i32, - ) { - let (a, b) = butterfly_2( - re[index], - re[index + 1], - zeta_0, - zeta_1, - zeta_2, - zeta_3, - zeta_4, - zeta_5, - zeta_6, - zeta_7, - ); - re[index] = a; - re[index + 1] = b; - } - - round( + butterfly_2( re, 0, 2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, ); - round( + butterfly_2( re, 2, 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, ); - round( + butterfly_2( re, 4, 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, ); - round( + butterfly_2( re, 6, -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, ); - round( + butterfly_2( re, 8, 342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, ); - round( + butterfly_2( re, 10, 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, ); - round( + butterfly_2( re, 12, -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, ); - round( + butterfly_2( re, 14, -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, ); - round( + butterfly_2( re, 16, -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, ); - round( + butterfly_2( re, 18, -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, ); - round( + butterfly_2( re, 20, -2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, ); - round( + butterfly_2( re, 22, -3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, ); - round( + butterfly_2( re, 24, -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, ); - round( + butterfly_2( re, 26, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, ); - round( + butterfly_2( re, 28, -2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, ); - round( + butterfly_2( re, 30, -554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782, ); } @@ -191,69 +158,43 @@ unsafe fn ntt_at_layer_0(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { #[cfg_attr(not(hax), target_feature(enable = "avx2"))] #[allow(unsafe_code)] unsafe fn ntt_at_layer_1(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { - #[inline(always)] - fn round( - re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], - index: usize, - zeta_0: i32, - zeta_1: i32, - zeta_2: i32, - zeta_3: i32, - ) { - let (a, b) = butterfly_4(re[index], re[index + 1], zeta_0, zeta_1, zeta_2, zeta_3); - re[index] = a; - re[index + 1] = b; - } - - round(re, 0, -3930395, -1528703, -3677745, -3041255); - round(re, 2, -1452451, 3475950, 2176455, -1585221); - round(re, 4, -1257611, 1939314, -4083598, -1000202); - round(re, 6, -3190144, -3157330, -3632928, 126922); - round(re, 8, 3412210, -983419, 2147896, 2715295); - round(re, 10, -2967645, -3693493, -411027, -2477047); - round(re, 12, -671102, -1228525, -22981, -1308169); - round(re, 14, -381987, 1349076, 1852771, -1430430); - round(re, 16, -3343383, 264944, 508951, 3097992); - round(re, 18, 44288, -1100098, 904516, 3958618); - round(re, 20, -3724342, -8578, 1653064, -3249728); - round(re, 22, 2389356, -210977, 759969, -1316856); - round(re, 24, 189548, -3553272, 3159746, -1851402); - round(re, 26, -2409325, -177440, 1315589, 1341330); - round(re, 28, 1285669, -1584928, -812732, -1439742); - round(re, 30, -3019102, -3881060, -3628969, 3839961); + butterfly_4(re, 0, -3930395, -1528703, -3677745, -3041255); + butterfly_4(re, 2, -1452451, 3475950, 2176455, -1585221); + butterfly_4(re, 4, -1257611, 1939314, -4083598, -1000202); + butterfly_4(re, 6, -3190144, -3157330, -3632928, 126922); + butterfly_4(re, 8, 3412210, -983419, 2147896, 2715295); + butterfly_4(re, 10, -2967645, -3693493, -411027, -2477047); + butterfly_4(re, 12, -671102, -1228525, -22981, -1308169); + butterfly_4(re, 14, -381987, 1349076, 1852771, -1430430); + butterfly_4(re, 16, -3343383, 264944, 508951, 3097992); + butterfly_4(re, 18, 44288, -1100098, 904516, 3958618); + butterfly_4(re, 20, -3724342, -8578, 1653064, -3249728); + butterfly_4(re, 22, 2389356, -210977, 759969, -1316856); + butterfly_4(re, 24, 189548, -3553272, 3159746, -1851402); + butterfly_4(re, 26, -2409325, -177440, 1315589, 1341330); + butterfly_4(re, 28, 1285669, -1584928, -812732, -1439742); + butterfly_4(re, 30, -3019102, -3881060, -3628969, 3839961); } #[cfg_attr(not(hax), target_feature(enable = "avx2"))] #[allow(unsafe_code)] unsafe fn ntt_at_layer_2(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { - #[inline(always)] - fn round( - re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT], - index: usize, - zeta_0: i32, - zeta_1: i32, - ) { - let (a, b) = butterfly_8(re[index], re[index + 1], zeta_0, zeta_1); - re[index] = a; - re[index + 1] = b; - } - - round(re, 0, 2706023, 95776); - round(re, 2, 3077325, 3530437); - round(re, 4, -1661693, -3592148); - round(re, 6, -2537516, 3915439); - round(re, 8, -3861115, -3043716); - round(re, 10, 3574422, -2867647); - round(re, 12, 3539968, -300467); - round(re, 14, 2348700, -539299); - round(re, 16, -1699267, -1643818); - round(re, 18, 3505694, -3821735); - round(re, 20, 3507263, -2140649); - round(re, 22, -1600420, 3699596); - round(re, 24, 811944, 531354); - round(re, 26, 954230, 3881043); - round(re, 28, 3900724, -2556880); - round(re, 30, 2071892, -2797779); + butterfly_8(re, 0, 2706023, 95776); + butterfly_8(re, 2, 3077325, 3530437); + butterfly_8(re, 4, -1661693, -3592148); + butterfly_8(re, 6, -2537516, 3915439); + butterfly_8(re, 8, -3861115, -3043716); + butterfly_8(re, 10, 3574422, -2867647); + butterfly_8(re, 12, 3539968, -300467); + butterfly_8(re, 14, 2348700, -539299); + butterfly_8(re, 16, -1699267, -1643818); + butterfly_8(re, 18, 3505694, -3821735); + butterfly_8(re, 20, 3507263, -2140649); + butterfly_8(re, 22, -1600420, 3699596); + butterfly_8(re, 24, 811944, 531354); + butterfly_8(re, 26, 954230, 3881043); + butterfly_8(re, 28, 3900724, -2556880); + butterfly_8(re, 30, 2071892, -2797779); } /// This is equivalent to the pqclean 0 and 1 @@ -369,12 +310,11 @@ unsafe fn ntt_at_layer_5_to_3(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { let offset = (index * STEP * 2) / COEFFICIENTS_IN_SIMD_UNIT; for j in offset..offset + STEP_BY { - let mut t = re[j + STEP_BY]; - arithmetic::montgomery_multiply(&mut t, &rhs); + arithmetic::montgomery_multiply(&mut re[j + STEP_BY], &rhs); - re[j + STEP_BY] = re[j]; - arithmetic::subtract(&mut re[j + STEP_BY], &t); - arithmetic::add(&mut re[j], &t); + let tmp = mm256_sub_epi32(re[j], re[j + STEP_BY]); + re[j] = mm256_add_epi32(re[j], re[j + STEP_BY]); + re[j + STEP_BY] = tmp; } () // Needed because of https://github.com/hacspec/hax/issues/720 } @@ -446,11 +386,14 @@ unsafe fn ntt_at_layer_5_to_3(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { #[allow(unsafe_code)] #[inline(always)] pub(crate) fn ntt(re: &mut AVX2RingElement) { - unsafe { + #[cfg_attr(not(hax), target_feature(enable = "avx2"))] + unsafe fn avx2_ntt(re: &mut AVX2RingElement) { ntt_at_layer_7_and_6(re); ntt_at_layer_5_to_3(re); ntt_at_layer_2(re); ntt_at_layer_1(re); ntt_at_layer_0(re); } + + unsafe { avx2_ntt(re) } }