Skip to content

Commit

Permalink
Add parallelism to aggregate_non_interactive_multi_party_server_key_s…
Browse files Browse the repository at this point in the history
…hares

For example,

On a 12-core server with `48GB RAM`,
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `47.56s`
  - with this commit it takes `4.66s`
- for `examples/non_interactive_fheuint8`:
  - prior to this commit it took `158.15s`
  - with this commit it takes `14.96s`

so roughly a `10x` reduction.

On a 4-core laptop with `8GB RAM` (a low-capacity laptop with multiple other apps running),
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `48.65s`
  - with this commit it takes `23.11s`

so roughly a `2x` reduction.

Co-authored-by: Carlos Pérez <[email protected]>
  • Loading branch information
arnaucube and CPerezz committed Jul 25, 2024
1 parent b9cfe75 commit 720b13f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 48 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ num-traits = "0.2.18"
rand = "0.8.5"
rand_chacha = "0.3.1"
rand_distr = "0.4.3"
rayon = "1.10.0"

[dev-dependencies]
criterion = "0.5.1"
Expand Down Expand Up @@ -59,4 +60,4 @@ required-features = ["non_interactive_mp"]
[[example]]
name = "if_and_else"
path = "./examples/if_and_else.rs"
required-features = ["non_interactive_mp"]
required-features = ["non_interactive_mp"]
108 changes: 61 additions & 47 deletions src/bool/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
use rand_distr::uniform::SampleUniform;

use rayon::iter::FlatMap;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};

use crate::{
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
bool::parameters::ParameterVariant,
Expand Down Expand Up @@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(

impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
where
M: MatrixEntity + MatrixMut,
M: MatrixEntity + MatrixMut + Send + Sync,
M::MatElement: PrimInt
+ Debug
+ Display
Expand All @@ -636,16 +640,21 @@ where
+ WrappingSub
+ WrappingAdd
+ SampleUniform
+ From<bool>,
NttOp: Ntt<Element = M::MatElement>,
+ From<bool>
+ Send
+ Sync,
NttOp: Ntt<Element = M::MatElement> + Send + Sync,
RlweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
+ ShoupMatrixFMA<M::R>,
+ ShoupMatrixFMA<M::R>
+ Send
+ Sync,
LweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
M::R: TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug,
M::R:
TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug + Send + Sync,
<M as Matrix>::R: RowMut,
{
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
Expand Down Expand Up @@ -1378,15 +1387,20 @@ where
)
})
.collect_vec();
// clone self.rlwe_n & self.parameters so that we don't access &self inside the
// threads
let rlwe_n = self.parameters().rlwe_n().0.clone();
let parameters = self.parameters().clone();
// Note: Each user is assigned a contiguous LWE segment and the LWE dimension is
// split approximately uniformly across all users. Hence, concatenation of all
// user specific lwe segments will give LWE dimension.
let rgsw_cts = user_segments
.into_iter()
.enumerate()
.flat_map(|(user_id, lwe_segment)| {
let mut rgsws = Vec::new();
(lwe_segment.0..lwe_segment.1)
.into_iter()
.into_par_iter()
.map(|lwe_index| {
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
// sampling we need to puncture a_prng d_max - d_b times to align
Expand All @@ -1396,7 +1410,7 @@ where
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
);

let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
let mut scratch = M::R::zeros(rlwe_n);
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
.for_each(|_| {
RandomFillUniformInModulus::random_fill(
Expand All @@ -1420,7 +1434,7 @@ where

let mut decomp_neg_ai = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
rlwe_n,
);
scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
Expand All @@ -1447,41 +1461,40 @@ where
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
// Hence, after aggregation we decompose a_{i, l} * s + e to
// prepare for key switching
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
.a()
.decomposition_count()
.0)
.map(|i| {
let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == self.parameters().rlwe_n().0);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});

// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el;
});
});

decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));

decomp_sum
})
.collect_vec();
let ni_rgsw_zero_encs =
(0..rgsw_x_rgsw_decomposer.a().decomposition_count().0)
.map(|i| {
let mut sum = M::R::zeros(rlwe_n);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == rlwe_n);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});

// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
rlwe_n,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] =
d_el;
});
});

decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));

decomp_sum
})
.collect_vec();

// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
// leader, ie user's id = user_id.
Expand All @@ -1503,7 +1516,7 @@ where
.0
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
&rlwe_x_rgsw_decomposer,
self.parameters(),
&parameters,
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
rlwe_modop,
nttop,
Expand All @@ -1524,7 +1537,7 @@ where
&ni_rgsw_zero_encs,
&decomp_neg_ais,
&rgsw_x_rgsw_decomposer,
self.parameters(),
&parameters,
(
&uj_to_s_ksks[other_user_id],
&uj_to_s_ksks_part_a_eval[other_user_id],
Expand All @@ -1548,7 +1561,7 @@ where
&rlwe_x_rgsw_decomposer,
&rgsw_x_rgsw_decomposer,
&mut RuntimeScratchMutRef::new(
scratch_rgsw_x_rgsw.as_mut(),
scratch_rgsw_x_rgsw.clone().as_mut(),
),
nttop,
rlwe_modop,
Expand All @@ -1557,7 +1570,8 @@ where

rgsw_i
})
.collect_vec()
.collect_into_vec(&mut rgsws);
rgsws
})
.collect_vec();

Expand Down

0 comments on commit 720b13f

Please sign in to comment.