Skip to content

Commit

Permalink
finally found the issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas James Yurek committed Oct 31, 2024
1 parent 8cda0d1 commit 9a6c491
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 186 deletions.
84 changes: 63 additions & 21 deletions ipa-core/src/query/runner/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures::{stream::iter, StreamExt, TryStreamExt};
use crate::{
error::Error,
ff::{
boolean_array::{BooleanArray, BA20, BA3, BA8},
boolean_array::{BooleanArray, BA3, BA8},
U128Conversions,
},
helpers::{
Expand All @@ -20,31 +20,40 @@ use crate::{
step::ProtocolStep::Hybrid,
},
query::runner::reshard_tag::reshard_aad,
report::hybrid::{
EncryptedHybridReport, IndistinguishableHybridReport, UniqueTag, UniqueTagValidator,
report::{
hybrid::{
EncryptedHybridReport, IndistinguishableHybridReport, UniqueTag, UniqueTagValidator,
},
hybrid_info::HybridInfo,
},
secret_sharing::replicated::semi_honest::AdditiveShare as Replicated,
};

#[allow(dead_code)]
pub struct Query<C, HV, R: PrivateKeyRegistry> {
pub struct Query<'a, C, HV, R: PrivateKeyRegistry> {
config: HybridQueryParams,
key_registry: Arc<R>,
hybrid_info: HybridInfo<'a>,
phantom_data: PhantomData<(C, HV)>,
}

#[allow(dead_code)]
impl<C, HV, R: PrivateKeyRegistry> Query<C, HV, R> {
pub fn new(query_params: HybridQueryParams, key_registry: Arc<R>) -> Self {
impl<'a, C, HV, R: PrivateKeyRegistry> Query<'a, C, HV, R> {
pub fn new(
query_params: HybridQueryParams,
key_registry: Arc<R>,
hybrid_info: HybridInfo<'a>,
) -> Self {
Self {
config: query_params,
key_registry,
hybrid_info,
phantom_data: PhantomData,
}
}
}

impl<C, HV, R> Query<C, HV, R>
impl<'a, C, HV, R> Query<'a, C, HV, R>
where
C: UpgradableContext + Shuffle + ShardedContext,
HV: BooleanArray + U128Conversions,
Expand All @@ -60,6 +69,7 @@ where
let Self {
config,
key_registry,
hybrid_info,
phantom_data: _,
} = self;

Expand All @@ -73,13 +83,13 @@ where
));
}

let stream = LengthDelimitedStream::<EncryptedHybridReport, _>::new(input_stream)
let stream = LengthDelimitedStream::<EncryptedHybridReport<BA8, BA3>, _>::new(input_stream)
.map_err(Into::<Error>::into)
.map_ok(|enc_reports| {
iter(enc_reports.into_iter().map({
|enc_report| {
let dec_report = enc_report
.decrypt::<R, BA8, BA3, BA20>(key_registry.as_ref())
.decrypt(key_registry.as_ref(), &hybrid_info)
.map_err(Into::<Error>::into);
let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
dec_report.map(|dec_report1| (dec_report1, unique_tag))
Expand Down Expand Up @@ -142,7 +152,7 @@ mod tests {

use crate::{
ff::{
boolean_array::{BA16, BA20, BA3, BA8},
boolean_array::{BA16, BA3, BA8},
U128Conversions,
},
helpers::{
Expand All @@ -151,7 +161,11 @@ mod tests {
},
hpke::{KeyPair, KeyRegistry},
query::runner::hybrid::Query as HybridQuery,
report::{OprfReport, DEFAULT_KEY_ID},
report::{
hybrid::HybridReport,
hybrid_info::{HybridConversionInfo, HybridInfo},
DEFAULT_KEY_ID,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_fixture::{
flatten3v, ipa::TestRawDataRecord, Reconstruct, RoundRobinInputDistribution, TestWorld,
Expand All @@ -165,7 +179,7 @@ mod tests {
// TODO: When Encryption/Decryption exists for HybridReports
// update these to use that, rather than generating OprfReports
vec![
TestRawDataRecord {
/*TestRawDataRecord {
timestamp: 0,
user_id: 12345,
is_trigger_report: false,
Expand All @@ -178,7 +192,7 @@ mod tests {
is_trigger_report: false,
breakdown_key: 1,
trigger_value: 0,
},
},*/
TestRawDataRecord {
timestamp: 10,
user_id: 12345,
Expand All @@ -193,13 +207,13 @@ mod tests {
breakdown_key: 0,
trigger_value: 2,
},
TestRawDataRecord {
/*TestRawDataRecord {
timestamp: 20,
user_id: 68362,
is_trigger_report: false,
breakdown_key: 1,
trigger_value: 0,
},
},*/
TestRawDataRecord {
timestamp: 30,
user_id: 68362,
Expand All @@ -216,17 +230,27 @@ mod tests {
query_sizes: Vec<QuerySize>,
}

fn build_buffers_from_records(records: &[TestRawDataRecord], s: usize) -> BufferAndKeyRegistry {
fn build_buffers_from_records(
records: &[TestRawDataRecord],
s: usize,
info: &HybridInfo,
) -> BufferAndKeyRegistry {
let mut rng = StdRng::seed_from_u64(42);
let key_id = DEFAULT_KEY_ID;
let key_registry = Arc::new(KeyRegistry::<KeyPair>::random(1, &mut rng));

let mut buffers: [_; 3] = std::array::from_fn(|_| vec![Vec::new(); s]);
let shares: [Vec<OprfReport<BA8, BA3, BA20>>; 3] = records.iter().cloned().share();
let shares: [Vec<HybridReport<BA8, BA3>>; 3] = records.iter().cloned().share();
for (buf, shares) in zip(&mut buffers, shares) {
for (i, share) in shares.into_iter().enumerate() {
share
.delimited_encrypt_to(key_id, key_registry.as_ref(), &mut rng, &mut buf[i % s])
.delimited_encrypt_to(
key_id,
key_registry.as_ref(),
info,
&mut rng,
&mut buf[i % s],
)
.unwrap();
}
}
Expand Down Expand Up @@ -265,11 +289,16 @@ mod tests {
const SHARDS: usize = 2;
let records = build_records();

let hybrid_info = HybridInfo::Conversion(
HybridConversionInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1)
.unwrap(),
);

let BufferAndKeyRegistry {
buffers,
key_registry,
query_sizes,
} = build_buffers_from_records(&records, SHARDS);
} = build_buffers_from_records(&records, SHARDS, &hybrid_info);

let world: TestWorld<WithShards<SHARDS, RoundRobinInputDistribution>> =
TestWorld::with_shards(TestWorldConfig::default());
Expand All @@ -295,6 +324,7 @@ mod tests {
HybridQuery::<_, BA16, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
hybrid_info.clone(),
)
.execute(ctx, query_size, input)
})
Expand Down Expand Up @@ -329,11 +359,16 @@ mod tests {
const SHARDS: usize = 2;
let records = build_records();

let hybrid_info = HybridInfo::Conversion(
HybridConversionInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1)
.unwrap(),
);

let BufferAndKeyRegistry {
mut buffers,
key_registry,
query_sizes,
} = build_buffers_from_records(&records, SHARDS);
} = build_buffers_from_records(&records, SHARDS, &hybrid_info);

// this is double, since we duplicate the data below
let query_sizes = query_sizes
Expand Down Expand Up @@ -381,6 +416,7 @@ mod tests {
HybridQuery::<_, BA16, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
hybrid_info.clone(),
)
.execute(ctx, query_size, input)
})
Expand All @@ -400,11 +436,16 @@ mod tests {
const SHARDS: usize = 2;
let records = build_records();

let hybrid_info = HybridInfo::Conversion(
HybridConversionInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1)
.unwrap(),
);

let BufferAndKeyRegistry {
buffers,
key_registry,
query_sizes,
} = build_buffers_from_records(&records, SHARDS);
} = build_buffers_from_records(&records, SHARDS, &hybrid_info);

let world: TestWorld<WithShards<SHARDS, RoundRobinInputDistribution>> =
TestWorld::with_shards(TestWorldConfig::default());
Expand All @@ -430,6 +471,7 @@ mod tests {
HybridQuery::<_, BA16, KeyRegistry<KeyPair>>::new(
query_params,
Arc::clone(&key_registry),
hybrid_info.clone(),
)
.execute(ctx, query_size, input)
})
Expand Down
Loading

0 comments on commit 9a6c491

Please sign in to comment.