diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index bf7a66e9a..c34591527 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -19,6 +19,8 @@ use prio::dp::distributions::DiscreteGaussian; use prio::idpf::test_utils::generate_zipf_distributed_batch; #[cfg(feature = "experimental")] use prio::vdaf::prio2::Prio2; +#[cfg(feature = "experimental")] +use prio::vidpf::VidpfServerId; use prio::{ benchmarked::*, field::{random_vector, Field128 as F, FieldElement}, @@ -813,7 +815,9 @@ fn vidpf(c: &mut Criterion) { let (public, keys) = vidpf.gen(&input, &weight, NONCE).unwrap(); b.iter(|| { - let _ = vidpf.eval(&keys[0], &public, &input, NONCE).unwrap(); + let _ = vidpf + .eval(VidpfServerId::S0, &keys[0], &public, &input, NONCE) + .unwrap(); }); }); } diff --git a/src/bt.rs b/src/bt.rs index 1bb0cb64d..f2cde66b3 100644 --- a/src/bt.rs +++ b/src/bt.rs @@ -78,13 +78,13 @@ type SubTree = Option>>; /// Represents a node of a binary tree. pub struct Node { - value: V, - left: SubTree, - right: SubTree, + pub(crate) value: V, + pub(crate) left: SubTree, + pub(crate) right: SubTree, } impl Node { - fn new(value: V) -> Self { + pub(crate) fn new(value: V) -> Self { Self { value, left: None, @@ -181,7 +181,7 @@ impl Node { /// Represents an append-only binary tree. pub struct BinaryTree { - root: SubTree, + pub(crate) root: SubTree, } impl BinaryTree { diff --git a/src/flp/szk.rs b/src/flp/szk.rs index ef504204f..4e1f22b14 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -18,7 +18,6 @@ use crate::{ prng::{Prng, PrngError}, vdaf::xof::{IntoFieldVec, Seed, Xof, XofTurboShake128}, }; -#[cfg(test)] use std::borrow::Cow; use std::ops::BitAnd; use std::{io::Cursor, marker::PhantomData}; @@ -27,7 +26,6 @@ use subtle::{Choice, ConstantTimeEq}; // Domain separation tags const DST_PROVE_RANDOMNESS: u16 = 0; const DST_PROOF_SHARE: u16 = 1; -#[allow(dead_code)] const DST_QUERY_RANDOMNESS: u16 = 2; const DST_JOINT_RAND_SEED: u16 = 3; const DST_JOINT_RAND_PART: u16 = 4; @@ -39,14 +37,19 @@ const MASTIC_VERSION: u8 = 0; #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum SzkError { - #[error("Szk decide error: {0}")] /// Returned for errors in Szk verification step + #[error("Szk decide error: {0}")] Decide(String), - #[error("Szk query error: {0}")] /// Returned for errors in query evaluation + #[error("Szk query error: {0}")] Query(String), + /// Returned when a user fails to store the length of the verifier so it + /// can be properly read upon receipt. + #[error("Part of Szk query state not stored")] + InvalidState(String), + /// Returned if an FLP operation encountered an error. #[error("Flp error: {0}")] Flp(#[from] FlpError), @@ -209,65 +212,26 @@ impl ParameterizedDecode<(bool } /// A tuple containing the state and messages produced by an SZK query. 
-#[cfg(test)] #[derive(Clone, Debug)] pub struct SzkQueryShare { joint_rand_part_opt: Option>, - flp_verifier: Vec, -} - -/// The state that needs to be stored by an Szk verifier between query() and decide() -pub type SzkQueryState = Option>; - -#[cfg(test)] -impl SzkQueryShare { - pub(crate) fn merge_verifiers( - mut leader_share: SzkQueryShare, - helper_share: SzkQueryShare, - ) -> SzkVerifier { - for (x, y) in leader_share - .flp_verifier - .iter_mut() - .zip(helper_share.flp_verifier) - { - *x += y; - } - SzkVerifier { - flp_verifier: leader_share.flp_verifier, - leader_joint_rand_part_opt: leader_share.joint_rand_part_opt, - helper_joint_rand_part_opt: helper_share.joint_rand_part_opt, - } - } -} - -/// Verifier type for the SZK proof. -#[derive(Clone, Debug, Eq, PartialEq)] -pub struct SzkVerifier { - flp_verifier: Vec, - leader_joint_rand_part_opt: Option>, - helper_joint_rand_part_opt: Option>, + pub(crate) flp_verifier: Vec, } -impl Encode for SzkVerifier { +impl Encode for SzkQueryShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - encode_fieldvec(&self.flp_verifier, bytes)?; - if let Some(ref part) = self.leader_joint_rand_part_opt { - part.encode(bytes)? - }; - if let Some(ref part) = self.helper_joint_rand_part_opt { + if let Some(ref part) = self.joint_rand_part_opt { part.encode(bytes)? }; + + encode_fieldvec(&self.flp_verifier, bytes)?; Ok(()) } fn encoded_len(&self) -> Option { Some( self.flp_verifier.len() * F::ENCODED_SIZE - + match self.leader_joint_rand_part_opt { - Some(ref part) => part.encoded_len()?, - None => 0, - } - + match self.helper_joint_rand_part_opt { + + match self.joint_rand_part_opt { Some(ref part) => part.encoded_len()?, None => 0, }, @@ -276,24 +240,81 @@ impl Encode for SzkVerifier ParameterizedDecode<(bool, usize)> - for SzkVerifier + for SzkQueryShare { fn decode_with_param( (requires_joint_rand, verifier_len): &(bool, usize), bytes: &mut Cursor<&[u8]>, + ) -> Result { + Ok(SzkQueryShare { + joint_rand_part_opt: (*requires_joint_rand) + .then(|| Seed::::decode(bytes)) + .transpose()?, + flp_verifier: decode_fieldvec(*verifier_len, bytes)?, + }) + } +} + +/// Szk query state. +/// +/// The state that needs to be stored by an Szk verifier between query() and decide(). +pub type SzkQueryState = Option>; + +/// Verifier type for the SZK proof. +pub type SzkVerifier = Vec; + +impl Encode for SzkVerifier { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + encode_fieldvec(self, bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(self.len() * F::ENCODED_SIZE) + } +} + +impl ParameterizedDecode for SzkVerifier { + fn decode_with_param( + verifier_len: &usize, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + decode_fieldvec(*verifier_len, bytes) + } +} + +/// Joint share type for the SZK proof. +pub type SzkJointShare = Option<[Seed; 2]>; + +impl Encode for SzkJointShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + if let Some([ref leader_part, ref helper_part]) = self { + leader_part.encode(bytes)?; + helper_part.encode(bytes)?; + }; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(match self { + Some(ref part) => part[0].encoded_len()? 
* 2, + None => 0, + }) + } +} + +impl ParameterizedDecode for SzkJointShare { + fn decode_with_param( + requires_joint_rand: &bool, + bytes: &mut Cursor<&[u8]>, ) -> Result { if *requires_joint_rand { - Ok(SzkVerifier { - flp_verifier: decode_fieldvec(*verifier_len, bytes)?, - leader_joint_rand_part_opt: Some(Seed::::decode(bytes)?), - helper_joint_rand_part_opt: Some(Seed::::decode(bytes)?), - }) + Ok(Some([ + Seed::::decode(bytes)?, + Seed::::decode(bytes)?, + ])) } else { - Ok(SzkVerifier { - flp_verifier: decode_fieldvec(*verifier_len, bytes)?, - leader_joint_rand_part_opt: None, - helper_joint_rand_part_opt: None, - }) + Ok(None) } } } @@ -418,7 +439,6 @@ where .collect() } - #[cfg(test)] fn derive_query_rand(&self, verify_key: &[u8; SEED_SIZE], nonce: &[u8; 16]) -> Vec { let mut xof = P::init( verify_key, @@ -429,7 +449,7 @@ where .into_field_vec(self.typ.query_rand_len()) } - pub(crate) fn has_joint_rand(&self) -> bool { + pub(crate) fn requires_joint_rand(&self) -> bool { self.typ.joint_rand_len() > 0 } @@ -498,7 +518,6 @@ where Ok([leader_proof_share, helper_proof_share]) } - #[cfg(test)] pub(crate) fn query( &self, input_share: &[T::Field], @@ -518,7 +537,7 @@ where } => Cow::Owned(self.derive_helper_proof_share(proof_share_seed_and_blind)), }; - let (joint_rand, joint_rand_seed, joint_rand_part) = if self.has_joint_rand() { + let (joint_rand, joint_rand_seed, joint_rand_part) = if self.requires_joint_rand() { let ((joint_rand_seed, joint_rand), host_joint_rand_part) = match proof_share { SzkProofShare::Leader { uncompressed_proof_share: _, @@ -591,34 +610,56 @@ where )) } - /// Returns true if the verifier message indicates that the input from which - /// it was generated is valid. + pub(crate) fn merge_verifiers( + &self, + mut leader_share: SzkQueryShare, + helper_share: SzkQueryShare, + ) -> Result, SzkError> { + for (x, y) in leader_share + .flp_verifier + .iter_mut() + .zip(helper_share.flp_verifier) + { + *x += y; + } + if self.typ.decide(&leader_share.flp_verifier)? 
{ + match ( + leader_share.joint_rand_part_opt, + helper_share.joint_rand_part_opt, + ) { + (Some(leader_part), Some(helper_part)) => Ok(Some([leader_part, helper_part])), + (None, None) => Ok(None), + _ => Err(SzkError::InvalidState( + "at least one of the joint randomness parts is missing".to_string(), + )), + } + } else { + Err(SzkError::Decide("failed to verify FLP proof".to_string())) + } + } + /// Returns true if the leader and helper derive identical joint randomness + /// seeds pub fn decide( &self, - verifier: SzkVerifier, query_state: SzkQueryState, - ) -> Result { - // Check if underlying FLP proof validates - let check_flp_proof = self.typ.decide(&verifier.flp_verifier)?; - if !check_flp_proof { - return Ok(false); - } + joint_share: &[Seed; 2], + ) -> Result<(), SzkError> { // Check that joint randomness was properly derived from both // aggregators' parts - match ( - query_state, - verifier.leader_joint_rand_part_opt, - verifier.helper_joint_rand_part_opt, - ) { - (Some(joint_rand_seed), Some(leader_joint_rand_part), Some(helper_joint_rand_part)) => { + match (query_state, joint_share) { + (Some(joint_rand_seed), [leader_joint_rand_part, helper_joint_rand_part]) => { let expected_joint_rand_seed = - self.derive_joint_rand_seed(&leader_joint_rand_part, &helper_joint_rand_part); - Ok(joint_rand_seed == expected_joint_rand_seed) + self.derive_joint_rand_seed(leader_joint_rand_part, helper_joint_rand_part); + if joint_rand_seed == expected_joint_rand_seed { + Ok(()) + } else { + Err(SzkError::Decide( + "Aggregators failed to compute identical joint randomness seeds" + .to_string(), + )) + } } - (None, None, None) => Ok(true), - (_, _, _) => Err(SzkError::Decide( - "at least one of the input seeds is missing".to_string(), - )), + (None, _) => Ok(()), } } } @@ -673,7 +714,7 @@ mod tests { thread_rng().fill(&mut nonce[..]); let prove_rand_seed = Seed::<16>::generate().unwrap(); let helper_seed = Seed::<16>::generate().unwrap(); - let leader_seed_opt = if szk_typ.has_joint_rand() { + let leader_seed_opt = if szk_typ.requires_joint_rand() { Some(Seed::<16>::generate().unwrap()) } else { None @@ -706,29 +747,34 @@ mod tests { .query(&helper_input_share, h_proof_share, &verify_key, &nonce) .unwrap(); - let verifier = SzkQueryShare::merge_verifiers(l_query_share.clone(), h_query_share.clone()); - if let Ok(leader_decision) = szk_typ.decide(verifier.clone(), l_query_state.clone()) { - assert_eq!( - leader_decision, valid, - "Leader incorrectly determined validity", - ); - } else { - panic!("Leader failed during decision"); - }; - if let Ok(helper_decision) = szk_typ.decide(verifier.clone(), h_query_state.clone()) { - assert_eq!( - helper_decision, valid, - "Helper incorrectly determined validity", - ); - } else { - panic!("Helper failed during decision"); + let joint_share_result = + szk_typ.merge_verifiers(l_query_share.clone(), h_query_share.clone()); + match joint_share_result { + Ok(Some(ref joint_share)) => { + let leader_decision = szk_typ.decide(l_query_state.clone(), joint_share).is_ok(); + assert_eq!( + leader_decision, valid, + "Leader incorrectly determined validity", + ); + let helper_decision = szk_typ.decide(h_query_state.clone(), joint_share).is_ok(); + assert_eq!( + helper_decision, valid, + "Helper incorrectly determined validity", + ); + } + Ok(None) => assert!(valid, "Aggregator incorrectly determined validity"), + Err(_) => { + assert!(!valid, "Aggregator incorrectly determined validity"); + } }; //test mutated jr seed - if szk_typ.has_joint_rand() { + if 
szk_typ.requires_joint_rand() { let joint_rand_seed_opt = Some(Seed::<16>::generate().unwrap()); - if let Ok(leader_decision) = szk_typ.decide(verifier, joint_rand_seed_opt.clone()) { - assert!(!leader_decision, "Leader accepted wrong jr seed"); + if let Ok(Some(ref joint_share)) = joint_share_result { + if let Ok(()) = szk_typ.decide(joint_rand_seed_opt.clone(), joint_share) { + panic!("Leader accepted wrong jr seed"); + }; }; }; @@ -740,9 +786,12 @@ mod tests { ); } - let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); - - let leader_decision = szk_typ.decide(verifier, l_query_state.clone()).unwrap(); + let joint_share_res = szk_typ.merge_verifiers(mutated_query_share, h_query_share.clone()); + let leader_decision = match joint_share_res { + Ok(Some(ref joint_share)) => szk_typ.decide(l_query_state.clone(), joint_share).is_ok(), + Ok(None) => true, + Err(_) => false, + }; assert!(!leader_decision, "Leader validated after proof mutation"); // test mutated input share @@ -753,11 +802,14 @@ mod tests { .query(&mutated_input, l_proof_share.clone(), &verify_key, &nonce) .unwrap(); - let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); + let joint_share_res = szk_typ.merge_verifiers(mutated_query_share, h_query_share.clone()); - if let Ok(leader_decision) = szk_typ.decide(verifier, mutated_query_state) { - assert!(!leader_decision, "Leader validated after input mutation"); + let leader_decision = match joint_share_res { + Ok(Some(ref joint_share)) => szk_typ.decide(mutated_query_state, joint_share).is_ok(), + Ok(None) => true, + Err(_) => false, }; + assert!(!leader_decision, "Leader validated after input mutation"); // test mutated proof share let (mut mutated_proof, leader_blind_and_helper_joint_rand_part_opt) = match l_proof_share { @@ -784,11 +836,14 @@ mod tests { &nonce, ) .unwrap(); - let verifier = SzkQueryShare::merge_verifiers(l_query_share, h_query_share.clone()); + let joint_share_res = szk_typ.merge_verifiers(l_query_share, h_query_share.clone()); - if let Ok(leader_decision) = szk_typ.decide(verifier, l_query_state) { - assert!(!leader_decision, "Leader validated after proof mutation"); + let leader_decision = match joint_share_res { + Ok(Some(ref joint_share)) => szk_typ.decide(l_query_state.clone(), joint_share).is_ok(), + Ok(None) => true, + Err(_) => false, }; + assert!(!leader_decision, "Leader validated after proof mutation"); } #[test] diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 7b8d63424..298445c23 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -5,19 +5,22 @@ //! 
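// --- Illustrative aside (not part of the patch) -----------------------------
// The szk.rs changes above replace the old `SzkQueryShare::merge_verifiers`
// associated function and the bool-returning `decide` with a two-step flow:
// `Szk::merge_verifiers` combines the two aggregators' query shares, runs the
// FLP decision predicate, and on success yields an optional pair of
// joint-randomness seeds (`SzkJointShare`); each aggregator then calls
// `decide`, which only confirms that the seed it stored between query() and
// decide() matches the one re-derived from the two parts. The sketch below
// mirrors that structure with toy stand-ins (u64 "field" elements, [u8; 16]
// "seeds", XOR in place of XofTurboShake128); names, signatures, and error
// types are simplified and do not match the crate's.

type ToySeed = [u8; 16];

struct ToyQueryShare {
    joint_rand_part: Option<ToySeed>,
    verifier: Vec<u64>,
}

// Step 1: merge the verifier shares additively and run the decision check.
// On success, hand both joint-randomness parts to the aggregators.
fn merge_verifiers(
    mut leader: ToyQueryShare,
    helper: ToyQueryShare,
) -> Result<Option<[ToySeed; 2]>, String> {
    for (x, y) in leader.verifier.iter_mut().zip(helper.verifier) {
        *x = x.wrapping_add(y);
    }
    // Toy decision predicate: the combined verifier must be all zeros.
    if leader.verifier.iter().any(|&x| x != 0) {
        return Err("failed to verify FLP proof".into());
    }
    match (leader.joint_rand_part, helper.joint_rand_part) {
        (Some(l), Some(h)) => Ok(Some([l, h])),
        (None, None) => Ok(None),
        _ => Err("at least one of the joint randomness parts is missing".into()),
    }
}

// Step 2: each aggregator checks that the seed it stored in its query state
// matches the seed derived from the two broadcast parts.
fn decide(query_state: Option<ToySeed>, joint_share: &[ToySeed; 2]) -> Result<(), String> {
    let Some(stored_seed) = query_state else {
        // No stored seed means this FLP type uses no joint randomness.
        return Ok(());
    };
    let mut expected = [0u8; 16];
    for i in 0..16 {
        // XOR stands in for the XOF-based seed derivation.
        expected[i] = joint_share[0][i] ^ joint_share[1][i];
    }
    if stored_seed == expected {
        Ok(())
    } else {
        Err("aggregators computed different joint randomness seeds".into())
    }
}
// -----------------------------------------------------------------------------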
[draft-mouris-cfrg-mastic-01]: https://www.ietf.org/archive/id/draft-mouris-cfrg-mastic-01.html use crate::{ + bt::{BinaryTree, Path}, codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::{decode_fieldvec, FieldElement}, flp::{ - szk::{Szk, SzkProofShare}, + szk::{Szk, SzkJointShare, SzkProofShare, SzkQueryShare, SzkQueryState}, Type, }, vdaf::{ poplar1::Poplar1AggregationParam, xof::{Seed, Xof}, - AggregateShare, Client, OutputShare, Vdaf, VdafError, + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, + PrepareTransition, Vdaf, VdafError, }, vidpf::{ - Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, VidpfWeight, + Vidpf, VidpfError, VidpfEvalCache, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, + VidpfWeight, }, }; @@ -26,6 +29,9 @@ use std::io::{Cursor, Read}; use std::ops::BitAnd; use subtle::{Choice, ConstantTimeEq}; +const DST_PATH_CHECK_BATCH: u16 = 6; +const NONCE_SIZE: usize = 16; + /// The main struct implementing the Mastic VDAF. /// Composed of a shared zero knowledge proof system and a verifiable incremental /// distributed point function. @@ -72,14 +78,24 @@ pub struct MasticAggregationParam { level_and_prefixes: Poplar1AggregationParam, /// Flag indicating whether the VIDPF weight needs to be validated using SZK. /// This flag must be set the first time any report is aggregated; however this may happen at any level of the tree. - require_check_flag: bool, + require_weight_check: bool, +} + +#[cfg(test)] +impl MasticAggregationParam { + fn new(prefixes: Vec, require_weight_check: bool) -> Result { + Ok(Self { + level_and_prefixes: Poplar1AggregationParam::try_from_prefixes(prefixes)?, + require_weight_check, + }) + } } impl Encode for MasticAggregationParam { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { self.level_and_prefixes.encode(bytes)?; - let require_check = if self.require_check_flag { 1u8 } else { 0u8 }; - require_check.encode(bytes)?; + let require_weight_check = if self.require_weight_check { 1u8 } else { 0u8 }; + require_weight_check.encode(bytes)?; Ok(()) } @@ -91,11 +107,11 @@ impl Encode for MasticAggregationParam { impl Decode for MasticAggregationParam { fn decode(bytes: &mut Cursor<&[u8]>) -> Result { let level_and_prefixes = Poplar1AggregationParam::decode(bytes)?; - let require_check = u8::decode(bytes)?; - let require_check_flag = require_check != 0; + let require_weight_check_u8 = u8::decode(bytes)?; + let require_weight_check = require_weight_check_u8 != 0; Ok(Self { level_and_prefixes, - require_check_flag, + require_weight_check, }) } } @@ -122,7 +138,7 @@ where } } -/// Mastic input share +/// Mastic input share. /// /// Message sent by the [`Client`] to each Aggregator during the Sharding phase. 
#[derive(Clone, Debug)] @@ -136,7 +152,7 @@ pub struct MasticInputShare { impl Encode for MasticInputShare { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { - bytes.extend_from_slice(&self.vidpf_key.value[..]); + bytes.extend_from_slice(&self.vidpf_key.0[..]); self.proof_share.encode(bytes)?; Ok(()) } @@ -161,15 +177,7 @@ where } let mut value = [0; 16]; bytes.read_exact(&mut value)?; - let vidpf_key = VidpfKey::new( - if *agg_id == 0 { - VidpfServerId::S0 - } else { - VidpfServerId::S1 - }, - value, - ); - + let vidpf_key = VidpfKey::from_bytes(value); let proof_share = SzkProofShare::::decode_with_param( &( *agg_id == 0, @@ -185,7 +193,6 @@ where } } -#[cfg(test)] impl PartialEq for MasticInputShare { fn eq(&self, other: &MasticInputShare) -> bool { self.ct_eq(other).into() @@ -260,7 +267,7 @@ where P: Xof, { type Measurement = (VidpfInput, T::Measurement); - type AggregateResult = T::AggregateResult; + type AggregateResult = Vec; type AggregationParam = MasticAggregationParam; type PublicShare = MasticPublicShare>; type InputShare = MasticInputShare; @@ -300,9 +307,11 @@ where )?; let leader_measurement_share = - self.vidpf.eval_root(&vidpf_keys[0], &public_share, nonce)?; + self.vidpf + .eval_root(VidpfServerId::S0, &vidpf_keys[0], &public_share, nonce)?; let helper_measurement_share = - self.vidpf.eval_root(&vidpf_keys[1], &public_share, nonce)?; + self.vidpf + .eval_root(VidpfServerId::S1, &vidpf_keys[1], &public_share, nonce)?; let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( leader_measurement_share.as_ref(), @@ -348,19 +357,13 @@ where return Err(VdafError::Vidpf(VidpfError::InvalidAttributeLength)); } - let vidpf_keys = [ - VidpfKey::gen(VidpfServerId::S0)?, - VidpfKey::gen(VidpfServerId::S1)?, - ]; - let joint_random_opt = if self.szk.has_joint_rand() { + let vidpf_keys = [VidpfKey::generate()?, VidpfKey::generate()?]; + let joint_random_opt = if self.szk.requires_joint_rand() { Some(Seed::::generate()?) } else { None }; - let szk_random = [ - Seed::::generate()?, - Seed::::generate()?, - ]; + let szk_random = [Seed::generate()?, Seed::generate()?]; let encoded_measurement = self.encode_measurement(weight)?; if encoded_measurement.as_ref().len() != self.vidpf.weight_parameter { @@ -379,32 +382,418 @@ where } } +/// Mastic prepare state. +/// +/// State held by an aggregator between rounds of Mastic preparation. Includes intermediate +/// state for [`Szk``] verification, the output shares currently being validated, and +/// parameters of Mastic used for encoding. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct MasticPrepareState { + /// Includes output shares for eventual aggregation. + output_shares: MasticOutputShare, + /// If [`Szk`]` verification is being performed, we also store the relevant state for that operation. + szk_query_state: SzkQueryState, + verifier_len: Option, +} + +/// Mastic prepare share. +/// +/// Broadcast message from an aggregator between rounds of Mastic. Includes the +/// [`Vidpf`] evaluation proof covering every prefix in the aggregation parameter, and optionally +/// the verification message for Szk. +#[derive(Clone, Debug)] +pub struct MasticPrepareShare { + /// [`Vidpf`] evaluation proof, which guarantees one-hotness and payload consistency. + vidpf_proof: Seed, + + /// If [`Szk`]` verification of the root weight is needed, a verification message. 
+ szk_query_share_opt: Option>, +} + +impl Encode for MasticPrepareShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + self.vidpf_proof.encode(bytes)?; + match &self.szk_query_share_opt { + Some(query_share) => query_share.encode(bytes), + None => Ok(()), + } + } + + fn encoded_len(&self) -> Option { + Some( + self.vidpf_proof.encoded_len()? + + match &self.szk_query_share_opt { + Some(query_share) => query_share.encoded_len()?, + None => 0, + }, + ) + } +} + +impl ParameterizedDecode> + for MasticPrepareShare +{ + fn decode_with_param( + prep_state: &MasticPrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let vidpf_proof = Seed::decode(bytes)?; + let requires_joint_rand = prep_state.szk_query_state.is_some(); + let szk_query_share_opt = prep_state + .verifier_len + .map(|verifier_len| { + SzkQueryShare::::decode_with_param( + &(requires_joint_rand, verifier_len), + bytes, + ) + }) + .transpose()?; + Ok(Self { + vidpf_proof, + szk_query_share_opt, + }) + } +} + +/// Mastic prepare message. +/// +/// Result of preprocessing the broadcast messages of both aggregators during the +/// preparation phase. +pub type MasticPrepareMessage = SzkJointShare; + +impl ParameterizedDecode> + for MasticPrepareMessage +{ + fn decode_with_param( + prep_state: &MasticPrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + match prep_state.szk_query_state { + Some(_) => SzkJointShare::::decode_with_param(&true, bytes), + None => SzkJointShare::::decode_with_param(&false, bytes), + } + } +} + +impl Aggregator for Mastic +where + T: Type, + P: Xof, +{ + type PrepareState = MasticPrepareState; + type PrepareShare = MasticPrepareShare; + type PrepareMessage = MasticPrepareMessage; + + fn prepare_init( + &self, + verify_key: &[u8; SEED_SIZE], + agg_id: usize, + agg_param: &MasticAggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &MasticPublicShare>, + input_share: &MasticInputShare, + ) -> Result< + ( + MasticPrepareState, + MasticPrepareShare, + ), + VdafError, + > { + let id = match agg_id { + 0 => Ok(VidpfServerId::S0), + 1 => Ok(VidpfServerId::S1), + _ => Err(VdafError::Uncategorized( + "Invalid aggregator ID".to_string(), + )), + }?; + let mut eval_proof = P::init( + verify_key, + &self.domain_separation_tag(DST_PATH_CHECK_BATCH), + ); + let mut output_shares = Vec::::with_capacity( + self.vidpf.weight_parameter * agg_param.level_and_prefixes.prefixes().len(), + ); + let mut cache_tree = BinaryTree::>>::default(); + let cache = VidpfEvalCache::>::init_from_key( + id, + &input_share.vidpf_key, + &self.vidpf.weight_parameter, + ); + cache_tree + .insert(Path::empty(), cache) + .expect("Should alwys be able to insert into empty tree at root"); + for prefix in agg_param.level_and_prefixes.prefixes() { + let mut value_share = self.vidpf.eval_with_cache( + id, + &input_share.vidpf_key, + public_share, + prefix, + &mut cache_tree, + nonce, + )?; + eval_proof.update(&value_share.proof); + output_shares.append(&mut value_share.share.0); + } + + let szk_verify_opt = if agg_param.require_weight_check { + let root_share = self.vidpf.eval_root_with_cache( + id, + &input_share.vidpf_key, + public_share, + &mut cache_tree, + nonce, + )?; + Some(self.szk.query( + root_share.as_ref(), + input_share.proof_share.clone(), + verify_key, + nonce, + )?) 
+ } else { + None + }; + + let (prep_share, prep_state) = + if let Some((szk_query_share, szk_query_state)) = szk_verify_opt { + let verifier_len = szk_query_share.flp_verifier.len(); + ( + MasticPrepareShare { + vidpf_proof: eval_proof.into_seed(), + szk_query_share_opt: Some(szk_query_share), + }, + MasticPrepareState { + output_shares: MasticOutputShare::::from(output_shares), + szk_query_state, + verifier_len: Some(verifier_len), + }, + ) + } else { + ( + MasticPrepareShare { + vidpf_proof: eval_proof.into_seed(), + szk_query_share_opt: None, + }, + MasticPrepareState { + output_shares: MasticOutputShare::::from(output_shares), + szk_query_state: None, + verifier_len: None, + }, + ) + }; + Ok((prep_state, prep_share)) + } + + fn prepare_shares_to_prepare_message< + M: IntoIterator>, + >( + &self, + _agg_param: &MasticAggregationParam, + inputs: M, + ) -> Result, VdafError> { + let mut inputs_iter = inputs.into_iter(); + let leader_share = inputs_iter.next().ok_or(VdafError::Uncategorized( + "No leader share received".to_string(), + ))?; + let helper_share = inputs_iter.next().ok_or(VdafError::Uncategorized( + "No helper share received".to_string(), + ))?; + if inputs_iter.next().is_some() { + return Err(VdafError::Uncategorized( + "Received more than two prepare shares".to_string(), + )); + }; + if leader_share.vidpf_proof != helper_share.vidpf_proof { + return Err(VdafError::Uncategorized( + "Vidpf proof verification failed".to_string(), + )); + }; + match ( + leader_share.szk_query_share_opt, + helper_share.szk_query_share_opt, + ) { + (Some(leader_query_share), Some(helper_query_share)) => Ok(self + .szk + .merge_verifiers(leader_query_share, helper_query_share)?), + (None, None) => Ok(None), + (_, _) => Err(VdafError::Uncategorized( + "Only one of leader and helper query shares is present".to_string(), + )), + } + } + + fn prepare_next( + &self, + state: MasticPrepareState, + input: MasticPrepareMessage, + ) -> Result, VdafError> { + match (state, input) { + ( + MasticPrepareState { + output_shares, + szk_query_state: _, + verifier_len: _, + }, + None, + ) => Ok(PrepareTransition::Finish(output_shares)), + ( + MasticPrepareState { + output_shares, + szk_query_state, + verifier_len: _, + }, + Some(ref joint_share), + ) => { + self.szk.decide(szk_query_state, joint_share)?; + Ok(PrepareTransition::Finish(output_shares)) + } + } + } + + fn aggregate>( + &self, + agg_param: &MasticAggregationParam, + output_shares: M, + ) -> Result, VdafError> { + let mut agg_share = MasticAggregateShare::::from(vec![ + T::Field::zero(); + self.vidpf.weight_parameter + * agg_param + .level_and_prefixes + .prefixes() + .len() + ]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + Ok(agg_share) + } +} + +impl Collector for Mastic +where + T: Type, + P: Xof, +{ + fn unshard>( + &self, + agg_param: &MasticAggregationParam, + agg_shares: M, + _num_measurements: usize, + ) -> Result { + let n = agg_param.level_and_prefixes.prefixes().len(); + let mut agg_final = MasticAggregateShare::::from(vec![ + T::Field::zero(); + self.vidpf.weight_parameter + * n + ]); + for agg_share in agg_shares.into_iter() { + agg_final.merge(&agg_share)?; + } + let mut result = Vec::::with_capacity(n); + for i in 0..n { + let encoded_result = &agg_final.0 + [i * self.vidpf.weight_parameter..(i + 1) * self.vidpf.weight_parameter]; + result.push( + self.szk + .typ + .decode_result(&self.szk.typ.truncate(encoded_result.to_vec())?[..], 1)?, + ); + } + Ok(result) + } +} + #[cfg(test)] mod 
tests { use super::*; use crate::field::Field128; use crate::flp::gadgets::{Mul, ParallelSum}; use crate::flp::types::{Count, Sum, SumVec}; + use crate::vdaf::test_utils::{run_vdaf, run_vdaf_prepare}; use rand::{thread_rng, Rng}; const TEST_NONCE_SIZE: usize = 16; #[test] - fn test_mastic_shard_sum() { + fn test_mastic_sum() { let algorithm_id = 6; let sum_typ = Sum::::new(5).unwrap(); - let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); - + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); let mut nonce = [0u8; 16]; let mut verify_key = [0u8; 16]; thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); - let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); - + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, 24u128), &nonce).unwrap(); + + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), 24), + (second_input.clone(), 0), + (third_input.clone(), 0), + (fourth_input.clone(), 3), + (fifth_input.clone(), 28) + ] + ) + .unwrap(), + vec![3] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), 24), + (second_input, 0), + (third_input, 0), + (fourth_input, 3), + (fifth_input, 28) + ] + ) + .unwrap(), + vec![31, 24] + ); + + let (public_share, input_shares) = mastic + .shard(&(first_input.clone(), 24u128), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic.shard(&(first_input, 4u128), &nonce).unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); } #[test] @@ -436,7 +825,30 @@ mod tests { } #[test] - fn test_mastic_shard_count() { + fn test_public_share_roundtrip_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, 25u128), &nonce).unwrap(); + + let encoded_public = public.get_encoded().unwrap(); + let decoded_public = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public[..]).unwrap(); + assert_eq!(public, decoded_public); + } + + #[test] + fn test_mastic_count() { let algorithm_id = 6; let count = Count::::new(); let szk = 
Szk::new_turboshake128(count, algorithm_id); @@ -447,18 +859,125 @@ mod tests { thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), true), + (second_input.clone(), false), + (third_input.clone(), false), + (fourth_input.clone(), true), + (fifth_input.clone(), true) + ] + ) + .unwrap(), + vec![1] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), true), + (second_input, false), + (third_input, false), + (fourth_input, true), + (fifth_input, true) + ] + ) + .unwrap(), + vec![2, 1] + ); + + let (public_share, input_shares) = + mastic.shard(&(first_input.clone(), false), &nonce).unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + } + + #[test] + fn test_public_share_encoded_len() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + let (public, _) = mastic.shard(&(first_input, true), &nonce).unwrap(); + + assert_eq!( + public.encoded_len().unwrap(), + public.get_encoded().unwrap().len() + ); } #[test] - fn test_mastic_shard_sumvec() { + fn test_public_share_roundtrip_count() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, true), &nonce).unwrap(); + + let encoded_public = public.get_encoded().unwrap(); + let decoded_public = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public[..]).unwrap(); + assert_eq!(public, decoded_public); + } + + #[test] + fn test_mastic_sumvec() { let algorithm_id = 6; let sumvec = SumVec::>>::new(5, 3, 3).unwrap(); - let 
measurement = vec![1, 16, 0]; let szk = Szk::new_turboshake128(sumvec, algorithm_id); let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); @@ -467,10 +986,81 @@ mod tests { thread_rng().fill(&mut verify_key[..]); thread_rng().fill(&mut nonce[..]); - let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); - + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_measurement = vec![1u128, 16u128, 0u128]; + let second_measurement = vec![0u128, 0u128, 0u128]; + let third_measurement = vec![0u128, 0u128, 0u128]; + let fourth_measurement = vec![1u128, 17u128, 31u128]; + let fifth_measurement = vec![6u128, 4u128, 11u128]; + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); - let (_public, _input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), first_measurement.clone()), + (second_input.clone(), second_measurement.clone()), + (third_input.clone(), third_measurement.clone()), + (fourth_input.clone(), fourth_measurement.clone()), + (fifth_input.clone(), fifth_measurement.clone()) + ] + ) + .unwrap(), + vec![vec![1, 17, 31]] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), first_measurement.clone()), + (second_input, second_measurement.clone()), + (third_input, third_measurement), + (fourth_input, fourth_measurement), + (fifth_input, fifth_measurement) + ] + ) + .unwrap(), + vec![vec![7, 21, 42], vec![1, 16, 0]] + ); + + let (public_share, input_shares) = mastic + .shard(&(first_input.clone(), first_measurement.clone()), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic + .shard(&(first_input, second_measurement.clone()), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); } #[test] @@ -536,4 +1126,54 @@ mod tests { .unwrap(); assert_eq!(helper_input_share, &decoded_helper_input_share); } + + #[test] + fn test_public_share_encode_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + + assert_eq!( + public.encoded_len().unwrap(), + public.get_encoded().unwrap().len() + ); + } + + #[test] 
+ fn test_public_share_roundtrip_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + + let encoded_public_share = public.get_encoded().unwrap(); + let decoded_public_share = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public_share[..]).unwrap(); + assert_eq!(public, decoded_public_share); + } } diff --git a/src/vidpf.rs b/src/vidpf.rs index 3ec8d1347..29a4cc44b 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -15,13 +15,15 @@ use core::{ }; use bitvec::field::BitField; +use bitvec::prelude::{BitVec, Lsb0}; use rand_core::RngCore; use std::fmt::Debug; -use std::io::Cursor; +use std::io::{Cursor, Read}; use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; use crate::{ - codec::{CodecError, Encode, ParameterizedDecode}, + bt::{BinaryTree, Node}, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::FieldElement, idpf::{ conditional_select_seed, conditional_swap_seed, conditional_xor_seeds, xor_seeds, @@ -105,10 +107,7 @@ impl Vidpf { weight: &W, nonce: &[u8; NONCE_SIZE], ) -> Result<(VidpfPublicShare, [VidpfKey; 2]), VidpfError> { - let keys = [ - VidpfKey::gen(VidpfServerId::S0)?, - VidpfKey::gen(VidpfServerId::S1)?, - ]; + let keys = [VidpfKey::generate()?, VidpfKey::generate()?]; let public = self.gen_with_keys(&keys, input, weight, nonce)?; Ok((public, keys)) } @@ -122,12 +121,11 @@ impl Vidpf { weight: &W, nonce: &[u8; NONCE_SIZE], ) -> Result, VidpfError> { - if keys[0].id == keys[1].id { - return Err(VidpfError::SameKeyId); - } - - let mut s_i = [keys[0].value, keys[1].value]; - let mut t_i = [Choice::from(keys[0].id), Choice::from(keys[1].id)]; + let mut s_i = [keys[0].0, keys[1].0]; + let mut t_i = [ + Choice::from(VidpfServerId::S0), + Choice::from(VidpfServerId::S1), + ]; let n = input.len(); let mut cw = Vec::with_capacity(n); @@ -204,12 +202,13 @@ impl Vidpf { /// input's weight. pub fn eval( &self, + id: VidpfServerId, key: &VidpfKey, public: &VidpfPublicShare, input: &VidpfInput, nonce: &[u8; NONCE_SIZE], ) -> Result, VidpfError> { - let mut state = VidpfEvalState::init_from_key(key); + let mut state = VidpfEvalState::init_from_key(id, key); let mut share = W::zero(&self.weight_parameter); let n = input.len(); @@ -217,7 +216,7 @@ impl Vidpf { return Err(VidpfError::InvalidAttributeLength); } for level in 0..n { - (state, share) = self.eval_next(key.id, public, input, level, &state, nonce)?; + (state, share) = self.eval_next(id, public, input, level, &state, nonce)?; } Ok(VidpfValueShare { @@ -226,6 +225,57 @@ impl Vidpf { }) } + /// [`Vidpf::eval_with_cache`] evaluates the entire `input` and produces a share of the + /// input's weight. It reuses computation from previous levels available in the + /// cache. 
+ pub fn eval_with_cache( + &self, + id: VidpfServerId, + key: &VidpfKey, + public: &VidpfPublicShare, + input: &VidpfInput, + cache_tree: &mut BinaryTree>, + nonce: &[u8; NONCE_SIZE], + ) -> Result, VidpfError> { + let n = input.len(); + if n > public.cw.len() { + return Err(VidpfError::InvalidAttributeLength); + } + + if cache_tree.root.is_none() { + cache_tree.root = Some(Box::new(Node::new(VidpfEvalCache { + state: VidpfEvalState::init_from_key(id, key), + share: W::zero(&self.weight_parameter), // not used + }))); + } + + let mut sub_tree = cache_tree.root.as_mut().expect("root was visited"); + for (level, bit) in input.iter().enumerate() { + sub_tree = if !bit { + if sub_tree.left.is_none() { + let (new_state, new_share) = + self.eval_next(id, public, input, level, &sub_tree.value.state, nonce)?; + sub_tree.left = Some(Box::new(Node::new(VidpfEvalCache { + state: new_state, + share: new_share, + }))); + } + sub_tree.left.as_mut().expect("left child was visited") + } else { + if sub_tree.right.is_none() { + let (new_state, new_share) = + self.eval_next(id, public, input, level, &sub_tree.value.state, nonce)?; + sub_tree.right = Some(Box::new(Node::new(VidpfEvalCache { + state: new_state, + share: new_share, + }))); + } + sub_tree.right.as_mut().expect("right child was visited") + } + } + Ok(sub_tree.value.to_share()) + } + /// [`Vidpf::eval_next`] evaluates the `input` at the given level using the provided initial /// state, and returns a new state and a share of the input's weight at that level. fn eval_next( @@ -276,17 +326,60 @@ impl Vidpf { Ok((next_state, y)) } + pub(crate) fn eval_root_with_cache( + &self, + id: VidpfServerId, + key: &VidpfKey, + public_share: &VidpfPublicShare, + cache_tree: &mut BinaryTree>, + nonce: &[u8; NONCE_SIZE], + ) -> Result { + Ok(self + .eval_with_cache( + id, + key, + public_share, + &VidpfInput::from_bools(&[false]), + cache_tree, + nonce, + )? + .share + + self + .eval_with_cache( + id, + key, + public_share, + &VidpfInput::from_bools(&[true]), + cache_tree, + nonce, + )? + .share) + } + pub(crate) fn eval_root( &self, + id: VidpfServerId, key: &VidpfKey, public_share: &VidpfPublicShare, nonce: &[u8; NONCE_SIZE], ) -> Result { Ok(self - .eval(key, public_share, &VidpfInput::from_bools(&[false]), nonce)? + .eval( + id, + key, + public_share, + &VidpfInput::from_bools(&[false]), + nonce, + )? .share + self - .eval(key, public_share, &VidpfInput::from_bools(&[true]), nonce)? + .eval( + id, + key, + public_share, + &VidpfInput::from_bools(&[true]), + nonce, + )? .share) } @@ -374,52 +467,16 @@ impl VidpfDomainSepTag { const NODE_PROOF_ADJUST: &'static [u8] = b"NodeProofAdjust"; } -#[derive(Clone, Debug)] -/// Vidpf key +/// Vidpf key. /// /// Private key of an aggregation server. -pub struct VidpfKey { - id: VidpfServerId, - pub(crate) value: [u8; 16], -} - -impl VidpfKey { - /// Generates a key at random. - /// - /// # Errors - /// Triggers an error if the random generator fails. 
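// --- Illustrative aside (not part of the patch) -----------------------------
// `eval_with_cache` above keeps one node per evaluated prefix in a
// `BinaryTree`, so that evaluating a prefix only computes the levels that have
// not been visited before (this is also why the bt.rs fields were opened up as
// pub(crate)). The stand-alone sketch below shows the same "descend, and
// create the child on first visit" pattern with a toy node type whose state is
// a single u64 instead of a VIDPF seed/control-bit/proof triple.

struct CacheNode {
    state: u64,
    left: Option<Box<CacheNode>>,
    right: Option<Box<CacheNode>>,
}

impl CacheNode {
    fn new(state: u64) -> Self {
        Self { state, left: None, right: None }
    }
}

// Stand-in for one level of VIDPF evaluation: derives the child state from
// the parent state and the branch bit.
fn eval_next(parent_state: u64, bit: bool) -> u64 {
    parent_state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(bit as u64 + 1)
}

// Walks the bits of `prefix`, reusing cached nodes and materializing missing
// ones, and returns the state at the end of the prefix.
fn eval_with_cache(root: &mut CacheNode, prefix: &[bool]) -> u64 {
    let mut node = root;
    for &bit in prefix {
        node = if bit {
            if node.right.is_none() {
                let next = eval_next(node.state, bit);
                node.right = Some(Box::new(CacheNode::new(next)));
            }
            node.right.as_deref_mut().expect("right child was just created")
        } else {
            if node.left.is_none() {
                let next = eval_next(node.state, bit);
                node.left = Some(Box::new(CacheNode::new(next)));
            }
            node.left.as_deref_mut().expect("left child was just created")
        };
    }
    node.state
}

fn main() {
    let mut root = CacheNode::new(0xfeed);
    // Evaluating a longer prefix after a shorter one only computes the new
    // levels; the shared prefix is served from the cache, so repeating an
    // evaluation returns the same state without recomputation.
    let first = eval_with_cache(&mut root, &[true, false]);
    let _longer = eval_with_cache(&mut root, &[true, false, true]);
    assert_eq!(first, eval_with_cache(&mut root, &[true, false]));
}
// -----------------------------------------------------------------------------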
- pub(crate) fn gen(id: VidpfServerId) -> Result { - let mut value = [0; 16]; - getrandom::getrandom(&mut value)?; - Ok(Self { id, value }) - } - - pub(crate) fn new(id: VidpfServerId, value: [u8; 16]) -> Self { - Self { id, value } - } -} +pub type VidpfKey = Seed<16>; -impl ConstantTimeEq for VidpfKey { - fn ct_eq(&self, other: &VidpfKey) -> Choice { - if self.id != other.id { - Choice::from(0) - } else { - self.value.ct_eq(&other.value) - } - } -} - -impl PartialEq for VidpfKey { - fn eq(&self, other: &VidpfKey) -> bool { - bool::from(self.ct_eq(other)) - } -} - -/// Vidpf server ID +/// Vidpf server ID. /// /// Identifies the two aggregation servers. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum VidpfServerId { +pub enum VidpfServerId { /// S0 is the first server. S0, /// S1 is the second server. @@ -435,7 +492,7 @@ impl From for Choice { } } -/// Vidpf correction word +/// Vidpf correction word. /// /// Adjusts values of shares during the VIDPF evaluation. #[derive(Clone, Debug)] @@ -464,25 +521,6 @@ where } } -impl Encode for VidpfCorrectionWord { - fn encode(&self, _bytes: &mut Vec) -> Result<(), CodecError> { - todo!(); - } - - fn encoded_len(&self) -> Option { - todo!(); - } -} - -impl ParameterizedDecode for VidpfCorrectionWord { - fn decode_with_param( - _decoding_parameter: &W::ValueParameter, - _bytes: &mut Cursor<&[u8]>, - ) -> Result { - todo!(); - } -} - /// Vidpf public share /// /// Common public information used by aggregation servers. @@ -493,27 +531,83 @@ pub struct VidpfPublicShare { } impl Encode for VidpfPublicShare { - fn encode(&self, _bytes: &mut Vec) -> Result<(), CodecError> { - todo!() + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into + // bytes in big-endian order. Thus, the first four levels will have their control bits + // encoded in the last byte, and the last levels will have their control bits encoded in the + // first byte. 
+ let mut control_bits: BitVec = BitVec::with_capacity(self.cw.len() * 2); + for correction_words in self.cw.iter() { + control_bits.extend( + [ + bool::from(correction_words.left_control_bit), + bool::from(correction_words.right_control_bit), + ] + .iter(), + ); + } + control_bits.set_uninitialized(false); + let mut packed_control = control_bits.into_vec(); + bytes.append(&mut packed_control); + + for correction_words in self.cw.iter() { + Seed(correction_words.seed).encode(bytes)?; + correction_words.weight.encode(bytes)?; + } + + for proof in &self.cs { + bytes.extend_from_slice(proof); + } + Ok(()) } fn encoded_len(&self) -> Option { - todo!() + let control_bits_count = (self.cw.len()) * 2; + let mut len = (control_bits_count + 7) / 8 + (self.cw.len()) * 16; + for correction_words in self.cw.iter() { + len += correction_words.weight.encoded_len()?; + } + len += self.cs.len() * 32; + Some(len) } } impl ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPublicShare { fn decode_with_param( - (_bits, _weight_parameter): &(usize, W::ValueParameter), - _bytes: &mut Cursor<&[u8]>, + (bits, weight_parameter): &(usize, W::ValueParameter), + bytes: &mut Cursor<&[u8]>, ) -> Result { - todo!() + let packed_control_len = (bits + 3) / 4; + let mut packed = vec![0u8; packed_control_len]; + bytes.read_exact(&mut packed)?; + let unpacked_control_bits: BitVec = BitVec::from_vec(packed); + + let mut cw = Vec::>::with_capacity(*bits); + for chunk in unpacked_control_bits[0..(bits) * 2].chunks(2) { + let left_control_bit = (chunk[0] as u8).into(); + let right_control_bit = (chunk[1] as u8).into(); + let seed = Seed::decode(bytes)?.0; + cw.push(VidpfCorrectionWord { + seed, + left_control_bit, + right_control_bit, + weight: W::decode_with_param(weight_parameter, bytes)?, + }) + } + let mut cs = Vec::::with_capacity(*bits); + for _ in 0..*bits { + let mut proof = [0u8; 32]; + bytes.read_exact(&mut proof)?; + cs.push(proof); + } + Ok(Self { cw, cs }) } } /// Vidpf evaluation state /// /// Contains the values produced during input evaluation at a given level. +#[derive(Debug)] pub struct VidpfEvalState { seed: VidpfSeed, control_bit: Choice, @@ -521,15 +615,44 @@ pub struct VidpfEvalState { } impl VidpfEvalState { - fn init_from_key(key: &VidpfKey) -> Self { + fn init_from_key(id: VidpfServerId, key: &VidpfKey) -> Self { Self { - seed: key.value, - control_bit: Choice::from(key.id), + seed: key.0, + control_bit: Choice::from(id), proof: VidpfProof::default(), } } } +/// Vidpf evaluation cache +/// +/// Contains the values produced during input evaluation at a given level. +#[derive(Debug)] +pub struct VidpfEvalCache { + state: VidpfEvalState, + share: W, +} + +impl VidpfEvalCache { + pub(crate) fn init_from_key( + id: VidpfServerId, + key: &VidpfKey, + length: &W::ValueParameter, + ) -> Self { + Self { + state: VidpfEvalState::init_from_key(id, key), + share: W::zero(length), + } + } + + fn to_share(&self) -> VidpfValueShare { + VidpfValueShare:: { + share: self.share.clone(), + proof: self.state.proof, + } + } +} + /// Contains a share of the input's weight together with a proof for verification. pub struct VidpfValueShare { /// Secret share of the input's weight. 
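// --- Illustrative aside (not part of the patch) -----------------------------
// The new `VidpfPublicShare` codec above packs the two control bits of every
// correction word into a shared bit vector before the seeds, weights, and node
// proofs are written. The patch uses the `bitvec` crate (`BitVec<u8, Lsb0>`),
// whose exact byte layout is described by the comment in the encoder; the
// stand-alone sketch below shows the basic LSB-first packing and unpacking
// with plain shifts, two bits per correction word.

fn pack_control_bits(bits: &[bool]) -> Vec<u8> {
    // One byte per eight bits, with unused high bits of the last byte zeroed.
    let mut packed = vec![0u8; (bits.len() + 7) / 8];
    for (i, &bit) in bits.iter().enumerate() {
        if bit {
            packed[i / 8] |= 1 << (i % 8); // least-significant bit first
        }
    }
    packed
}

fn unpack_control_bits(packed: &[u8], count: usize) -> Vec<bool> {
    (0..count)
        .map(|i| (packed[i / 8] >> (i % 8)) & 1 == 1)
        .collect()
}

fn main() {
    // (left, right) control bits for three correction words.
    let bits = vec![true, false, false, true, true, true];
    let packed = pack_control_bits(&bits);
    assert_eq!(packed.len(), 1); // six bits fit in a single byte
    assert_eq!(unpack_control_bits(&packed, bits.len()), bits);
}
// -----------------------------------------------------------------------------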
@@ -703,6 +826,7 @@ impl ParameterizedDecode<::ValueParameter> f #[cfg(test)] mod tests { + use crate::field::Field128; use super::VidpfWeight; @@ -714,15 +838,37 @@ mod tests { mod vidpf { use crate::{ + bt::BinaryTree, + codec::{Encode, ParameterizedDecode}, idpf::IdpfValue, vidpf::{ - Vidpf, VidpfError, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, + Vidpf, VidpfEvalCache, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, }, }; + use std::io::Cursor; use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN}; + #[test] + fn roundtrip_codec() { + let input = VidpfInput::from_bytes(&[0xFF]); + let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); + let (_, public, _, _) = vidpf_gen_setup(&input, &weight); + + let mut bytes = vec![]; + public.encode(&mut bytes).unwrap(); + + assert_eq!(public.encoded_len().unwrap(), bytes.len()); + + let decoded = VidpfPublicShare::::decode_with_param( + &(8, TEST_WEIGHT_LEN), + &mut Cursor::new(&bytes), + ) + .unwrap(); + assert_eq!(public, decoded); + } + fn vidpf_gen_setup( input: &VidpfInput, weight: &TestWeight, @@ -737,31 +883,18 @@ mod tests { (vidpf, public, keys, *TEST_NONCE) } - #[test] - fn gen_with_keys() { - let input = VidpfInput::from_bytes(&[0xFF]); - let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); - let vidpf = Vidpf::new(TEST_WEIGHT_LEN); - let keys_with_same_id = [ - VidpfKey::gen(VidpfServerId::S0).unwrap(), - VidpfKey::gen(VidpfServerId::S0).unwrap(), - ]; - - let err = vidpf - .gen_with_keys(&keys_with_same_id, &input, &weight, TEST_NONCE) - .unwrap_err(); - - assert_eq!(err.to_string(), VidpfError::SameKeyId.to_string()); - } - #[test] fn correctness_at_last_level() { let input = VidpfInput::from_bytes(&[0xFF]); let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(&input, &weight); - let value_share_0 = vidpf.eval(&key_0, &public, &input, &nonce).unwrap(); - let value_share_1 = vidpf.eval(&key_1, &public, &input, &nonce).unwrap(); + let value_share_0 = vidpf + .eval(VidpfServerId::S0, &key_0, &public, &input, &nonce) + .unwrap(); + let value_share_1 = vidpf + .eval(VidpfServerId::S1, &key_1, &public, &input, &nonce) + .unwrap(); assert_eq!( value_share_0.share + value_share_1.share, @@ -776,8 +909,12 @@ mod tests { let bad_input = VidpfInput::from_bytes(&[0x00]); let zero = TestWeight::zero(&TEST_WEIGHT_LEN); - let value_share_0 = vidpf.eval(&key_0, &public, &bad_input, &nonce).unwrap(); - let value_share_1 = vidpf.eval(&key_1, &public, &bad_input, &nonce).unwrap(); + let value_share_0 = vidpf + .eval(VidpfServerId::S0, &key_0, &public, &bad_input, &nonce) + .unwrap(); + let value_share_1 = vidpf + .eval(VidpfServerId::S1, &key_1, &public, &bad_input, &nonce) + .unwrap(); assert_eq!( value_share_0.share + value_share_1.share, @@ -813,18 +950,18 @@ mod tests { weight: &TestWeight, nonce: &[u8; TEST_NONCE_SIZE], ) { - let mut state_0 = VidpfEvalState::init_from_key(key_0); - let mut state_1 = VidpfEvalState::init_from_key(key_1); + let mut state_0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0); + let mut state_1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1); let n = input.len(); for level in 0..n { let share_0; let share_1; (state_0, share_0) = vidpf - .eval_next(key_0.id, public, input, level, &state_0, nonce) + .eval_next(VidpfServerId::S0, public, input, level, &state_0, nonce) .unwrap(); (state_1, share_1) = vidpf - .eval_next(key_1.id, public, input, 
level, &state_1, nonce) + .eval_next(VidpfServerId::S1, public, input, level, &state_1, nonce) .unwrap(); assert_eq!( @@ -841,6 +978,94 @@ mod tests { ); } } + + #[test] + fn caching_at_each_level() { + let input = VidpfInput::from_bytes(&[0xFF]); + let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); + let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight); + + test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce); + } + + /// Ensures that VIDPF outputs match regardless of whether the path to + /// each node is recomputed or cached during evaluation. + fn test_equivalence_of_eval_with_caching( + vidpf: &Vidpf, + [key_0, key_1]: &[VidpfKey; 2], + public: &VidpfPublicShare, + input: &VidpfInput, + nonce: &[u8; TEST_NONCE_SIZE], + ) { + let mut cache_tree_0 = BinaryTree::>::default(); + let mut cache_tree_1 = BinaryTree::>::default(); + + let n = input.len(); + for level in 0..n { + let val_share_0 = vidpf + .eval( + VidpfServerId::S0, + key_0, + public, + &input.prefix(level), + nonce, + ) + .unwrap(); + let val_share_1 = vidpf + .eval( + VidpfServerId::S1, + key_1, + public, + &input.prefix(level), + nonce, + ) + .unwrap(); + let val_share_0_cached = vidpf + .eval_with_cache( + VidpfServerId::S0, + key_0, + public, + &input.prefix(level), + &mut cache_tree_0, + nonce, + ) + .unwrap(); + let val_share_1_cached = vidpf + .eval_with_cache( + VidpfServerId::S1, + key_1, + public, + &input.prefix(level), + &mut cache_tree_1, + nonce, + ) + .unwrap(); + + assert_eq!( + val_share_0.share, val_share_0_cached.share, + "shares must be computed equally with or without caching: {:?}", + level + ); + + assert_eq!( + val_share_1.share, val_share_1_cached.share, + "shares must be computed equally with or without caching: {:?}", + level + ); + + assert_eq!( + val_share_0.proof, val_share_0_cached.proof, + "proofs must be equal with or without caching: {:?}", + level + ); + + assert_eq!( + val_share_1.proof, val_share_1_cached.proof, + "proofs must be equal with or without caching: {:?}", + level + ); + } + } } mod weight {
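// --- Illustrative aside (not part of the patch) -----------------------------
// Putting the Mastic pieces of this patch together: for every candidate prefix
// in the aggregation parameter, each aggregator sums its additive shares of
// the weights of all reports whose input starts with that prefix
// (`aggregate`), and the collector adds the two aggregate shares to recover
// the per-prefix totals (`unshard`). The sketch below shows that arithmetic
// with toy u64 weights split in the clear; the real code works over field
// elements produced by VIDPF evaluation and verified by the SZK proof.

fn starts_with(input: &[bool], prefix: &[bool]) -> bool {
    input.len() >= prefix.len() && &input[..prefix.len()] == prefix
}

fn main() {
    // (input bits, weight) pairs reported by clients.
    let reports: Vec<(Vec<bool>, u64)> = vec![
        (vec![true, true, false], 3),
        (vec![true, false, true], 5),
        (vec![false, true, true], 2),
    ];
    let prefixes: Vec<Vec<bool>> = vec![vec![true], vec![false]];

    let mut agg_share_0 = vec![0u64; prefixes.len()];
    let mut agg_share_1 = vec![0u64; prefixes.len()];
    for (input, weight) in &reports {
        // Each weight is split into two additive shares (mod 2^64); the fixed
        // constant here is an arbitrary stand-in for a random share.
        let share_0 = 0x1234_5678_9abc_def0u64;
        let share_1 = weight.wrapping_sub(share_0);
        for (i, prefix) in prefixes.iter().enumerate() {
            if starts_with(input, prefix) {
                agg_share_0[i] = agg_share_0[i].wrapping_add(share_0);
                agg_share_1[i] = agg_share_1[i].wrapping_add(share_1);
            }
        }
    }

    // The collector merges the two aggregate shares into per-prefix totals.
    let totals: Vec<u64> = agg_share_0
        .iter()
        .zip(&agg_share_1)
        .map(|(a, b)| a.wrapping_add(*b))
        .collect();
    assert_eq!(totals, vec![3 + 5, 2u64]);
}
// -----------------------------------------------------------------------------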