From e2751fe6cc1b994867c40ad47cf1e2c23771bb9b Mon Sep 17 00:00:00 2001 From: Tim Geoghegan Date: Mon, 11 Sep 2023 16:49:56 -0700 Subject: [PATCH] Align `Aggregator` trait with spec (#744) `Vdaf::prepare_preprocess` was missing the `&Vdaf::AggregationParam` argument specified in VDAF. This adds that argument to the trait method and its implementations. While we're at it, to further align with the spec, we add new methods `prepare_shares_to_prepare_message` and `prepare_next`, to align with the spec's `prep_shares_to_prep` and `prep_next`, as aliases for the existing `prepare_preprocess` and `prepare_step`. The existing methods are marked as deprecated. Resolves #670 --- src/vdaf.rs | 92 ++++++++++++++++++++++++++++++++++++------ src/vdaf/poplar1.rs | 23 +++++++---- src/vdaf/prio2.rs | 11 ++--- src/vdaf/prio3.rs | 13 ++++-- src/vdaf/prio3_test.rs | 4 +- 5 files changed, 112 insertions(+), 31 deletions(-) diff --git a/src/vdaf.rs b/src/vdaf.rs index 5fb568514..d7c1499f8 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -208,6 +208,10 @@ pub trait Vdaf: Clone + Debug { pub trait Client: Vdaf { /// Shards a measurement into a public share and a sequence of input shares, one for each /// Aggregator. + /// + /// Implements `Vdaf::shard` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.1 fn shard( &self, measurement: &Self::Measurement, @@ -234,8 +238,11 @@ pub trait Aggregator: Vda type PrepareMessage: Clone + Debug + ParameterizedDecode + Encode; /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned - /// is passed to [`Aggregator::prepare_step`] to get this aggregator's first-round prepare - /// message. + /// is passed to [`Self::prepare_next`] to get this aggregator's first-round prepare message. + /// + /// Implements `Vdaf.prep_init` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 fn prepare_init( &self, verify_key: &[u8; VERIFY_KEY_SIZE], @@ -246,9 +253,36 @@ pub trait Aggregator: Vda input_share: &Self::InputShare, ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>; - /// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`]. + /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`]. + /// + /// Implements `Vdaf.prep_shares_to_prep` from [VDAF]. + /// + /// # Notes + /// + /// [`Self::prepare_shares_to_prepare_message`] is preferable since its name better matches the + /// specification. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + #[deprecated( + since = "0.15.0", + note = "Use Vdaf::prepare_shares_to_prepare_message instead" + )] fn prepare_preprocess>( &self, + agg_param: &Self::AggregationParam, + inputs: M, + ) -> Result { + self.prepare_shares_to_prepare_message(agg_param, inputs) + } + + /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`]. + /// + /// Implements `Vdaf.prep_shares_to_prep` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + fn prepare_shares_to_prepare_message>( + &self, + agg_param: &Self::AggregationParam, inputs: M, ) -> Result; @@ -259,10 +293,38 @@ pub trait Aggregator: Vda /// returns [`PrepareTransition::Finish`], at which point the returned output share may be /// aggregated. If the method returns an error, the aggregator should consider its input share /// invalid and not attempt to process it any further. + /// + /// Implements `Vdaf.prep_next` from [VDAF]. + /// + /// # Notes + /// + /// [`Self::prepare_next`] is preferable since its name better matches the specification. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + #[deprecated(since = "0.15.0", note = "Use Vdaf::prepare_next")] fn prepare_step( &self, state: Self::PrepareState, input: Self::PrepareMessage, + ) -> Result, VdafError> { + self.prepare_next(state, input) + } + + /// Compute the next state transition from the current state and the previous round of input + /// messages. If this returns [`PrepareTransition::Continue`], then the returned + /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from + /// this round and passed into another call to this method. This continues until this method + /// returns [`PrepareTransition::Finish`], at which point the returned output share may be + /// aggregated. If the method returns an error, the aggregator should consider its input share + /// invalid and not attempt to process it any further. + /// + /// Implements `Vdaf.prep_next` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + fn prepare_next( + &self, + state: Self::PrepareState, + input: Self::PrepareMessage, ) -> Result, VdafError>; /// Aggregates a sequence of output shares into an aggregate share. @@ -531,17 +593,20 @@ where } let mut inbound = vdaf - .prepare_preprocess(outbound.iter().map(|encoded| { - V::PrepareShare::get_decoded_with_param(&states[0], encoded) - .expect("failed to decode prep share") - }))? + .prepare_shares_to_prepare_message( + agg_param, + outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }), + )? .get_encoded(); let mut out_shares = Vec::new(); loop { let mut outbound = Vec::new(); for state in states.iter_mut() { - match vdaf.prepare_step( + match vdaf.prepare_next( state.clone(), V::PrepareMessage::get_decoded_with_param(state, &inbound) .expect("failed to decode prep message"), @@ -559,10 +624,13 @@ where if outbound.len() == vdaf.num_aggregators() { // Another round is required before output shares are computed. inbound = vdaf - .prepare_preprocess(outbound.iter().map(|encoded| { - V::PrepareShare::get_decoded_with_param(&states[0], encoded) - .expect("failed to decode prep share") - }))? + .prepare_shares_to_prepare_message( + agg_param, + outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }), + )? .get_encoded(); } else if outbound.is_empty() { // Each Aggregator recovered an output share. diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 342fe5bf2..727e6f739 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -1076,8 +1076,9 @@ impl, const SEED_SIZE: usize> Aggregator } } - fn prepare_preprocess>( + fn prepare_shares_to_prepare_message>( &self, + _: &Poplar1AggregationParam, inputs: M, ) -> Result { let mut inputs = inputs.into_iter(); @@ -1114,7 +1115,7 @@ impl, const SEED_SIZE: usize> Aggregator } } - fn prepare_step( + fn prepare_next( &self, state: Poplar1PrepareState, msg: Poplar1PrepareMessage, @@ -2065,35 +2066,41 @@ mod tests { .unwrap(); let r1_prep_msg = poplar - .prepare_preprocess([init_prep_share_0.clone(), init_prep_share_1.clone()]) + .prepare_shares_to_prepare_message( + &agg_param, + [init_prep_share_0.clone(), init_prep_share_1.clone()], + ) .unwrap(); let (r1_prep_state_0, r1_prep_share_0) = assert_matches!( poplar - .prepare_step(init_prep_state_0.clone(), r1_prep_msg.clone()) + .prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let (r1_prep_state_1, r1_prep_share_1) = assert_matches!( poplar - .prepare_step(init_prep_state_1.clone(), r1_prep_msg.clone()) + .prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone()) .unwrap(), PrepareTransition::Continue(state, share) => (state, share) ); let r2_prep_msg = poplar - .prepare_preprocess([r1_prep_share_0.clone(), r1_prep_share_1.clone()]) + .prepare_shares_to_prepare_message( + &agg_param, + [r1_prep_share_0.clone(), r1_prep_share_1.clone()], + ) .unwrap(); let out_share_0 = assert_matches!( poplar - .prepare_step(r1_prep_state_0.clone(), r2_prep_msg.clone()) + .prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); let out_share_1 = assert_matches!( poplar - .prepare_step(r1_prep_state_1, r2_prep_msg.clone()) + .prepare_next(r1_prep_state_1, r2_prep_msg.clone()) .unwrap(), PrepareTransition::Finish(out) => out ); diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index ad8f89102..4669c47d0 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -268,8 +268,9 @@ impl Aggregator<32, 16> for Prio2 { self.prepare_init_with_query_rand(query_rand, input_share, is_leader) } - fn prepare_preprocess>( + fn prepare_shares_to_prepare_message>( &self, + _: &Self::AggregationParam, inputs: M, ) -> Result<(), VdafError> { let verifier_shares: Vec> = @@ -289,7 +290,7 @@ impl Aggregator<32, 16> for Prio2 { Ok(()) } - fn prepare_step( + fn prepare_next( &self, state: Prio2PrepareState, _input: (), @@ -491,12 +492,12 @@ mod tests { let (prepare_state_2, prepare_share_2) = vdaf .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2) .unwrap(); - vdaf.prepare_preprocess([prepare_share_1, prepare_share_2]) + vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2]) .unwrap(); - let transition_1 = vdaf.prepare_step(prepare_state_1, ()).unwrap(); + let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap(); let output_share_1 = assert_matches!(transition_1, PrepareTransition::Finish(out) => out); - let transition_2 = vdaf.prepare_step(prepare_state_2, ()).unwrap(); + let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap(); let output_share_2 = assert_matches!(transition_2, PrepareTransition::Finish(out) => out); leader_output_shares.push(output_share_1); diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 4f1a6e8de..5dd224380 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -313,7 +313,7 @@ impl Prio3Average { /// prep_states.push(state); /// prep_shares.push(share); /// } -/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap(); +/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap(); /// /// for (agg_id, state) in prep_states.into_iter().enumerate() { /// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { @@ -1173,8 +1173,11 @@ where )) } - fn prepare_preprocess>>( + fn prepare_shares_to_prepare_message< + M: IntoIterator>, + >( &self, + _: &Self::AggregationParam, inputs: M, ) -> Result, VdafError> { let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; @@ -1228,7 +1231,7 @@ where Ok(Prio3PrepareMessage { joint_rand_seed }) } - fn prepare_step( + fn prepare_next( &self, step: Prio3PrepareState, msg: Prio3PrepareMessage, @@ -1851,7 +1854,9 @@ mod tests { last_prepare_state = Some(prepare_state); } - let prepare_message = prio3.prepare_preprocess(prepare_shares).unwrap(); + let prepare_message = prio3 + .prepare_shares_to_prepare_message(&(), prepare_shares) + .unwrap(); let encoded_prepare_message = prepare_message.get_encoded(); let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param( diff --git a/src/vdaf/prio3_test.rs b/src/vdaf/prio3_test.rs index 78614b1fd..a0358db39 100644 --- a/src/vdaf/prio3_test.rs +++ b/src/vdaf/prio3_test.rs @@ -106,14 +106,14 @@ fn check_prep_test_vec( } let inbound = prio3 - .prepare_preprocess(prep_shares) + .prepare_shares_to_prepare_message(&(), prep_shares) .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); assert_eq!(t.prep_messages.len(), 1); assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref()); let mut out_shares = Vec::new(); for state in states.iter_mut() { - match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() { + match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() { PrepareTransition::Finish(out_share) => { out_shares.push(out_share); }