Skip to content

Commit

Permalink
caching completed
Browse files Browse the repository at this point in the history
  • Loading branch information
Hannah Davis committed Oct 3, 2024
1 parent 5699e50 commit 4d47513
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/idpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl IdpfInput {
/// Return the single bit of this IDPF input at the given level.
pub fn next_branch(&self, level: usize) -> Self {
Self {
index: self.index[level - 1..level].to_owned().into(),
index: self.index[level..level + 1].to_owned().into(),
}
}

Expand Down
47 changes: 24 additions & 23 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! [draft-mouris-cfrg-mastic-01]: https://www.ietf.org/archive/id/draft-mouris-cfrg-mastic-01.html

use crate::{
bt::{BinaryTree, Path},
codec::{CodecError, Decode, Encode, ParameterizedDecode},
field::{decode_fieldvec, FieldElement},
flp::{
Expand All @@ -18,7 +19,8 @@ use crate::{
PrepareTransition, Vdaf, VdafError,
},
vidpf::{
Vidpf, VidpfError, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId, VidpfWeight,
Vidpf, VidpfError, VidpfEvalCache, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId,
VidpfWeight,
},
};

Expand Down Expand Up @@ -549,33 +551,32 @@ where
let mut output_shares = Vec::<T::Field>::with_capacity(
self.vidpf.weight_parameter * agg_param.level_and_prefixes.prefixes().len(),
);
let mut cache_tree = BinaryTree::<VidpfEvalCache<VidpfWeight<T::Field>>>::default();
let cache = VidpfEvalCache::<VidpfWeight<T::Field>>::init_from_key(
&input_share.vidpf_key,
&self.vidpf.weight_parameter,
);
cache_tree
.insert(Path::empty(), cache)
.expect("Should alwys be able to insert into empty tree at root");
for prefix in agg_param.level_and_prefixes.prefixes() {
let mut value_share =
self.vidpf
.eval(&input_share.vidpf_key, public_share, prefix, nonce)?;
let mut value_share = self.vidpf.eval_with_cache(
&input_share.vidpf_key,
public_share,
prefix,
&mut cache_tree,
nonce,
)?;
xof.update(&value_share.proof);
output_shares.append(&mut value_share.share.0);
}
let root_share_opt = if agg_param.require_check_flag {
Some(
self.vidpf
.eval(
&input_share.vidpf_key,
public_share,
&VidpfInput::from_bools(&[false]),
nonce,
)?
.share
+ self
.vidpf
.eval(
&input_share.vidpf_key,
public_share,
&VidpfInput::from_bools(&[true]),
nonce,
)?
.share,
)
Some(self.vidpf.eval_root_with_cache(
&input_share.vidpf_key,
public_share,
&mut cache_tree,
nonce,
)?)
} else {
None
};
Expand Down
171 changes: 134 additions & 37 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use core::{
};

use bitvec::field::BitField;
use bitvec::prelude::{BitVec, Lsb0};
use bitvec::prelude::{BitSlice, BitVec, Lsb0};
use rand_core::RngCore;
use std::fmt::Debug;
use std::io::{Cursor, Read};
Expand Down Expand Up @@ -232,7 +232,7 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
})
}

/// [`Vidpf::eval_cached`] evaluates the entire `input` and produces a share of the
/// [`Vidpf::eval_with_cache`] evaluates the entire `input` and produces a share of the
/// input's weight. It reuses computation from previous levels available in the
/// cache
pub fn eval_with_cache(
Expand All @@ -250,24 +250,19 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {

let state = VidpfEvalState::init_from_key(key);
let path = input;
match cache_tree
.insert(
input.prefix(0).index.as_bitslice(),
self.eval_next_with_cache(key.id, public, input, 0, &state, nonce)?
)
{
Ok(_) => (),
Err(value) => match value {
BinaryTreeError::<VidpfEvalCache<W>>::InsertNonEmptyNode(_) => Ok(()),
BinaryTreeError::<VidpfEvalCache<W>>::UnreachableNode(_) => Err(VidpfError::CacheError),
}?
};

let mut cache_node = cache_tree
.get_node(IdpfInput::from_bools(&[]).index.as_bitslice())
.expect("previous match statement ensures initialization");

for level in 1..n {
let first_cache = VidpfEvalCache::init_from_state(state, W::zero(&self.weight_parameter));

let mut cache_node = match cache_tree.get_node(BitSlice::empty()) {
Some(node) => Ok(node),
None => match cache_tree.insert(BitSlice::empty(), first_cache) {
Ok(_) => Ok(cache_tree
.get_node(BitSlice::empty())
.expect("node inserted by above successful statement")),
Err(BinaryTreeError::InsertNonEmptyNode(_)) => Err(VidpfError::CacheError),
Err(BinaryTreeError::UnreachableNode(_)) => Err(VidpfError::CacheError),
},
}?;
for level in 0..n {
if cache_node
.get(path.next_branch(level).index.as_bitslice())
.is_some()
Expand All @@ -290,10 +285,10 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
.expect("node was inserted by previous statement");
}
}
let final_cache = cache_node
.get(IdpfInput::from_bools(&[]).index.as_bitslice())
.expect("node was inserted by last loop iteration");
Ok(final_cache.to_share())
Ok(cache_tree
.get(input.index.as_bitslice())
.expect("node was inserted by last loop iteration")
.to_share())
}

/// [`Vidpf::eval_next`] evaluates the `input` at the given level using the provided initial
Expand Down Expand Up @@ -346,19 +341,31 @@ impl<W: VidpfValue, const NONCE_SIZE: usize> Vidpf<W, NONCE_SIZE> {
Ok((next_state, y))
}

/// [`Vidpf::eval_next_cached`] evaluates the `input` at the given level using the provided initial
/// state, and returns a cache containing a new state and a share of the input's weight at that level.
fn eval_next_with_cache(
pub(crate) fn eval_root_with_cache(
&self,
id: VidpfServerId,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
level: usize,
state: &VidpfEvalState,
key: &VidpfKey,
public_share: &VidpfPublicShare<W>,
cache_tree: &mut BinaryTree<VidpfEvalCache<W>>,
nonce: &[u8; NONCE_SIZE],
) -> Result<VidpfEvalCache<W>, VidpfError> {
let (state, share) = self.eval_next(id, public, input, level, state, nonce)?;
Ok(VidpfEvalCache::init_from_state(state, share))
) -> Result<W, VidpfError> {
Ok(self
.eval_with_cache(
key,
public_share,
&VidpfInput::from_bools(&[false]),
cache_tree,
nonce,
)?
.share
+ self
.eval_with_cache(
key,
public_share,
&VidpfInput::from_bools(&[true]),
cache_tree,
nonce,
)?
.share)
}

pub(crate) fn eval_root(
Expand Down Expand Up @@ -662,6 +669,12 @@ pub struct VidpfEvalCache<W: VidpfValue> {
}

impl<W: VidpfValue> VidpfEvalCache<W> {
pub(crate) fn init_from_key(key: &VidpfKey, length: &W::ValueParameter) -> Self {
Self {
state: VidpfEvalState::init_from_key(key),
share: W::zero(length),
}
}
fn init_from_state(state: VidpfEvalState, share: W) -> Self {
Self { state, share }
}
Expand Down Expand Up @@ -847,6 +860,7 @@ impl<F: FieldElement> ParameterizedDecode<<Self as IdpfValue>::ValueParameter> f

#[cfg(test)]
mod tests {

use crate::field::Field128;

use super::VidpfWeight;
Expand All @@ -858,10 +872,11 @@ mod tests {

mod vidpf {
use crate::{
bt::{BinaryTree, Path},
idpf::IdpfValue,
vidpf::{
Vidpf, VidpfError, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare,
VidpfServerId,
Vidpf, VidpfError, VidpfEvalCache, VidpfEvalState, VidpfInput, VidpfKey,
VidpfPublicShare, VidpfServerId,
},
};

Expand Down Expand Up @@ -985,6 +1000,88 @@ mod tests {
);
}
}

#[test]
fn caching_at_each_level() {
let input = VidpfInput::from_bytes(&[0xFF]);
let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight);

equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce);
}

fn equivalence_of_eval_with_caching(
vidpf: &Vidpf<TestWeight, TEST_NONCE_SIZE>,
[key_0, key_1]: &[VidpfKey; 2],
public: &VidpfPublicShare<TestWeight>,
input: &VidpfInput,
nonce: &[u8; TEST_NONCE_SIZE],
) {
let mut cache_tree_0 = BinaryTree::<VidpfEvalCache<TestWeight>>::default();
let mut cache_tree_1 = BinaryTree::<VidpfEvalCache<TestWeight>>::default();
let cache_0 =
VidpfEvalCache::<TestWeight>::init_from_key(key_0, &vidpf.weight_parameter);
let cache_1 =
VidpfEvalCache::<TestWeight>::init_from_key(key_1, &vidpf.weight_parameter);
cache_tree_0
.insert(Path::empty(), cache_0)
.expect("Should alwys be able to insert into empty tree at root");
cache_tree_1
.insert(Path::empty(), cache_1)
.expect("Should alwys be able to insert into empty tree at root");

let n = input.len();
for level in 0..n {
let val_share_0 = vidpf
.eval(key_0, public, &input.prefix(level), nonce)
.unwrap();
let val_share_1 = vidpf
.eval(key_1, public, &input.prefix(level), nonce)
.unwrap();
let val_share_0_cached = vidpf
.eval_with_cache(
key_0,
public,
&input.prefix(level),
&mut cache_tree_0,
nonce,
)
.unwrap();
let val_share_1_cached = vidpf
.eval_with_cache(
key_1,
public,
&input.prefix(level),
&mut cache_tree_1,
nonce,
)
.unwrap();

assert_eq!(
val_share_0.share, val_share_0_cached.share,
"shares must be computed equally with or without caching: {:?}",
level
);

assert_eq!(
val_share_1.share, val_share_1_cached.share,
"shares must be computed equally with or without caching: {:?}",
level
);

assert_eq!(
val_share_0.proof, val_share_0_cached.proof,
"proofs must be equal with or without caching: {:?}",
level
);

assert_eq!(
val_share_1.proof, val_share_1_cached.proof,
"proofs must be equal with or without caching: {:?}",
level
);
}
}
}

mod weight {
Expand Down

0 comments on commit 4d47513

Please sign in to comment.