Skip to content

Commit

Permalink
Align Aggregator trait with spec (#744)
Browse files Browse the repository at this point in the history
`Vdaf::prepare_preprocess` was missing the `&Vdaf::AggregationParam`
argument specified in VDAF. This adds that argument to the trait method
and its implementations. While we're at it, to further align with the
spec, we add new methods `prepare_shares_to_prepare_message` and
`prepare_next`, to align with the spec's `prep_shares_to_prep` and
`prep_next`, as aliases for the existing `prepare_preprocess` and
`prepare_step`. The existing methods are marked as deprecated.

Resolves #670
  • Loading branch information
tgeoghegan authored Sep 11, 2023
1 parent 753a461 commit e2751fe
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 31 deletions.
92 changes: 80 additions & 12 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ pub trait Vdaf: Clone + Debug {
pub trait Client<const NONCE_SIZE: usize>: Vdaf {
/// Shards a measurement into a public share and a sequence of input shares, one for each
/// Aggregator.
///
/// Implements `Vdaf::shard` from [VDAF].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.1
fn shard(
&self,
measurement: &Self::Measurement,
Expand All @@ -234,8 +238,11 @@ pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vda
type PrepareMessage: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode;

/// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned
/// is passed to [`Aggregator::prepare_step`] to get this aggregator's first-round prepare
/// message.
/// is passed to [`Self::prepare_next`] to get this aggregator's first-round prepare message.
///
/// Implements `Vdaf.prep_init` from [VDAF].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
fn prepare_init(
&self,
verify_key: &[u8; VERIFY_KEY_SIZE],
Expand All @@ -246,9 +253,36 @@ pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vda
input_share: &Self::InputShare,
) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>;

/// Preprocess a round of preparation shares into a single input to [`Aggregator::prepare_step`].
/// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`].
///
/// Implements `Vdaf.prep_shares_to_prep` from [VDAF].
///
/// # Notes
///
/// [`Self::prepare_shares_to_prepare_message`] is preferable since its name better matches the
/// specification.
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
#[deprecated(
since = "0.15.0",
note = "Use Vdaf::prepare_shares_to_prepare_message instead"
)]
fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>(
&self,
agg_param: &Self::AggregationParam,
inputs: M,
) -> Result<Self::PrepareMessage, VdafError> {
self.prepare_shares_to_prepare_message(agg_param, inputs)
}

/// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`].
///
/// Implements `Vdaf.prep_shares_to_prep` from [VDAF].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>(
&self,
agg_param: &Self::AggregationParam,
inputs: M,
) -> Result<Self::PrepareMessage, VdafError>;

Expand All @@ -259,10 +293,38 @@ pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vda
/// returns [`PrepareTransition::Finish`], at which point the returned output share may be
/// aggregated. If the method returns an error, the aggregator should consider its input share
/// invalid and not attempt to process it any further.
///
/// Implements `Vdaf.prep_next` from [VDAF].
///
/// # Notes
///
/// [`Self::prepare_next`] is preferable since its name better matches the specification.
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
#[deprecated(since = "0.15.0", note = "Use Vdaf::prepare_next")]
fn prepare_step(
&self,
state: Self::PrepareState,
input: Self::PrepareMessage,
) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError> {
self.prepare_next(state, input)
}

/// Compute the next state transition from the current state and the previous round of input
/// messages. If this returns [`PrepareTransition::Continue`], then the returned
/// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from
/// this round and passed into another call to this method. This continues until this method
/// returns [`PrepareTransition::Finish`], at which point the returned output share may be
/// aggregated. If the method returns an error, the aggregator should consider its input share
/// invalid and not attempt to process it any further.
///
/// Implements `Vdaf.prep_next` from [VDAF].
///
/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2
fn prepare_next(
&self,
state: Self::PrepareState,
input: Self::PrepareMessage,
) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError>;

/// Aggregates a sequence of output shares into an aggregate share.
Expand Down Expand Up @@ -531,17 +593,20 @@ where
}

let mut inbound = vdaf
.prepare_preprocess(outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}))?
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded();

let mut out_shares = Vec::new();
loop {
let mut outbound = Vec::new();
for state in states.iter_mut() {
match vdaf.prepare_step(
match vdaf.prepare_next(
state.clone(),
V::PrepareMessage::get_decoded_with_param(state, &inbound)
.expect("failed to decode prep message"),
Expand All @@ -559,10 +624,13 @@ where
if outbound.len() == vdaf.num_aggregators() {
// Another round is required before output shares are computed.
inbound = vdaf
.prepare_preprocess(outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}))?
.prepare_shares_to_prepare_message(
agg_param,
outbound.iter().map(|encoded| {
V::PrepareShare::get_decoded_with_param(&states[0], encoded)
.expect("failed to decode prep share")
}),
)?
.get_encoded();
} else if outbound.is_empty() {
// Each Aggregator recovered an output share.
Expand Down
23 changes: 15 additions & 8 deletions src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1076,8 +1076,9 @@ impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16>
}
}

fn prepare_preprocess<M: IntoIterator<Item = Poplar1FieldVec>>(
fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Poplar1FieldVec>>(
&self,
_: &Poplar1AggregationParam,
inputs: M,
) -> Result<Poplar1PrepareMessage, VdafError> {
let mut inputs = inputs.into_iter();
Expand Down Expand Up @@ -1114,7 +1115,7 @@ impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16>
}
}

fn prepare_step(
fn prepare_next(
&self,
state: Poplar1PrepareState,
msg: Poplar1PrepareMessage,
Expand Down Expand Up @@ -2065,35 +2066,41 @@ mod tests {
.unwrap();

let r1_prep_msg = poplar
.prepare_preprocess([init_prep_share_0.clone(), init_prep_share_1.clone()])
.prepare_shares_to_prepare_message(
&agg_param,
[init_prep_share_0.clone(), init_prep_share_1.clone()],
)
.unwrap();

let (r1_prep_state_0, r1_prep_share_0) = assert_matches!(
poplar
.prepare_step(init_prep_state_0.clone(), r1_prep_msg.clone())
.prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone())
.unwrap(),
PrepareTransition::Continue(state, share) => (state, share)
);
let (r1_prep_state_1, r1_prep_share_1) = assert_matches!(
poplar
.prepare_step(init_prep_state_1.clone(), r1_prep_msg.clone())
.prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone())
.unwrap(),
PrepareTransition::Continue(state, share) => (state, share)
);

let r2_prep_msg = poplar
.prepare_preprocess([r1_prep_share_0.clone(), r1_prep_share_1.clone()])
.prepare_shares_to_prepare_message(
&agg_param,
[r1_prep_share_0.clone(), r1_prep_share_1.clone()],
)
.unwrap();

let out_share_0 = assert_matches!(
poplar
.prepare_step(r1_prep_state_0.clone(), r2_prep_msg.clone())
.prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone())
.unwrap(),
PrepareTransition::Finish(out) => out
);
let out_share_1 = assert_matches!(
poplar
.prepare_step(r1_prep_state_1, r2_prep_msg.clone())
.prepare_next(r1_prep_state_1, r2_prep_msg.clone())
.unwrap(),
PrepareTransition::Finish(out) => out
);
Expand Down
11 changes: 6 additions & 5 deletions src/vdaf/prio2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ impl Aggregator<32, 16> for Prio2 {
self.prepare_init_with_query_rand(query_rand, input_share, is_leader)
}

fn prepare_preprocess<M: IntoIterator<Item = Prio2PrepareShare>>(
fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Prio2PrepareShare>>(
&self,
_: &Self::AggregationParam,
inputs: M,
) -> Result<(), VdafError> {
let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> =
Expand All @@ -289,7 +290,7 @@ impl Aggregator<32, 16> for Prio2 {
Ok(())
}

fn prepare_step(
fn prepare_next(
&self,
state: Prio2PrepareState,
_input: (),
Expand Down Expand Up @@ -491,12 +492,12 @@ mod tests {
let (prepare_state_2, prepare_share_2) = vdaf
.prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2)
.unwrap();
vdaf.prepare_preprocess([prepare_share_1, prepare_share_2])
vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2])
.unwrap();
let transition_1 = vdaf.prepare_step(prepare_state_1, ()).unwrap();
let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap();
let output_share_1 =
assert_matches!(transition_1, PrepareTransition::Finish(out) => out);
let transition_2 = vdaf.prepare_step(prepare_state_2, ()).unwrap();
let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap();
let output_share_2 =
assert_matches!(transition_2, PrepareTransition::Finish(out) => out);
leader_output_shares.push(output_share_1);
Expand Down
13 changes: 9 additions & 4 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ impl Prio3Average {
/// prep_states.push(state);
/// prep_shares.push(share);
/// }
/// let prep_msg = vdaf.prepare_preprocess(prep_shares).unwrap();
/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap();
///
/// for (agg_id, state) in prep_states.into_iter().enumerate() {
/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() {
Expand Down Expand Up @@ -1173,8 +1173,11 @@ where
))
}

fn prepare_preprocess<M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>>(
fn prepare_shares_to_prepare_message<
M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>,
>(
&self,
_: &Self::AggregationParam,
inputs: M,
) -> Result<Prio3PrepareMessage<SEED_SIZE>, VdafError> {
let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()];
Expand Down Expand Up @@ -1228,7 +1231,7 @@ where
Ok(Prio3PrepareMessage { joint_rand_seed })
}

fn prepare_step(
fn prepare_next(
&self,
step: Prio3PrepareState<T::Field, SEED_SIZE>,
msg: Prio3PrepareMessage<SEED_SIZE>,
Expand Down Expand Up @@ -1851,7 +1854,9 @@ mod tests {
last_prepare_state = Some(prepare_state);
}

let prepare_message = prio3.prepare_preprocess(prepare_shares).unwrap();
let prepare_message = prio3
.prepare_shares_to_prepare_message(&(), prepare_shares)
.unwrap();

let encoded_prepare_message = prepare_message.get_encoded();
let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param(
Expand Down
4 changes: 2 additions & 2 deletions src/vdaf/prio3_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ fn check_prep_test_vec<M, T, P, const SEED_SIZE: usize>(
}

let inbound = prio3
.prepare_preprocess(prep_shares)
.prepare_shares_to_prepare_message(&(), prep_shares)
.unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
assert_eq!(t.prep_messages.len(), 1);
assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref());

let mut out_shares = Vec::new();
for state in states.iter_mut() {
match prio3.prepare_step(state.clone(), inbound.clone()).unwrap() {
match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() {
PrepareTransition::Finish(out_share) => {
out_shares.push(out_share);
}
Expand Down

0 comments on commit e2751fe

Please sign in to comment.