diff --git a/libcrux-ml-dsa/src/simd/avx2.rs b/libcrux-ml-dsa/src/simd/avx2.rs index f4236caa2..608f37add 100644 --- a/libcrux-ml-dsa/src/simd/avx2.rs +++ b/libcrux-ml-dsa/src/simd/avx2.rs @@ -129,9 +129,8 @@ impl Operations for AVX2SIMDUnit { } #[inline(always)] - #[allow(unsafe_code)] fn ntt(simd_units: [Self; SIMD_UNITS_IN_RING_ELEMENT]) -> [Self; SIMD_UNITS_IN_RING_ELEMENT] { - let result = unsafe { ntt::ntt(simd_units.map(|x| x.coefficients)) }; + let result = ntt::ntt(simd_units.map(|x| x.coefficients)); result.map(|x| x.into()) } diff --git a/libcrux-ml-dsa/src/simd/avx2/ntt.rs b/libcrux-ml-dsa/src/simd/avx2/ntt.rs index 8ae3c9d68..cb84a2933 100644 --- a/libcrux-ml-dsa/src/simd/avx2/ntt.rs +++ b/libcrux-ml-dsa/src/simd/avx2/ntt.rs @@ -488,16 +488,18 @@ unsafe fn ntt_at_layer_5_to_3(re: &mut [Vec256; SIMD_UNITS_IN_RING_ELEMENT]) { () } -#[cfg_attr(not(hax), target_feature(enable = "avx2"))] #[allow(unsafe_code)] -pub(crate) unsafe fn ntt( +#[inline(always)] +pub(crate) fn ntt( mut re: [Vec256; SIMD_UNITS_IN_RING_ELEMENT], ) -> [Vec256; SIMD_UNITS_IN_RING_ELEMENT] { - ntt_at_layer_7_and_6(&mut re); - ntt_at_layer_5_to_3(&mut re); - ntt_at_layer_2(&mut re); - ntt_at_layer_1(&mut re); - ntt_at_layer_0(&mut re); + unsafe { + ntt_at_layer_7_and_6(&mut re); + ntt_at_layer_5_to_3(&mut re); + ntt_at_layer_2(&mut re); + ntt_at_layer_1(&mut re); + ntt_at_layer_0(&mut re); + } re }