From 4fd5e2be4ce39c6f461db9563341f9ff84718688 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 29 Oct 2024 10:00:06 -0700 Subject: [PATCH] Save intermediates during sharded shuffle (#1372) * Save intermediates during sharded shuffle * Change IntermediateShuffleMessages to an enum * Clarify public entry points to shuffle --- .../ipa_prf/aggregation/breakdown_reveal.rs | 6 +- ipa-core/src/protocol/ipa_prf/mod.rs | 3 +- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 3 +- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 219 ++++++++---------- .../src/protocol/ipa_prf/shuffle/malicious.rs | 26 ++- ipa-core/src/protocol/ipa_prf/shuffle/mod.rs | 125 +++++----- .../src/protocol/ipa_prf/shuffle/sharded.rs | 154 +++++++++--- ipa-core/src/query/runner/oprf_ipa.rs | 4 +- ipa-core/src/secret_sharing/mod.rs | 2 +- 9 files changed, 316 insertions(+), 226 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index d3ec2ca3e..372cf2447 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -24,7 +24,7 @@ use crate::{ }, oprf_padding::{apply_dp_padding, PaddingParameters}, prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, - shuffle::shuffle_attribution_outputs, + shuffle::{shuffle_attribution_outputs, Shuffle}, BreakdownKey, }, BooleanProtocols, RecordId, @@ -70,7 +70,7 @@ pub async fn breakdown_reveal_aggregation( padding_params: &PaddingParameters, ) -> Result>, Error> where - C: UpgradableContext, + C: UpgradableContext + Shuffle, Boolean: FieldSimd, Replicated: BooleanProtocols, B>, BK: BreakdownKey, @@ -153,7 +153,7 @@ async fn shuffle_attributions( contribs: Vec>, ) -> Result>, Error> where - C: Context, + C: Context + Shuffle, BK: BreakdownKey, TV: BooleanArray + U128Conversions, { diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 55a02f9f8..f01370567 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -56,6 +56,7 @@ pub(crate) mod step; pub mod validation_protocol; pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; +pub use shuffle::Shuffle; /// Match key type pub type MatchKey = BA64; @@ -98,7 +99,7 @@ use crate::{ protocol::{ context::Validator, dp::dp_for_histogram, - ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle}, + ipa_prf::{oprf_padding::PaddingParameters, prf_eval::PrfSharing}, }, secret_sharing::replicated::semi_honest::AdditiveShare, }; diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 0df23f6c9..f6bf2e339 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -43,6 +43,7 @@ use crate::{ AttributionWindowStep as WindowStep, AttributionZeroOutTriggerStep as ZeroOutTriggerStep, UserNthRowStep, }, + shuffle::Shuffle, BreakdownKey, AGG_CHUNK, }, RecordId, @@ -471,7 +472,7 @@ pub async fn attribute_cap_aggregate< padding_parameters: &PaddingParameters, ) -> Result>, Error> where - C: UpgradableContext + 'ctx, + C: UpgradableContext + Shuffle + 'ctx, BK: BreakdownKey, TV: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 83343b739..6b495eff7 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -9,31 +9,23 @@ use rand::{distributions::Standard, prelude::Distribution, seq::SliceRandom, Rng use crate::{ error::Error, helpers::{Direction, MpcReceivingEnd, Role, TotalRecords}, - protocol::{context::Context, ipa_prf::shuffle::step::OPRFShuffleStep, RecordId}, + protocol::{ + context::Context, + ipa_prf::shuffle::{step::OPRFShuffleStep, IntermediateShuffleMessages}, + RecordId, + }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, SharedValue, }, }; +/// Internal entry point to non-sharded shuffle protocol, excluding validation of +/// intermediates for malicious security. Protocols should use `trait Shuffle`. +/// /// # Errors /// Will propagate errors from transport and a few typecasts -pub async fn semi_honest_shuffle(ctx: C, shares: I) -> Result>, Error> -where - C: Context, - I: IntoIterator>, - I::IntoIter: ExactSizeIterator, - S: SharedValue + Add, - for<'a> &'a S: Add, - for<'a> &'a S: Add<&'a S, Output = S>, - Standard: Distribution, -{ - Ok(shuffle_protocol(ctx, shares).await?.0) -} - -/// # Errors -/// Will propagate errors from transport and a few typecasts -pub async fn shuffle_protocol( +pub(super) async fn shuffle_protocol( ctx: C, shares: I, ) -> Result<(Vec>, IntermediateShuffleMessages), Error> @@ -50,13 +42,7 @@ where // This protocol can take a mutable iterator and replace items in the input. let shares = shares.into_iter(); let Some(shares_len) = NonZeroUsize::new(shares.len()) else { - return Ok(( - vec![], - IntermediateShuffleMessages { - x1_or_y1: None, - x2_or_y2: None, - }, - )); + return Ok((vec![], IntermediateShuffleMessages::empty(&ctx))); }; let ctx_z = ctx.narrow(&OPRFShuffleStep::GenerateZ); let zs = generate_random_tables_with_peers(shares_len, &ctx_z); @@ -68,48 +54,6 @@ where } } -/// This struct stores some intermediate messages during the shuffle. -/// In a maliciously secure shuffle, -/// these messages need to be checked for consistency across helpers. -/// `H1` stores `x1`, `H2` stores `x2` and `H3` stores `y1` and `y2`. -#[derive(Debug, Clone)] -pub struct IntermediateShuffleMessages { - x1_or_y1: Option>, - x2_or_y2: Option>, -} - -impl IntermediateShuffleMessages { - /// When `IntermediateShuffleMessages` is initialized correctly, - /// this function returns `x1` when `Role = H1` - /// and `y1` when `Role = H3`. - /// - /// ## Panics - /// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`. - pub fn get_x1_or_y1(self) -> Vec { - self.x1_or_y1.unwrap() - } - - /// When `IntermediateShuffleMessages` is initialized correctly, - /// this function returns `x2` when `Role = H2` - /// and `y2` when `Role = H3`. - /// - /// ## Panics - /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`. - pub fn get_x2_or_y2(self) -> Vec { - self.x2_or_y2.unwrap() - } - - /// When `IntermediateShuffleMessages` is initialized correctly, - /// this function returns `y1` and `y2` when `Role = H3`. - /// - /// ## Panics - /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or - /// when `Role = H2`, i.e. `x1_or_y1` is `None`. - pub fn get_both_x_or_ys(self) -> (Vec, Vec) { - (self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap()) - } -} - async fn run_h1( ctx: &C, batch_size: NonZeroUsize, @@ -152,13 +96,7 @@ where let res = combine_single_shares(a_hat, b_hat).collect::>(); // we only need to store x_1 in IntermediateShuffleMessage - Ok(( - res, - IntermediateShuffleMessages { - x1_or_y1: Some(x_1), - x2_or_y2: None, - }, - )) + Ok((res, IntermediateShuffleMessages::H1 { x1: x_1 })) } async fn run_h2( @@ -234,13 +172,7 @@ where let c_hat = add_single_shares(c_hat_1.iter(), c_hat_2.iter()); let res = combine_single_shares(b_hat, c_hat).collect::>(); // we only need to store x_2 in IntermediateShuffleMessage - Ok(( - res, - IntermediateShuffleMessages { - x1_or_y1: None, - x2_or_y2: Some(x_2), - }, - )) + Ok((res, IntermediateShuffleMessages::H2 { x2: x_2 })) } async fn run_h3( @@ -310,13 +242,7 @@ where let c_hat = add_single_shares(c_hat_1, c_hat_2); let res = combine_single_shares(c_hat, a_hat).collect::>(); - Ok(( - res, - IntermediateShuffleMessages { - x1_or_y1: Some(y_1), - x2_or_y2: Some(y_2), - }, - )) + Ok((res, IntermediateShuffleMessages::H3 { y1: y_1, y2: y_2 })) } fn add_single_shares(l: L, r: R) -> impl Iterator @@ -439,14 +365,95 @@ where Ok(()) } +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +pub(super) mod test_helpers { + use std::iter::zip; + + use crate::{ + protocol::ipa_prf::shuffle::IntermediateShuffleMessages, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, + SharedValue, + }, + }; + + fn check_replicated_shares<'a, S, I1, I2>(left_helper_shares: I1, right_helper_shares: I2) + where + S: SharedValue, + I1: Iterator>, + I2: Iterator>, + { + assert!(zip(left_helper_shares, right_helper_shares) + .all(|(lhs, rhs)| lhs.right() == rhs.left())); + } + + pub struct ExtractedShuffleResults { + pub x1_xor_y1: Vec, + pub x2_xor_y2: Vec, + pub a_xor_b_xor_c: Vec, + } + + impl ExtractedShuffleResults { + pub fn empty() -> Self { + ExtractedShuffleResults { + x1_xor_y1: vec![], + x2_xor_y2: vec![], + a_xor_b_xor_c: vec![], + } + } + } + + /// Extract the data returned from shuffle (the shuffled records and intermediate + /// values) into a more usable form for verification. This routine is used for + /// both the unsharded and sharded shuffles. + pub fn extract_shuffle_results( + results: [(Vec>, IntermediateShuffleMessages); 3], + ) -> ExtractedShuffleResults { + // check consistency + // i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B + let [(h1_shares, h1_messages), (h2_shares, h2_messages), (h3_shares, h3_messages)] = + results; + + let IntermediateShuffleMessages::H1 { x1 } = h1_messages else { + panic!("H1 returned shuffle messages for {:?}", h1_messages.role()); + }; + let IntermediateShuffleMessages::H2 { x2 } = h2_messages else { + panic!("H2 returned shuffle messages for {:?}", h2_messages.role()); + }; + let IntermediateShuffleMessages::H3 { y1, y2 } = h3_messages else { + panic!("H3 returned shuffle messages for {:?}", h3_messages.role()); + }; + + check_replicated_shares(h1_shares.iter(), h2_shares.iter()); + check_replicated_shares(h2_shares.iter(), h3_shares.iter()); + check_replicated_shares(h3_shares.iter(), h1_shares.iter()); + + let x1_xor_y1 = zip(x1, y1).map(|(x1, y1)| x1 + y1).collect(); + + let x2_xor_y2 = zip(x2, y2).map(|(x2, y2)| x2 + y2).collect(); + + let a_xor_b_xor_c = zip(&h1_shares, h3_shares) + .map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left()) + .collect(); + + ExtractedShuffleResults { + x1_xor_y1, + x2_xor_y2, + a_xor_b_xor_c, + } + } +} + #[cfg(all(test, unit_test))] -pub mod tests { +pub(super) mod tests { use rand::{thread_rng, Rng}; use super::shuffle_protocol; use crate::{ ff::{Gf40Bit, U128Conversions}, - secret_sharing::replicated::ReplicatedSecretSharing, + protocol::ipa_prf::shuffle::base::test_helpers::{ + extract_shuffle_results, ExtractedShuffleResults, + }, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -495,37 +502,17 @@ pub mod tests { .map(|_| rng.gen()) .collect::>(); - let [h1, h2, h3] = world + let results = world .semi_honest(records.clone().into_iter(), |ctx, records| async move { - shuffle_protocol(ctx, records).await + shuffle_protocol(ctx, records).await.unwrap() }) .await; - // check consistency - // i.e. x_1 xor y_1 = x_2 xor y_2 = C xor A xor B - let (h1_shares, h1_messages) = h1.unwrap(); - let (_, h2_messages) = h2.unwrap(); - let (h3_shares, h3_messages) = h3.unwrap(); - - let mut x1_xor_y1 = h1_messages - .x1_or_y1 - .unwrap() - .iter() - .zip(h3_messages.x1_or_y1.unwrap()) - .map(|(x1, y1)| x1 + y1) - .collect::>(); - let mut x2_xor_y2 = h2_messages - .x2_or_y2 - .unwrap() - .iter() - .zip(h3_messages.x2_or_y2.unwrap()) - .map(|(x2, y2)| x2 + y2) - .collect::>(); - let mut a_xor_b_xor_c = h1_shares - .iter() - .zip(h3_shares) - .map(|(h1_share, h3_share)| h1_share.left() + h1_share.right() + h3_share.left()) - .collect::>(); + let ExtractedShuffleResults { + mut x1_xor_y1, + mut x2_xor_y2, + mut a_xor_b_xor_c, + } = extract_shuffle_results(results); // unshuffle by sorting records.sort(); diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index b03363288..c059684ce 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -13,15 +13,15 @@ use crate::{ ff::{boolean_array::BooleanArray, Field, Gf32Bit, Serializable}, helpers::{ hashing::{compute_hash, Hash}, - Direction, Role, TotalRecords, + Direction, TotalRecords, }, protocol::{ basics::{malicious_reveal, mul::semi_honest_multiply}, context::Context, ipa_prf::shuffle::{ - base::IntermediateShuffleMessages, shuffle_protocol, step::{OPRFShuffleStep, VerifyShuffleStep}, + IntermediateShuffleMessages, }, prss::SharedRandomness, RecordId, @@ -40,7 +40,7 @@ use crate::{ /// /// ## Panics /// Panics when `S::Bits + 32 != B::Bits` or type conversions fail. -pub async fn malicious_shuffle( +pub(super) async fn malicious_shuffle( ctx: C, shares: I, ) -> Result>, Error> @@ -149,16 +149,17 @@ async fn verify_shuffle( .map(Gf32Bit::from_array) .collect::>(); + assert_eq!(messages.role(), ctx.role()); + // verify messages and shares - match ctx.role() { - Role::H1 => { - h1_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await + match messages { + IntermediateShuffleMessages::H1 { x1 } => { + h1_verify::<_, S, B>(ctx, &keys, shuffled_shares, x1).await } - Role::H2 => { - h2_verify::<_, S, B>(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await + IntermediateShuffleMessages::H2 { x2 } => { + h2_verify::<_, S, B>(ctx, &keys, shuffled_shares, x2).await } - Role::H3 => { - let (y1, y2) = messages.get_both_x_or_ys(); + IntermediateShuffleMessages::H3 { y1, y2 } => { h3_verify::<_, S, B>(ctx, &keys, shuffled_shares, y1, y2).await } } @@ -479,7 +480,10 @@ mod tests { boolean_array::{BA112, BA144, BA20, BA32, BA64}, Serializable, U128Conversions, }, - helpers::in_memory_config::{MaliciousHelper, MaliciousHelperContext}, + helpers::{ + in_memory_config::{MaliciousHelper, MaliciousHelperContext}, + Role, + }, protocol::ipa_prf::shuffle::base::shuffle_protocol, secret_sharing::SharedValue, test_executor::run, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs index 582445190..5d8517fd5 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/mod.rs @@ -1,8 +1,8 @@ use std::{future::Future, ops::Add}; +use futures::FutureExt; use rand::distributions::{Distribution, Standard}; -use self::base::shuffle_protocol; use super::{ boolean_ops::{expand_shared_array_in_place, extract_from_shared_array}, prf_sharding::SecretSharedAttributionOutputs, @@ -14,10 +14,11 @@ use crate::{ boolean_array::{BooleanArray, BA112, BA144, BA64, BA96}, ArrayAccess, }, + helpers::Role, protocol::{ context::{Context, MaliciousContext, SemiHonestContext}, ipa_prf::{ - shuffle::{base::semi_honest_shuffle, malicious::malicious_shuffle}, + shuffle::{base::shuffle_protocol, malicious::malicious_shuffle}, OPRFIPAInputRow, }, }, @@ -34,6 +35,41 @@ pub mod malicious; mod sharded; pub(crate) mod step; +/// This struct stores some intermediate messages during the shuffle. +/// In a maliciously secure shuffle, +/// these messages need to be checked for consistency across helpers. +/// `H1` stores `x1`, `H2` stores `x2` and `H3` stores `y1` and `y2`. +#[derive(Debug, Clone)] +enum IntermediateShuffleMessages { + H1 { x1: Vec }, + H2 { x2: Vec }, + H3 { y1: Vec, y2: Vec }, +} + +impl IntermediateShuffleMessages { + pub fn role(&self) -> Role { + match *self { + IntermediateShuffleMessages::H1 { .. } => Role::H1, + IntermediateShuffleMessages::H2 { .. } => Role::H2, + IntermediateShuffleMessages::H3 { .. } => Role::H3, + } + } + + /// Return an empty `IntermediateShuffleMessages` for the currrent helper. + pub fn empty(ctx: &C) -> Self { + match ctx.role() { + Role::H1 => IntermediateShuffleMessages::H1 { x1: vec![] }, + Role::H2 => IntermediateShuffleMessages::H2 { x2: vec![] }, + Role::H3 => IntermediateShuffleMessages::H3 { + y1: vec![], + y2: vec![], + }, + } + } +} + +/// Trait used by protocols to invoke either semi-honest or malicious shuffle, depending +/// on the type of context being used. pub trait Shuffle: Context { fn shuffle( self, @@ -43,14 +79,10 @@ pub trait Shuffle: Context { S: BooleanArray, B: BooleanArray, I: IntoIterator> + Send, - I::IntoIter: ExactSizeIterator, - ::IntoIter: Send, - for<'a> &'a S: Add, - for<'a> &'a S: Add<&'a S, Output = S>, - for<'a> &'a B: Add, - for<'a> &'a B: Add<&'a B, Output = B>, - Standard: Distribution, - Standard: Distribution; + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + for<'a> &'a B: Add + Add<&'a B, Output = B>, + Standard: Distribution + Distribution; } impl<'b, T: ShardBinding> Shuffle for SemiHonestContext<'b, T> { @@ -62,16 +94,13 @@ impl<'b, T: ShardBinding> Shuffle for SemiHonestContext<'b, T> { S: BooleanArray, B: BooleanArray, I: IntoIterator> + Send, - I::IntoIter: ExactSizeIterator, - ::IntoIter: Send, - for<'a> &'a S: Add, - for<'a> &'a S: Add<&'a S, Output = S>, - for<'a> &'a B: Add, - for<'a> &'a B: Add<&'a B, Output = B>, - Standard: Distribution, - Standard: Distribution, + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + for<'a> &'a B: Add + Add<&'a B, Output = B>, + Standard: Distribution + Distribution, { - semi_honest_shuffle::<_, I, S>(self, shares) + let fut = shuffle_protocol::<_, I, S>(self, shares); + fut.map(|res| res.map(|(output, _intermediates)| output)) } } @@ -84,14 +113,10 @@ impl<'b> Shuffle for MaliciousContext<'b> { S: BooleanArray, B: BooleanArray, I: IntoIterator> + Send, - I::IntoIter: ExactSizeIterator, - ::IntoIter: Send, - for<'a> &'a S: Add, - for<'a> &'a S: Add<&'a S, Output = S>, - for<'a> &'a B: Add, - for<'a> &'a B: Add<&'a B, Output = B>, - Standard: Distribution, - Standard: Distribution, + I::IntoIter: ExactSizeIterator + Send, + for<'a> &'a S: Add + Add<&'a S, Output = S>, + for<'a> &'a B: Add + Add<&'a B, Output = B>, + Standard: Distribution + Distribution, { malicious_shuffle::<_, S, B, I>(self, shares) } @@ -127,7 +152,7 @@ pub async fn shuffle_attribution_outputs( input: Vec>, ) -> Result>, Error> where - C: Context, + C: Context + Shuffle, BK: BooleanArray, TV: BooleanArray, R: BooleanArray, @@ -140,7 +165,7 @@ where .map(|item| attribution_outputs_to_shuffle_input::(&item)) .collect::>(); - let shuffled = malicious_shuffle::<_, R, BA96, _>(ctx, shuffle_input).await?; + let shuffled = ctx.shuffle::(shuffle_input).await?; Ok(shuffled .into_iter() @@ -260,21 +285,16 @@ pub mod tests { use crate::{ ff::{ - boolean::Boolean, boolean_array::{BA20, BA3, BA32, BA64, BA8}, U128Conversions, }, - protocol::{ - context::UpgradedSemiHonestContext, - ipa_prf::{ - prf_sharding::{ - tests::PreAggregationTestOutputInDecimal, AttributionOutputsTestInput, - SecretSharedAttributionOutputs, - }, - shuffle::{shuffle_attribution_outputs, shuffle_inputs}, + protocol::ipa_prf::{ + prf_sharding::{ + tests::PreAggregationTestOutputInDecimal, AttributionOutputsTestInput, + SecretSharedAttributionOutputs, }, + shuffle::{shuffle_attribution_outputs, shuffle_inputs}, }, - sharding::NotSharded, test_executor::run, test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld}, }; @@ -351,21 +371,18 @@ pub mod tests { expectation.push(e); } let mut result: Vec = world - .upgraded_semi_honest( - inputs.into_iter(), - |ctx: UpgradedSemiHonestContext, input_rows| async move { - let aos: Vec<_> = input_rows - .into_iter() - .map(|ti| SecretSharedAttributionOutputs { - attributed_breakdown_key_bits: ti.0, - capped_attributed_trigger_value: ti.1, - }) - .collect(); - shuffle_attribution_outputs::<_, BA32, BA32, BA64>(ctx, aos) - .await - .unwrap() - }, - ) + .semi_honest(inputs.into_iter(), |ctx, input_rows| async move { + let aos: Vec<_> = input_rows + .into_iter() + .map(|ti| SecretSharedAttributionOutputs { + attributed_breakdown_key_bits: ti.0, + capped_attributed_trigger_value: ti.1, + }) + .collect(); + shuffle_attribution_outputs::<_, BA32, BA32, BA64>(ctx, aos) + .await + .unwrap() + }) .await .reconstruct(); assert_ne!(result, expectation); diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 48c02c103..2d1d78695 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -8,7 +8,7 @@ //! MPC communication, it uses 6 rounds of intra-helper communications to send data between shards. //! In this implementation, this operation is called "resharding". -use std::{future::Future, num::NonZeroUsize, ops::Add}; +use std::{borrow::Borrow, future::Future, num::NonZeroUsize, ops::Add}; use futures::{future::try_join, stream, StreamExt, TryFutureExt}; use ipa_step::Step; @@ -19,7 +19,8 @@ use crate::{ helpers::{Direction, Error, Role, TotalRecords}, protocol::{ context::{reshard_iter, ShardedContext}, - prss::{FromRandom, FromRandomU128, SharedRandomness}, + ipa_prf::shuffle::IntermediateShuffleMessages, + prss::{FromRandom, SharedRandomness}, RecordId, }, secret_sharing::{ @@ -81,7 +82,8 @@ trait ShuffleContext: ShardedContext { data: I, ) -> impl Future, crate::error::Error>> + Send where - I: IntoIterator, + I: IntoIterator, + I::Item: Borrow, I::IntoIter: ExactSizeIterator + Send, S: ShuffleShare, { @@ -93,12 +95,12 @@ trait ShuffleContext: ShardedContext { data.enumerate().map(|(i, item)| { // FIXME(1029): update PRSS trait to compute only left or right part let (l, r) = masking_ctx.prss().generate(RecordId::from(i)); - let mask = match direction { + let mask: S = match direction { Direction::Left => l, Direction::Right => r, }; - item + mask + item.borrow().clone() + mask }), |ctx, record_id, _| ctx.pick_shard(record_id, direction), )) @@ -246,7 +248,7 @@ pub trait Shuffleable: Send + 'static { fn new(l: Self::Share, r: Self::Share) -> Self; } -impl Shuffleable for AdditiveShare { +impl Shuffleable for AdditiveShare { type Share = V; fn left(&self) -> Self::Share { @@ -263,7 +265,10 @@ impl Shuffleable for AdditiveShare { } /// Sharded shuffle as performed by shards on H1. -async fn h1_shuffle_for_shard(ctx: C, shares: I) -> Result, crate::error::Error> +async fn h1_shuffle_for_shard( + ctx: C, + shares: I, +) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> where I: IntoIterator, I::IntoIter: Send + ExactSizeIterator, @@ -271,22 +276,22 @@ where S: Shuffleable, { // Generate X_1 = perm_12(left ⊕ right ⊕ z_12). - let x1 = ctx + let x1: Vec = ctx .narrow(&ShuffleStep::Permute12) - .mask_and_shuffle::<_, S::Share>( + .mask_and_shuffle( Direction::Right, shares.into_iter().map(|share| share.left() + share.right()), ) .await?; - // Generate X_2 = perm_31(x_1 ⊕ z_31) and reshard it using the randomness + // Generate X_2 = perm_31(X_1 ⊕ z_31) and reshard it using the randomness // shared with the left helper. - let x2 = ctx + let x2: Vec = ctx .narrow(&ShuffleStep::Permute31) - .mask_and_shuffle(Direction::Left, x1) + .mask_and_shuffle(Direction::Left, &x1) .await?; - // X2 is masked now and cannot reveal anything to the helper on the right. + // X_2 is masked now and cannot reveal anything to the helper on the right. ctx.narrow(&ShuffleStep::LeftToRight) .send_all(x2, Direction::Right) .await?; @@ -302,7 +307,7 @@ where // set our shares let ctx = ctx.narrow(&ShuffleStep::PseudoRandomTable); - Ok((0..sz) + let res = (0..sz) .map(|i| { // This may be confusing as paper specifies à and B̃ as independent tables, but // there is really no reason to generate them using unique PRSS keys. @@ -310,11 +315,16 @@ where S::new(a, b) }) - .collect()) + .collect(); + + Ok((res, IntermediateShuffleMessages::H1 { x1 })) } /// Sharded shuffle as performed by shards on H2. -async fn h2_shuffle_for_shard(ctx: C, shares: I) -> Result, crate::error::Error> +async fn h2_shuffle_for_shard( + ctx: C, + shares: I, +) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> where I: IntoIterator, I::IntoIter: Send + ExactSizeIterator, @@ -341,10 +351,10 @@ where .recv_all::(Direction::Left) .await?; - // generate X_3 = perm_23(x_2 ⊕ z_23) - let x3 = ctx + // generate X_3 = perm_23(X_2 ⊕ z_23) + let x3: Vec = ctx .narrow(&ShuffleStep::Permute23) - .mask_and_shuffle(Direction::Right, x2) + .mask_and_shuffle(Direction::Right, &x2) .await?; // at this moment we know the cardinality of C, and we let H1 know it, so it can start @@ -354,10 +364,10 @@ where .await?; let Some(x3_len) = NonZeroUsize::new(x3.len()) else { - return Ok(Vec::new()); + return Ok((Vec::new(), IntermediateShuffleMessages::H2 { x2 })); }; - // Generate c_1 = x_3 ⊕ b, stream it to H3 and receive c_2 from it at the same time. + // Generate c_1 = X_3 ⊕ b, stream it to H3 and receive c_2 from it at the same time. // Knowing b, c_1 and c_2 lets us set our resulting share, according to the paper it is // (b, c_1 + c_2) let send_channel = ctx @@ -368,7 +378,7 @@ where .narrow(&ShuffleStep::C) .recv_channel(ctx.role().peer(Direction::Right)); - Ok(ctx + let res = ctx .try_join(x3.into_iter().enumerate().map(|(i, x3)| { let record_id = RecordId::from(i); // FIXME(1029): update PRSS trait to compute only left or right part @@ -376,19 +386,24 @@ where .narrow(&ShuffleStep::PseudoRandomTable) .prss() .generate(RecordId::from(i)); - let c1 = x3 + b.clone(); + let c1: S::Share = x3 + b.clone(); try_join( send_channel.send(record_id, c1.clone()), recv_channel.receive(record_id), ) .map_ok(|((), c2)| S::new(b, c1 + c2)) })) - .await?) + .await?; + + Ok((res, IntermediateShuffleMessages::H2 { x2 })) } /// Sharded shuffle as performed by shards on H3. Note that in semi-honest setting, H3 does not /// use its input. Adding support for active security will change that. -async fn h3_shuffle_for_shard(ctx: C, _: I) -> Result, crate::error::Error> +async fn h3_shuffle_for_shard( + ctx: C, + _: I, +) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> where I: IntoIterator, I::IntoIter: Send + ExactSizeIterator, @@ -402,19 +417,19 @@ where .await?; // Generate y2 = perm_31(y_1 ⊕ z_31) - let y2 = ctx + let y2: Vec = ctx .narrow(&ShuffleStep::Permute31) - .mask_and_shuffle(Direction::Right, y1) + .mask_and_shuffle(Direction::Right, &y1) .await?; // Generate y3 = perm_23(y_2 ⊕ z_23) - let y3 = ctx + let y3: Vec = ctx .narrow(&ShuffleStep::Permute23) - .mask_and_shuffle(Direction::Left, y2) + .mask_and_shuffle(Direction::Left, &y2) .await?; let Some(y3_len) = NonZeroUsize::new(y3.len()) else { - return Ok(Vec::new()); + return Ok((Vec::new(), IntermediateShuffleMessages::H3 { y1, y2 })); }; // Generate c_2 = y_3 ⊕ a, stream it to H2 and receive c_1 from it at the same time. @@ -426,7 +441,7 @@ where let recv_channel = ctx .narrow(&ShuffleStep::C) .recv_channel::(ctx.role().peer(Direction::Left)); - Ok(ctx + let res = ctx .try_join(y3.into_iter().enumerate().map(|(i, y3)| { let record_id = RecordId::from(i); // FIXME(1029): update PRSS trait to compute only left or right part @@ -441,14 +456,21 @@ where ) .map_ok(|((), c1)| S::new(c1 + c2, a)) })) - .await?) + .await?; + + Ok((res, IntermediateShuffleMessages::H3 { y1, y2 })) } -/// Entry point to execute sharded shuffle. +/// Internal entry point to sharded shuffle protocol, excluding validation of +/// intermediates for malicious security. Protocols should use `trait Shuffle`. +/// /// ## Errors /// Failure to communicate over the network, either to other MPC helpers, and/or to other shards /// will generate a shuffle error. -pub async fn shuffle(ctx: C, shares: I) -> Result, crate::error::Error> +pub(super) async fn shuffle( + ctx: C, + shares: I, +) -> Result<(Vec, IntermediateShuffleMessages), crate::error::Error> where I: IntoIterator, I::IntoIter: Send + ExactSizeIterator, @@ -464,9 +486,14 @@ where #[cfg(all(test, any(unit_test, feature = "shuttle")))] mod tests { + use rand::{thread_rng, Rng}; + use crate::{ - ff::{boolean_array::BA8, U128Conversions}, - protocol::ipa_prf::shuffle::sharded::shuffle, + ff::{boolean_array::BA8, Gf40Bit, U128Conversions}, + protocol::ipa_prf::shuffle::{ + base::test_helpers::{extract_shuffle_results, ExtractedShuffleResults}, + sharded::shuffle, + }, test_executor::run, test_fixture::{ Distribute, RandomInputDistribution, Reconstruct, RoundRobinInputDistribution, Runner, @@ -479,7 +506,7 @@ mod tests { TestWorld::with_shards(TestWorldConfig::default()); world .semi_honest(input.into_iter(), |ctx, input| async move { - shuffle(ctx, input).await.unwrap() + shuffle(ctx, input).await.unwrap().0 }) .await .into_iter() @@ -539,4 +566,57 @@ mod tests { assert!(result.is_empty()); }); } + + #[test] + fn check_intermediate_messages() { + const SHARDS: usize = 3; + const RECORD_AMOUNT: usize = 100; + type Distribution = RandomInputDistribution; + run(|| async { + let mut rng = thread_rng(); + // using Gf40Bit here since it implements cmp such that vec can later be sorted + let mut records = (0..RECORD_AMOUNT) + .map(|_| rng.gen()) + .collect::>(); + + let results = TestWorld::>::with_shards( + TestWorldConfig::default(), + ) + .semi_honest(records.clone().into_iter(), |ctx, input| async move { + shuffle(ctx, input).await.unwrap() + }) + .await + .into_iter() + .map(extract_shuffle_results) + .fold(ExtractedShuffleResults::empty(), |mut acc, results| { + let ExtractedShuffleResults { + x1_xor_y1, + x2_xor_y2, + a_xor_b_xor_c, + } = results; + + acc.x1_xor_y1.extend(x1_xor_y1); + acc.x2_xor_y2.extend(x2_xor_y2); + acc.a_xor_b_xor_c.extend(a_xor_b_xor_c); + + acc + }); + + let ExtractedShuffleResults { + mut x1_xor_y1, + mut x2_xor_y2, + mut a_xor_b_xor_c, + } = results; + + // unshuffle by sorting + records.sort(); + x1_xor_y1.sort(); + x2_xor_y2.sort(); + a_xor_b_xor_c.sort(); + + assert_eq!(records, a_xor_b_xor_c); + assert_eq!(records, x1_xor_y1); + assert_eq!(records, x2_xor_y2); + }); + } } diff --git a/ipa-core/src/query/runner/oprf_ipa.rs b/ipa-core/src/query/runner/oprf_ipa.rs index 11846c86c..8cbf45eac 100644 --- a/ipa-core/src/query/runner/oprf_ipa.rs +++ b/ipa-core/src/query/runner/oprf_ipa.rs @@ -21,8 +21,8 @@ use crate::{ basics::{BooleanArrayMul, Reveal, ShareKnownValue}, context::{DZKPUpgraded, MacUpgraded, UpgradableContext}, ipa_prf::{ - oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, shuffle::Shuffle, - OPRFIPAInputRow, AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, + oprf_ipa, oprf_padding::PaddingParameters, prf_eval::PrfSharing, OPRFIPAInputRow, + Shuffle, AGG_CHUNK, CONV_CHUNK, PRF_CHUNK, SORT_CHUNK, }, prss::FromPrss, step::ProtocolStep::IpaPrf, diff --git a/ipa-core/src/secret_sharing/mod.rs b/ipa-core/src/secret_sharing/mod.rs index e27604b8b..3e3085149 100644 --- a/ipa-core/src/secret_sharing/mod.rs +++ b/ipa-core/src/secret_sharing/mod.rs @@ -62,7 +62,7 @@ pub trait Block: Sized + Copy + Debug { type Size: ArrayLength; } -pub trait Sendable: Send + Debug + Serializable + 'static {} +pub trait Sendable: Send + Sync + Debug + Serializable + 'static {} impl Sendable for V {}