diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 353541a11..a779caf2d 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,7 +1,6 @@ use std::{ borrow::Borrow, fmt::Debug, - iter::{repeat, zip}, }; use generic_array::{ArrayLength, GenericArray}; @@ -95,8 +94,10 @@ where N: ArrayLength, M: ArrayLength, { - /// This function uses the `LagrangeTable` to evaluate `polynomial` on the specified output "x coordinates" - /// outputs the "y coordinates" such that `(x,y)` lies on `polynomial` + /// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates" + /// that were used to generate this table. + /// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates" + /// that were used to generate this table. pub fn eval(&self, y_coordinates: I) -> GenericArray where I: IntoIterator + Copy, @@ -110,7 +111,7 @@ where .map(|table_row| { table_row .iter() - .zip(y_coordinates.into_iter()) + .zip(y_coordinates) .fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow())) }) .collect() @@ -173,7 +174,7 @@ where mod test { use std::{borrow::Borrow, fmt::Debug}; - use generic_array::{sequence::GenericSequence, ArrayLength, GenericArray}; + use generic_array::{ArrayLength, GenericArray}; use proptest::{prelude::*, proptest}; use typenum::{U1, U32, U7, U8}; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index cee049621..56ad818e6 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,5 +1,5 @@ use std::{ - iter::zip, + iter::{zip, once}, ops::{Add, Sub}, }; @@ -27,13 +27,21 @@ pub struct ProofGenerator { /// Distributed Zero Knowledge Proofs algorithm drawn from /// `https://eprint.iacr.org/2023/909.pdf` /// +#[allow(non_camel_case_types)] impl ProofGenerator where F: PrimeField, { - #![allow(non_camel_case_types)] + pub fn new(u: Vec, v: Vec) -> Self { + debug_assert_eq!(u.len(), v.len(), "u and v must be of equal length"); + Self { + u, + v, + } + } + pub fn compute_proof<λ: ArrayLength>( - self, + &self, r: F, ) -> ( ZeroKnowledgeProof, U1>>, @@ -45,13 +53,11 @@ where <<λ as Add>::Output as Sub>::Output: ArrayLength, <λ as Sub>::Output: ArrayLength, { - assert!(self.u.len() % λ::USIZE == 0); // We should pad with zeroes eventually + debug_assert_eq!(self.u.len() % λ::USIZE, 0); // We should pad with zeroes eventually let s = self.u.len() / λ::USIZE; - if s <= 1 { - panic!("When the output is this small, you should call compute_final_proof"); - } + assert!(s > 1, "When the output is this small, you should call `compute_final_proof`"); let mut next_proof_generator = ProofGenerator { u: Vec::::with_capacity(s), @@ -72,7 +78,10 @@ where let q_r = lagrange_table_r.eval(q)[0]; next_proof_generator.u.push(p_r); next_proof_generator.v.push(q_r); - zip(p.into_iter(), q.into_iter()) + // p.into_iter() has elements that are &F + // p_extrapolated.into_iter() has elements that are F + // So these iterators cannot be chained. + zip(p, q) .map(|(a, b)| *a * *b) .chain(zip(p_extrapolated, q_extrapolated).map(|(a, b)| a * b)) .collect::>() @@ -108,10 +117,10 @@ mod test { const R1: u128 = 22; const EXPECTED_NEXT_U: [u128; 8] = [0, 0, 26, 0, 7, 18, 24, 13]; const EXPECTED_NEXT_V: [u128; 8] = [10, 21, 30, 28, 15, 21, 3, 3]; - let pg: ProofGenerator = ProofGenerator { - u: U.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), - v: V.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), - }; + let pg: ProofGenerator = ProofGenerator::new( + U.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), + V.into_iter().map(|x| Fp31::try_from(x).unwrap()).collect(), + ); let (proof, next_proof_generator) = pg.compute_proof::(Fp31::try_from(R1).unwrap()); assert_eq!( proof.g.into_iter().map(|x| x.as_u128()).collect::>(),