diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index cac779ad..48e6a806 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -815,7 +815,7 @@ fn vidpf(c: &mut Criterion) { b.iter(|| { let _ = vidpf - .eval(&VidpfServerId::S0, &keys[0], &public, &input, NONCE) + .eval(VidpfServerId::S0, &keys[0], &public, &input, NONCE) .unwrap(); }); }); diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs index 3ca214ce..9148549f 100644 --- a/src/vdaf/mastic.rs +++ b/src/vdaf/mastic.rs @@ -308,10 +308,10 @@ where let leader_measurement_share = self.vidpf - .eval_root(&VidpfServerId::S0, &vidpf_keys[0], &public_share, nonce)?; + .eval_root(VidpfServerId::S0, &vidpf_keys[0], &public_share, nonce)?; let helper_measurement_share = self.vidpf - .eval_root(&VidpfServerId::S1, &vidpf_keys[1], &public_share, nonce)?; + .eval_root(VidpfServerId::S1, &vidpf_keys[1], &public_share, nonce)?; let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( leader_measurement_share.as_ref(), @@ -537,7 +537,7 @@ where ); let mut cache_tree = BinaryTree::>>::default(); let cache = VidpfEvalCache::>::init_from_key( - &id, + id, &input_share.vidpf_key, &self.vidpf.weight_parameter, ); @@ -546,7 +546,7 @@ where .expect("Should alwys be able to insert into empty tree at root"); for prefix in agg_param.level_and_prefixes.prefixes() { let mut value_share = self.vidpf.eval_with_cache( - &id, + id, &input_share.vidpf_key, public_share, prefix, @@ -558,7 +558,7 @@ where } let root_share_opt = if agg_param.require_weight_check { Some(self.vidpf.eval_root_with_cache( - &id, + id, &input_share.vidpf_key, public_share, &mut cache_tree, @@ -624,53 +624,24 @@ where ))?; if inputs_iter.next().is_some() { return Err(VdafError::Uncategorized( - "more than 2 prepare shares".to_string(), + "Received more than two prepare shares".to_string(), )); }; - - match (leader_share, helper_share) { - ( - MasticPrepareShare { - vidpf_proof: leader_vidpf_proof, - szk_query_share_opt: Some(leader_query_share), - }, - MasticPrepareShare { - vidpf_proof: helper_vidpf_proof, - szk_query_share_opt: Some(helper_query_share), - }, - ) => { - if leader_vidpf_proof == helper_vidpf_proof { - Ok(Some(SzkQueryShare::merge_verifiers( - leader_query_share, - helper_query_share, - ))) - } else { - Err(VdafError::Uncategorized( - "Vidpf proof verification failed".to_string(), - )) - } - } - ( - MasticPrepareShare { - vidpf_proof: leader_vidpf_proof, - szk_query_share_opt: None, - }, - MasticPrepareShare { - vidpf_proof: helper_vidpf_proof, - szk_query_share_opt: None, - }, - ) => { - if leader_vidpf_proof == helper_vidpf_proof { - Ok(None) - } else { - Err(VdafError::Uncategorized( - "Vidpf proof verification failed".to_string(), - )) - } - } - _ => Err(VdafError::Uncategorized( - "Prepare state and message disagree on whether Szk verification should occur" - .to_string(), + if leader_share.vidpf_proof != helper_share.vidpf_proof { + return Err(VdafError::Uncategorized( + "Vidpf proof verification failed".to_string(), + )); + }; + match ( + leader_share.szk_query_share_opt, + helper_share.szk_query_share_opt, + ) { + (Some(leader_query_share), Some(helper_query_share)) => Ok(Some( + SzkQueryShare::merge_verifiers(leader_query_share, helper_query_share), + )), + (None, None) => Ok(None), + (_, _) => Err(VdafError::Uncategorized( + "Only one of leader and helper query shares is present".to_string(), )), } } diff --git a/src/vidpf.rs b/src/vidpf.rs index a2d5e04b..d9646620 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -202,7 +202,7 @@ impl Vidpf { /// input's weight. pub fn eval( &self, - id: &VidpfServerId, + id: VidpfServerId, key: &VidpfKey, public: &VidpfPublicShare, input: &VidpfInput, @@ -230,7 +230,7 @@ impl Vidpf { /// cache. pub fn eval_with_cache( &self, - id: &VidpfServerId, + id: VidpfServerId, key: &VidpfKey, public: &VidpfPublicShare, input: &VidpfInput, @@ -280,7 +280,7 @@ impl Vidpf { /// state, and returns a new state and a share of the input's weight at that level. fn eval_next( &self, - id: &VidpfServerId, + id: VidpfServerId, public: &VidpfPublicShare, input: &VidpfInput, level: usize, @@ -306,7 +306,7 @@ impl Vidpf { let zero = ::zero(&self.weight_parameter); let mut y = ::conditional_select(&zero, &cw.weight, next_control_bit); y += w_i; - y.conditional_negate(Choice::from(*id)); + y.conditional_negate(Choice::from(id)); let pi_i = &state.proof; let cs_i = public.cs.get(level).ok_or(VidpfError::IndexLevel)?; @@ -328,7 +328,7 @@ impl Vidpf { pub(crate) fn eval_root_with_cache( &self, - id: &VidpfServerId, + id: VidpfServerId, key: &VidpfKey, public_share: &VidpfPublicShare, cache_tree: &mut BinaryTree>, @@ -358,7 +358,7 @@ impl Vidpf { pub(crate) fn eval_root( &self, - id: &VidpfServerId, + id: VidpfServerId, key: &VidpfKey, public_share: &VidpfPublicShare, nonce: &[u8; NONCE_SIZE], @@ -615,10 +615,10 @@ pub struct VidpfEvalState { } impl VidpfEvalState { - fn init_from_key(id: &VidpfServerId, key: &VidpfKey) -> Self { + fn init_from_key(id: VidpfServerId, key: &VidpfKey) -> Self { Self { seed: key.0, - control_bit: Choice::from(*id), + control_bit: Choice::from(id), proof: VidpfProof::default(), } } @@ -635,7 +635,7 @@ pub struct VidpfEvalCache { impl VidpfEvalCache { pub(crate) fn init_from_key( - id: &VidpfServerId, + id: VidpfServerId, key: &VidpfKey, length: &W::ValueParameter, ) -> Self { @@ -839,15 +839,36 @@ mod tests { mod vidpf { use crate::{ bt::{BinaryTree, Path}, + codec::{Encode, ParameterizedDecode}, idpf::IdpfValue, vidpf::{ Vidpf, VidpfEvalCache, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, }, }; + use std::io::Cursor; use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN}; + #[test] + fn roundtrip_codec() { + let input = VidpfInput::from_bytes(&[0xFF]); + let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); + let (_, public, _, _) = vidpf_gen_setup(&input, &weight); + + let mut bytes = vec![]; + public.encode(&mut bytes).unwrap(); + + assert_eq!(public.encoded_len().unwrap(), bytes.len()); + + let decoded = VidpfPublicShare::::decode_with_param( + &(8, TEST_WEIGHT_LEN), + &mut Cursor::new(&bytes), + ) + .unwrap(); + assert_eq!(public, decoded); + } + fn vidpf_gen_setup( input: &VidpfInput, weight: &TestWeight, @@ -869,10 +890,10 @@ mod tests { let (vidpf, public, [key_0, key_1], nonce) = vidpf_gen_setup(&input, &weight); let value_share_0 = vidpf - .eval(&VidpfServerId::S0, &key_0, &public, &input, &nonce) + .eval(VidpfServerId::S0, &key_0, &public, &input, &nonce) .unwrap(); let value_share_1 = vidpf - .eval(&VidpfServerId::S1, &key_1, &public, &input, &nonce) + .eval(VidpfServerId::S1, &key_1, &public, &input, &nonce) .unwrap(); assert_eq!( @@ -889,10 +910,10 @@ mod tests { let bad_input = VidpfInput::from_bytes(&[0x00]); let zero = TestWeight::zero(&TEST_WEIGHT_LEN); let value_share_0 = vidpf - .eval(&VidpfServerId::S0, &key_0, &public, &bad_input, &nonce) + .eval(VidpfServerId::S0, &key_0, &public, &bad_input, &nonce) .unwrap(); let value_share_1 = vidpf - .eval(&VidpfServerId::S1, &key_1, &public, &bad_input, &nonce) + .eval(VidpfServerId::S1, &key_1, &public, &bad_input, &nonce) .unwrap(); assert_eq!( @@ -929,18 +950,18 @@ mod tests { weight: &TestWeight, nonce: &[u8; TEST_NONCE_SIZE], ) { - let mut state_0 = VidpfEvalState::init_from_key(&VidpfServerId::S0, key_0); - let mut state_1 = VidpfEvalState::init_from_key(&VidpfServerId::S1, key_1); + let mut state_0 = VidpfEvalState::init_from_key(VidpfServerId::S0, key_0); + let mut state_1 = VidpfEvalState::init_from_key(VidpfServerId::S1, key_1); let n = input.len(); for level in 0..n { let share_0; let share_1; (state_0, share_0) = vidpf - .eval_next(&VidpfServerId::S0, public, input, level, &state_0, nonce) + .eval_next(VidpfServerId::S0, public, input, level, &state_0, nonce) .unwrap(); (state_1, share_1) = vidpf - .eval_next(&VidpfServerId::S1, public, input, level, &state_1, nonce) + .eval_next(VidpfServerId::S1, public, input, level, &state_1, nonce) .unwrap(); assert_eq!( @@ -964,10 +985,12 @@ mod tests { let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]); let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight); - equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce); + test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce); } - fn equivalence_of_eval_with_caching( + /// Ensures that VIDPF outputs match regardless of whether the path to + /// each node is recomputed or cached during evaluation. + fn test_equivalence_of_eval_with_caching( vidpf: &Vidpf, [key_0, key_1]: &[VidpfKey; 2], public: &VidpfPublicShare, @@ -977,12 +1000,12 @@ mod tests { let mut cache_tree_0 = BinaryTree::>::default(); let mut cache_tree_1 = BinaryTree::>::default(); let cache_0 = VidpfEvalCache::::init_from_key( - &VidpfServerId::S0, + VidpfServerId::S0, key_0, &vidpf.weight_parameter, ); let cache_1 = VidpfEvalCache::::init_from_key( - &VidpfServerId::S1, + VidpfServerId::S1, key_1, &vidpf.weight_parameter, ); @@ -997,7 +1020,7 @@ mod tests { for level in 0..n { let val_share_0 = vidpf .eval( - &VidpfServerId::S0, + VidpfServerId::S0, key_0, public, &input.prefix(level), @@ -1006,7 +1029,7 @@ mod tests { .unwrap(); let val_share_1 = vidpf .eval( - &VidpfServerId::S1, + VidpfServerId::S1, key_1, public, &input.prefix(level), @@ -1015,7 +1038,7 @@ mod tests { .unwrap(); let val_share_0_cached = vidpf .eval_with_cache( - &VidpfServerId::S0, + VidpfServerId::S0, key_0, public, &input.prefix(level), @@ -1025,7 +1048,7 @@ mod tests { .unwrap(); let val_share_1_cached = vidpf .eval_with_cache( - &VidpfServerId::S1, + VidpfServerId::S1, key_1, public, &input.prefix(level),