Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid shuffle #1387

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ipa-core/src/protocol/hybrid/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
pub(crate) mod shuffle;
pub(crate) mod step;

use step::HybridStep as Step;

use self::shuffle::shuffle_hybrid_inputs;
use crate::{
error::Error,
ff::{
Expand All @@ -19,6 +21,7 @@ use crate::{
report::hybrid::IndistinguishableHybridReport,
secret_sharing::replicated::semi_honest::AdditiveShare as Replicated,
};

// In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count)
// ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and
// the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a
Expand Down Expand Up @@ -81,12 +84,14 @@ where
}

// Apply DP padding for OPRF
let _padded_input_rows = apply_dp_padding::<_, IndistinguishableHybridReport<BK, V>, B>(
let padded_input_rows = apply_dp_padding::<_, IndistinguishableHybridReport<BK, V>, B>(
ctx.narrow(&Step::PaddingDp),
input_rows,
&dp_padding_params,
)
.await?;

let _shuffled = shuffle_hybrid_inputs(ctx.narrow(&Step::Shuffle), padded_input_rows).await?;

unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented")
}
217 changes: 217 additions & 0 deletions ipa-core/src/protocol/hybrid/shuffle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
use std::ops::Add;

use generic_array::{ArrayLength, GenericArray};
use typenum::{Unsigned, U18};

use crate::{
error::{Error, UnwrapInfallible},
ff::{
boolean_array::{BooleanArray, BA64},
Serializable,
},
protocol::{context::Context, ipa_prf::shuffle::Shuffle},
report::hybrid::IndistinguishableHybridReport,
secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Sendable, SharedValue},
};

impl<BK, V> Serializable for IndistinguishableHybridReport<BK, V>
where
BK: SharedValue,
V: SharedValue,
Replicated<BK>: Serializable,
Replicated<V>: Serializable,
<Replicated<BK> as Serializable>::Size: Add<U18>,
<Replicated<V> as Serializable>::Size:
Add<<<Replicated<BK> as Serializable>::Size as Add<U18>>::Output>,
<<Replicated<V> as Serializable>::Size as Add<
<<Replicated<BK> as Serializable>::Size as Add<U18>>::Output,
>>::Output: ArrayLength,
{
type Size = <<Replicated<V> as Serializable>::Size as Add<
<<Replicated<BK> as Serializable>::Size as Add<U18>>::Output,
>>::Output;

type DeserializationError = Error;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
let mk_sz = <Replicated<BA64> as Serializable>::Size::USIZE;
let bk_sz = <Replicated<BK> as Serializable>::Size::USIZE;
let v_sz = <Replicated<V> as Serializable>::Size::USIZE;

self.match_key
.serialize(GenericArray::from_mut_slice(&mut buf[..mk_sz]));

self.breakdown_key
.serialize(GenericArray::from_mut_slice(&mut buf[mk_sz..mk_sz + bk_sz]));

self.value.serialize(GenericArray::from_mut_slice(
&mut buf[mk_sz + bk_sz..mk_sz + bk_sz + v_sz],
));
}

fn deserialize(buf: &GenericArray<u8, Self::Size>) -> Result<Self, Self::DeserializationError> {
let mk_sz = <Replicated<BA64> as Serializable>::Size::USIZE;
let bk_sz = <Replicated<BK> as Serializable>::Size::USIZE;
let v_sz = <Replicated<V> as Serializable>::Size::USIZE;

let match_key = Replicated::<BA64>::deserialize(GenericArray::from_slice(&buf[..mk_sz]))
.unwrap_infallible();
let breakdown_key =
Replicated::<BK>::deserialize(GenericArray::from_slice(&buf[mk_sz..mk_sz + bk_sz]))
.map_err(|e| Error::ParseError(e.into()))?;
let value = Replicated::<V>::deserialize(GenericArray::from_slice(
&buf[mk_sz + bk_sz..mk_sz + bk_sz + v_sz],
))
.map_err(|e| Error::ParseError(e.into()))?;

Ok(Self {
match_key,
value,
breakdown_key,
})
}
}

/// Field-wise addition of two reports: each field of `self` is added to the
/// corresponding field of `rhs`.
impl<BK: SharedValue, V: SharedValue> Add for IndistinguishableHybridReport<BK, V> {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        let match_key = self.match_key + rhs.match_key;
        let value = self.value + rhs.value;
        let breakdown_key = self.breakdown_key + rhs.breakdown_key;

        Self {
            match_key,
            value,
            breakdown_key,
        }
    }
}

/// Marks [`IndistinguishableHybridReport`] as [`Sendable`].
///
/// The impl body is empty; everything here is expressed through the `where` bounds.
// NOTE(review): the bounds add `<BK as Serializable>::Size` and `<V as Serializable>::Size`
// to themselves (`Add` with its default `Rhs = Self`) and then combine them with `U18`.
// Presumably this mirrors the `Size` of the `Serializable` impl for this type, with a
// replicated share taking twice the size of the underlying value — TODO confirm.
impl<BK: SharedValue, V: SharedValue> Sendable for IndistinguishableHybridReport<BK, V>
where
    <V as Serializable>::Size: Add,
    <BK as Serializable>::Size: Add,
    <<V as Serializable>::Size as Add>::Output: ArrayLength,
    <<BK as Serializable>::Size as Add>::Output: ArrayLength,
    <<BK as Serializable>::Size as Add>::Output: Add<U18>,
    <<V as Serializable>::Size as Add>::Output:
        Add<<<<BK as Serializable>::Size as Add>::Output as Add<U18>>::Output>,
    <<<V as Serializable>::Size as Add>::Output as Add<
        <<<BK as Serializable>::Size as Add>::Output as Add<U18>>::Output,
    >>::Output: ArrayLength,
{
}

/// Shuffles a `Vec` of [`IndistinguishableHybridReport`].
///
/// This is currently a stub: neither `_ctx` nor `_input` is used and the body
/// unconditionally panics.
///
/// # Errors
/// Intended to propagate errors from `ctx.shuffle`; no error can be returned
/// until the body is implemented.
///
/// # Panics
/// Always panics with "shuffle_hybrid_inputs is unimplemented".
#[tracing::instrument(name = "shuffle_inputs", skip_all)]
pub async fn shuffle_hybrid_inputs<C, BK, V>(
    _ctx: C,
    _input: Vec<IndistinguishableHybridReport<BK, V>>,
) -> Result<Vec<IndistinguishableHybridReport<BK, V>>, Error>
where
    C: Context + Shuffle,
    BK: BooleanArray,
    V: BooleanArray,
{
    unimplemented!("shuffle_hybrid_inputs is unimplemented");
}

#[cfg(all(test, unit_test))]
pub mod tests {
use generic_array::GenericArray;
use rand::Rng;

use super::shuffle_hybrid_inputs;
use crate::{
ff::{
boolean_array::{BA3, BA64, BA8},
Serializable,
},
report::hybrid::IndistinguishableHybridReport,
secret_sharing::replicated::{
semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing,
},
test_executor::run,
test_fixture::{hybrid::TestIndistinguishableHybridReport, Reconstruct, Runner, TestWorld},
};

/// Round-trips a randomly generated report through `serialize` and `deserialize`
/// and checks that the result equals the input.
#[tokio::test]
async fn hybrid_serialize_deserialize() {
    let world = TestWorld::default();
    let mut rng = world.rng();

    let original = IndistinguishableHybridReport::<BA8, BA3> {
        match_key: Replicated::new(rng.gen(), rng.gen()),
        breakdown_key: Replicated::new(rng.gen(), rng.gen()),
        value: Replicated::new(rng.gen(), rng.gen()),
    };

    // Buffer length is inferred from the report's `Serializable::Size`.
    let mut bytes = GenericArray::default();
    original.serialize(&mut bytes);

    let round_tripped = IndistinguishableHybridReport::<BA8, BA3>::deserialize(&bytes).unwrap();
    assert_eq!(original, round_tripped);
}

/// Checks `Add` on reports: adding `ZERO` is the identity, and adding two reports
/// matches the hand-computed XOR of each share.
#[test]
fn hybrid_add() {
    /// Builds a report from the `(left, right)` share pair of each field.
    fn make_report(
        match_key: (u128, u128),
        value: (u128, u128),
        breakdown_key: (u128, u128),
    ) -> IndistinguishableHybridReport<BA8, BA3> {
        IndistinguishableHybridReport {
            match_key: Replicated::new(
                BA64::try_from(match_key.0).unwrap(),
                BA64::try_from(match_key.1).unwrap(),
            ),
            value: Replicated::new(
                BA3::try_from(value.0).unwrap(),
                BA3::try_from(value.1).unwrap(),
            ),
            breakdown_key: Replicated::new(
                BA8::try_from(breakdown_key.0).unwrap(),
                BA8::try_from(breakdown_key.1).unwrap(),
            ),
        }
    }

    let lhs = make_report((5, 3), (1, 2), (2, 4));

    // Adding the zero report must leave the report unchanged.
    assert_eq!(
        lhs,
        lhs.clone() + IndistinguishableHybridReport::<BA8, BA3>::ZERO
    );

    // Shares are listed as match_key, value, breakdown_key:
    //   lhs:      (5,3), (1,2), (2,4)
    //   rhs:      (2,2), (0,1), (3,3)
    //   XOR
    //   expected: (7,1), (1,3), (1,7)
    let rhs = make_report((2, 2), (0, 1), (3, 3));
    let expected = make_report((7, 1), (1, 3), (1, 7));

    assert_eq!(lhs + rhs, expected);
}

#[test]
#[should_panic(expected = "shuffle_hybrid_inputs is unimplemented")]
fn test_shuffle_hybrid_inputs() {
const BATCHSIZE: usize = 50;
run(|| async {
let world = TestWorld::default();
let mut rng = world.rng();
let mut records = Vec::new();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can you do something like?

let records = (0..BATCHSIZE).map(|_| TestIndistinguishableHybridReport{ ... })


for _ in 0..BATCHSIZE {
records.push({
TestIndistinguishableHybridReport {
match_key: rng.gen::<u64>(),
breakdown_key: rng.gen_range(0u32..1 << 8),
value: rng.gen_range(0u32..1 << 3),
}
});
}

let mut result: Vec<TestIndistinguishableHybridReport> = world
.semi_honest(records.clone().into_iter(), |ctx, input_rows| async move {
shuffle_hybrid_inputs::<_, BA8, BA3>(ctx, input_rows)
.await
.unwrap()
})
.await
.reconstruct();
assert_ne!(result, records);
records.sort();
result.sort();
assert_eq!(result, records);
});
}
}
2 changes: 2 additions & 0 deletions ipa-core/src/protocol/hybrid/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ pub(crate) enum HybridStep {
ReshardByTag,
#[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="padding_dp")]
PaddingDp,
#[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)]
Shuffle,
}
42 changes: 39 additions & 3 deletions ipa-core/src/test_fixture/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
iter::zip,
};

use rand::Rng;

use crate::{
ff::{boolean_array::BooleanArray, U128Conversions},
ff::{
boolean_array::{BooleanArray, BA64},
U128Conversions,
},
report::hybrid::IndistinguishableHybridReport,
secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, IntoShares},
test_fixture::sharing::Reconstruct,
Expand All @@ -13,7 +21,7 @@ pub enum TestHybridRecord {
TestConversion { match_key: u64, value: u32 },
}

#[derive(PartialEq, Eq)]
#[derive(Debug, Clone, Ord, PartialEq, PartialOrd, Eq)]
pub struct TestIndistinguishableHybridReport {
pub match_key: u64,
pub value: u32,
Expand Down Expand Up @@ -51,6 +59,34 @@ where
}
}

/// Secret-shares a plaintext test report into three [`IndistinguishableHybridReport`]s,
/// one per element of the returned array.
impl<BK, V> IntoShares<IndistinguishableHybridReport<BK, V>> for TestIndistinguishableHybridReport
where
    BK: BooleanArray + U128Conversions + IntoShares<Replicated<BK>>,
    V: BooleanArray + U128Conversions + IntoShares<Replicated<V>>,
{
    /// Shares `match_key`, `breakdown_key` and `value` independently (in that order)
    /// and regroups the shares into one report per position.
    ///
    /// # Panics
    /// If any field value is rejected by the corresponding `try_from` conversion.
    fn share_with<R: Rng>(self, rng: &mut R) -> [IndistinguishableHybridReport<BK, V>; 3] {
        let mk_shares = BA64::try_from(u128::from(self.match_key))
            .unwrap()
            .share_with(rng);
        let bk_shares = BK::try_from(self.breakdown_key.into())
            .unwrap()
            .share_with(rng);
        let v_shares = V::try_from(self.value.into()).unwrap().share_with(rng);

        // Walk the three share arrays in lockstep, building one report per position.
        let reports: Vec<_> = zip(mk_shares, zip(bk_shares, v_shares))
            .map(
                |(match_key, (breakdown_key, value))| IndistinguishableHybridReport {
                    match_key,
                    breakdown_key,
                    value,
                },
            )
            .collect();

        reports.try_into().unwrap()
    }
}

struct HashmapEntry {
breakdown_key: u32,
total_value: u32,
Expand Down
Loading