Skip to content

Commit

Permalink
Make FLP Eval Return Vec<F> rather than a single F (#1132)
Browse files Browse the repository at this point in the history
* Simplified Average to use Sum internally

* Made Type::valid return a Vec<F> rather than F

* Made Prio3Sum circuit output bits-many elements

---------

Co-authored-by: Michael Rosenberg <[email protected]>
  • Loading branch information
rozbb and Michael Rosenberg authored Nov 12, 2024
1 parent baf3b86 commit a8e48e7
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 170 deletions.
108 changes: 77 additions & 31 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// [BBCG+19]: https://ia.cr/2019/188
fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>;

/// Returns the number of gadgets associated with this validity circuit. This MUST equal `self.gadget().len()`.
fn num_gadgets(&self) -> usize;

/// Evaluates the validity circuit on an input and returns the output.
///
/// # Parameters
Expand All @@ -179,15 +182,15 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// let input: Vec<Field64> = count.encode_measurement(&true).unwrap();
/// let joint_rand = random_vector(count.joint_rand_len()).unwrap();
/// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap();
/// assert_eq!(v, Field64::zero());
/// assert!(v.into_iter().all(|f| f == Field64::zero()));
/// ```
fn valid(
&self,
gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>,
input: &[Self::Field],
joint_rand: &[Self::Field],
num_shares: usize,
) -> Result<Self::Field, FlpError>;
) -> Result<Vec<Self::Field>, FlpError>;

/// Constructs an aggregatable output from an encoded input. Calling this method is only safe
/// once `input` has been validated.
Expand All @@ -208,14 +211,25 @@ pub trait Type: Sized + Eq + Clone + Debug {
/// The length of the joint random input.
fn joint_rand_len(&self) -> usize;

/// The length of the circuit output
fn eval_output_len(&self) -> usize;

/// The length in field elements of the random input consumed by the prover to generate a
/// proof. This is the same as the sum of the arity of each gadget in the validity circuit.
fn prove_rand_len(&self) -> usize;

/// The length in field elements of the random input consumed by the verifier to make queries
/// against inputs and proofs. This is the same as the number of gadgets in the validity
/// circuit.
fn query_rand_len(&self) -> usize;
/// circuit, plus the number of elements output by the validity circuit (if >1).
fn query_rand_len(&self) -> usize {
let mut n = self.num_gadgets();
let eval_elems = self.eval_output_len();
if eval_elems > 1 {
n += eval_elems;
}

n
}

/// Generate a proof of an input's validity. The return value is a sequence of
/// [`Self::proof_len`] field elements.
Expand Down Expand Up @@ -388,6 +402,24 @@ pub trait Type: Sized + Eq + Clone + Debug {
self.query_rand_len()
)));
}
// We use query randomness to compress outputs from `valid()` (if size is > 1), as well as
// for gadget evaluations. Split these up
let (query_rand_for_validity, query_rand_for_gadgets) = if self.eval_output_len() > 1 {
query_rand.split_at(self.eval_output_len())
} else {
query_rand.split_at(0)
};

// Another check that we have the right amount of randomness
let my_gadgets = self.gadget();
if query_rand_for_gadgets.len() != my_gadgets.len() {
return Err(FlpError::Query(format!(
"length of query randomness for gadgets doesn't match number of gadgets: \
got {}; want {}",
query_rand_for_gadgets.len(),
my_gadgets.len()
)));
}

if joint_rand.len() != self.joint_rand_len() {
return Err(FlpError::Query(format!(
Expand All @@ -398,15 +430,13 @@ pub trait Type: Sized + Eq + Clone + Debug {
}

let mut proof_len = 0;
let mut shims = self
.gadget()
let mut shims = my_gadgets
.into_iter()
.enumerate()
.map(|(idx, gadget)| {
.zip(query_rand_for_gadgets)
.map(|(gadget, &r)| {
let gadget_degree = gadget.degree();
let gadget_arity = gadget.arity();
let m = (1 + gadget.calls()).next_power_of_two();
let r = query_rand[idx];

// Make sure the query randomness isn't a root of unity. Evaluating the gadget
// polynomial at any of these points would be a privacy violation, since these points
Expand All @@ -419,7 +449,7 @@ pub trait Type: Sized + Eq + Clone + Debug {
)));
}

// Compute the length of the sub-proof corresponding to the `idx`-th gadget.
// Compute the length of the sub-proof corresponding to this gadget.
let next_len = gadget_arity + gadget_degree * (m - 1) + 1;
let proof_data = &proof[proof_len..proof_len + next_len];
proof_len += next_len;
Expand All @@ -444,10 +474,23 @@ pub trait Type: Sized + Eq + Clone + Debug {
// should be OK, since it's possible to transform any circuit into one for which this is true.
// (Needs security analysis.)
let validity = self.valid(&mut shims, input, joint_rand, num_shares)?;
verifier.push(validity);
assert_eq!(validity.len(), self.eval_output_len());
// If `valid()` outputs multiple field elements, compress them into 1 field element using
// query randomness
let check = if validity.len() > 1 {
validity
.iter()
.zip(query_rand_for_validity)
.fold(Self::Field::zero(), |acc, (&val, &r)| acc + r * val)
} else {
// If `valid()` outputs one field element, just use that. If it outputs none, then it is
// trivially satisfied, so use 0
validity.first().cloned().unwrap_or(Self::Field::zero())
};
verifier.push(check);

// Fill the buffer with the verifier message.
for (query_rand_val, shim) in query_rand[..shims.len()].iter().zip(shims.iter_mut()) {
for (query_rand_val, shim) in query_rand_for_gadgets.iter().zip(shims.iter_mut()) {
let gadget = shim
.as_any()
.downcast_ref::<QueryShimGadget<Self::Field>>()
Expand Down Expand Up @@ -836,11 +879,6 @@ pub mod test_utils {
let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap();
let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap();
let query_rand = random_vector(self.flp.query_rand_len()).unwrap();
assert_eq!(
self.flp.query_rand_len(),
gadgets.len(),
"{name}: unexpected number of gadgets"
);
assert_eq!(
self.flp.joint_rand_len(),
joint_rand.len(),
Expand All @@ -863,9 +901,9 @@ pub mod test_utils {
.valid(&mut gadgets, self.input, &joint_rand, 1)
.unwrap();
assert_eq!(
v == T::Field::zero(),
v.iter().all(|f| f == &T::Field::zero()),
self.expect_valid,
"{name}: unexpected output of valid() returned {v}",
"{name}: unexpected output of valid() returned {v:?}",
);

// Generate the proof.
Expand Down Expand Up @@ -1056,7 +1094,7 @@ mod tests {
input: &[F],
joint_rand: &[F],
_num_shares: usize,
) -> Result<F, FlpError> {
) -> Result<Vec<F>, FlpError> {
let r = joint_rand[0];
let mut res = F::zero();

Expand All @@ -1071,7 +1109,7 @@ mod tests {
let x_checked = g[1].call(&[input[0]])?;
res += (r * r) * x_checked;

Ok(res)
Ok(vec![res])
}

fn input_len(&self) -> usize {
Expand Down Expand Up @@ -1108,12 +1146,12 @@ mod tests {
1
}

fn prove_rand_len(&self) -> usize {
3
fn eval_output_len(&self) -> usize {
1
}

fn query_rand_len(&self) -> usize {
2
fn prove_rand_len(&self) -> usize {
3
}

fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
Expand All @@ -1123,6 +1161,10 @@ mod tests {
]
}

fn num_gadgets(&self) -> usize {
2
}

fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
Ok(vec![
F::from(*measurement),
Expand Down Expand Up @@ -1190,7 +1232,7 @@ mod tests {
input: &[F],
_joint_rand: &[F],
_num_shares: usize,
) -> Result<F, FlpError> {
) -> Result<Vec<F>, FlpError> {
// This is a useless circuit, as it only accepts "0". Its purpose is to exercise the
// use of multiple gadgets, each of which is called an arbitrary number of times.
let mut res = F::zero();
Expand All @@ -1200,7 +1242,7 @@ mod tests {
for _ in 0..self.num_gadget_calls[1] {
res += g[1].call(&[input[0]])?;
}
Ok(res)
Ok(vec![res])
}

fn input_len(&self) -> usize {
Expand Down Expand Up @@ -1237,6 +1279,10 @@ mod tests {
0
}

fn eval_output_len(&self) -> usize {
1
}

fn prove_rand_len(&self) -> usize {
// First chunk
let first = 1; // gadget arity
Expand All @@ -1247,10 +1293,6 @@ mod tests {
first + second
}

fn query_rand_len(&self) -> usize {
2 // number of gadgets
}

fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> {
let poly = poly_range_check(0, 2); // A polynomial with degree 2
vec![
Expand All @@ -1259,6 +1301,10 @@ mod tests {
]
}

fn num_gadgets(&self) -> usize {
2
}

fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> {
Ok(vec![F::from(*measurement)])
}
Expand Down
4 changes: 2 additions & 2 deletions src/flp/szk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ mod tests {
let szk_typ = Szk::new_turboshake128(sum, algorithm_id);
let prove_rand_seed = Seed::<16>::generate().unwrap();
let helper_seed = Seed::<16>::generate().unwrap();
let leader_seed_opt = Some(Seed::<16>::generate().unwrap());
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap();
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
Expand Down Expand Up @@ -944,7 +944,7 @@ mod tests {
let szk_typ = Szk::new_turboshake128(sum, algorithm_id);
let prove_rand_seed = Seed::<16>::generate().unwrap();
let helper_seed = Seed::<16>::generate().unwrap();
let leader_seed_opt = Some(Seed::<16>::generate().unwrap());
let leader_seed_opt = None;
let helper_input_share = random_vector(szk_typ.typ.input_len()).unwrap();
let mut leader_input_share = encoded_measurement.clone().to_owned();
for (x, y) in leader_input_share.iter_mut().zip(&helper_input_share) {
Expand Down
Loading

0 comments on commit a8e48e7

Please sign in to comment.