Skip to content

Commit

Permalink
Save intermediates during sharded shuffle (#1372)
Browse files Browse the repository at this point in the history
* Save intermediates during sharded shuffle

* Change IntermediateShuffleMessages to an enum

* Clarify public entry points to shuffle
  • Loading branch information
andyleiserson authored Oct 29, 2024
1 parent 933a902 commit 4fd5e2b
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 226 deletions.
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::{
},
oprf_padding::{apply_dp_padding, PaddingParameters},
prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs},
shuffle::shuffle_attribution_outputs,
shuffle::{shuffle_attribution_outputs, Shuffle},
BreakdownKey,
},
BooleanProtocols, RecordId,
Expand Down Expand Up @@ -70,7 +70,7 @@ pub async fn breakdown_reveal_aggregation<C, BK, TV, HV, const B: usize>(
padding_params: &PaddingParameters,
) -> Result<BitDecomposed<Replicated<Boolean, B>>, Error>
where
C: UpgradableContext,
C: UpgradableContext + Shuffle,
Boolean: FieldSimd<B>,
Replicated<Boolean, B>: BooleanProtocols<DZKPUpgraded<C>, B>,
BK: BreakdownKey<B>,
Expand Down Expand Up @@ -153,7 +153,7 @@ async fn shuffle_attributions<C, BK, TV, const B: usize>(
contribs: Vec<SecretSharedAttributionOutputs<BK, TV>>,
) -> Result<Vec<SecretSharedAttributionOutputs<BK, TV>>, Error>
where
C: Context,
C: Context + Shuffle,
BK: BreakdownKey<B>,
TV: BooleanArray + U128Conversions,
{
Expand Down
3 changes: 2 additions & 1 deletion ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub(crate) mod step;
pub mod validation_protocol;

pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator};
pub use shuffle::Shuffle;

/// Match key type
pub type MatchKey = BA64;
Expand Down Expand Up @@ -98,7 +99,7 @@ use crate::{
protocol::{
context::Validator,
dp::dp_for_histogram,
ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle},
ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing},
},
secret_sharing::replicated::semi_honest::AdditiveShare,
};
Expand Down
3 changes: 2 additions & 1 deletion ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::{
AttributionWindowStep as WindowStep,
AttributionZeroOutTriggerStep as ZeroOutTriggerStep, UserNthRowStep,
},
shuffle::Shuffle,
BreakdownKey, AGG_CHUNK,
},
RecordId,
Expand Down Expand Up @@ -471,7 +472,7 @@ pub async fn attribute_cap_aggregate<
padding_parameters: &PaddingParameters,
) -> Result<BitDecomposed<Replicated<Boolean, B>>, Error>
where
C: UpgradableContext + 'ctx,
C: UpgradableContext + Shuffle + 'ctx,
BK: BreakdownKey<B>,
TV: BooleanArray + U128Conversions,
HV: BooleanArray + U128Conversions,
Expand Down
219 changes: 103 additions & 116 deletions ipa-core/src/protocol/ipa_prf/shuffle/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,23 @@ use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng
use crate::{
error::Error,
helpers::{Direction, MpcReceivingEnd, Role, TotalRecords},
protocol::{context::Context, ipa_prf::shuffle::step::OPRFShuffleStep, RecordId},
protocol::{
context::Context,
ipa_prf::shuffle::{step::OPRFShuffleStep, IntermediateShuffleMessages},
RecordId,
},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue,
},
};

/// Internal entry point to non-sharded shuffle protocol, excluding validation of
/// intermediates for malicious security. Protocols should use `trait Shuffle`.
///
/// # Errors
/// Will propagate errors from transport and a few typecasts
pub async fn semi_honest_shuffle<C, I, S>(ctx: C, shares: I) -> Result<Vec<AdditiveShare<S>>, Error>
where
C: Context,
I: IntoIterator<Item = AdditiveShare<S>>,
I::IntoIter: ExactSizeIterator,
S: SharedValue + Add<Output = S>,
for<'a> &'a S: Add<S, Output = S>,
for<'a> &'a S: Add<&'a S, Output = S>,
Standard: Distribution<S>,
{
Ok(shuffle_protocol(ctx, shares).await?.0)
}

/// # Errors
/// Will propagate errors from transport and a few typecasts
pub async fn shuffle_protocol<C, I, S>(
pub(super) async fn shuffle_protocol<C, I, S>(
ctx: C,
shares: I,
) -> Result<(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>), Error>
Expand All @@ -50,13 +42,7 @@ where
// This protocol can take a mutable iterator and replace items in the input.
let shares = shares.into_iter();
let Some(shares_len) = NonZeroUsize::new(shares.len()) else {
return Ok((
vec![],
IntermediateShuffleMessages {
x1_or_y1: None,
x2_or_y2: None,
},
));
return Ok((vec![], IntermediateShuffleMessages::empty(&ctx)));
};
let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ);
let zs = generate_random_tables_with_peers(shares_len, &ctx_z);
Expand All @@ -68,48 +54,6 @@ where
}
}

/// This struct stores some intermediate messages during the shuffle.
/// In a maliciously secure shuffle,
/// these messages need to be checked for consistency across helpers.
/// `H1` stores `x1`, `H2` stores `x2` and `H3` stores `y1` and `y2`.
#[derive(Debug, Clone)]
pub struct IntermediateShuffleMessages<S: SharedValue> {
x1_or_y1: Option<Vec<S>>,
x2_or_y2: Option<Vec<S>>,
}

impl<S: SharedValue> IntermediateShuffleMessages<S> {
/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x1` when `Role = H1`
/// and `y1` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_x1_or_y1(self) -> Vec<S> {
self.x1_or_y1.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x2` when `Role = H2`
/// and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`.
pub fn get_x2_or_y2(self) -> Vec<S> {
self.x2_or_y2.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `y1` and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or
/// when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_both_x_or_ys(self) -> (Vec<S>, Vec<S>) {
(self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap())
}
}

async fn run_h1<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
Expand Down Expand Up @@ -152,13 +96,7 @@ where

let res = combine_single_shares(a_hat, b_hat).collect::<Vec<_>>();
// we only need to store x_1 in IntermediateShuffleMessage
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: Some(x_1),
x2_or_y2: None,
},
))
Ok((res, IntermediateShuffleMessages::H1 { x1: x_1 }))
}

async fn run_h2<C, I, S, Zl, Zr>(
Expand Down Expand Up @@ -234,13 +172,7 @@ where
let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter());
let res = combine_single_shares(b_hat, c_hat).collect::<Vec<_>>();
// we only need to store x_2 in IntermediateShuffleMessage
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: None,
x2_or_y2: Some(x_2),
},
))
Ok((res, IntermediateShuffleMessages::H2 { x2: x_2 }))
}

async fn run_h3<C, S, Zl, Zr>(
Expand Down Expand Up @@ -310,13 +242,7 @@ where

let c_hat = add_single_shares(c_hat_1, c_hat_2);
let res = combine_single_shares(c_hat, a_hat).collect::<Vec<_>>();
Ok((
res,
IntermediateShuffleMessages {
x1_or_y1: Some(y_1),
x2_or_y2: Some(y_2),
},
))
Ok((res, IntermediateShuffleMessages::H3 { y1: y_1, y2: y_2 }))
}

fn add_single_shares<A, B, S, L, R>(l: L, r: R) -> impl Iterator<Item = S>
Expand Down Expand Up @@ -439,14 +365,95 @@ where
Ok(())
}

#[cfg(all(test, any(unit_test, feature = "shuttle")))]
pub(super) mod test_helpers {
use std::iter::zip;

use crate::{
protocol::ipa_prf::shuffle::IntermediateShuffleMessages,
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue,
},
};

fn check_replicated_shares<'a, S, I1, I2>(left_helper_shares: I1, right_helper_shares: I2)
where
S: SharedValue,
I1: Iterator<Item = &'a AdditiveShare<S>>,
I2: Iterator<Item = &'a AdditiveShare<S>>,
{
assert!(zip(left_helper_shares, right_helper_shares)
.all(|(lhs, rhs)| lhs.right() == rhs.left()));
}

pub struct ExtractedShuffleResults<S> {
pub x1_xor_y1: Vec<S>,
pub x2_xor_y2: Vec<S>,
pub a_xor_b_xor_c: Vec<S>,
}

impl<S> ExtractedShuffleResults<S> {
pub fn empty() -> Self {
ExtractedShuffleResults {
x1_xor_y1: vec![],
x2_xor_y2: vec![],
a_xor_b_xor_c: vec![],
}
}
}

/// Extract the data returned from shuffle (the shuffled records and intermediate
/// values) into a more usable form for verification. This routine is used for
/// both the unsharded and sharded shuffles.
pub fn extract_shuffle_results<S: SharedValue>(
results: [(Vec<AdditiveShare<S>>, IntermediateShuffleMessages<S>); 3],
) -> ExtractedShuffleResults<S> {
// check consistency
// i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B
let [(h1_shares, h1_messages), (h2_shares, h2_messages), (h3_shares, h3_messages)] =
results;

let IntermediateShuffleMessages::H1 { x1 } = h1_messages else {
panic!("H1 returned shuffle messages for {:?}", h1_messages.role());
};
let IntermediateShuffleMessages::H2 { x2 } = h2_messages else {
panic!("H2 returned shuffle messages for {:?}", h2_messages.role());
};
let IntermediateShuffleMessages::H3 { y1, y2 } = h3_messages else {
panic!("H3 returned shuffle messages for {:?}", h3_messages.role());
};

check_replicated_shares(h1_shares.iter(), h2_shares.iter());
check_replicated_shares(h2_shares.iter(), h3_shares.iter());
check_replicated_shares(h3_shares.iter(), h1_shares.iter());

let x1_xor_y1 = zip(x1, y1).map(|(x1, y1)| x1 + y1).collect();

let x2_xor_y2 = zip(x2, y2).map(|(x2, y2)| x2 + y2).collect();

let a_xor_b_xor_c = zip(&h1_shares, h3_shares)
.map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left())
.collect();

ExtractedShuffleResults {
x1_xor_y1,
x2_xor_y2,
a_xor_b_xor_c,
}
}
}

#[cfg(all(test, unit_test))]
pub mod tests {
pub(super) mod tests {
use rand::{thread_rng, Rng};

use super::shuffle_protocol;
use crate::{
ff::{Gf40Bit, U128Conversions},
secret_sharing::replicated::ReplicatedSecretSharing,
protocol::ipa_prf::shuffle::base::test_helpers::{
extract_shuffle_results, ExtractedShuffleResults,
},
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig},
};
Expand Down Expand Up @@ -495,37 +502,17 @@ pub mod tests {
.map(|_| rng.gen())
.collect::<Vec<Gf40Bit>>();

let [h1, h2, h3] = world
let results = world
.semi_honest(records.clone().into_iter(), |ctx, records| async move {
shuffle_protocol(ctx, records).await
shuffle_protocol(ctx, records).await.unwrap()
})
.await;

// check consistency
// i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B
let (h1_shares, h1_messages) = h1.unwrap();
let (_, h2_messages) = h2.unwrap();
let (h3_shares, h3_messages) = h3.unwrap();

let mut x1_xor_y1 = h1_messages
.x1_or_y1
.unwrap()
.iter()
.zip(h3_messages.x1_or_y1.unwrap())
.map(|(x1, y1)| x1 + y1)
.collect::<Vec<_>>();
let mut x2_xor_y2 = h2_messages
.x2_or_y2
.unwrap()
.iter()
.zip(h3_messages.x2_or_y2.unwrap())
.map(|(x2, y2)| x2 + y2)
.collect::<Vec<_>>();
let mut a_xor_b_xor_c = h1_shares
.iter()
.zip(h3_shares)
.map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left())
.collect::<Vec<_>>();
let ExtractedShuffleResults {
mut x1_xor_y1,
mut x2_xor_y2,
mut a_xor_b_xor_c,
} = extract_shuffle_results(results);

// unshuffle by sorting
records.sort();
Expand Down
Loading

0 comments on commit 4fd5e2b

Please sign in to comment.