Skip to content

Commit

Permalink
Clean up internal Prng API usage
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton committed Nov 11, 2021
1 parent 0f0d237 commit caf7f08
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 99 deletions.
15 changes: 9 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
polynomial::{poly_fft, PolyAuxMemory},
prng::{Prng, PrngError},
util::{proof_length, unpack_proof_mut},
vdaf::suite::Suite,
vdaf::suite::{Key, KeyStream, Suite},
};

use std::convert::TryFrom;
Expand Down Expand Up @@ -99,11 +99,15 @@ impl<F: FieldElement> Client<F> {

// use prng to share the proof: share2 is the PRNG seed, and proof is mutated
// in-place
let share2 = crate::prng::secret_share(&mut proof)?;
let share2 = Key::generate(Suite::Aes128CtrHmacSha256)?;
let share2_prng = Prng::from_key_stream(KeyStream::from_key(&share2));
for (s1, d) in proof.iter_mut().zip(share2_prng.into_iter()) {
*s1 -= d;
}
let share1 = F::slice_into_byte_vec(&proof);
// encrypt shares with respective keys
let encrypted_share1 = encrypt_share(&share1, &self.public_key1)?;
let encrypted_share2 = encrypt_share(&share2, &self.public_key2)?;
let encrypted_share2 = encrypt_share(share2.as_slice(), &self.public_key2)?;
Ok((encrypted_share1, encrypted_share2))
}

Expand Down Expand Up @@ -185,8 +189,8 @@ fn construct_proof<F: FieldElement>(
let n = (dimension + 1).next_power_of_two();

// set zero terms to random
*f0 = mem.prng.next().unwrap();
*g0 = mem.prng.next().unwrap();
*f0 = mem.prng.get();
*g0 = mem.prng.get();
mem.points_f[0] = *f0;
mem.points_g[0] = *g0;

Expand Down Expand Up @@ -219,7 +223,6 @@ fn construct_proof<F: FieldElement>(
#[test]
fn test_encode() {
use crate::field::Field32;

let pub_key1 = PublicKey::from_base64(
"BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=",
)
Expand Down
4 changes: 4 additions & 0 deletions src/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! Utilities for ECIES encryption / decryption used by the Prio client and server.

use crate::prng::PrngError;
use crate::vdaf::suite::SuiteError;

use aes_gcm::aead::generic_array::typenum::U16;
use aes_gcm::aead::generic_array::GenericArray;
Expand Down Expand Up @@ -39,6 +40,9 @@ pub enum EncryptError {
/// PRNG error
#[error("prng error: {0}")]
Prng(#[from] PrngError),
/// Suite error
#[error("suite error: {0}")]
Suite(#[from] SuiteError),
}

/// NIST P-256, public key in X9.62 uncompressed format
Expand Down
13 changes: 4 additions & 9 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,8 +686,8 @@ mod tests {

// add + sub
for _ in 0..100 {
let f = prng.next().unwrap();
let g = prng.next().unwrap();
let f = prng.get();
let g = prng.get();
assert_eq!(f + g - f - g, zero);
assert_eq!(f + g - g, f);
assert_eq!(f + g - f, g);
Expand All @@ -708,7 +708,7 @@ mod tests {

// mul + div
for _ in 0..100 {
let f = prng.next().unwrap();
let f = prng.get();
if f == zero {
continue;
}
Expand Down Expand Up @@ -736,12 +736,7 @@ mod tests {
}

// serialization
let test_inputs = vec![
zero,
one,
prng.next().unwrap(),
F::from(int_modulus - int_one),
];
let test_inputs = vec![zero, one, prng.get(), F::from(int_modulus - int_one)];
for want in test_inputs.iter() {
println!("check {:?}", want);
let mut bytes: Vec<u8> = vec![];
Expand Down
4 changes: 2 additions & 2 deletions src/pcp/gadgets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,10 @@ mod tests {
let mut poly_outp = vec![F::zero(); (g.degree() * (1 + num_calls)).next_power_of_two()];
let mut poly_inp = vec![vec![F::zero(); 1 + num_calls]; g.arity()];

let r = prng.next().unwrap();
let r = prng.get();
for i in 0..g.arity() {
for j in 0..num_calls {
poly_inp[i][j] = prng.next().unwrap();
poly_inp[i][j] = prng.get();
}
inp[i] = poly_eval(&poly_inp[i], r);
}
Expand Down
92 changes: 24 additions & 68 deletions src/prng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,12 @@ use std::marker::PhantomData;

const BUFFER_SIZE_IN_ELEMENTS: usize = 128;

pub(crate) fn secret_share<F: FieldElement>(share1: &mut [F]) -> Result<Vec<u8>, PrngError> {
let key = Key::generate(Suite::Aes128CtrHmacSha256)?;

// get prng array
let data: Vec<F> = Prng::from_key_stream(KeyStream::from_key(&key))
.take(share1.len())
.collect();

// secret share
for (s1, d) in share1.iter_mut().zip(data.iter()) {
*s1 -= *d;
}

Ok(key.as_slice().to_vec())
}

pub(crate) fn extract_share_from_seed<F: FieldElement>(
length: usize,
seed: &[u8],
) -> Result<Vec<F>, PrngError> {
if seed.len() != aes::BLOCK_SIZE * 2 {
return Err(PrngError::SeedSize);
}

let mut key = [0; aes::BLOCK_SIZE * 2];
key.copy_from_slice(seed);
let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key));
Ok(Prng::from_key_stream(key_stream).take(length).collect())
}

/// Errors propagated by methods in this module.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum PrngError {
/// Tried to construct a PRNG from a seed of invalid length.
#[error("invalid seed length")]
SeedSize,

/// VDAF suite error.
#[error("vdaf suite error: {0}")]
VdafSuite(#[from] SuiteError),
/// Suite error.
#[error("suite error: {0}")]
Suite(#[from] SuiteError),
}

/// This type implements an iterator that generates a pseudorandom sequence of field elements. The
Expand Down Expand Up @@ -86,12 +52,8 @@ impl<F: FieldElement> Prng<F> {
output_written: 0,
}
}
}

impl<F: FieldElement> Iterator for Prng<F> {
type Item = F;

fn next(&mut self) -> Option<F> {
pub(crate) fn get(&mut self) -> F {
loop {
// Seek to the next chunk of the buffer that encodes an element of F.
for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) {
Expand All @@ -104,7 +66,7 @@ impl<F: FieldElement> Iterator for Prng<F> {
// Set the buffer index to the next chunk.
self.buffer_index = j;
self.output_written += 1;
return Some(x);
return x;
}
}

Expand All @@ -118,32 +80,18 @@ impl<F: FieldElement> Iterator for Prng<F> {
}
}

impl<F: FieldElement> Iterator for Prng<F> {
type Item = F;

fn next(&mut self) -> Option<F> {
Some(self.get())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::field::{Field32, FieldPriov2};

#[test]
fn secret_sharing() {
let mut data = vec![Field32::from(0); 123];
data[3] = 23.into();

let data_clone = data.clone();

let seed = secret_share(&mut data).unwrap();
assert_ne!(data, data_clone);

let share2 = extract_share_from_seed(data.len(), &seed).unwrap();

assert_eq!(data.len(), share2.len());

// recombine
for (d, d2) in data.iter_mut().zip(share2.iter()) {
*d += *d2;
}

assert_eq!(data, data_clone);
}
use crate::field::FieldPriov2;

#[test]
fn secret_sharing_interop() {
Expand All @@ -161,15 +109,15 @@ mod tests {
0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58,
];

let share2 = extract_share_from_seed::<FieldPriov2>(reference.len(), &seed).unwrap();
let share2 = extract_share_from_seed::<FieldPriov2>(reference.len(), &seed);

assert_eq!(share2, reference);
}

/// takes a seed and hash as base64 encoded strings
fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) {
let seed = base64::decode(seed_base64).unwrap();
let random_data = extract_share_from_seed::<FieldPriov2>(len, &seed).unwrap();
let random_data = extract_share_from_seed::<FieldPriov2>(len, &seed);

let random_bytes = FieldPriov2::slice_into_byte_vec(&random_data);

Expand Down Expand Up @@ -210,4 +158,12 @@ mod tests {
100_000,
);
}

fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> {
assert_eq!(seed.len(), aes::BLOCK_SIZE * 2);
let mut key = [0; aes::BLOCK_SIZE * 2];
key.copy_from_slice(seed);
let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key));
Prng::from_key_stream(key_stream).take(length).collect()
}
}
20 changes: 15 additions & 5 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ use crate::{
encrypt::{decrypt_share, EncryptError, PrivateKey},
field::{merge_vector, FieldElement, FieldError},
polynomial::{poly_interpret_eval, PolyAuxMemory},
prng::{extract_share_from_seed, Prng, PrngError},
prng::{Prng, PrngError},
util::{proof_length, unpack_proof, SerializeError},
vdaf::suite::Suite,
vdaf::suite::{Key, KeyStream, Suite},
};
use serde::{Deserialize, Serialize};

/// Possible errors from server operations
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
/// Unexpected Share Length
#[error("unexpected share length")]
ShareLength,
/// Encryption/decryption error
#[error("encryption/decryption error")]
Encrypt(#[from] EncryptError),
Expand Down Expand Up @@ -91,12 +94,19 @@ impl<F: FieldElement> Server<F> {

/// Decrypt and deserialize
fn deserialize_share(&self, encrypted_share: &[u8]) -> Result<Vec<F>, ServerError> {
let len = proof_length(self.dimension);
let share = decrypt_share(encrypted_share, &self.private_key)?;
Ok(if self.is_first_server {
F::byte_slice_into_vec(&share)?
} else {
let len = proof_length(self.dimension);
extract_share_from_seed(len, &share)?
if share.len() != 32 {
return Err(ServerError::ShareLength);
}

let mut key = [0; 32];
key.copy_from_slice(&share);
let key_stream = KeyStream::from_key(&Key::Aes128CtrHmacSha256(key));
Prng::from_key_stream(key_stream).take(len).collect()
})
}

Expand Down Expand Up @@ -169,7 +179,7 @@ impl<F: FieldElement> Server<F> {
/// evaluation.
pub fn choose_eval_at(&mut self) -> F {
loop {
let eval_at = self.prng.next().unwrap();
let eval_at = self.prng.get();
if !self.validation_mem.poly_mem.roots_2n.contains(&eval_at) {
break eval_at;
}
Expand Down
15 changes: 6 additions & 9 deletions src/vdaf/hits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,19 +291,16 @@ impl<I: Idpf<2, 2>> Client for Hits<I> {
// [BBCG+21, Appendix C.4]
//
// $(a, b, c)$
let a =
leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap();
let b =
leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap();
let c =
leader_sketch_start_prng.next().unwrap() + helper_sketch_start_prng.next().unwrap();
let a = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
let b = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();
let c = leader_sketch_start_prng.get() + helper_sketch_start_prng.get();

// $A = -2a + k$
// $B = a^2 + b + -ak + c$
let d = k - (a + a);
let e = (a * a) + b - (a * k) + c;
leader_sketch_next.push(d - helper_sketch_next_prng.next().unwrap());
leader_sketch_next.push(e - helper_sketch_next_prng.next().unwrap());
leader_sketch_next.push(d - helper_sketch_next_prng.get());
leader_sketch_next.push(e - helper_sketch_next_prng.get());
}

// Generate IDPF shares of the data and authentication vectors.
Expand Down Expand Up @@ -389,7 +386,7 @@ impl<I: Idpf<2, 2>> Aggregator for Hits<I> {
for prefix in agg_param.iter() {
let value = input_share.idpf.eval(prefix)?;
let (v, k) = (value[0], value[1]);
let r = verify_rand_prng.next().unwrap();
let r = verify_rand_prng.get();

// [BBCG+21, Appendix C.4]
//
Expand Down

0 comments on commit caf7f08

Please sign in to comment.