Skip to content

Commit

Permalink
Implement Prio3MutlihotCountVec (#1123)
Browse files Browse the repository at this point in the history
  • Loading branch information
rozbb authored Oct 31, 2024
1 parent a2ffdcd commit d2fe428
Show file tree
Hide file tree
Showing 3 changed files with 404 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ pub trait Type: Sized + Eq + Clone + Debug {
measurement: &Self::Measurement,
) -> Result<Vec<Self::Field>, FlpError>;

/// Decode an aggregate result.
/// Decodes an aggregate result.
///
/// This is NOT the inverse of `encode_measurement`. Rather, the input is an aggregation of
/// truncated measurements.
fn decode_result(
&self,
data: &[Self::Field],
Expand Down
349 changes: 345 additions & 4 deletions src/flp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt};
use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval};
use crate::flp::{FlpError, Gadget, Type};
use crate::polynomial::poly_range_check;
use crate::vdaf::prio3::ilog2;
use std::convert::TryInto;
use std::fmt::{self, Debug};
use std::marker::PhantomData;
Expand Down Expand Up @@ -471,6 +472,237 @@ where
}
}

/// The multihot counter data type. Each measurement is a list of booleans of length `length`, with
/// at most `max_weight` true values, and the aggregate is a histogram counting the number of true
/// values at each position across all measurements.
#[derive(PartialEq, Eq)]
pub struct MultihotCountVec<F, S> {
// Parameters
/// The number of elements in the list of booleans
length: usize,
/// The max number of permissible `true` values in the list of booleans
max_weight: usize,
/// The size of the chunks fed into our gadget calls
chunk_length: usize,

// Calculated from parameters
gadget_calls: usize,
bits_for_weight: usize,
offset: usize,
phantom: PhantomData<(F, S)>,
}

impl<F: FftFriendlyFieldElement, S> Debug for MultihotCountVec<F, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MultihotCountVec")
.field("length", &self.length)
.field("max_weight", &self.max_weight)
.field("chunk_length", &self.chunk_length)
.finish()
}
}

impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> MultihotCountVec<F, S> {
/// Return a new [`MultihotCountVec`] type with the given number of buckets.
pub fn new(
num_buckets: usize,
max_weight: usize,
chunk_length: usize,
) -> Result<Self, FlpError> {
if num_buckets >= u32::MAX as usize {
return Err(FlpError::Encode(
"invalid num_buckets: exceeds maximum permitted".to_string(),
));
}
if num_buckets == 0 {
return Err(FlpError::InvalidParameter(
"num_buckets cannot be zero".to_string(),
));
}
if chunk_length == 0 {
return Err(FlpError::InvalidParameter(
"chunk_length cannot be zero".to_string(),
));
}
if max_weight == 0 {
return Err(FlpError::InvalidParameter(
"max_weight cannot be zero".to_string(),
));
}

// The bitlength of a measurement is the number of buckets plus the bitlength of the max
// weight
let bits_for_weight = ilog2(max_weight) as usize + 1;
let meas_length = num_buckets + bits_for_weight;

// Gadget calls is ⌈meas_length / chunk_length⌉
let gadget_calls = (meas_length + chunk_length - 1) / chunk_length;
// Offset is 2^max_weight.bitlen() - 1 - max_weight
let offset = (1 << bits_for_weight) - 1 - max_weight;

Ok(Self {
length: num_buckets,
max_weight,
chunk_length,
gadget_calls,
bits_for_weight,
offset,
phantom: PhantomData,
})
}
}

// Cannot autoderive clone because it requires F and S to be Clone, which they're not in general
impl<F, S> Clone for MultihotCountVec<F, S> {
fn clone(&self) -> Self {
Self {
length: self.length,
max_weight: self.max_weight,
chunk_length: self.chunk_length,
bits_for_weight: self.bits_for_weight,
offset: self.offset,
gadget_calls: self.gadget_calls,
phantom: self.phantom,
}
}
}

impl<F, S> Type for MultihotCountVec<F, S>
where
F: FftFriendlyFieldElement,
S: ParallelSumGadget<F, Mul<F>> + Eq + 'static,
{
type Measurement = Vec<bool>;
type AggregateResult = Vec<F::Integer>;
type Field = F;

fn encode_measurement(&self, measurement: &Vec<bool>) -> Result<Vec<F>, FlpError> {
let weight_reported: usize = measurement.iter().filter(|bit| **bit).count();

if measurement.len() != self.length {
return Err(FlpError::Encode(format!(
"unexpected measurement length: got {}; want {}",
measurement.len(),
self.length
)));
}
if weight_reported > self.max_weight {
return Err(FlpError::Encode(format!(
"unexpected measurement weight: got {}; want ≤{}",
weight_reported, self.max_weight
)));
}

// Convert bool vector to field elems
let multihot_vec: Vec<F> = measurement
.iter()
// We can unwrap because any Integer type can cast from bool
.map(|bit| F::from(F::valid_integer_try_from(*bit as usize).unwrap()))
.collect();

// Encode the measurement weight in binary (actually, the weight plus some offset)
let offset_weight_bits = {
let offset_weight_reported = F::valid_integer_try_from(self.offset + weight_reported)?;
F::encode_as_bitvector(offset_weight_reported, self.bits_for_weight)?.collect()
};

// Report the concat of the two
Ok([multihot_vec, offset_weight_bits].concat())
}

fn decode_result(
&self,
data: &[Self::Field],
_num_measurements: usize,
) -> Result<Self::AggregateResult, FlpError> {
// The aggregate is the same as the decoded result. Just convert to integers
decode_result_vec(data, self.length)
}

fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
vec![Box::new(S::new(
Mul::new(self.gadget_calls),
self.chunk_length,
))]
}

fn valid(
&self,
g: &mut Vec<Box<dyn Gadget<F>>>,
input: &[F],
joint_rand: &[F],
num_shares: usize,
) -> Result<F, FlpError> {
self.valid_call_check(input, joint_rand)?;

// Check that each element of `input` is a 0 or 1.
let range_check = parallel_sum_range_checks(
&mut g[0],
input,
joint_rand[0],
self.chunk_length,
num_shares,
)?;

// Check that the elements of `input` sum to at most `max_weight`.
let count_vec = &input[..self.length];
let weight = count_vec.iter().fold(F::zero(), |a, b| a + *b);
let offset_weight_reported = F::decode_bitvector(&input[self.length..])?;

// From spec: weight_check = self.offset*shares_inv + weight - weight_reported
let weight_check = {
let offset = F::from(F::valid_integer_try_from(self.offset)?);
let shares_inv = F::from(F::valid_integer_try_from(num_shares)?).inv();
offset * shares_inv + weight - offset_weight_reported
};

// Take a random linear combination of both checks.
let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * weight_check;
Ok(out)
}

// Truncates the measurement, removing extra data that was necessary for validity (here, the
// encoded weight), but not important for aggregation
fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError> {
self.truncate_call_check(&input)?;
// Cut off the encoded weight
Ok(input[..self.length].to_vec())
}

// The length in field elements of the encoded input returned by [`Self::encode_measurement`].
fn input_len(&self) -> usize {
self.length + self.bits_for_weight
}

fn proof_len(&self) -> usize {
(self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1
}

fn verifier_len(&self) -> usize {
2 + self.chunk_length * 2
}

// The length of the truncated output (i.e., the output of [`Type::truncate`]).
fn output_len(&self) -> usize {
self.length
}

// The number of random values needed in the validity checks
fn joint_rand_len(&self) -> usize {
2
}

fn prove_rand_len(&self) -> usize {
self.chunk_length * 2
}

fn query_rand_len(&self) -> usize {
// TODO: this will need to be increase once draft-10 is implemented and more randomness is
// necessary due to random linear combination computations
1
}
}

/// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19],
/// Corollary 4.9] to reduce the proof size to roughly the square root of the input size.
///
Expand Down Expand Up @@ -685,13 +917,13 @@ pub(crate) fn call_gadget_on_vec_entries<F: FftFriendlyFieldElement>(
input: &[F],
rnd: F,
) -> Result<F, FlpError> {
let mut range_check = F::zero();
let mut comb = F::zero();
let mut r = rnd;
for chunk in input.chunks(1) {
range_check += r * g.call(chunk)?;
comb += r * g.call(chunk)?;
r *= rnd;
}
Ok(range_check)
Ok(comb)
}

/// Given a vector `data` of field elements which should contain exactly one entry, return the
Expand Down Expand Up @@ -776,7 +1008,9 @@ pub(crate) fn parallel_sum_range_checks<F: FftFriendlyFieldElement>(
#[cfg(test)]
mod tests {
use super::*;
use crate::field::{random_vector, Field64 as TestField, FieldElement};
use crate::field::{
random_vector, Field64 as TestField, FieldElement, FieldElementWithInteger,
};
use crate::flp::gadgets::ParallelSum;
#[cfg(feature = "multithreaded")]
use crate::flp::gadgets::ParallelSumMultithreaded;
Expand Down Expand Up @@ -957,6 +1191,113 @@ mod tests {
);
}

fn test_multihot<F, S>(constructor: F)
where
F: Fn(usize, usize, usize) -> Result<MultihotCountVec<TestField, S>, FlpError>,
S: ParallelSumGadget<TestField, Mul<TestField>> + Eq + 'static,
{
const NUM_SHARES: usize = 3;

// Chunk size for our range check gadget
let chunk_size = 2;

// Our test is on multihot vecs of length 3, with max weight 2
let num_buckets = 3;
let max_weight = 2;

let multihot_instance = constructor(num_buckets, max_weight, chunk_size).unwrap();
let zero = TestField::zero();
let one = TestField::one();
let nine = TestField::from(9);

let encoded_weight_plus_offset = |weight| {
let bits_for_weight = ilog2(max_weight) as usize + 1;
let offset = (1 << bits_for_weight) - 1 - max_weight;
TestField::encode_as_bitvector(
<TestField as FieldElementWithInteger>::Integer::try_from(weight + offset).unwrap(),
bits_for_weight,
)
.unwrap()
.collect::<Vec<TestField>>()
};

assert_eq!(
multihot_instance
.encode_measurement(&vec![true, true, false])
.unwrap(),
[&[one, one, zero], &*encoded_weight_plus_offset(2)].concat(),
);
assert_eq!(
multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap(),
[&[zero, one, one], &*encoded_weight_plus_offset(2)].concat(),
);

// Round trip
assert_eq!(
multihot_instance
.decode_result(
&multihot_instance
.truncate(
multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap()
)
.unwrap(),
1
)
.unwrap(),
[0, 1, 1]
);

// Test valid inputs with weights 0, 1, and 2
FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![true, false, false])
.unwrap(),
&[one, zero, zero],
);

FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![false, true, true])
.unwrap(),
&[zero, one, one],
);

FlpTest::expect_valid::<NUM_SHARES>(
&multihot_instance,
&multihot_instance
.encode_measurement(&vec![false, false, false])
.unwrap(),
&[zero, zero, zero],
);

// Test invalid inputs.

// Not binary
FlpTest::expect_invalid::<NUM_SHARES>(
&multihot_instance,
&[&[zero, zero, nine], &*encoded_weight_plus_offset(1)].concat(),
);
// Wrong weight
FlpTest::expect_invalid::<NUM_SHARES>(
&multihot_instance,
&[&[zero, zero, one], &*encoded_weight_plus_offset(2)].concat(),
);
// We cannot test the case where the weight is higher than max_weight. This is because
// weight + offset cannot fit into a bitvector of the correct length. In other words, being
// out-of-range requires the prover to lie about their weight, which is tested above
}

#[test]
fn test_multihot_serial() {
test_multihot(MultihotCountVec::<TestField, ParallelSum<TestField, Mul<TestField>>>::new);
}

fn test_sum_vec<F, S>(f: F)
where
F: Fn(usize, usize, usize) -> Result<SumVec<TestField, S>, FlpError>,
Expand Down
Loading

0 comments on commit d2fe428

Please sign in to comment.