diff --git a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs index c9b4eaaa0..12954f91b 100644 --- a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs +++ b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs @@ -38,20 +38,29 @@ fn simd_multiply_i32_and_return_high(lhs: Vec256, rhs: Vec256) -> Vec256 { } #[inline(always)] -pub fn montgomery_multiply_by_constant(simd_unit: Vec256, constant: i32) -> Vec256 { - let constant = mm256_set1_epi32(constant); +pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 { + let rhs = mm256_set1_epi32(constant); let field_modulus = mm256_set1_epi32(FIELD_MODULUS); let inverse_of_modulus_mod_montgomery_r = mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32); - let product_low = mm256_mullo_epi32(simd_unit, constant); + let prod02 = mm256_mul_epi32(lhs, rhs); + let prod13 = mm256_mul_epi32( + mm256_shuffle_epi32::<0b11_11_01_01>(lhs), + mm256_shuffle_epi32::<0b11_11_01_01>(rhs), + ); - let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r); + let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r); + let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r); - let c = simd_multiply_i32_and_return_high(k, field_modulus); - let product_high = simd_multiply_i32_and_return_high(simd_unit, constant); + let c02 = mm256_mul_epi32(k02, field_modulus); + let c13 = mm256_mul_epi32(k13, field_modulus); - mm256_sub_epi32(product_high, c) + let res02 = mm256_sub_epi32(prod02, c02); + let res13 = mm256_sub_epi32(prod13, c13); + let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02); + let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13); + res } #[inline(always)] @@ -60,14 +69,22 @@ pub fn montgomery_multiply(lhs: Vec256, rhs: Vec256) -> Vec256 { let inverse_of_modulus_mod_montgomery_r = mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32); - let product_low = mm256_mullo_epi32(lhs, rhs); - - let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r); + let prod02 = mm256_mul_epi32(lhs, rhs); + let prod13 = mm256_mul_epi32( + mm256_shuffle_epi32::<0b11_11_01_01>(lhs), + mm256_shuffle_epi32::<0b11_11_01_01>(rhs), + ); + let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r); + let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r); - let c = simd_multiply_i32_and_return_high(k, field_modulus); - let product_high = simd_multiply_i32_and_return_high(lhs, rhs); + let c02 = mm256_mul_epi32(k02, field_modulus); + let c13 = mm256_mul_epi32(k13, field_modulus); - mm256_sub_epi32(product_high, c) + let res02 = mm256_sub_epi32(prod02, c02); + let res13 = mm256_sub_epi32(prod13, c13); + let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02); + let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13); + res } #[inline(always)] diff --git a/libcrux-ml-dsa/src/simd/avx2/ntt.rs b/libcrux-ml-dsa/src/simd/avx2/ntt.rs index b18887b1d..c6f302155 100644 --- a/libcrux-ml-dsa/src/simd/avx2/ntt.rs +++ b/libcrux-ml-dsa/src/simd/avx2/ntt.rs @@ -93,7 +93,10 @@ fn butterfly_8(a: Vec256, b: Vec256, zeta0: i32, zeta1: i32) -> (Vec256, Vec256) let add_terms = arithmetic::add(summands, zeta_products); let sub_terms = arithmetic::subtract(summands, zeta_products); - let a_out = mm256_set_m128i(mm256_castsi256_si128(sub_terms), mm256_castsi256_si128(add_terms)); + let a_out = mm256_set_m128i( + mm256_castsi256_si128(sub_terms), + mm256_castsi256_si128(add_terms), + ); let b_out = mm256_permute2x128_si256::<0b0001_0011>(sub_terms, add_terms); (a_out, b_out) diff --git a/libcrux-ml-dsa/src/simd/portable/arithmetic.rs b/libcrux-ml-dsa/src/simd/portable/arithmetic.rs index a57ceb241..1785d108e 100644 --- a/libcrux-ml-dsa/src/simd/portable/arithmetic.rs +++ b/libcrux-ml-dsa/src/simd/portable/arithmetic.rs @@ -2,7 +2,10 @@ use crate::{ constants::BITS_IN_LOWER_PART_OF_T, simd::{ portable::PortableSIMDUnit, - traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R, FieldElementTimesMontgomeryR}, + traits::{ + FieldElementTimesMontgomeryR, Operations, FIELD_MODULUS, + INVERSE_OF_MODULUS_MOD_MONTGOMERY_R, + }, }, }; diff --git a/libcrux-ml-dsa/src/simd/portable/ntt.rs b/libcrux-ml-dsa/src/simd/portable/ntt.rs index ae2e76284..1674c9c66 100644 --- a/libcrux-ml-dsa/src/simd/portable/ntt.rs +++ b/libcrux-ml-dsa/src/simd/portable/ntt.rs @@ -1,7 +1,10 @@ use super::arithmetic; use crate::simd::{ portable::PortableSIMDUnit, - traits::{montgomery_multiply_by_fer, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT, ZETAS_TIMES_MONTGOMERY_R}, + traits::{ + montgomery_multiply_by_fer, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT, + ZETAS_TIMES_MONTGOMERY_R, + }, }; #[inline(always)]