From 2c36004fd4e95760d196926b114d0c80d3649cc2 Mon Sep 17 00:00:00 2001 From: Hannah Davis Date: Wed, 12 Jun 2024 14:51:23 +0000 Subject: [PATCH] Create Mastic module with server implementation Implements aggregator and collector functionality for the Mastic protocol for weighted heavy-hitters and attribute-based metrics. --- src/codec.rs | 97 ++++ src/flp.rs | 2 +- src/flp/szk.rs | 840 ++++++++++++++++++++++++----- src/lib.rs | 2 - src/mastic.rs | 336 ------------ src/vdaf.rs | 10 + src/vdaf/mastic.rs | 1254 ++++++++++++++++++++++++++++++++++++++++++++ src/vidpf.rs | 184 ++++++- 8 files changed, 2255 insertions(+), 470 deletions(-) delete mode 100644 src/mastic.rs create mode 100644 src/vdaf/mastic.rs diff --git a/src/codec.rs b/src/codec.rs index 98e6299a..fda71ac1 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -270,6 +270,73 @@ impl Encode for u64 { } } +impl Decode for [D; SIZE] { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let mut v = Vec::with_capacity(SIZE); + for _ in 0..SIZE { + v.push(D::decode(bytes)?); + } + Ok(v.try_into().expect("If the above for loop completes, then the vector will always contain exactly BUFFER_SIZE elements.")) + } +} + +impl Encode for [E; BUFFER_SIZE] { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + for input in self { + input.encode(bytes)? + } + Ok(()) + } + + fn encoded_len(&self) -> Option { + let mut total = 0; + for item in self { + total += item.encoded_len()? + } + Some(total) + } +} + +impl ParameterizedDecode for Vec { + fn decode_with_param(len: &usize, bytes: &mut Cursor<&[u8]>) -> Result { + let mut out = Vec::with_capacity(*len); + for _ in 0..*len { + out.push(::decode(bytes)?) + } + Ok(out) + } +} + +impl> ParameterizedDecode<(usize, P)> for Vec { + fn decode_with_param( + (len, decoding_parameter): &(usize, P), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let mut out = Vec::with_capacity(*len); + for _ in 0..*len { + out.push(::decode_with_param(decoding_parameter, bytes)?) + } + Ok(out) + } +} + +impl Encode for Vec { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + for input in self { + input.encode(bytes)? + } + Ok(()) + } + + fn encoded_len(&self) -> Option { + let mut total = 0; + for item in self { + total += item.encoded_len()? + } + Some(total) + } +} + /// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`. 
/// /// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 @@ -533,6 +600,28 @@ mod tests { assert_eq!(value, decoded); } + #[test] + fn roundtrip_vec() { + let value = vec![1u32, 2u32, 3u32, 4u32]; + let mut bytes = vec![]; + value.encode(&mut bytes).unwrap(); + assert_eq!(bytes.len(), 16); + assert_eq!(bytes, vec![0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4]); + let decoded = Vec::::decode_with_param(&4, &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_array() { + let value = [1u32, 2u32, 3u32, 4u32]; + let mut bytes = vec![]; + value.encode(&mut bytes).unwrap(); + assert_eq!(bytes.len(), 16); + assert_eq!(bytes, vec![0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4]); + let decoded = <[u32; 4]>::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + #[derive(Debug, Eq, PartialEq)] struct TestMessage { field_u8: u8, @@ -779,6 +868,14 @@ mod tests { 0u64.encoded_len().unwrap(), 0u64.get_encoded().unwrap().len() ); + assert_eq!( + [0u8; 7].encoded_len().unwrap(), + [0u8; 7].get_encoded().unwrap().len() + ); + assert_eq!( + vec![0u8; 7].encoded_len().unwrap(), + vec![0u8; 7].get_encoded().unwrap().len() + ); } #[test] diff --git a/src/flp.rs b/src/flp.rs index 62308bf8..707d7333 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -57,7 +57,7 @@ use std::convert::TryFrom; use std::fmt::Debug; pub mod gadgets; -#[cfg(all(feature = "experimental", test))] +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] pub mod szk; pub mod types; diff --git a/src/flp/szk.rs b/src/flp/szk.rs index a5256269..42dbd74f 100644 --- a/src/flp/szk.rs +++ b/src/flp/szk.rs @@ -12,17 +12,21 @@ //! following a strategy similar to [`Prio3`](crate::vdaf::prio3::Prio3). use crate::{ - codec::{CodecError, Encode}, - field::{FftFriendlyFieldElement, FieldElement}, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{decode_fieldvec, encode_fieldvec, FieldElement}, flp::{FlpError, Type}, prng::{Prng, PrngError}, vdaf::xof::{IntoFieldVec, Seed, Xof, XofTurboShake128}, }; -use std::{borrow::Cow, marker::PhantomData}; +use std::borrow::Cow; +use std::ops::BitAnd; +use std::{io::Cursor, marker::PhantomData}; +use subtle::{Choice, ConstantTimeEq}; // Domain separation tags const DST_PROVE_RANDOMNESS: u16 = 0; const DST_PROOF_SHARE: u16 = 1; +#[allow(dead_code)] const DST_QUERY_RANDOMNESS: u16 = 2; const DST_JOINT_RAND_SEED: u16 = 3; const DST_JOINT_RAND_PART: u16 = 4; @@ -57,51 +61,301 @@ pub enum SzkError { /// Contains an FLP proof share, and if joint randomness is needed, the blind /// used to derive it and the other party's joint randomness part. -#[derive(Clone)] -pub enum SzkProofShare { - /// Leader's proof share is uncompressed. The first Seed is a blind, second - /// is a joint randomness part. +#[derive(Debug, Clone)] +pub enum SzkProofShare { + /// Leader's proof share is uncompressed. Leader { + /// Share of an FLP proof, as a vector of Field elements. uncompressed_proof_share: Vec, - leader_blind_and_helper_joint_rand_part: Option<(Seed, Seed)>, + /// Set only if joint randomness is needed. The first Seed is a blind, second + /// is the helper's joint randomness part. + leader_blind_and_helper_joint_rand_part_opt: Option<(Seed, Seed)>, }, /// The Helper uses one seed for both its compressed proof share and as the blind for its joint /// randomness. Helper { + /// The Seed that acts both as the compressed proof share and, optionally, as the blind. 
proof_share_seed_and_blind: Seed, - leader_joint_rand_part: Option>, + /// The leader's joint randomness part, if needed. + leader_joint_rand_part_opt: Option>, }, } +impl PartialEq for SzkProofShare { + fn eq(&self, other: &SzkProofShare) -> bool { + bool::from(self.ct_eq(other)) + } +} + +impl ConstantTimeEq for SzkProofShare { + fn ct_eq(&self, other: &SzkProofShare) -> Choice { + match (self, other) { + ( + SzkProofShare::Leader { + uncompressed_proof_share: s_proof, + leader_blind_and_helper_joint_rand_part_opt: s_blind, + }, + SzkProofShare::Leader { + uncompressed_proof_share: o_proof, + leader_blind_and_helper_joint_rand_part_opt: o_blind, + }, + ) => s_proof[..] + .ct_eq(&o_proof[..]) + .bitand(option_tuple_ct_eq(s_blind, o_blind)), + ( + SzkProofShare::Helper { + proof_share_seed_and_blind: s_seed, + leader_joint_rand_part_opt: s_rand, + }, + SzkProofShare::Helper { + proof_share_seed_and_blind: o_seed, + leader_joint_rand_part_opt: o_rand, + }, + ) => s_seed.ct_eq(o_seed).bitand(option_ct_eq(s_rand, o_rand)), + _ => Choice::from(0), + } + } +} + +impl Encode for SzkProofShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + match self { + SzkProofShare::Leader { + uncompressed_proof_share, + leader_blind_and_helper_joint_rand_part_opt, + } => ( + encode_fieldvec(uncompressed_proof_share, bytes)?, + if let Some((blind, helper_joint_rand_part)) = + leader_blind_and_helper_joint_rand_part_opt + { + blind.encode(bytes)?; + helper_joint_rand_part.encode(bytes)?; + }, + ), + SzkProofShare::Helper { + proof_share_seed_and_blind, + leader_joint_rand_part_opt, + } => ( + proof_share_seed_and_blind.encode(bytes)?, + if let Some(leader_joint_rand_part) = leader_joint_rand_part_opt { + leader_joint_rand_part.encode(bytes)?; + }, + ), + }; + Ok(()) + } + + fn encoded_len(&self) -> Option { + match self { + SzkProofShare::Leader { + uncompressed_proof_share, + leader_blind_and_helper_joint_rand_part_opt, + } => Some( + uncompressed_proof_share.len() * F::ENCODED_SIZE + + if let Some((blind, helper_joint_rand_part)) = + leader_blind_and_helper_joint_rand_part_opt + { + blind.encoded_len()? + helper_joint_rand_part.encoded_len()? + } else { + 0 + }, + ), + SzkProofShare::Helper { + proof_share_seed_and_blind, + leader_joint_rand_part_opt, + } => Some( + proof_share_seed_and_blind.encoded_len()? + + if let Some(leader_joint_rand_part) = leader_joint_rand_part_opt { + leader_joint_rand_part.encoded_len()? + } else { + 0 + }, + ), + } + } +} + +impl ParameterizedDecode<(bool, usize, bool)> + for SzkProofShare +{ + fn decode_with_param( + (is_leader, proof_len, requires_joint_rand): &(bool, usize, bool), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *is_leader { + Ok(SzkProofShare::Leader { + uncompressed_proof_share: decode_fieldvec::(*proof_len, bytes)?, + leader_blind_and_helper_joint_rand_part_opt: if *requires_joint_rand { + Some(( + Seed::::decode(bytes)?, + Seed::::decode(bytes)?, + )) + } else { + None + }, + }) + } else { + Ok(SzkProofShare::Helper { + proof_share_seed_and_blind: Seed::::decode(bytes)?, + leader_joint_rand_part_opt: if *requires_joint_rand { + Some(Seed::::decode(bytes)?) + } else { + None + }, + }) + } + } +} + /// A tuple containing the state and messages produced by an SZK query. 
-#[derive(Clone)] -pub(crate) struct SzkQueryShare { - joint_rand_part: Option>, - verifier: SzkVerifier, +#[derive(Clone, Debug)] +pub struct SzkQueryShare { + joint_rand_part_opt: Option>, + flp_verifier: Vec, } -/// The state that needs to be stored by an Szk verifier between query() and decide() -pub(crate) struct SzkQueryState { - joint_rand_seed: Option>, +impl Encode for SzkQueryShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + if let Some(ref part) = self.joint_rand_part_opt { + part.encode(bytes)? + }; + + self.flp_verifier.encode(bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some( + self.flp_verifier.encoded_len()? + + match self.joint_rand_part_opt { + Some(ref part) => part.encoded_len()?, + None => 0, + }, + ) + } +} + +impl ParameterizedDecode<(bool, usize)> + for SzkQueryShare +{ + fn decode_with_param( + (requires_joint_rand, verifier_len): &(bool, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *requires_joint_rand { + Ok(SzkQueryShare { + joint_rand_part_opt: Some(Seed::::decode(bytes)?), + flp_verifier: Vec::::decode_with_param(verifier_len, bytes)?, + }) + } else { + Ok(SzkQueryShare { + joint_rand_part_opt: None, + flp_verifier: Vec::::decode_with_param(verifier_len, bytes)?, + }) + } + } +} + +impl SzkQueryShare { + pub(crate) fn verifier_len(&self) -> usize { + self.flp_verifier.len() + } + + pub(crate) fn merge_verifiers( + mut leader_share: SzkQueryShare, + helper_share: SzkQueryShare, + ) -> SzkVerifier { + for (x, y) in leader_share + .flp_verifier + .iter_mut() + .zip(helper_share.flp_verifier) + { + *x += y; + } + SzkVerifier { + flp_verifier: leader_share.flp_verifier, + leader_joint_rand_part_opt: leader_share.joint_rand_part_opt, + helper_joint_rand_part_opt: helper_share.joint_rand_part_opt, + } + } } +/// The state that needs to be stored by an Szk verifier between query() and decide() +pub type SzkQueryState = Option>; + /// Verifier type for the SZK proof. -pub type SzkVerifier = Vec; +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SzkVerifier { + flp_verifier: Vec, + leader_joint_rand_part_opt: Option>, + helper_joint_rand_part_opt: Option>, +} + +impl Encode for SzkVerifier { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + encode_fieldvec(&self.flp_verifier, bytes)?; + if let Some(ref part) = self.leader_joint_rand_part_opt { + part.encode(bytes)? + }; + if let Some(ref part) = self.helper_joint_rand_part_opt { + part.encode(bytes)? + }; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some( + self.flp_verifier.len() * F::ENCODED_SIZE + + match self.leader_joint_rand_part_opt { + Some(ref part) => part.encoded_len()?, + None => 0, + } + + match self.helper_joint_rand_part_opt { + Some(ref part) => part.encoded_len()?, + None => 0, + }, + ) + } +} + +impl ParameterizedDecode<(bool, usize)> + for SzkVerifier +{ + fn decode_with_param( + (requires_joint_rand, verifier_len): &(bool, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *requires_joint_rand { + Ok(SzkVerifier { + flp_verifier: decode_fieldvec(*verifier_len, bytes)?, + leader_joint_rand_part_opt: Some(Seed::::decode(bytes)?), + helper_joint_rand_part_opt: Some(Seed::::decode(bytes)?), + }) + } else { + Ok(SzkVerifier { + flp_verifier: decode_fieldvec(*verifier_len, bytes)?, + leader_joint_rand_part_opt: None, + helper_joint_rand_part_opt: None, + }) + } + } +} /// Main struct encapsulating the shared zero-knowledge functionality. The type /// T is the underlying FLP proof system. 
P is the XOF used to derive all random /// coins (it should be indifferentiable from a random oracle for security.) +#[derive(Clone, Debug)] pub struct Szk where T: Type, P: Xof, { - typ: T, + /// The Type representing the specific FLP system used to prove validity of an input. + pub(crate) typ: T, algorithm_id: u32, phantom: PhantomData
<P>
, } -#[cfg(test)] impl Szk { /// Create an instance of [`Szk`] using [`XofTurboShake128`]. pub fn new_turboshake128(typ: T, algorithm_id: u32) -> Self { @@ -124,6 +378,10 @@ where } } + pub(crate) fn typ(&self) -> &T { + &self.typ + } + fn domain_separation_tag(&self, usage: u16) -> [u8; 8] { let mut dst = [0u8; 8]; dst[0] = MASTIC_VERSION; @@ -221,7 +479,15 @@ where self.typ.joint_rand_len() > 0 } - fn prove( + /// Used by a client to prove validity (according to an FLP system) of an input + /// that is both shared between the leader and helper + /// and encoded as a measurement. Has a precondition that leader_input_share + /// \+ helper_input_share = encoded_measurement. + /// leader_seed_opt should be set only if the underlying FLP system requires + /// joint randomness. + /// In this case, the helper uses the same seed to derive its proof share and + /// joint randomness. + pub(crate) fn prove( &self, leader_input_share: &[T::Field], helper_input_share: &[T::Field], @@ -236,8 +502,8 @@ where // leader its blinding seed and the helper's joint randomness part, and // pass the helper the leader's joint randomness part. (The seed used to // derive the helper's proof share is reused as the helper's blind.) - let (leader_blind_and_helper_joint_rand_part, leader_joint_rand_part, joint_rand) = - if let Some(leader_seed) = leader_seed_opt.clone() { + let (leader_blind_and_helper_joint_rand_part_opt, leader_joint_rand_part_opt, joint_rand) = + if let Some(leader_seed) = leader_seed_opt { let leader_joint_rand_part = self.derive_joint_rand_part(&leader_seed, leader_input_share, nonce)?; let helper_joint_rand_part = @@ -269,16 +535,16 @@ where // Construct the output messages. let leader_proof_share = SzkProofShare::Leader { uncompressed_proof_share: leader_proof_share, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, }; let helper_proof_share = SzkProofShare::Helper { - proof_share_seed_and_blind: helper_seed.clone(), - leader_joint_rand_part, + proof_share_seed_and_blind: helper_seed, + leader_joint_rand_part_opt, }; Ok([leader_proof_share, helper_proof_share]) } - fn query( + pub(crate) fn query( &self, input_share: &[T::Field], proof_share: SzkProofShare, @@ -301,8 +567,8 @@ where let ((joint_rand_seed, joint_rand), host_joint_rand_part) = match proof_share { SzkProofShare::Leader { uncompressed_proof_share: _, - leader_blind_and_helper_joint_rand_part, - } => match leader_blind_and_helper_joint_rand_part { + leader_blind_and_helper_joint_rand_part_opt, + } => match leader_blind_and_helper_joint_rand_part_opt { Some((seed, helper_joint_rand_part)) => { match self.derive_joint_rand_part(&seed, input_share, nonce) { Ok(leader_joint_rand_part) => ( @@ -323,8 +589,8 @@ where }, SzkProofShare::Helper { proof_share_seed_and_blind, - leader_joint_rand_part, - } => match leader_joint_rand_part { + leader_joint_rand_part_opt, + } => match leader_joint_rand_part_opt { Some(leader_joint_rand_part) => match self.derive_joint_rand_part( &proof_share_seed_and_blind, input_share, @@ -363,33 +629,31 @@ where )?; Ok(( SzkQueryShare { - joint_rand_part, - verifier: verifier_share, + joint_rand_part_opt: joint_rand_part, + flp_verifier: verifier_share, }, - SzkQueryState { joint_rand_seed }, + joint_rand_seed, )) } /// Returns true if the verifier message indicates that the input from which /// it was generated is valid. 
- fn decide( + pub fn decide( &self, - verifier: &[T::Field], - leader_joint_rand_part_opt: Option>, - helper_joint_rand_part_opt: Option>, - joint_rand_seed_opt: Option>, + verifier: SzkVerifier, + query_state: SzkQueryState, ) -> Result { // Check if underlying FLP proof validates - let check_flp_proof = self.typ.decide(verifier)?; + let check_flp_proof = self.typ.decide(&verifier.flp_verifier)?; if !check_flp_proof { return Ok(false); } // Check that joint randomness was properly derived from both // aggregators' parts match ( - joint_rand_seed_opt, - leader_joint_rand_part_opt, - helper_joint_rand_part_opt, + query_state, + verifier.leader_joint_rand_part_opt, + verifier.helper_joint_rand_part_opt, ) { (Some(joint_rand_seed), Some(leader_joint_rand_part), Some(helper_joint_rand_part)) => { let expected_joint_rand_seed = @@ -404,12 +668,45 @@ where } } +#[inline] +fn option_ct_eq(left: &Option, right: &Option) -> Choice +where + T: ConstantTimeEq + Sized, +{ + match (left, right) { + (Some(left), Some(right)) => left.ct_eq(right), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + +// This function determines equality between two optional, constant-time comparable tuples. It +// short-circuits on the existence (but not contents) of the values -- a timing side-channel may +// reveal whether the values match on Some or None. +#[inline] +fn option_tuple_ct_eq(left: &Option<(T, T)>, right: &Option<(T, T)>) -> Choice +where + T: ConstantTimeEq + Sized, +{ + match (left, right) { + (Some((left_0, left_1)), Some((right_0, right_1))) => { + left_0.ct_eq(right_0).bitand(left_1.ct_eq(right_1)) + } + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + +#[cfg(test)] mod tests { use super::*; - use crate::field::Field128 as TestField; - use crate::field::{random_vector, FieldElementWithInteger}; - use crate::flp::types::{Count, Sum}; - use crate::flp::Type; + use crate::{ + field::Field128, + field::{random_vector, FieldElementWithInteger}, + flp::gadgets::{Mul, ParallelSum}, + flp::types::{Count, Sum, SumVec}, + flp::Type, + }; use rand::{thread_rng, Rng}; fn generic_szk_test(typ: T, encoded_measurement: &[T::Field], valid: bool) { @@ -454,21 +751,8 @@ mod tests { .query(&helper_input_share, h_proof_share, &verify_key, &nonce) .unwrap(); - let mut verifier = l_query_share.clone().verifier; - - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } - let h_jr_part = h_query_share.clone().joint_rand_part; - let h_jr_seed = h_query_state.joint_rand_seed; - let l_jr_part = l_query_share.joint_rand_part; - let l_jr_seed = l_query_state.joint_rand_seed; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - l_jr_seed.clone(), - ) { + let verifier = SzkQueryShare::merge_verifiers(l_query_share.clone(), h_query_share.clone()); + if let Ok(leader_decision) = szk_typ.decide(verifier.clone(), l_query_state.clone()) { assert_eq!( leader_decision, valid, "Leader incorrectly determined validity", @@ -476,12 +760,7 @@ mod tests { } else { panic!("Leader failed during decision"); }; - if let Ok(helper_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - h_jr_seed.clone(), - ) { + if let Ok(helper_decision) = szk_typ.decide(verifier.clone(), h_query_state.clone()) { assert_eq!( helper_decision, valid, "Helper incorrectly determined validity", @@ -493,33 +772,22 @@ mod tests { //test mutated jr seed if szk_typ.has_joint_rand() { let joint_rand_seed_opt = 
Some(Seed::<16>::generate().unwrap()); - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - joint_rand_seed_opt, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, joint_rand_seed_opt.clone()) { assert!(!leader_decision, "Leader accepted wrong jr seed"); }; }; - //test mutated verifier - let mut verifier = l_query_share.verifier; - - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y + T::Field::from( + // test mutated verifier + let mut mutated_query_share = l_query_share.clone(); + for x in mutated_query_share.flp_verifier.iter_mut() { + *x += T::Field::from( ::Integer::try_from(7).unwrap(), ); } - let leader_decision = szk_typ - .decide( - &verifier, - l_jr_part.clone(), - h_jr_part.clone(), - l_jr_seed.clone(), - ) - .unwrap(); + let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); + + let leader_decision = szk_typ.decide(verifier, l_query_state.clone()).unwrap(); assert!(!leader_decision, "Leader validated after proof mutation"); // test mutated input share @@ -529,31 +797,21 @@ mod tests { let (mutated_query_share, mutated_query_state) = szk_typ .query(&mutated_input, l_proof_share.clone(), &verify_key, &nonce) .unwrap(); - let mut verifier = mutated_query_share.verifier; - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } + let verifier = SzkQueryShare::merge_verifiers(mutated_query_share, h_query_share.clone()); - let mutated_jr_seed = mutated_query_state.joint_rand_seed; - let mutated_jr_part = mutated_query_share.joint_rand_part; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - mutated_jr_part.clone(), - h_jr_part.clone(), - mutated_jr_seed, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, mutated_query_state) { assert!(!leader_decision, "Leader validated after input mutation"); }; // test mutated proof share - let (mut mutated_proof, leader_blind_and_helper_joint_rand_part) = match l_proof_share { + let (mut mutated_proof, leader_blind_and_helper_joint_rand_part_opt) = match l_proof_share { SzkProofShare::Leader { uncompressed_proof_share, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, } => ( uncompressed_proof_share.clone(), - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, ), _ => (vec![], None), }; @@ -561,7 +819,7 @@ mod tests { T::Field::from(::Integer::try_from(23).unwrap()); let mutated_proof_share = SzkProofShare::Leader { uncompressed_proof_share: mutated_proof, - leader_blind_and_helper_joint_rand_part, + leader_blind_and_helper_joint_rand_part_opt, }; let (l_query_share, l_query_state) = szk_typ .query( @@ -571,38 +829,384 @@ mod tests { &nonce, ) .unwrap(); - let mut verifier = l_query_share.verifier; - - for (x, y) in verifier.iter_mut().zip(h_query_share.clone().verifier) { - *x += y; - } + let verifier = SzkQueryShare::merge_verifiers(l_query_share, h_query_share.clone()); - let mutated_jr_seed = l_query_state.joint_rand_seed; - let mutated_jr_part = l_query_share.joint_rand_part; - if let Ok(leader_decision) = szk_typ.decide( - &verifier, - mutated_jr_part.clone(), - h_jr_part.clone(), - mutated_jr_seed, - ) { + if let Ok(leader_decision) = szk_typ.decide(verifier, l_query_state) { assert!(!leader_decision, "Leader validated after proof mutation"); }; } + #[test] + fn test_sum_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = 
Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_sumvec_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_count_proof_share_encode() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + assert_eq!( + l_proof_share.encoded_len().unwrap(), + l_proof_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_sum_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut 
leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_sum_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sum = Sum::::new(5).unwrap(); + let encoded_measurement = sum.encode_measurement(&9).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sum, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } + + #[test] + fn test_count_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = None; + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_count_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let count = Count::::new(); + let encoded_measurement = count.encode_measurement(&true).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(count, algorithm_id); + let 
prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = None; + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } + + #[test] + fn test_sumvec_leader_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [l_proof_share, _] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + true, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = l_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(l_proof_share, decoded_proof_share); + } + + #[test] + fn test_sumvec_helper_proof_share_roundtrip() { + let mut nonce = [0u8; 16]; + thread_rng().fill(&mut nonce[..]); + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let algorithm_id = 5; + let szk_typ = Szk::new_turboshake128(sumvec, algorithm_id); + let prove_rand_seed = Seed::<16>::generate().unwrap(); + let helper_seed = Seed::<16>::generate().unwrap(); + let leader_seed_opt = Some(Seed::<16>::generate().unwrap()); + let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap(); + let mut leader_input_share = encoded_measurement.clone().to_owned(); + for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) { + *x -= *y; + } + + let [_, h_proof_share] = szk_typ + .prove( + &leader_input_share, + &helper_input_share, + &encoded_measurement[..], + [prove_rand_seed, helper_seed], + leader_seed_opt, + &nonce, + ) + .unwrap(); + + let decoding_parameter = ( + false, + szk_typ.typ.proof_len(), + szk_typ.typ.joint_rand_len() != 0, + ); + let encoded_proof_share = h_proof_share.get_encoded().unwrap(); + let decoded_proof_share = + SzkProofShare::get_decoded_with_param(&decoding_parameter, &encoded_proof_share[..]) + .unwrap(); + assert_eq!(h_proof_share, decoded_proof_share); + } 
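// --- Hedged usage sketch (illustrative only, not part of the patch). ---
// Shows one full Szk round for `Count`, mirroring the roundtrip tests above and
// `generic_szk_test`: the client calls `prove`, each aggregator calls `query` on its
// additive input share, the query shares are merged with
// `SzkQueryShare::merge_verifiers`, and `decide` accepts the report. It assumes the
// imports already in scope in this test module (`Count`, `Field128`, `random_vector`,
// `Seed`, `thread_rng`).
#[test]
fn szk_count_end_to_end_sketch() {
    let mut nonce = [0u8; 16];
    let mut verify_key = [0u8; 16];
    thread_rng().fill(&mut nonce[..]);
    thread_rng().fill(&mut verify_key[..]);

    let count = Count::<Field128>::new();
    let encoded = count.encode_measurement(&true).unwrap();
    let szk = Szk::new_turboshake128(count, 0);

    // Additively share the encoded measurement: leader_share + helper_share = encoded.
    let helper_share = random_vector(szk.typ.input_len()).unwrap();
    let mut leader_share = encoded.clone();
    for (x, y) in leader_share.iter_mut().zip(&helper_share) {
        *x -= *y;
    }

    // `Count` needs no joint randomness, so the leader seed is `None`.
    let [l_proof_share, h_proof_share] = szk
        .prove(
            &leader_share,
            &helper_share,
            &encoded[..],
            [Seed::<16>::generate().unwrap(), Seed::<16>::generate().unwrap()],
            None,
            &nonce,
        )
        .unwrap();

    let (l_query_share, l_query_state) = szk
        .query(&leader_share, l_proof_share, &verify_key, &nonce)
        .unwrap();
    let (h_query_share, _h_query_state) = szk
        .query(&helper_share, h_proof_share, &verify_key, &nonce)
        .unwrap();

    let verifier = SzkQueryShare::merge_verifiers(l_query_share, h_query_share);
    assert!(szk.decide(verifier, l_query_state).unwrap());
}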
+ #[test] fn test_sum() { - let sum = Sum::::new(5).unwrap(); + let sum = Sum::::new(5).unwrap(); - let five = TestField::from(5); + let five = Field128::from(5); let nine = sum.encode_measurement(&9).unwrap(); let bad_encoding = &vec![five; sum.input_len()]; generic_szk_test(sum.clone(), &nine, true); generic_szk_test(sum, bad_encoding, false); } + #[test] + fn test_sumvec() { + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + + let five = Field128::from(5); + let encoded_measurement = sumvec.encode_measurement(&vec![1, 16, 0]).unwrap(); + let bad_encoding = &vec![five; sumvec.input_len()]; + generic_szk_test(sumvec.clone(), &encoded_measurement, true); + generic_szk_test(sumvec, bad_encoding, false); + } + #[test] fn test_count() { - let count = Count::::new(); + let count = Count::::new(); let encoded_true = count.encode_measurement(&true).unwrap(); generic_szk_test(count, &encoded_true, true); } diff --git a/src/lib.rs b/src/lib.rs index 9ada4596..e5280d50 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,8 +32,6 @@ mod fp64; doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) )] pub mod idpf; -#[cfg(feature = "experimental")] -pub mod mastic; mod polynomial; mod prng; pub mod topology; diff --git a/src/mastic.rs b/src/mastic.rs deleted file mode 100644 index 657b7da3..00000000 --- a/src/mastic.rs +++ /dev/null @@ -1,336 +0,0 @@ -use crate::{ - - field::{FieldElement}, - idpf::{Idpf, IdpfInput, IdpfOutputShare, IdpfPublicShare, IdpfValue, RingBufferCache}, - prng::Prng, - flp::Type, - szk::{Szk, SzkProofShare}, - vidpf::{Vidpf, VidpfInput, VidpfValue}, - vdaf::{ - xof::{Seed, Xof, XofTurboShake128}, - Aggregatable, Aggregator, Client, Collector, PrepareTransition, Vdaf, VdafError, - }, -}; - -/// The MASTIC VDAF. -#[derive(Clone, Debug)] -pub struct Mastic -where - T: Type, - T::Measurement: VidpfValue, - P: Xof, - V: Vidpf -{ - algorithm_id: u32, - szk: Szk, - vpf: V, - bits: usize, - phantom: PhantomData

, -} - -impl Mastic -where - T: Type, - T::Measurement: VidpfValue, - P: Xof, - V: Vidpf -{ -pub fn new( - algorithm_id: u32, - szk: S, - vpf: V, - bits: usize) -> Self { - Self { - algorithm_id, - szk, - vpf, - bits, - phantom: PhantomData, - } -} -} -/// Mastic aggregation parameter. -/// -/// This includes the VIDPF tree level under evaluation, a set of prefixes to evaluate at that level, -/// and, optionally, the aggregate results of prior levels. -#[derive(Clone, Debug)] -pub struct MasticAggregationParam { - level: u16, - prefixes: Vec, - counts: Vec> -} - -/// Add necessary traits for MasticAggregationParam here. - -pub struct MasticPublicShare { - joint_rand_parts: Option>>, - vidpf_public_share: VidpfPublicShare, -} - -/// Add necessary traits for MasticPublicShare here - - - -/// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase. -#[derive(Debug, Clone)] -pub struct MasticInputShare { - /// VIDPF key share. - vidpf_key: VidpfKey, - - /// The proof share. - proofs_share: SzkProofShare, - - /// Blinding seed used by the Aggregator to compute the joint randomness. This field is optional - /// because not every [`Type`] requires joint randomness. - joint_rand_blind: Option>, -} - -pub struct MasticOutputShare { - result: Vec, -} - - -impl Vdaf for Mastic -where - S: Szk, - S::Type::Measurement: VidpfValue, - P: Xof, - V: Vidpf -{ - type Measurement = S::Type::Measurement; - type AggregateResult = S::Type::AggregateResult; - type AggregationParam = MasticAggregationParam; - type PublicShare = MasticPublicShare; - type InputShare = MasticInputShare; - type OutputShare = MasticOutputShare; - type AggregateShare = MasticOutputShare; - - fn algorithm_id(&self) -> u32 { - self.algorithm_id - } - - fn num_aggregators(&self) -> usize { - 2 - } -} - -impl Mastic -where - S: Szk, - P: Xof, - V: Vidpf { - - fn shard_with_random( - &self, - measurement_label: &VidpfInput, - measurement_weight: &VidpfValue, - nonce: &[u8; 16], - vidpf_random: &[[u8; 16]; 2], - szk_random: &[[u8; SEED_SIZE]], - ) -> Result<(MasticPublicShare, Vec>), VdafError> { - - if input.len() != self.bits { - return Err(VdafError::Uncategorized(format!( - "unexpected input length ({})", - input.len() - ))); - } - // Compute the measurement shares for each aggregator by generating VIDPF - // keys for the measurement and evaluating each of them. - let (public, keys) = self.vpf.gen(measurement_label, measurement_weight, nonce); - let leader_measurement_share = self.vpf.eval(keys[0], public, input, nonce); - let helper_measurement_share = self.vpf.eval(keys[1], public, input, nonce); - let encoded_measurement = leader_measurement_share.clone(); - for (x, y) in encoded_measurement - .iter_mut() - .zip(helper_measurement_share) - { - *x -= y; - } - match (self.szk.has_joint_rand(), szk_random.len()){ - (true, 3) => (), - (false, 2) => (), - (_, _) => return Err(VdafError::Uncategorized(format!( - "incorrect Szk coins length ({})", - szk_random.len();, - ))) - } - // Compute the Szk proof shares for each aggregator - let szk_coins = [Seed::SEED_SIZE::from_bytes(szk_random[0]), Seed::SEED_SIZE::from_bytes(szk_random[1])]; - let leader_seed_opt = if self.szk.has_joint_rand() { - Some(Seed::SEED_SIZE::from_bytes(szk_random[2])) - } else { - None - }; - let szk_proof_shares = prove( - &self, - leader_measurement_share, - helper_measurement_share, - encoded_measurement, - szk_coins, - leader_seed_opt, - nonce, - ); - // Compute the joint randomness. 
- let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 { - Some(Vec::with_capacity(num_aggregators as usize - 1)) - } else { - None - }; - let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap(); - let joint_rand_blind = if let Some(helper_joint_rand_parts) = - helper_joint_rand_parts.as_mut() { - let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); - let mut joint_rand_part_xof = P::init( - &joint_rand_blind, - &self.domain_separation_tag(DST_JOINT_RAND_PART), - ); - joint_rand_part_xof.update(&[agg_id]); // Aggregator ID - joint_rand_part_xof.update(nonce); - - let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); - for (x, y) in leader_measurement_share - .iter_mut() - .zip(measurement_share_prng) - { - *x -= y; - y.encode(&mut encoding_buffer).map_err(|_| { - VdafError::Uncategorized("failed to encode measurement share".to_string()) - })?; - joint_rand_part_xof.update(&encoding_buffer); - encoding_buffer.clear(); - } - - helper_joint_rand_parts.push(joint_rand_part_xof.into_seed()); - - Some(joint_rand_blind) - } else { - for (x, y) in leader_measurement_share - .iter_mut() - .zip(measurement_share_prng) - { - *x -= y; - } - None - }; - let helper = - HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind); - helper_shares.push(helper); - - let mut leader_blind_opt = None; - let public_share = Prio3PublicShare { - joint_rand_parts: helper_joint_rand_parts - .as_ref() - .map( - |helper_joint_rand_parts| -> Result>, VdafError> { - let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap(); - let leader_blind = Seed::from_bytes(leader_blind_bytes); - - let mut joint_rand_part_xof = P::init( - leader_blind.as_ref(), - &self.domain_separation_tag(DST_JOINT_RAND_PART), - ); - joint_rand_part_xof.update(&[0]); // Aggregator ID - joint_rand_part_xof.update(nonce); - let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); - for x in leader_measurement_share.iter() { - x.encode(&mut encoding_buffer).map_err(|_| { - VdafError::Uncategorized( - "failed to encode measurement share".to_string(), - ) - })?; - joint_rand_part_xof.update(&encoding_buffer); - encoding_buffer.clear(); - } - leader_blind_opt = Some(leader_blind); - - let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed(); - - let mut vec = Vec::with_capacity(self.num_aggregators()); - vec.push(leader_joint_rand_seed_part); - vec.extend(helper_joint_rand_parts.iter().cloned()); - Ok(vec) - }, - ) - .transpose()?, - }; - - // Compute the joint randomness. - let joint_rands = public_share - .joint_rand_parts - .as_ref() - .map(|joint_rand_parts| self.derive_joint_rands(joint_rand_parts.iter()).1) - .unwrap_or_default(); - - // Generate the proofs. - let prove_rands = self.derive_prove_rands(&Seed::from_bytes( - random_seeds.next().unwrap().try_into().unwrap(), - )); - let mut leader_proofs_share = Vec::with_capacity(self.typ.proof_len() * self.num_proofs()); - for p in 0..self.num_proofs() { - let prove_rand = - &prove_rands[p * self.typ.prove_rand_len()..(p + 1) * self.typ.prove_rand_len()]; - let joint_rand = - &joint_rands[p * self.typ.joint_rand_len()..(p + 1) * self.typ.joint_rand_len()]; - - leader_proofs_share.append(&mut self.typ.prove( - &encoded_measurement, - prove_rand, - joint_rand, - )?); - } - - // Generate the proof shares and distribute the joint randomness seed hints. 
- for (j, helper) in helper_shares.iter_mut().enumerate() { - for (x, y) in - leader_proofs_share - .iter_mut() - .zip(self.derive_helper_proofs_share( - &helper.proofs_share, - u8::try_from(j).unwrap() + 1, - )) - .take(self.typ.proof_len() * self.num_proofs()) - { - *x -= y; - } - } - - // Prep the output messages. - let mut out = Vec::with_capacity(num_aggregators as usize); - out.push(Prio3InputShare { - measurement_share: Share::Leader(leader_measurement_share), - proofs_share: Share::Leader(leader_proofs_share), - joint_rand_blind: leader_blind_opt, - }); - - for helper in helper_shares.into_iter() { - out.push(Prio3InputShare { - measurement_share: Share::Helper(helper.measurement_share), - proofs_share: Share::Helper(helper.proofs_share), - joint_rand_blind: helper.joint_rand_blind, - }); - } - - Ok((public_share, out)) - -} - - -impl Client<16> -for Mastic -where - T: Type, - T::Measurement: VidpfValue, - P: Xof, - V: Vidpf{ - fn shard( - &self, - measurement: &Self::Measurement, - nonce: &[u8; NONCE_SIZE], - ) -> Result<(Self::PublicShare, Vec), VdafError>{ - let mut random = vec![0u8; self.random_size()]; - getrandom::getrandom(&mut random)?; - self.shard_with_random(measurement, nonce, &random) - } - -} \ No newline at end of file diff --git a/src/vdaf.rs b/src/vdaf.rs index e5f4e14c..b87363c5 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -8,8 +8,11 @@ #[cfg(feature = "experimental")] use crate::dp::DifferentialPrivacyStrategy; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +use crate::flp::szk::SzkError; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] use crate::idpf::IdpfError; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] use crate::vidpf::VidpfError; use crate::{ codec::{CodecError, Decode, Encode, ParameterizedDecode}, @@ -46,6 +49,11 @@ pub enum VdafError { #[error("flp error: {0}")] Flp(#[from] FlpError), + /// SZK error. + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] + #[error("Szk error: {0}")] + Szk(#[from] SzkError), + /// PRNG error. #[error("prng error: {0}")] Prng(#[from] PrngError), @@ -740,6 +748,8 @@ mod tests { #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] pub mod dummy; #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +pub mod mastic; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs, doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) diff --git a/src/vdaf/mastic.rs b/src/vdaf/mastic.rs new file mode 100644 index 00000000..3a18dbf7 --- /dev/null +++ b/src/vdaf/mastic.rs @@ -0,0 +1,1254 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of Mastic as specified in [[draft-mouris-cfrg-mastic-01]]. +//! +//! 
[draft-mouris-cfrg-mastic-01]: https://www.ietf.org/archive/id/draft-mouris-cfrg-mastic-01.html + +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{decode_fieldvec, FieldElement}, + flp::{ + szk::{Szk, SzkProofShare, SzkQueryShare, SzkQueryState, SzkVerifier}, + Type, + }, + vdaf::{ + poplar1::Poplar1AggregationParam, + xof::{Seed, Xof}, + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, + PrepareTransition, Vdaf, VdafError, + }, + vidpf::{ + Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, VidpfWeight, + }, +}; + +use std::fmt::Debug; +use std::io::{Cursor, Read}; +use std::ops::BitAnd; +use subtle::{Choice, ConstantTimeEq}; + +const DST_PATH_CHECK_BATCH: u16 = 6; +const NONCE_SIZE: usize = 16; +/// The main struct implementing the Mastic VDAF. +/// Composed of a shared zero knowledge proof system and a verifiable incremental +/// distributed point function. +#[derive(Clone, Debug)] +pub struct Mastic +where + T: Type, + P: Xof, +{ + algorithm_id: u32, + szk: Szk, + pub(crate) vidpf: Vidpf, 16>, + /// The length of the private attribute associated with any input. + pub(crate) bits: usize, +} + +impl Mastic +where + T: Type, + P: Xof, +{ + /// Creates a new instance of Mastic, with a specific attribute length and weight type. + pub fn new( + algorithm_id: u32, + szk: Szk, + vidpf: Vidpf, 16>, + bits: usize, + ) -> Self { + Self { + algorithm_id, + szk, + vidpf, + bits, + } + } +} + +/// Mastic aggregation parameter. +/// +/// This includes the VIDPF tree level under evaluation and a set of prefixes to evaluate at that level. +#[derive(Clone, Debug)] +pub struct MasticAggregationParam { + /// aggregation parameter inherited from [`Poplar1`]: contains the level (attribute length) and a vector of attribute prefixes (IdpfInputs) + level_and_prefixes: Poplar1AggregationParam, + /// Flag indicating whether the VIDPF weight needs to be validated using SZK. + /// This flag must be set the first time any report is aggregated; however this may happen at any level of the tree. + require_check_flag: bool, +} + +#[cfg(test)] +impl MasticAggregationParam { + fn new(prefixes: Vec, require_check_flag: bool) -> Result { + Ok(Self { + level_and_prefixes: Poplar1AggregationParam::try_from_prefixes(prefixes)?, + require_check_flag, + }) + } +} + +impl Encode for MasticAggregationParam { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + self.level_and_prefixes.encode(bytes)?; + let require_check = if self.require_check_flag { 1u8 } else { 0u8 }; + require_check.encode(bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(self.level_and_prefixes.encoded_len()? + 1usize) + } +} + +impl Decode for MasticAggregationParam { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result { + let level_and_prefixes = Poplar1AggregationParam::decode(bytes)?; + let require_check = u8::decode(bytes)?; + let require_check_flag = require_check != 0; + Ok(Self { + level_and_prefixes, + require_check_flag, + }) + } +} + +/// Mastic public share. +/// +/// Contains broadcast information shared between parties to support VIDPF correctness. +pub type MasticPublicShare = VidpfPublicShare; + +impl ParameterizedDecode> + for MasticPublicShare> +where + T: Type, + P: Xof, +{ + fn decode_with_param( + mastic: &Mastic, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + MasticPublicShare::>::decode_with_param( + &(mastic.bits, mastic.vidpf.weight_parameter), + bytes, + ) + } +} + +/// Mastic input share. 
+/// +/// Message sent by the [`Client`] to each Aggregator during the Sharding phase. +#[derive(Clone, Debug)] +pub struct MasticInputShare { + /// VIDPF key share. + vidpf_key: VidpfKey, + + /// The proof share. + proof_share: SzkProofShare, +} + +impl Encode for MasticInputShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + bytes.extend_from_slice(&self.vidpf_key.value[..]); + self.proof_share.encode(bytes)?; + Ok(()) + } + + fn encoded_len(&self) -> Option { + Some(16 + self.proof_share.encoded_len()?) + } +} + +impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Mastic, usize)> + for MasticInputShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + (mastic, agg_id): &(&'a Mastic, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + if *agg_id > 1 { + return Err(CodecError::UnexpectedValue); + } + let mut value = [0; 16]; + bytes.read_exact(&mut value)?; + let vidpf_key = VidpfKey::new( + if *agg_id == 0 { + VidpfServerId::S0 + } else { + VidpfServerId::S1 + }, + value, + ); + + let proof_share = SzkProofShare::::decode_with_param( + &( + *agg_id == 0, + mastic.szk.typ.proof_len(), + mastic.szk.typ.joint_rand_len() != 0, + ), + bytes, + )?; + Ok(Self { + vidpf_key, + proof_share, + }) + } +} + +impl PartialEq for MasticInputShare { + fn eq(&self, other: &MasticInputShare) -> bool { + self.ct_eq(other).into() + } +} + +impl ConstantTimeEq for MasticInputShare { + fn ct_eq(&self, other: &MasticInputShare) -> Choice { + self.vidpf_key + .ct_eq(&other.vidpf_key) + .bitand(self.proof_share.ct_eq(&other.proof_share)) + } +} + +/// Mastic output share. +/// +/// Contains a flattened vector of VIDPF outputs: one for each prefix. +pub type MasticOutputShare = OutputShare; + +/// Mastic aggregate share. +/// +/// Contains a flattened vector of VIDPF outputs to be aggregated by Mastic aggregators +pub type MasticAggregateShare = AggregateShare; + +impl<'a, T, P, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Mastic, &'a MasticAggregationParam)> + for MasticAggregateShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + decoding_parameter: &(&Mastic, &MasticAggregationParam), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let (mastic, agg_param) = decoding_parameter; + let l = mastic + .vidpf + .weight_parameter + .checked_mul(agg_param.level_and_prefixes.prefixes().len()) + .ok_or_else(|| CodecError::Other("multiplication overflow".into()))?; + let result = decode_fieldvec(l, bytes)?; + Ok(AggregateShare(result)) + } +} + +impl<'a, T, P, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Mastic, &'a MasticAggregationParam)> + for MasticOutputShare +where + T: Type, + P: Xof, +{ + fn decode_with_param( + decoding_parameter: &(&Mastic, &MasticAggregationParam), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let (mastic, agg_param) = decoding_parameter; + let l = mastic + .vidpf + .weight_parameter + .checked_mul(agg_param.level_and_prefixes.prefixes().len()) + .ok_or_else(|| CodecError::Other("multiplication overflow".into()))?; + let result = decode_fieldvec(l, bytes)?; + Ok(OutputShare(result)) + } +} + +impl Vdaf for Mastic +where + T: Type, + P: Xof, +{ + type Measurement = (VidpfInput, T::Measurement); + type AggregateResult = Vec; + type AggregationParam = MasticAggregationParam; + type PublicShare = MasticPublicShare>; + type InputShare = MasticInputShare; + type OutputShare = MasticOutputShare; + type AggregateShare = MasticAggregateShare; + + fn algorithm_id(&self) -> u32 { + self.algorithm_id + } + + fn num_aggregators(&self) -> 
usize { + 2 + } +} + +impl Mastic +where + T: Type, + P: Xof, +{ + fn shard_with_random( + &self, + measurement_attribute: &VidpfInput, + measurement_weight: &VidpfWeight, + nonce: &[u8; 16], + vidpf_keys: [VidpfKey; 2], + szk_random: [Seed; 2], + joint_random_opt: Option>, + ) -> Result<(::PublicShare, Vec<::InputShare>), VdafError> { + // Compute the measurement shares for each aggregator by generating VIDPF + let public_share = self.vidpf.gen_with_keys( + &vidpf_keys, + measurement_attribute, + measurement_weight, + nonce, + )?; + + let leader_measurement_share = + self.vidpf.eval_root(&vidpf_keys[0], &public_share, nonce)?; + let helper_measurement_share = + self.vidpf.eval_root(&vidpf_keys[1], &public_share, nonce)?; + + let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove( + leader_measurement_share.as_ref(), + helper_measurement_share.as_ref(), + measurement_weight.as_ref(), + szk_random, + joint_random_opt, + nonce, + )?; + let [leader_vidpf_key, helper_vidpf_key] = vidpf_keys; + let leader_share = MasticInputShare:: { + vidpf_key: leader_vidpf_key, + proof_share: leader_szk_proof_share, + }; + let helper_share = MasticInputShare:: { + vidpf_key: helper_vidpf_key, + proof_share: helper_szk_proof_share, + }; + Ok((public_share, vec![leader_share, helper_share])) + } + + fn encode_measurement( + &self, + measurement: &T::Measurement, + ) -> Result, VdafError> { + Ok(VidpfWeight::::from( + self.szk.typ.encode_measurement(measurement)?, + )) + } +} + +impl Client<16> for Mastic +where + T: Type, + P: Xof, +{ + fn shard( + &self, + (attribute, weight): &(VidpfInput, T::Measurement), + nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec), VdafError> { + if attribute.len() != self.bits { + return Err(VdafError::Vidpf(VidpfError::InvalidAttributeLength)); + } + + let vidpf_keys = [ + VidpfKey::gen(VidpfServerId::S0)?, + VidpfKey::gen(VidpfServerId::S1)?, + ]; + let joint_random_opt = if self.szk.has_joint_rand() { + Some(Seed::::generate()?) + } else { + None + }; + let szk_random = [ + Seed::::generate()?, + Seed::::generate()?, + ]; + + let encoded_measurement = self.encode_measurement(weight)?; + if encoded_measurement.as_ref().len() != self.vidpf.weight_parameter { + return Err(VdafError::Uncategorized( + "encoded_measurement is the wrong length".to_string(), + )); + } + self.shard_with_random( + attribute, + &encoded_measurement, + nonce, + vidpf_keys, + szk_random, + joint_random_opt, + ) + } +} + +/// Mastic prepare state +/// +/// State held by an aggregator between rounds of Mastic. Includes intermediate +/// state for Szk verification, the output shares currently being validated, and +/// parameters of Mastic used for encoding. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum MasticPrepareState { + /// Includes state for performing Szk verification at the root. + FirstRound(SzkQueryState, MasticOutputShare, usize, bool), + /// In all rounds but the first, SZK is not run, so only the output shares and encoding information are stored. + LaterRound(MasticOutputShare, bool), +} + +/// Mastic prepare share +/// +/// Broadcast message from an aggregator between rounds of Mastic. Includes the +/// hashed VIDPF proofs for every prefix in the aggregation parameter, and optionally +/// the verification message for Szk. +#[derive(Clone, Debug)] +pub enum MasticPrepareShare { + /// Includes a batched VIDPF proof and an SZK verification message for the root weight. + FirstRound(Seed, SzkQueryShare), + /// Includes only a batched VIDPF proof as SZK will not be run. 
+ LaterRound(Seed), +} + +impl Encode for MasticPrepareShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + match self { + MasticPrepareShare::FirstRound(seed, query_share) => { + seed.encode(bytes).and_then(|_| query_share.encode(bytes)) + } + MasticPrepareShare::LaterRound(seed) => seed.encode(bytes), + } + } + + fn encoded_len(&self) -> Option { + match self { + MasticPrepareShare::FirstRound(seed, query_share) => { + Some(seed.encoded_len()? + query_share.encoded_len()?) + } + MasticPrepareShare::LaterRound(seed) => seed.encoded_len(), + } + } +} + +impl ParameterizedDecode> + for MasticPrepareShare +{ + fn decode_with_param( + prep_state: &MasticPrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + match prep_state { + MasticPrepareState::FirstRound(_, _, verifier_len, requires_joint_rand) => { + Ok(MasticPrepareShare::FirstRound( + Seed::::decode(bytes)?, + SzkQueryShare::::decode_with_param( + &(*requires_joint_rand, *verifier_len), + bytes, + )?, + )) + } + MasticPrepareState::LaterRound(_, _) => { + Ok(MasticPrepareShare::LaterRound(Seed::::decode( + bytes, + )?)) + } + } + } +} + +/// Mastic prepare message +/// +/// Result of preprocessing the broadcast messages of all parties during the +/// preparation phase. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum MasticPrepareMessage { + ///If Szk is being run, all SzkQueryShares have been summed + /// to produce a verifier to be input to decide() + FirstRound(SzkVerifier), + /// If Szk is not being run, no further computation is necessary as the VIDPF proofs have already + /// been checked. + LaterRound, +} + +impl Encode for MasticPrepareMessage { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + match self { + MasticPrepareMessage::FirstRound(verifier) => verifier.encode(bytes), + MasticPrepareMessage::LaterRound => Ok(()), + } + } + + fn encoded_len(&self) -> Option { + match self { + MasticPrepareMessage::FirstRound(verifier) => verifier.encoded_len(), + MasticPrepareMessage::LaterRound => Some(0), + } + } +} + +impl ParameterizedDecode> + for MasticPrepareMessage +{ + fn decode_with_param( + prep_state: &MasticPrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + match prep_state { + MasticPrepareState::FirstRound(_, _, verifier_len, requires_joint_rand) => { + Ok(MasticPrepareMessage::FirstRound( + SzkVerifier::decode_with_param(&(*requires_joint_rand, *verifier_len), bytes)?, + )) + } + MasticPrepareState::LaterRound(_, _) => Ok(MasticPrepareMessage::LaterRound), + } + } +} + +impl Aggregator for Mastic +where + T: Type, + P: Xof, +{ + type PrepareState = MasticPrepareState; + + /// The type of messages sent by each aggregator at each round of the Prepare Process. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareShare = MasticPrepareShare; + + /// Result of preprocessing a round of preparation shares. This is used by all aggregators as an + /// input to the next round of the Prepare Process. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareMessage = MasticPrepareMessage; + + /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned + /// is passed to [`Self::prepare_next`] to get this aggregator's first-round prepare message. 
+ /// + /// Implements `Vdaf.prep_init` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 + fn prepare_init( + &self, + verify_key: &[u8; SEED_SIZE], + _agg_id: usize, + agg_param: &MasticAggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &MasticPublicShare>, + input_share: &MasticInputShare, + ) -> Result< + ( + MasticPrepareState, + MasticPrepareShare, + ), + VdafError, + > { + let mut xof = P::init( + verify_key, + &self.domain_separation_tag(DST_PATH_CHECK_BATCH), + ); + let mut output_shares = Vec::::with_capacity( + self.vidpf.weight_parameter * agg_param.level_and_prefixes.prefixes().len(), + ); + for prefix in agg_param.level_and_prefixes.prefixes() { + let mut value_share = + self.vidpf + .eval(&input_share.vidpf_key, public_share, prefix, nonce)?; + xof.update(&value_share.proof); + output_shares.append(&mut value_share.share.0); + } + let root_share_opt = if agg_param.require_check_flag { + Some( + self.vidpf + .eval( + &input_share.vidpf_key, + public_share, + &VidpfInput::from_bools(&[false]), + nonce, + )? + .share + + self + .vidpf + .eval( + &input_share.vidpf_key, + public_share, + &VidpfInput::from_bools(&[true]), + nonce, + )? + .share, + ) + } else { + None + }; + + let szk_verify_opt = if let Some(root_share) = root_share_opt { + Some(self.szk.query( + root_share.as_ref(), + input_share.proof_share.clone(), + verify_key, + nonce, + )?) + } else { + None + }; + let (prep_share, prep_state) = + if let Some((szk_query_share, szk_query_state)) = szk_verify_opt { + let verifier_len = szk_query_share.verifier_len(); + ( + MasticPrepareShare::FirstRound(xof.into_seed(), szk_query_share), + MasticPrepareState::FirstRound( + szk_query_state, + MasticOutputShare::::from(output_shares), + verifier_len, + self.szk.has_joint_rand(), + ), + ) + } else { + ( + MasticPrepareShare::LaterRound(xof.into_seed()), + MasticPrepareState::LaterRound( + MasticOutputShare::::from(output_shares), + self.szk.has_joint_rand(), + ), + ) + }; + Ok((prep_state, prep_share)) + } + + /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`]. + /// + /// Implements `Vdaf.prep_shares_to_prep` from [VDAF]. 
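// Illustrative sketch (not part of this patch): a prepare share is decoded
// with the recipient's own prepare state as the decoding parameter, since the
// state records whether a first-round SZK verification message is present and
// how long the verifier is. Setup values mirror the `Count` tests below.
fn example_prepare_share_roundtrip() {
    use crate::{field::Field128, flp::types::Count};

    let szk = Szk::new_turboshake128(Count::<Field128>::new(), 6);
    let vidpf = Vidpf::<VidpfWeight<Field128>, 16>::new(1);
    let mastic = Mastic::new(6, szk, vidpf, 32);

    let verify_key = [7u8; 16];
    let nonce = [0u8; 16];
    let attribute = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]);
    let agg_param =
        MasticAggregationParam::new(vec![VidpfInput::from_bools(&[false])], true).unwrap();
    let (public_share, input_shares) = mastic.shard(&(attribute, true), &nonce).unwrap();

    // The Leader (aggregator 0) initializes preparation and broadcasts its share.
    let (state, prep_share) = mastic
        .prepare_init(&verify_key, 0, &agg_param, &nonce, &public_share, &input_shares[0])
        .unwrap();
    let encoded = prep_share.get_encoded().unwrap();
    let decoded = MasticPrepareShare::get_decoded_with_param(&state, &encoded).unwrap();

    // The share has no PartialEq impl, so compare the two encodings instead.
    assert_eq!(decoded.get_encoded().unwrap(), encoded);
}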
+ /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 + fn prepare_shares_to_prepare_message< + M: IntoIterator>, + >( + &self, + _agg_param: &MasticAggregationParam, + inputs: M, + ) -> Result, VdafError> { + let mut inputs_iter = inputs.into_iter(); + let leader_share = inputs_iter.next().ok_or(VdafError::Uncategorized( + "No leader share received".to_string(), + ))?; + let helper_share = inputs_iter.next().ok_or(VdafError::Uncategorized( + "No helper share received".to_string(), + ))?; + if inputs_iter.next().is_some() { + return Err(VdafError::Uncategorized( + "more than 2 prepare shares".to_string(), + )); + }; + + match (leader_share, helper_share) { + ( + MasticPrepareShare::FirstRound(leader_vidpf_proof, leader_query_share), + MasticPrepareShare::FirstRound(helper_vidpf_proof, helper_query_share), + ) => { + if leader_vidpf_proof == helper_vidpf_proof { + Ok(MasticPrepareMessage::FirstRound( + SzkQueryShare::merge_verifiers(leader_query_share, helper_query_share), + )) + } else { + Err(VdafError::Uncategorized( + "Vidpf proof verification failed".to_string(), + )) + } + } + ( + MasticPrepareShare::LaterRound(leader_vidpf_proof), + MasticPrepareShare::LaterRound(helper_vidpf_proof), + ) => { + if leader_vidpf_proof == helper_vidpf_proof { + Ok(MasticPrepareMessage::LaterRound) + } else { + Err(VdafError::Uncategorized( + "Vidpf proof verification failed".to_string(), + )) + } + } + _ => Err(VdafError::Uncategorized( + "Prepare state and message disagree on whether Szk verification should occur" + .to_string(), + )), + } + } + + /// Compute the next state transition from the current state and the previous round of input + /// messages. If this returns [`PrepareTransition::Continue`], then the returned + /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from + /// this round and passed into another call to this method. This continues until this method + /// returns [`PrepareTransition::Finish`], at which point the returned output share may be + /// aggregated. If the method returns an error, the aggregator should consider its input share + /// invalid and not attempt to process it any further. + /// + /// Implements `Vdaf.prep_next` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-08#section-5.2 + fn prepare_next( + &self, + state: MasticPrepareState, + input: MasticPrepareMessage, + ) -> Result, VdafError> { + match (state, input) { + (MasticPrepareState::LaterRound(output_share, _), MasticPrepareMessage::LaterRound) => { + Ok(PrepareTransition::Finish(output_share)) + } + ( + MasticPrepareState::FirstRound(query_state, output_share, _, _), + MasticPrepareMessage::FirstRound(verifier), + ) => { + if self.szk.decide(verifier, query_state)? { + Ok(PrepareTransition::Finish(output_share)) + } else { + Err(VdafError::Uncategorized( + "Szk proof failed verification".to_string(), + )) + } + } + _ => Err(VdafError::Uncategorized( + "Prepare state and message disagree on whether Szk verification should occur" + .to_string(), + )), + } + } + + /// Aggregates a sequence of output shares into an aggregate share. 
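// Illustrative end-to-end sketch (not part of this patch): one report moving
// through sharding, a single round of preparation, aggregation, and
// unsharding, with the Leader as aggregator 0 and the Helper as aggregator 1.
// This is roughly what the `run_vdaf`/`run_vdaf_prepare` test utilities do;
// the algorithm ID, prefixes, and measurement are arbitrary example values.
fn example_round_trip() {
    use crate::{field::Field128, flp::types::Count};

    let szk = Szk::new_turboshake128(Count::<Field128>::new(), 6);
    let vidpf = Vidpf::<VidpfWeight<Field128>, 16>::new(1);
    let mastic = Mastic::new(6, szk, vidpf, 32);

    let verify_key = [7u8; 16];
    let nonce = [0u8; 16];
    let attribute = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]);

    // Query both level-0 prefixes and request the SZK weight check.
    let agg_param = MasticAggregationParam::new(
        vec![
            VidpfInput::from_bools(&[false]),
            VidpfInput::from_bools(&[true]),
        ],
        true,
    )
    .unwrap();

    // Client: shard the (attribute, weight) measurement.
    let (public_share, input_shares) = mastic.shard(&(attribute, true), &nonce).unwrap();

    // Aggregators: initialize preparation and exchange prepare shares.
    let (leader_state, leader_prep_share) = mastic
        .prepare_init(&verify_key, 0, &agg_param, &nonce, &public_share, &input_shares[0])
        .unwrap();
    let (helper_state, helper_prep_share) = mastic
        .prepare_init(&verify_key, 1, &agg_param, &nonce, &public_share, &input_shares[1])
        .unwrap();
    let prep_msg = mastic
        .prepare_shares_to_prepare_message(&agg_param, [leader_prep_share, helper_prep_share])
        .unwrap();

    // Mastic preparation finishes after a single round.
    let leader_out = match mastic.prepare_next(leader_state, prep_msg.clone()).unwrap() {
        PrepareTransition::Finish(out) => out,
        _ => panic!("unexpected transition"),
    };
    let helper_out = match mastic.prepare_next(helper_state, prep_msg).unwrap() {
        PrepareTransition::Finish(out) => out,
        _ => panic!("unexpected transition"),
    };

    // Each aggregator aggregates its output shares; the Collector unshards.
    let leader_agg_share = mastic.aggregate(&agg_param, [leader_out]).unwrap();
    let helper_agg_share = mastic.aggregate(&agg_param, [helper_out]).unwrap();
    let result = mastic
        .unshard(&agg_param, [leader_agg_share, helper_agg_share], 1)
        .unwrap();

    // One aggregate per queried prefix.
    assert_eq!(result.len(), 2);
}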
+ fn aggregate>( + &self, + agg_param: &MasticAggregationParam, + output_shares: M, + ) -> Result, VdafError> { + let mut agg_share = MasticAggregateShare::::from(vec![ + T::Field::zero(); + self.vidpf.weight_parameter + * agg_param + .level_and_prefixes + .prefixes() + .len() + ]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + Ok(agg_share) + } +} + +/// The Collector's role in the execution of a VDAF. +impl Collector for Mastic +where + T: Type, + P: Xof, +{ + /// Combines aggregate shares into the aggregate result. + fn unshard>( + &self, + agg_param: &MasticAggregationParam, + agg_shares: M, + _num_measurements: usize, + ) -> Result { + let n = agg_param.level_and_prefixes.prefixes().len(); + let mut agg_final = MasticAggregateShare::::from(vec![ + T::Field::zero(); + self.vidpf.weight_parameter + * n + ]); + for agg_share in agg_shares.into_iter() { + agg_final.merge(&agg_share)?; + } + let mut result = Vec::::with_capacity(n); + for i in 0..n { + let encoded_result = &agg_final.0 + [i * self.vidpf.weight_parameter..(i + 1) * self.vidpf.weight_parameter]; + result.push( + self.szk + .typ() + .decode_result(&self.szk.typ().truncate(encoded_result.to_vec())?[..], 1)?, + ); + } + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::Field128; + use crate::flp::gadgets::{Mul, ParallelSum}; + use crate::flp::types::{Count, Sum, SumVec}; + use crate::vdaf::test_utils::{run_vdaf, run_vdaf_prepare}; + use rand::{thread_rng, Rng}; + + const TEST_NONCE_SIZE: usize = 16; + + #[test] + fn test_mastic_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), 24), + (second_input.clone(), 0), + (third_input.clone(), 0), + (fourth_input.clone(), 3), + (fifth_input.clone(), 28) + ] + ) + .unwrap(), + vec![3] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), 24), + (second_input, 0), + (third_input, 0), + (fourth_input, 3), + (fifth_input, 28) + ] + ) + .unwrap(), + vec![31, 24] + ); + + let (public_share, input_shares) = mastic + .shard(&(first_input.clone(), 24u128), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic.shard(&(first_input, 4u128), &nonce).unwrap(); + run_vdaf_prepare( + &mastic, + 
&verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + } + + #[test] + fn test_input_share_encode_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + let (_, input_shares) = mastic.shard(&(first_input, 26u128), &nonce).unwrap(); + let [leader_input_share, helper_input_share] = [&input_shares[0], &input_shares[1]]; + + assert_eq!( + leader_input_share.encoded_len().unwrap(), + leader_input_share.get_encoded().unwrap().len() + ); + assert_eq!( + helper_input_share.encoded_len().unwrap(), + helper_input_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_public_share_roundtrip_sum() { + let algorithm_id = 6; + let sum_typ = Sum::::new(5).unwrap(); + let sum_szk = Szk::new_turboshake128(sum_typ, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(5); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, sum_szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, 25u128), &nonce).unwrap(); + + let encoded_public = public.get_encoded().unwrap(); + let decoded_public = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public[..]).unwrap(); + assert_eq!(public, decoded_public); + } + + #[test] + fn test_mastic_count() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), true), + (second_input.clone(), false), + (third_input.clone(), false), + (fourth_input.clone(), true), + (fifth_input.clone(), true) + ] + ) + .unwrap(), + vec![1] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), true), + (second_input, false), + (third_input, false), + (fourth_input, true), + (fifth_input, true) + ] + ) + .unwrap(), + vec![2, 1] + ); + + let (public_share, input_shares) = + mastic.shard(&(first_input.clone(), false), &nonce).unwrap(); + 
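// Illustrative sketch (not part of this patch): how a Collector might refine
// prefixes level by level for weighted heavy-hitters. The first level requests
// the SZK weight check; later levels can set `require_check_flag` to false,
// since each report's weight only needs to be checked once. Which prefixes
// survive to the next level is application logic; the values here are
// hypothetical.
fn example_refine_prefixes() {
    let level_zero = MasticAggregationParam::new(
        vec![
            VidpfInput::from_bools(&[false]),
            VidpfInput::from_bools(&[true]),
        ],
        true, // run the SZK check at the first level queried
    )
    .unwrap();

    // Suppose only the `0` prefix exceeded the threshold: extend it by one bit.
    let level_one = MasticAggregationParam::new(
        vec![
            VidpfInput::from_bools(&[false, false]),
            VidpfInput::from_bools(&[false, true]),
        ],
        false, // weights were already validated at level zero
    )
    .unwrap();
    let _ = (level_zero, level_one);
}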
run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic.shard(&(first_input, true), &nonce).unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + } + + #[test] + fn test_public_share_encoded_len() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, true), &nonce).unwrap(); + + assert_eq!( + public.encoded_len().unwrap(), + public.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_public_share_roundtrip_count() { + let algorithm_id = 6; + let count = Count::::new(); + let szk = Szk::new_turboshake128(count, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(1); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, true), &nonce).unwrap(); + + let encoded_public = public.get_encoded().unwrap(); + let decoded_public = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public[..]).unwrap(); + assert_eq!(public, decoded_public); + } + + #[test] + fn test_mastic_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[240u8, 0u8, 1u8, 4u8][..]); + let second_input = VidpfInput::from_bytes(&[112u8, 0u8, 1u8, 4u8][..]); + let third_input = VidpfInput::from_bytes(&[48u8, 0u8, 1u8, 4u8][..]); + let fourth_input = VidpfInput::from_bytes(&[32u8, 0u8, 1u8, 4u8][..]); + let fifth_input = VidpfInput::from_bytes(&[0u8, 0u8, 1u8, 4u8][..]); + let first_measurement = vec![1u128, 16u128, 0u128]; + let second_measurement = vec![0u128, 0u128, 0u128]; + let third_measurement = vec![0u128, 0u128, 0u128]; + let fourth_measurement = vec![1u128, 17u128, 31u128]; + let fifth_measurement = vec![6u128, 4u128, 11u128]; + let first_prefix = VidpfInput::from_bools(&[false, false, true]); + let second_prefix = VidpfInput::from_bools(&[false]); + let third_prefix = VidpfInput::from_bools(&[true]); + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let first_agg_param = MasticAggregationParam::new(vec![first_prefix], true).unwrap(); + let second_agg_param = + MasticAggregationParam::new(vec![second_prefix, third_prefix], true).unwrap(); + + assert_eq!( + run_vdaf( + &mastic, + &first_agg_param, + [ + (first_input.clone(), first_measurement.clone()), + (second_input.clone(), second_measurement.clone()), + (third_input.clone(), third_measurement.clone()), + (fourth_input.clone(), fourth_measurement.clone()), + (fifth_input.clone(), fifth_measurement.clone()) 
+ ] + ) + .unwrap(), + vec![vec![1, 17, 31]] + ); + + assert_eq!( + run_vdaf( + &mastic, + &second_agg_param, + [ + (first_input.clone(), first_measurement.clone()), + (second_input, second_measurement.clone()), + (third_input, third_measurement), + (fourth_input, fourth_measurement), + (fifth_input, fifth_measurement) + ] + ) + .unwrap(), + vec![vec![7, 21, 42], vec![1, 16, 0]] + ); + + let (public_share, input_shares) = mastic + .shard(&(first_input.clone(), first_measurement.clone()), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + + let (public_share, input_shares) = mastic + .shard(&(first_input, second_measurement.clone()), &nonce) + .unwrap(); + run_vdaf_prepare( + &mastic, + &verify_key, + &first_agg_param, + &nonce, + public_share, + input_shares, + ) + .unwrap(); + } + + #[test] + fn test_input_share_encode_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let leader_input_share = &input_shares[0]; + let helper_input_share = &input_shares[1]; + + assert_eq!( + leader_input_share.encoded_len().unwrap(), + leader_input_share.get_encoded().unwrap().len() + ); + assert_eq!( + helper_input_share.encoded_len().unwrap(), + helper_input_share.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_input_share_roundtrip_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (_public, input_shares) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + let leader_input_share = &input_shares[0]; + let helper_input_share = &input_shares[1]; + + let encoded_input_share = leader_input_share.get_encoded().unwrap(); + let decoded_leader_input_share = + MasticInputShare::get_decoded_with_param(&(&mastic, 0), &encoded_input_share[..]) + .unwrap(); + assert_eq!(leader_input_share, &decoded_leader_input_share); + let encoded_input_share = helper_input_share.get_encoded().unwrap(); + let decoded_helper_input_share = + MasticInputShare::get_decoded_with_param(&(&mastic, 1), &encoded_input_share[..]) + .unwrap(); + assert_eq!(helper_input_share, &decoded_helper_input_share); + } + + #[test] + fn test_public_share_encode_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut 
nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + + assert_eq!( + public.encoded_len().unwrap(), + public.get_encoded().unwrap().len() + ); + } + + #[test] + fn test_public_share_roundtrip_sumvec() { + let algorithm_id = 6; + let sumvec = + SumVec::>>::new(5, 3, 3).unwrap(); + let measurement = vec![1, 16, 0]; + let szk = Szk::new_turboshake128(sumvec, algorithm_id); + let sum_vidpf = Vidpf::, TEST_NONCE_SIZE>::new(15); + + let mut nonce = [0u8; 16]; + let mut verify_key = [0u8; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let first_input = VidpfInput::from_bytes(&[15u8, 0u8, 1u8, 4u8][..]); + + let mastic = Mastic::new(algorithm_id, szk, sum_vidpf, 32); + let (public, _) = mastic.shard(&(first_input, measurement), &nonce).unwrap(); + + let encoded_public_share = public.get_encoded().unwrap(); + let decoded_public_share = + MasticPublicShare::get_decoded_with_param(&mastic, &encoded_public_share[..]).unwrap(); + assert_eq!(public, decoded_public_share); + } +} diff --git a/src/vidpf.rs b/src/vidpf.rs index c8ba5db2..1e9bf1ad 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -11,16 +11,18 @@ use core::{ iter::zip, - ops::{Add, AddAssign, BitXor, BitXorAssign, Index, Sub}, + ops::{Add, AddAssign, BitAnd, BitXor, BitXorAssign, Index, Sub}, }; use bitvec::field::BitField; +use bitvec::prelude::{BitVec, Lsb0}; use rand_core::RngCore; -use std::io::Cursor; -use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable}; +use std::fmt::Debug; +use std::io::{Cursor, Read}; +use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; use crate::{ - codec::{CodecError, Encode, ParameterizedDecode}, + codec::{CodecError, Decode, Encode, ParameterizedDecode}, field::FieldElement, idpf::{ conditional_select_seed, conditional_swap_seed, conditional_xor_seeds, xor_seeds, @@ -45,6 +47,11 @@ pub enum VidpfError { #[error("level index out of bounds")] IndexLevel, + /// Error when input attribute has too few or many bits to be a path in an initialized + /// VIDPF tree. + #[error("invalid attribute length")] + InvalidAttributeLength, + /// Error when weight's length mismatches the length in weight's parameter. #[error("invalid weight length")] InvalidWeightLength, @@ -58,12 +65,13 @@ pub enum VidpfError { pub type VidpfInput = IdpfInput; /// Represents the codomain of an incremental point function. -pub trait VidpfValue: IdpfValue + Clone {} +pub trait VidpfValue: IdpfValue + Clone + Debug + PartialEq + ConstantTimeEq {} +#[derive(Clone, Debug)] /// A VIDPF instance. pub struct Vidpf { /// Any parameters required to instantiate a weight value. - weight_parameter: W::ValueParameter, + pub(crate) weight_parameter: W::ValueParameter, } impl Vidpf { @@ -108,7 +116,7 @@ impl Vidpf { /// [`Vidpf::gen_with_keys`] works as the [`Vidpf::gen`] method, except that two different /// keys must be provided. 
- fn gen_with_keys( + pub(crate) fn gen_with_keys( &self, keys: &[VidpfKey; 2], input: &VidpfInput, @@ -206,6 +214,9 @@ impl Vidpf { let mut share = W::zero(&self.weight_parameter); let n = input.len(); + if n > public.cw.len() { + return Err(VidpfError::InvalidAttributeLength); + } for level in 0..n { (state, share) = self.eval_next(key.id, public, input, level, &state, nonce)?; } @@ -266,6 +277,20 @@ impl Vidpf { Ok((next_state, y)) } + pub(crate) fn eval_root( + &self, + key: &VidpfKey, + public_share: &VidpfPublicShare, + nonce: &[u8; NONCE_SIZE], + ) -> Result { + Ok(self + .eval(key, public_share, &VidpfInput::from_bools(&[false]), nonce)? + .share + + self + .eval(key, public_share, &VidpfInput::from_bools(&[true]), nonce)? + .share) + } + fn prg(seed: &VidpfSeed, nonce: &[u8]) -> VidpfPrgOutput { let mut rng = XofFixedKeyAes128::seed_stream(&Seed(*seed), VidpfDomainSepTag::PRG, nonce); @@ -339,6 +364,8 @@ impl Vidpf { } } +/// Vidpf domain separation tag +/// /// Contains the domain separation tags for invoking different oracles. struct VidpfDomainSepTag; impl VidpfDomainSepTag { @@ -348,10 +375,13 @@ impl VidpfDomainSepTag { const NODE_PROOF_ADJUST: &'static [u8] = b"NodeProofAdjust"; } +#[derive(Clone, Debug)] +/// Vidpf key +/// /// Private key of an aggregation server. pub struct VidpfKey { id: VidpfServerId, - value: [u8; 16], + pub(crate) value: [u8; 16], } impl VidpfKey { @@ -364,10 +394,30 @@ impl VidpfKey { getrandom::getrandom(&mut value)?; Ok(Self { id, value }) } + + pub(crate) fn new(id: VidpfServerId, value: [u8; 16]) -> Self { + Self { id, value } + } +} + +impl ConstantTimeEq for VidpfKey { + fn ct_eq(&self, other: &VidpfKey) -> Choice { + Choice::from(self.id) + .ct_eq(&Choice::from(other.id)) + .bitand(self.value.ct_eq(&other.value)) + } } +impl PartialEq for VidpfKey { + fn eq(&self, other: &VidpfKey) -> bool { + bool::from(self.ct_eq(other)) + } +} + +/// Vidpf server ID +/// /// Identifies the two aggregation servers. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub(crate) enum VidpfServerId { /// S0 is the first server. S0, @@ -384,8 +434,10 @@ impl From for Choice { } } -/// Adjusts values of shares during the VIDPF evaluation. -#[derive(Debug)] +/// Vidpf correction word +/// +/// Adjusts values of shares during the VIDPF evaluation. +#[derive(Clone, Debug)] struct VidpfCorrectionWord { seed: VidpfSeed, left_control_bit: Choice, @@ -393,13 +445,107 @@ struct VidpfCorrectionWord { weight: W, } +impl ConstantTimeEq for VidpfCorrectionWord { + fn ct_eq(&self, other: &Self) -> Choice { + self.seed.ct_eq(&other.seed) + & self.left_control_bit.ct_eq(&other.left_control_bit) + & self.right_control_bit.ct_eq(&other.right_control_bit) + & self.weight.ct_eq(&other.weight) + } +} + +impl PartialEq for VidpfCorrectionWord +where + W: ConstantTimeEq, +{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +/// Vidpf public share +/// /// Common public information used by aggregation servers. -#[derive(Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct VidpfPublicShare { cw: Vec>, cs: Vec, } +impl Encode for VidpfPublicShare { + fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { + // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into + // bytes in big-endian order. Thus, the first four levels will have their control bits + // encoded in the last byte, and the last levels will have their control bits encoded in the + // first byte. 
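        // Illustrative note (not part of this patch): a two-level tree
        // contributes four control bits (level-0 left/right, then level-1
        // left/right), which pack into a single byte. The packed control bits
        // are followed by, for each level, the 16-byte correction-word seed
        // and that level's encoded weight, and finally by one 32-byte node
        // proof per level.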
+ let mut control_bits: BitVec = BitVec::with_capacity(self.cw.len() * 2); + for correction_words in self.cw.iter() { + control_bits.extend( + [ + bool::from(correction_words.left_control_bit), + bool::from(correction_words.right_control_bit), + ] + .iter(), + ); + } + control_bits.set_uninitialized(false); + let mut packed_control = control_bits.into_vec(); + bytes.append(&mut packed_control); + + for correction_words in self.cw.iter() { + Seed(correction_words.seed).encode(bytes)?; + correction_words.weight.encode(bytes)?; + } + + for proof in &self.cs { + bytes.extend_from_slice(proof); + } + Ok(()) + } + + fn encoded_len(&self) -> Option { + let control_bits_count = (self.cw.len()) * 2; + let mut len = (control_bits_count + 7) / 8 + (self.cw.len()) * 16; + for correction_words in self.cw.iter() { + len += correction_words.weight.encoded_len()?; + } + len += self.cs.len() * 32; + Some(len) + } +} + +impl ParameterizedDecode<(usize, W::ValueParameter)> for VidpfPublicShare { + fn decode_with_param( + (bits, weight_parameter): &(usize, W::ValueParameter), + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let packed_control_len = (bits + 3) / 4; + let mut packed = vec![0u8; packed_control_len]; + bytes.read_exact(&mut packed)?; + let unpacked_control_bits: BitVec = BitVec::from_vec(packed); + + let mut cw = Vec::>::with_capacity(*bits); + for chunk in unpacked_control_bits[0..(bits) * 2].chunks(2) { + let left_control_bit = (chunk[0] as u8).into(); + let right_control_bit = (chunk[1] as u8).into(); + let seed = Seed::decode(bytes)?.0; + cw.push(VidpfCorrectionWord { + seed, + left_control_bit, + right_control_bit, + weight: W::decode_with_param(weight_parameter, bytes)?, + }) + } + let mut cs = Vec::::with_capacity(*bits); + for _ in 0..*bits { + cs.push(VidpfProof::decode(bytes)?); + } + Ok(Self { cw, cs }) + } +} + +/// Vidpf evaluation state +/// /// Contains the values produced during input evaluation at a given level. pub struct VidpfEvalState { seed: VidpfSeed, @@ -454,7 +600,7 @@ struct VidpfPrgOutput { /// Represents an array of field elements that implements the [`VidpfValue`] trait. #[derive(Debug, PartialEq, Eq, Clone)] -pub struct VidpfWeight(Vec); +pub struct VidpfWeight(pub(crate) Vec); impl From> for VidpfWeight { fn from(value: Vec) -> Self { @@ -462,6 +608,12 @@ impl From> for VidpfWeight { } } +impl AsRef<[F]> for VidpfWeight { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + impl VidpfValue for VidpfWeight {} impl IdpfValue for VidpfWeight { @@ -549,6 +701,12 @@ impl Sub for VidpfWeight { } } +impl ConstantTimeEq for VidpfWeight { + fn ct_eq(&self, other: &Self) -> Choice { + self.0[..].ct_eq(&other.0[..]) + } +} + impl Encode for VidpfWeight { fn encode(&self, bytes: &mut Vec) -> Result<(), CodecError> { for e in &self.0 {