diff --git a/src/benchmarked.rs b/src/benchmarked.rs index 1882de91..d8023236 100644 --- a/src/benchmarked.rs +++ b/src/benchmarked.rs @@ -9,7 +9,7 @@ use crate::fft::discrete_fourier_transform; use crate::field::FftFriendlyFieldElement; use crate::flp::gadgets::Mul; use crate::flp::FlpError; -use crate::polynomial::{poly_fft, PolyAuxMemory}; +use crate::polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory}; /// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. pub fn benchmarked_iterative_fft(outp: &mut [F], inp: &[F]) { @@ -18,15 +18,9 @@ pub fn benchmarked_iterative_fft(outp: &mut [F], inp /// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. pub fn benchmarked_recursive_fft(outp: &mut [F], inp: &[F]) { - let mut mem = PolyAuxMemory::new(inp.len() / 2); - poly_fft( - outp, - inp, - &mem.roots_2n, - inp.len(), - false, - &mut mem.fft_memory, - ) + let roots_2n = fft_get_roots(inp.len(), false); + let mut fft_memory = PolyFFTTempMemory::new(inp.len()); + poly_fft(outp, inp, &roots_2n, inp.len(), false, &mut fft_memory) } /// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function diff --git a/src/fft.rs b/src/fft.rs index 039f183f..10729a2a 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -127,7 +127,7 @@ fn bitrev(d: usize, x: usize) -> usize { mod tests { use super::*; use crate::field::{random_vector, split_vector, Field128, Field64, FieldElement, FieldPrio2}; - use crate::polynomial::{poly_fft, PolyAuxMemory}; + use crate::polynomial::{poly_fft, TestPolyAuxMemory}; fn discrete_fourier_transform_then_inv_test() -> Result<(), FftError> { @@ -164,7 +164,7 @@ mod tests { #[test] fn test_recursive_fft() { let size = 128; - let mut mem = PolyAuxMemory::new(size / 2); + let mut mem = TestPolyAuxMemory::new(size / 2); let inp = random_vector(size).unwrap(); let mut want = vec![FieldPrio2::zero(); size]; diff --git a/src/polynomial.rs b/src/polynomial.rs index 272266f5..782b803e 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -18,7 +18,7 @@ pub struct PolyFFTTempMemory { } impl PolyFFTTempMemory { - fn new(length: usize) -> Self { + pub(crate) fn new(length: usize) -> Self { PolyFFTTempMemory { fft_tmp: vec![F::zero(); length], fft_y_sub: vec![F::zero(); length], @@ -27,25 +27,20 @@ impl PolyFFTTempMemory { } } -/// Auxiliary memory for polynomial interpolation and evaluation +#[cfg(test)] #[derive(Clone, Debug)] -pub struct PolyAuxMemory { +pub(crate) struct TestPolyAuxMemory { pub roots_2n: Vec, pub roots_2n_inverted: Vec, - pub roots_n: Vec, - pub roots_n_inverted: Vec, - pub coeffs: Vec, pub fft_memory: PolyFFTTempMemory, } -impl PolyAuxMemory { - pub fn new(n: usize) -> Self { - PolyAuxMemory { +#[cfg(test)] +impl TestPolyAuxMemory { + pub(crate) fn new(n: usize) -> Self { + Self { roots_2n: fft_get_roots(2 * n, false), roots_2n_inverted: fft_get_roots(2 * n, true), - roots_n: fft_get_roots(n, false), - roots_n_inverted: fft_get_roots(n, true), - coeffs: vec![F::zero(); 2 * n], fft_memory: PolyFFTTempMemory::new(2 * n), } } @@ -109,7 +104,7 @@ fn fft_recurse( } /// Calculate `count` number of roots of unity of order `count` -fn fft_get_roots(count: usize, invert: bool) -> Vec { +pub(crate) fn fft_get_roots(count: usize, invert: bool) -> Vec { let mut roots = vec![F::zero(); count]; let mut gen = F::generator(); if invert { @@ -236,7 +231,8 @@ mod tests { FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2, }, polynomial::{ - fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, PolyAuxMemory, + fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, + TestPolyAuxMemory, }, }; use rand::prelude::*; @@ -333,7 +329,7 @@ mod tests { #[test] fn test_fft() { let count = 128; - let mut mem = PolyAuxMemory::new(count / 2); + let mut mem = TestPolyAuxMemory::new(count / 2); let mut poly = vec![FieldPrio2::from(0); count]; let mut points2 = vec![FieldPrio2::from(0); count]; diff --git a/src/vdaf/prio2/client.rs b/src/vdaf/prio2/client.rs index 9515601d..32b9af52 100644 --- a/src/vdaf/prio2/client.rs +++ b/src/vdaf/prio2/client.rs @@ -6,7 +6,7 @@ use crate::{ codec::CodecError, field::FftFriendlyFieldElement, - polynomial::{poly_fft, PolyAuxMemory}, + polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory}, prng::{Prng, PrngError}, vdaf::{ xof::{Seed, SeedStreamAes128}, @@ -48,7 +48,10 @@ pub(crate) struct ClientMemory { points_g: Vec, evals_f: Vec, evals_g: Vec, - poly_mem: PolyAuxMemory, + roots_2n: Vec, + roots_n_inverted: Vec, + fft_memory: PolyFFTTempMemory, + coeffs: Vec, } impl ClientMemory { @@ -72,7 +75,10 @@ impl ClientMemory { points_g: vec![F::zero(); n], evals_f: vec![F::zero(); 2 * n], evals_g: vec![F::zero(); 2 * n], - poly_mem: PolyAuxMemory::new(n), + roots_2n: fft_get_roots(2 * n, false), + roots_n_inverted: fft_get_roots(n, true), + fft_memory: PolyFFTTempMemory::new(2 * n), + coeffs: vec![F::zero(); 2 * n], }) } } @@ -191,30 +197,33 @@ pub(crate) fn unpack_proof_mut( } } +/// Interpolate a polynomial at the nth roots of unity, and then evaluate it at the 2nth roots of +/// unity. +/// +/// # Arguments +/// +/// * `n` - The number of points to interpolate a polynomial through. +/// * `points_in` - The values that the polynomial must take on when evaluated at the nth roots of +/// unity. This must have length n. +/// * `evals_out` - The values that the polynomial takes on when evaluated at the 2nth roots of +/// unity. This must have length 2 * n. +/// * `roots_n_inverted` - Precomputed inverses of the nth roots of unity. +/// * `roots_2n` - Precomputed 2nth roots of unity. +/// * `fft_memory` - Scratch space for the FFT algorithm. +/// * `coeffs` - Scratch space. This must have length 2 * n. fn interpolate_and_evaluate_at_2n( n: usize, points_in: &[F], evals_out: &mut [F], - mem: &mut PolyAuxMemory, + roots_n_inverted: &[F], + roots_2n: &[F], + fft_memory: &mut PolyFFTTempMemory, + coeffs: &mut [F], ) { // interpolate through roots of unity - poly_fft( - &mut mem.coeffs, - points_in, - &mem.roots_n_inverted, - n, - true, - &mut mem.fft_memory, - ); + poly_fft(coeffs, points_in, roots_n_inverted, n, true, fft_memory); // evaluate at 2N roots of unity - poly_fft( - evals_out, - &mem.coeffs, - &mem.roots_2n, - 2 * n, - false, - &mut mem.fft_memory, - ); + poly_fft(evals_out, coeffs, roots_2n, 2 * n, false, fft_memory); } /// Proof construction @@ -253,8 +262,24 @@ fn construct_proof( } // interpolate and evaluate at roots of unity - interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem); - interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem); + interpolate_and_evaluate_at_2n( + n, + &mem.points_f, + &mut mem.evals_f, + &mem.roots_n_inverted, + &mem.roots_2n, + &mut mem.fft_memory, + &mut mem.coeffs, + ); + interpolate_and_evaluate_at_2n( + n, + &mem.points_g, + &mut mem.evals_g, + &mem.roots_n_inverted, + &mem.roots_2n, + &mut mem.fft_memory, + &mut mem.coeffs, + ); // calculate the proof polynomial as evals_f(r) * evals_g(r) // only add non-zero points