Skip to content

Commit

Permalink
Karthik's montgomery patch
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneider-bensch committed Sep 18, 2024
1 parent f7aa4b4 commit 87e9358
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
43 changes: 30 additions & 13 deletions libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,29 @@ fn simd_multiply_i32_and_return_high(lhs: Vec256, rhs: Vec256) -> Vec256 {
}

// NOTE(review): this listing appears to be a diff hunk with its +/- markers
// stripped, so the pre-commit and post-commit bodies of
// `montgomery_multiply_by_constant` are interleaved below. Lines that mention
// `simd_unit`, `product_low`, `k`, `c` or `product_high` belong to the old
// body; lines that mention `lhs`, `rhs` or the `*02` / `*13` names belong to
// the new body. The `field_modulus` / `inverse_of_modulus_mod_montgomery_r`
// setup is shared by both versions.
//
// Purpose (both versions): multiply every 32-bit element of the input vector
// by the scalar `constant` and Montgomery-reduce each product, returning the
// eight reduced values as a `Vec256`.
//
// NOTE(review): the lane-level remarks below assume the `mm256_*` helpers are
// thin wrappers over the Intel intrinsics of the same name — confirm against
// the wrapper definitions, which are not visible here.
#[inline(always)]
// Old signature and broadcast of the scalar (removed by this commit).
pub fn montgomery_multiply_by_constant(simd_unit: Vec256, constant: i32) -> Vec256 {
let constant = mm256_set1_epi32(constant);
// New signature: the vector parameter is renamed to `lhs`, and the broadcast
// scalar gets its own name `rhs` instead of shadowing `constant`.
pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
let rhs = mm256_set1_epi32(constant);
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

// Old: low 32 bits of each element-wise product.
let product_low = mm256_mullo_epi32(simd_unit, constant);
// New: full-width products, computed separately for the even-indexed
// elements (`prod02`) and the odd-indexed elements (`prod13`). The shuffle
// selector 0b11_11_01_01 picks source elements 1, 1, 3, 3 in each 128-bit
// half, which moves the odd-indexed elements into the even positions that
// `mm256_mul_epi32` presumably reads.
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);

// Old: k = product_low * (modulus inverse), keeping the low 32 bits.
let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
// New: the same factor `k`, once per half. Only the low 32 bits of each
// product are presumably consumed by the multiply.
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

// Old: c = high half of k * modulus, and the high half of the original
// product (going by the helper's name; its body is not visible here).
let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(simd_unit, constant);
// New: c = k * modulus as a full-width product, once per half.
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

// Old result: product_high - c, returned as the tail expression.
mm256_sub_epi32(product_high, c)
// New result: subtract 32-bit-wise, so the upper 32-bit element of each
// product/c pair holds the reduced value (the counterpart of the old
// `product_high - c`).
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
// Move the even-half results from the upper (odd) element positions down to
// the even positions. The mask 0b10101010 then takes the odd positions from
// `res13` and the even positions from `res02_shifted`, restoring the
// original element order.
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand All @@ -60,14 +69,22 @@ pub fn montgomery_multiply(lhs: Vec256, rhs: Vec256) -> Vec256 {
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

let product_low = mm256_mullo_epi32(lhs, rhs);

let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(lhs, rhs);
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

mm256_sub_epi32(product_high, c)
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand Down
5 changes: 4 additions & 1 deletion libcrux-ml-dsa/src/simd/avx2/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ fn butterfly_8(a: Vec256, b: Vec256, zeta0: i32, zeta1: i32) -> (Vec256, Vec256)
let add_terms = arithmetic::add(summands, zeta_products);
let sub_terms = arithmetic::subtract(summands, zeta_products);

let a_out = mm256_set_m128i(mm256_castsi256_si128(sub_terms), mm256_castsi256_si128(add_terms));
let a_out = mm256_set_m128i(
mm256_castsi256_si128(sub_terms),
mm256_castsi256_si128(add_terms),
);
let b_out = mm256_permute2x128_si256::<0b0001_0011>(sub_terms, add_terms);

(a_out, b_out)
Expand Down
5 changes: 4 additions & 1 deletion libcrux-ml-dsa/src/simd/portable/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use crate::{
constants::BITS_IN_LOWER_PART_OF_T,
simd::{
portable::PortableSIMDUnit,
traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R, FieldElementTimesMontgomeryR},
traits::{
FieldElementTimesMontgomeryR, Operations, FIELD_MODULUS,
INVERSE_OF_MODULUS_MOD_MONTGOMERY_R,
},
},
};

Expand Down
5 changes: 4 additions & 1 deletion libcrux-ml-dsa/src/simd/portable/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use super::arithmetic;
use crate::simd::{
portable::PortableSIMDUnit,
traits::{montgomery_multiply_by_fer, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT, ZETAS_TIMES_MONTGOMERY_R},
traits::{
montgomery_multiply_by_fer, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT,
ZETAS_TIMES_MONTGOMERY_R,
},
};

#[inline(always)]
Expand Down

0 comments on commit 87e9358

Please sign in to comment.