Skip to content

Commit

Permalink
Add parallelism to aggregate_non_interactive_multi_party_server_key_s…
Browse files Browse the repository at this point in the history
…hares

For example,

In a `12 cores` server with `48GB RAM`,
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `47.56s`
  - with this commit it takes `4.66s`
- for `examples/non_interactive_fheuint8`:
  - prior to this commit it took `158.15s`
  - with this commit it takes `14.96s`

so about `~10x` reduction.

In a `4 cores` laptop with `8GB RAM` (and multiple other apps running),
the call to `aggregate_server_key_shares`:
- for `examples/if_and_else`:
  - prior to this commit it took `48.65s`
  - with this commit it takes `23.11s`

so about `~2x` reduction.

Co-authored-by: Carlos Pérez <[email protected]>
  • Loading branch information
arnaucube and CPerezz committed Jul 24, 2024
1 parent a8e6c27 commit 7f6584f
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 54 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rand = "0.8.5"
rand_chacha = "0.3.1"
rand_distr = "0.4.3"
num-bigint-dig = { version = "0.8.4", features = ["prime"] }
rayon = "1.10.0"

[dev-dependencies]
criterion = "0.5.1"
Expand Down Expand Up @@ -58,4 +59,4 @@ required-features = ["non_interactive_mp"]
[[example]]
name = "if_and_else"
path = "./examples/if_and_else.rs"
required-features = ["non_interactive_mp"]
required-features = ["non_interactive_mp"]
128 changes: 75 additions & 53 deletions src/bool/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ use itertools::{izip, Itertools};
use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, WrappingAdd, WrappingSub, Zero};
use rand_distr::uniform::SampleUniform;

use rayon::iter::FlatMap;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};

use crate::{
backend::{ArithmeticOps, GetModulus, ModInit, Modulus, ShoupMatrixFMA, VectorOps},
bool::parameters::ParameterVariant,
Expand Down Expand Up @@ -627,7 +631,7 @@ pub(super) fn multi_party_user_id_lwe_segment(

impl<M: Matrix, NttOp, RlweModOp, LweModOp, SKey> BoolEvaluator<M, NttOp, RlweModOp, LweModOp, SKey>
where
M: MatrixEntity + MatrixMut,
M: MatrixEntity + MatrixMut + Send + Sync,
M::MatElement: PrimInt
+ Debug
+ Display
Expand All @@ -636,16 +640,21 @@ where
+ WrappingSub
+ WrappingAdd
+ SampleUniform
+ Send
+ Sync
+ From<bool>,
NttOp: Ntt<Element = M::MatElement>,
NttOp: Ntt<Element = M::MatElement> + Send + Sync,
RlweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>
+ ShoupMatrixFMA<M::R>,
+ ShoupMatrixFMA<M::R>
+ Send
+ Sync,
LweModOp: ArithmeticOps<Element = M::MatElement>
+ VectorOps<Element = M::MatElement>
+ GetModulus<Element = M::MatElement, M = CiphertextModulus<M::MatElement>>,
M::R: TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug,
M::R:
TryConvertFrom1<[i32], CiphertextModulus<M::MatElement>> + RowEntity + Debug + Send + Sync,
<M as Matrix>::R: RowMut,
{
pub(super) fn new(parameters: BoolParameters<M::MatElement>) -> Self
Expand Down Expand Up @@ -1369,7 +1378,7 @@ where
// u_j)
//
// a_{i, l} * s + e = \sum_{j \in P} a_{i, l} * s_{j} + e
let user_segments = (0..total_users)
let user_segments: Vec<(usize, usize)> = (0..total_users)
.map(|user_id| {
multi_party_user_id_lwe_segment(
user_id,
Expand All @@ -1381,22 +1390,31 @@ where
// Note: Each user is assigned a contigous LWE segement and the LWE dimension is
// split approximately uniformly across all users. Hence, concatenation of all
// user specific lwe segments will give LWE dimension.

// clone self.rlwe_n & self.parameters so that we don't access to &self inside the
// threads
let rlwe_n = self.parameters().rlwe_n().0.clone();
let parameters = self.parameters().clone();

let rgsw_cts = user_segments
.into_iter()
.enumerate()
.flat_map(|(user_id, lwe_segment)| {
// `results` will contain the computed values by each thread (unsorted)
let results = Arc::new(Mutex::new(Vec::new()));
(lwe_segment.0..lwe_segment.1)
.into_iter()
.map(|lwe_index| {
.collect::<Vec<usize>>()
.par_iter()
.for_each(|lwe_index| {
// We sample d_b `-a_i`s to key switch and generate RLWE'(m). But before
// we sampling we need to puncture a_prng d_max - d_b times to align
// a_i's. After sampling we decompose `-a_i`s and send them to
// evaluation domain for upcoming key switches.
let mut a_prng = DefaultSecureRng::new_seeded(
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(lwe_index),
cr_seed.ni_rgsw_ct_seed_for_index::<DefaultSecureRng>(*lwe_index),
);

let mut scratch = M::R::zeros(self.parameters().rlwe_n().0);
let mut scratch = M::R::zeros(rlwe_n);
(0..d_max - rgsw_x_rgsw_decomposer.b().decomposition_count().0)
.for_each(|_| {
RandomFillUniformInModulus::random_fill(
Expand All @@ -1420,7 +1438,7 @@ where

let mut decomp_neg_ai = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
rlwe_n,
);
scratch.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
Expand All @@ -1447,41 +1465,40 @@ where
// then use to produce RLWE'(-sX^{s_{lwe}[l]}).
// Hence, after aggregation we decompose a_{i, l} * s + e to
// prepare for key switching
let ni_rgsw_zero_encs = (0..rgsw_x_rgsw_decomposer
.a()
.decomposition_count()
.0)
.map(|i| {
let mut sum = M::R::zeros(self.parameters().rlwe_n().0);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == self.parameters().rlwe_n().0);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});

// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
self.parameters().rlwe_n().0,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] = d_el;
});
});

decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));

decomp_sum
})
.collect_vec();
let ni_rgsw_zero_encs =
(0..rgsw_x_rgsw_decomposer.a().decomposition_count().0)
.map(|i| {
let mut sum = M::R::zeros(rlwe_n);
key_shares.iter().for_each(|k| {
let to_add_ref = k
.ni_rgsw_zero_enc_for_lwe_index(*lwe_index)
.get_row_slice(i);
assert!(to_add_ref.len() == rlwe_n);
rlwe_modop.elwise_add_mut(sum.as_mut(), to_add_ref);
});

// decompose
let mut decomp_sum = M::zeros(
ni_uj_to_s_decomposer.decomposition_count().0,
rlwe_n,
);
sum.as_ref().iter().enumerate().for_each(|(index, el)| {
ni_uj_to_s_decomposer
.decompose_iter(el)
.enumerate()
.for_each(|(row_j, d_el)| {
(decomp_sum.as_mut()[row_j]).as_mut()[index] =
d_el;
});
});

decomp_sum
.iter_rows_mut()
.for_each(|r| nttop.forward(r.as_mut()));

decomp_sum
})
.collect_vec();

// Produce RGSW(X^{s_{j=user_id, lwe}[l]}) for the
// leader, ie user's id = user_id.
Expand All @@ -1491,7 +1508,7 @@ where
// X^{s_{j != user_id, lwe}[l]} from other users.
let mut rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw(
key_shares[user_id]
.ni_rgsw_cts_for_self_leader_lwe_index(lwe_index),
.ni_rgsw_cts_for_self_leader_lwe_index(*lwe_index),
&ni_rgsw_zero_encs[rgsw_x_rgsw_decomposer
.a()
.decomposition_count()
Expand All @@ -1503,7 +1520,7 @@ where
.0
- rlwe_x_rgsw_decomposer.b().decomposition_count().0..],
&rlwe_x_rgsw_decomposer,
self.parameters(),
&parameters,
(&uj_to_s_ksks[user_id], &uj_to_s_ksks_part_a_eval[user_id]),
rlwe_modop,
nttop,
Expand All @@ -1520,11 +1537,11 @@ where
.for_each(|other_user_id| {
let mut other_rgsw_i = produce_rgsw_ciphertext_from_ni_rgsw(
key_shares[other_user_id]
.ni_rgsw_cts_for_self_not_leader_lwe_index(lwe_index),
.ni_rgsw_cts_for_self_not_leader_lwe_index(*lwe_index),
&ni_rgsw_zero_encs,
&decomp_neg_ais,
&rgsw_x_rgsw_decomposer,
self.parameters(),
&parameters,
(
&uj_to_s_ksks[other_user_id],
&uj_to_s_ksks_part_a_eval[other_user_id],
Expand All @@ -1548,16 +1565,21 @@ where
&rlwe_x_rgsw_decomposer,
&rgsw_x_rgsw_decomposer,
&mut RuntimeScratchMutRef::new(
scratch_rgsw_x_rgsw.as_mut(),
scratch_rgsw_x_rgsw.clone().as_mut(),
),
nttop,
rlwe_modop,
)
});

rgsw_i
})
.collect_vec()
results.lock().unwrap().push((lwe_index.clone(), rgsw_i));
});

// collect from threads
let mut res = results.lock().unwrap().clone();
// sort results by lwe_index and return the rgsw_i vec
res.sort_by(|x, y| x.0.partial_cmp(&y.0).unwrap());
res.iter().map(|r| r.1.clone()).collect::<Vec<M>>()
})
.collect_vec();

Expand Down

0 comments on commit 7f6584f

Please sign in to comment.