diff --git a/src/client.rs b/src/client.rs index 29e24e2b0..a592ca4ab 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,7 +9,7 @@ use crate::{ polynomial::{poly_fft, PolyAuxMemory}, prng::{Prng, PrngError}, util::{proof_length, unpack_proof_mut}, - vdaf::suite::Suite, + vdaf::suite::{Key, KeyStream, Suite}, }; use std::convert::TryFrom; @@ -99,11 +99,15 @@ impl Client { // use prng to share the proof: share2 is the PRNG seed, and proof is mutated // in-place - let share2 = crate::prng::secret_share(&mut proof)?; + let share2 = Key::generate(Suite::Aes128CtrHmacSha256)?; + let share2_prng = Prng::from_key_stream(KeyStream::from_key(&share2)); + for (s1, d) in proof.iter_mut().zip(share2_prng.into_iter()) { + *s1 -= d; + } let share1 = F::slice_into_byte_vec(&proof); // encrypt shares with respective keys let encrypted_share1 = encrypt_share(&share1, &self.public_key1)?; - let encrypted_share2 = encrypt_share(&share2, &self.public_key2)?; + let encrypted_share2 = encrypt_share(share2.as_slice(), &self.public_key2)?; Ok((encrypted_share1, encrypted_share2)) } @@ -185,8 +189,8 @@ fn construct_proof( let n = (dimension + 1).next_power_of_two(); // set zero terms to random - *f0 = mem.prng.next().unwrap(); - *g0 = mem.prng.next().unwrap(); + *f0 = mem.prng.get(); + *g0 = mem.prng.get(); mem.points_f[0] = *f0; mem.points_g[0] = *g0; @@ -219,7 +223,6 @@ fn construct_proof( #[test] fn test_encode() { use crate::field::Field32; - let pub_key1 = PublicKey::from_base64( "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=", ) diff --git a/src/encrypt.rs b/src/encrypt.rs index cd9432cd2..193fd2132 100644 --- a/src/encrypt.rs +++ b/src/encrypt.rs @@ -4,6 +4,7 @@ //! Utilities for ECIES encryption / decryption used by the Prio client and server. use crate::prng::PrngError; +use crate::vdaf::suite::SuiteError; use aes_gcm::aead::generic_array::typenum::U16; use aes_gcm::aead::generic_array::GenericArray; @@ -39,6 +40,9 @@ pub enum EncryptError { /// PRNG error #[error("prng error: {0}")] Prng(#[from] PrngError), + /// Suite error + #[error("suite error: {0}")] + Suite(#[from] SuiteError), } /// NIST P-256, public key in X9.62 uncompressed format diff --git a/src/field.rs b/src/field.rs index 3b008b405..48d52cf1d 100644 --- a/src/field.rs +++ b/src/field.rs @@ -686,8 +686,8 @@ mod tests { // add + sub for _ in 0..100 { - let f = prng.next().unwrap(); - let g = prng.next().unwrap(); + let f = prng.get(); + let g = prng.get(); assert_eq!(f + g - f - g, zero); assert_eq!(f + g - g, f); assert_eq!(f + g - f, g); @@ -708,7 +708,7 @@ mod tests { // mul + div for _ in 0..100 { - let f = prng.next().unwrap(); + let f = prng.get(); if f == zero { continue; } @@ -736,12 +736,7 @@ mod tests { } // serialization - let test_inputs = vec![ - zero, - one, - prng.next().unwrap(), - F::from(int_modulus - int_one), - ]; + let test_inputs = vec![zero, one, prng.get(), F::from(int_modulus - int_one)]; for want in test_inputs.iter() { println!("check {:?}", want); let mut bytes: Vec = vec![]; diff --git a/src/pcp/gadgets.rs b/src/pcp/gadgets.rs index c6cbc7be0..04e403ce0 100644 --- a/src/pcp/gadgets.rs +++ b/src/pcp/gadgets.rs @@ -306,10 +306,10 @@ mod tests { let mut poly_outp = vec![F::zero(); (g.degree() * (1 + num_calls)).next_power_of_two()]; let mut poly_inp = vec![vec![F::zero(); 1 + num_calls]; g.arity()]; - let r = prng.next().unwrap(); + let r = prng.get(); for i in 0..g.arity() { for j in 0..num_calls { - poly_inp[i][j] = prng.next().unwrap(); + poly_inp[i][j] = prng.get(); } inp[i] = poly_eval(&poly_inp[i], r); } diff --git a/src/prng.rs b/src/prng.rs index 802690e0f..6c1fefaa4 100644 --- a/src/prng.rs +++ b/src/prng.rs @@ -12,46 +12,12 @@ use std::marker::PhantomData; const BUFFER_SIZE_IN_ELEMENTS: usize = 128; -pub(crate) fn secret_share(share1: &mut [F]) -> Result, PrngError> { - let key = Key::generate(Suite::Aes128CtrHmacSha256)?; - - // get prng array - let data: Vec = Prng::from_key_stream(KeyStream::from_key(&key)) - .take(share1.len()) - .collect(); - - // secret share - for (s1, d) in share1.iter_mut().zip(data.iter()) { - *s1 -= *d; - } - - Ok(key.as_slice().to_vec()) -} - -pub(crate) fn extract_share_from_seed( - length: usize, - seed: &[u8], -) -> Result, PrngError> { - if seed.len() != aes::BLOCK_SIZE * 2 { - return Err(PrngError::SeedSize); - } - - let mut key = [0; aes::BLOCK_SIZE * 2]; - key.copy_from_slice(seed); - let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key)); - Ok(Prng::from_key_stream(key_stream).take(length).collect()) -} - /// Errors propagated by methods in this module. #[derive(Debug, PartialEq, thiserror::Error)] pub enum PrngError { - /// Tried to construct a PRNG from a seed of invalid length. - #[error("invalid seed length")] - SeedSize, - - /// VDAF suite error. - #[error("vdaf suite error: {0}")] - VdafSuite(#[from] SuiteError), + /// Suite error. + #[error("suite error: {0}")] + Suite(#[from] SuiteError), } /// This type implements an iterator that generates a pseudorandom sequence of field elements. The @@ -86,12 +52,8 @@ impl Prng { output_written: 0, } } -} -impl Iterator for Prng { - type Item = F; - - fn next(&mut self) -> Option { + pub(crate) fn get(&mut self) -> F { loop { // Seek to the next chunk of the buffer that encodes an element of F. for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) { @@ -104,7 +66,7 @@ impl Iterator for Prng { // Set the buffer index to the next chunk. self.buffer_index = j; self.output_written += 1; - return Some(x); + return x; } } @@ -118,32 +80,18 @@ impl Iterator for Prng { } } +impl Iterator for Prng { + type Item = F; + + fn next(&mut self) -> Option { + Some(self.get()) + } +} + #[cfg(test)] mod tests { use super::*; - use crate::field::{Field32, FieldPriov2}; - - #[test] - fn secret_sharing() { - let mut data = vec![Field32::from(0); 123]; - data[3] = 23.into(); - - let data_clone = data.clone(); - - let seed = secret_share(&mut data).unwrap(); - assert_ne!(data, data_clone); - - let share2 = extract_share_from_seed(data.len(), &seed).unwrap(); - - assert_eq!(data.len(), share2.len()); - - // recombine - for (d, d2) in data.iter_mut().zip(share2.iter()) { - *d += *d2; - } - - assert_eq!(data, data_clone); - } + use crate::field::FieldPriov2; #[test] fn secret_sharing_interop() { @@ -161,7 +109,7 @@ mod tests { 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58, ]; - let share2 = extract_share_from_seed::(reference.len(), &seed).unwrap(); + let share2 = extract_share_from_seed::(reference.len(), &seed); assert_eq!(share2, reference); } @@ -169,7 +117,7 @@ mod tests { /// takes a seed and hash as base64 encoded strings fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) { let seed = base64::decode(seed_base64).unwrap(); - let random_data = extract_share_from_seed::(len, &seed).unwrap(); + let random_data = extract_share_from_seed::(len, &seed); let random_bytes = FieldPriov2::slice_into_byte_vec(&random_data); @@ -210,4 +158,12 @@ mod tests { 100_000, ); } + + fn extract_share_from_seed(length: usize, seed: &[u8]) -> Vec { + assert_eq!(seed.len(), aes::BLOCK_SIZE * 2); + let mut key = [0; aes::BLOCK_SIZE * 2]; + key.copy_from_slice(seed); + let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key)); + Prng::from_key_stream(key_stream).take(length).collect() + } } diff --git a/src/server.rs b/src/server.rs index 7ee4021f7..60b879658 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,15 +6,18 @@ use crate::{ encrypt::{decrypt_share, EncryptError, PrivateKey}, field::{merge_vector, FieldElement, FieldError}, polynomial::{poly_interpret_eval, PolyAuxMemory}, - prng::{extract_share_from_seed, Prng, PrngError}, + prng::{Prng, PrngError}, util::{proof_length, unpack_proof, SerializeError}, - vdaf::suite::Suite, + vdaf::suite::{Key, KeyStream, Suite}, }; use serde::{Deserialize, Serialize}; /// Possible errors from server operations #[derive(Debug, thiserror::Error)] pub enum ServerError { + /// Unexpected Share Length + #[error("unexpected share length")] + ShareLength, /// Encryption/decryption error #[error("encryption/decryption error")] Encrypt(#[from] EncryptError), @@ -91,12 +94,19 @@ impl Server { /// Decrypt and deserialize fn deserialize_share(&self, encrypted_share: &[u8]) -> Result, ServerError> { + let len = proof_length(self.dimension); let share = decrypt_share(encrypted_share, &self.private_key)?; Ok(if self.is_first_server { F::byte_slice_into_vec(&share)? } else { - let len = proof_length(self.dimension); - extract_share_from_seed(len, &share)? + if share.len() != 32 { + return Err(ServerError::ShareLength); + } + + let mut key = [0; 32]; + key.copy_from_slice(&share); + let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key)); + Prng::from_key_stream(key_stream).take(len).collect() }) } @@ -169,7 +179,7 @@ impl Server { /// evaluation. pub fn choose_eval_at(&mut self) -> F { loop { - let eval_at = self.prng.next().unwrap(); + let eval_at = self.prng.get(); if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) { break eval_at; } diff --git a/src/vdaf/hits.rs b/src/vdaf/hits.rs index 18df8b30a..9cb8d86cf 100644 --- a/src/vdaf/hits.rs +++ b/src/vdaf/hits.rs @@ -291,19 +291,16 @@ impl> Client for Hits { // [BBCG+21, Appendix C.4] // // $(a, b, c)$ - let a = - leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap(); - let b = - leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap(); - let c = - leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap(); + let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); + let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get(); // $A = -2a + k$ // $B = a^2 + b + -ak + c$ let d = k - (a + a); let e = (a * a) + b - (a * k) + c; - leader_sketch_next.push(d - helper_sketch_next_prng.next().unwrap()); - leader_sketch_next.push(e - helper_sketch_next_prng.next().unwrap()); + leader_sketch_next.push(d - helper_sketch_next_prng.get()); + leader_sketch_next.push(e - helper_sketch_next_prng.get()); } // Generate IDPF shares of the data and authentication vectors. @@ -389,7 +386,7 @@ impl> Aggregator for Hits { for prefix in agg_param.iter() { let value = input_share.idpf.eval(prefix)?; let (v, k) = (value[0], value[1]); - let r = verify_rand_prng.next().unwrap(); + let r = verify_rand_prng.get(); // [BBCG+21, Appendix C.4] //