Skip to content

Commit

Permalink
Merge pull request #584 from cryspen/goutam/ml-dsa-improve-avx2-ntt
Browse files Browse the repository at this point in the history
[ML-DSA] AVX2 performance improvements in NTT
  • Loading branch information
jschneider-bensch authored Sep 25, 2024
2 parents 9cc8d3a + 83fe42e commit 8c05744
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 160 deletions.
7 changes: 3 additions & 4 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{
constants::COEFFICIENTS_IN_RING_ELEMENT,
polynomial::{PolynomialRingElement, SIMD_UNITS_IN_RING_ELEMENT},
constants::COEFFICIENTS_IN_RING_ELEMENT, polynomial::PolynomialRingElement,
simd::traits::Operations,
};

Expand Down Expand Up @@ -72,7 +71,7 @@ pub(crate) fn decompose_vector<SIMDUnit: Operations, const DIMENSION: usize, con
let mut vector_high = [PolynomialRingElement::<SIMDUnit>::ZERO(); DIMENSION];

for i in 0..DIMENSION {
for j in 0..SIMD_UNITS_IN_RING_ELEMENT {
for j in 0..vector_low[0].simd_units.len() {
let (low, high) = SIMDUnit::decompose::<GAMMA2>(t[i].simd_units[j]);

vector_low[i].simd_units[j] = low;
Expand Down Expand Up @@ -118,7 +117,7 @@ pub(crate) fn use_hint<SIMDUnit: Operations, const DIMENSION: usize, const GAMMA
for i in 0..DIMENSION {
let hint_simd = PolynomialRingElement::<SIMDUnit>::from_i32_array(&hint[i]);

for j in 0..SIMD_UNITS_IN_RING_ELEMENT {
for j in 0..result[0].simd_units.len() {
result[i].simd_units[j] =
SIMDUnit::use_hint::<GAMMA2>(re_vector[i].simd_units[j], hint_simd.simd_units[j]);
}
Expand Down
93 changes: 4 additions & 89 deletions libcrux-ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,98 +34,13 @@ const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
-1362209, 3937738, 1400424, -846154, 1976782,
];

#[inline(always)]
fn ntt_at_layer_0<SIMDUnit: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<SIMDUnit>,
) {
*zeta_i += 1;

for round in 0..re.simd_units.len() {
re.simd_units[round] = SIMDUnit::ntt_at_layer_0(
re.simd_units[round],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i + 1],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i + 2],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i + 3],
);

*zeta_i += 4;
}

*zeta_i -= 1;
}
#[inline(always)]
fn ntt_at_layer_1<SIMDUnit: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<SIMDUnit>,
) {
*zeta_i += 1;

for round in 0..re.simd_units.len() {
re.simd_units[round] = SIMDUnit::ntt_at_layer_1(
re.simd_units[round],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i + 1],
);

*zeta_i += 2;
}

*zeta_i -= 1;
}
#[inline(always)]
fn ntt_at_layer_2<SIMDUnit: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<SIMDUnit>,
) {
for round in 0..re.simd_units.len() {
*zeta_i += 1;
re.simd_units[round] =
SIMDUnit::ntt_at_layer_2(re.simd_units[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]);
}
}
#[inline(always)]
fn ntt_at_layer_3_plus<SIMDUnit: Operations, const LAYER: usize>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<SIMDUnit>,
) {
let step = 1 << LAYER;

for round in 0..(128 >> LAYER) {
*zeta_i += 1;

let offset = (round * step * 2) / COEFFICIENTS_IN_SIMD_UNIT;
let step_by = step / COEFFICIENTS_IN_SIMD_UNIT;

for j in offset..offset + step_by {
let t = montgomery_multiply_by_fer::<SIMDUnit>(
re.simd_units[j + step_by],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
);

re.simd_units[j + step_by] = SIMDUnit::subtract(&re.simd_units[j], &t);
re.simd_units[j] = SIMDUnit::add(&re.simd_units[j], &t);
}
}
}

#[inline(always)]
pub(crate) fn ntt<SIMDUnit: Operations>(
mut re: PolynomialRingElement<SIMDUnit>,
re: PolynomialRingElement<SIMDUnit>,
) -> PolynomialRingElement<SIMDUnit> {
let mut zeta_i = 0;

ntt_at_layer_3_plus::<SIMDUnit, 7>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 6>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 5>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 4>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 3>(&mut zeta_i, &mut re);
ntt_at_layer_2::<SIMDUnit>(&mut zeta_i, &mut re);
ntt_at_layer_1::<SIMDUnit>(&mut zeta_i, &mut re);
ntt_at_layer_0::<SIMDUnit>(&mut zeta_i, &mut re);

re
PolynomialRingElement {
simd_units: SIMDUnit::ntt(re.simd_units),
}
}

#[inline(always)]
Expand Down
5 changes: 1 addition & 4 deletions libcrux-ml-dsa/src/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::simd::traits::{Operations, COEFFICIENTS_IN_SIMD_UNIT};

pub(crate) const SIMD_UNITS_IN_RING_ELEMENT: usize =
crate::constants::COEFFICIENTS_IN_RING_ELEMENT / COEFFICIENTS_IN_SIMD_UNIT;
use crate::simd::traits::{Operations, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT};

#[derive(Clone, Copy)]
pub(crate) struct PolynomialRingElement<SIMDUnit: Operations> {
Expand Down
14 changes: 5 additions & 9 deletions libcrux-ml-dsa/src/simd/avx2.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::simd::traits::Operations;
use crate::simd::traits::{Operations, SIMD_UNITS_IN_RING_ELEMENT};
use libcrux_intrinsics;

mod arithmetic;
Expand Down Expand Up @@ -118,14 +118,10 @@ impl Operations for AVX2SIMDUnit {
encoding::t1::deserialize(serialized).into()
}

fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self {
ntt::ntt_at_layer_0(simd_unit.coefficients, zeta0, zeta1, zeta2, zeta3).into()
}
fn ntt_at_layer_1(simd_unit: Self, zeta0: i32, zeta1: i32) -> Self {
ntt::ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1).into()
}
fn ntt_at_layer_2(simd_unit: Self, zeta: i32) -> Self {
ntt::ntt_at_layer_2(simd_unit.coefficients, zeta).into()
fn ntt(simd_units: [Self; SIMD_UNITS_IN_RING_ELEMENT]) -> [Self; SIMD_UNITS_IN_RING_ELEMENT] {
let result = ntt::ntt(simd_units.map(|x| x.coefficients));

result.map(|x| x.into())
}

fn invert_ntt_at_layer_0(
Expand Down
43 changes: 30 additions & 13 deletions libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,29 @@ fn simd_multiply_i32_and_return_high(lhs: Vec256, rhs: Vec256) -> Vec256 {
}

#[inline(always)]
pub fn montgomery_multiply_by_constant(simd_unit: Vec256, constant: i32) -> Vec256 {
let constant = mm256_set1_epi32(constant);
pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
let rhs = mm256_set1_epi32(constant);
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

let product_low = mm256_mullo_epi32(simd_unit, constant);
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);

let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(simd_unit, constant);
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

mm256_sub_epi32(product_high, c)
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand All @@ -60,14 +69,22 @@ pub fn montgomery_multiply(lhs: Vec256, rhs: Vec256) -> Vec256 {
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

let product_low = mm256_mullo_epi32(lhs, rhs);

let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(lhs, rhs);
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

mm256_sub_epi32(product_high, c)
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand Down
Loading

0 comments on commit 8c05744

Please sign in to comment.