Skip to content

Commit

Permalink
Prio2: split precomputations to fix warnings (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Jun 13, 2024
1 parent 790ba4d commit c8c37bb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 49 deletions.
14 changes: 4 additions & 10 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::fft::discrete_fourier_transform;
use crate::field::FftFriendlyFieldElement;
use crate::flp::gadgets::Mul;
use crate::flp::FlpError;
use crate::polynomial::{poly_fft, PolyAuxMemory};
use crate::polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
Expand All @@ -18,15 +18,9 @@ pub fn benchmarked_iterative_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
inp,
&mem.roots_2n,
inp.len(),
false,
&mut mem.fft_memory,
)
let roots_2n = fft_get_roots(inp.len(), false);
let mut fft_memory = PolyFFTTempMemory::new(inp.len());
poly_fft(outp, inp, &roots_2n, inp.len(), false, &mut fft_memory)
}

/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function
Expand Down
4 changes: 2 additions & 2 deletions src/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn bitrev(d: usize, x: usize) -> usize {
mod tests {
use super::*;
use crate::field::{random_vector, split_vector, Field128, Field64, FieldElement, FieldPrio2};
use crate::polynomial::{poly_fft, PolyAuxMemory};
use crate::polynomial::{poly_fft, TestPolyAuxMemory};

fn discrete_fourier_transform_then_inv_test<F: FftFriendlyFieldElement>() -> Result<(), FftError>
{
Expand Down Expand Up @@ -164,7 +164,7 @@ mod tests {
#[test]
fn test_recursive_fft() {
let size = 128;
let mut mem = PolyAuxMemory::new(size / 2);
let mut mem = TestPolyAuxMemory::new(size / 2);

let inp = random_vector(size).unwrap();
let mut want = vec![FieldPrio2::zero(); size];
Expand Down
26 changes: 11 additions & 15 deletions src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct PolyFFTTempMemory<F> {
}

impl<F: FftFriendlyFieldElement> PolyFFTTempMemory<F> {
fn new(length: usize) -> Self {
pub(crate) fn new(length: usize) -> Self {
PolyFFTTempMemory {
fft_tmp: vec![F::zero(); length],
fft_y_sub: vec![F::zero(); length],
Expand All @@ -27,25 +27,20 @@ impl<F: FftFriendlyFieldElement> PolyFFTTempMemory<F> {
}
}

/// Auxiliary memory for polynomial interpolation and evaluation
#[cfg(test)]
#[derive(Clone, Debug)]
pub struct PolyAuxMemory<F> {
pub(crate) struct TestPolyAuxMemory<F> {
pub roots_2n: Vec<F>,
pub roots_2n_inverted: Vec<F>,
pub roots_n: Vec<F>,
pub roots_n_inverted: Vec<F>,
pub coeffs: Vec<F>,
pub fft_memory: PolyFFTTempMemory<F>,
}

impl<F: FftFriendlyFieldElement> PolyAuxMemory<F> {
pub fn new(n: usize) -> Self {
PolyAuxMemory {
#[cfg(test)]
impl<F: FftFriendlyFieldElement> TestPolyAuxMemory<F> {
pub(crate) fn new(n: usize) -> Self {
Self {
roots_2n: fft_get_roots(2 * n, false),
roots_2n_inverted: fft_get_roots(2 * n, true),
roots_n: fft_get_roots(n, false),
roots_n_inverted: fft_get_roots(n, true),
coeffs: vec![F::zero(); 2 * n],
fft_memory: PolyFFTTempMemory::new(2 * n),
}
}
Expand Down Expand Up @@ -109,7 +104,7 @@ fn fft_recurse<F: FftFriendlyFieldElement>(
}

/// Calculate `count` number of roots of unity of order `count`
fn fft_get_roots<F: FftFriendlyFieldElement>(count: usize, invert: bool) -> Vec<F> {
pub(crate) fn fft_get_roots<F: FftFriendlyFieldElement>(count: usize, invert: bool) -> Vec<F> {
let mut roots = vec![F::zero(); count];
let mut gen = F::generator();
if invert {
Expand Down Expand Up @@ -236,7 +231,8 @@ mod tests {
FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2,
},
polynomial::{
fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, PolyAuxMemory,
fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check,
TestPolyAuxMemory,
},
};
use rand::prelude::*;
Expand Down Expand Up @@ -333,7 +329,7 @@ mod tests {
#[test]
fn test_fft() {
let count = 128;
let mut mem = PolyAuxMemory::new(count / 2);
let mut mem = TestPolyAuxMemory::new(count / 2);

let mut poly = vec![FieldPrio2::from(0); count];
let mut points2 = vec![FieldPrio2::from(0); count];
Expand Down
69 changes: 47 additions & 22 deletions src/vdaf/prio2/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
use crate::{
codec::CodecError,
field::FftFriendlyFieldElement,
polynomial::{poly_fft, PolyAuxMemory},
polynomial::{fft_get_roots, poly_fft, PolyFFTTempMemory},
prng::{Prng, PrngError},
vdaf::{
xof::{Seed, SeedStreamAes128},
Expand Down Expand Up @@ -48,7 +48,10 @@ pub(crate) struct ClientMemory<F> {
points_g: Vec<F>,
evals_f: Vec<F>,
evals_g: Vec<F>,
poly_mem: PolyAuxMemory<F>,
roots_2n: Vec<F>,
roots_n_inverted: Vec<F>,
fft_memory: PolyFFTTempMemory<F>,
coeffs: Vec<F>,
}

impl<F: FftFriendlyFieldElement> ClientMemory<F> {
Expand All @@ -72,7 +75,10 @@ impl<F: FftFriendlyFieldElement> ClientMemory<F> {
points_g: vec![F::zero(); n],
evals_f: vec![F::zero(); 2 * n],
evals_g: vec![F::zero(); 2 * n],
poly_mem: PolyAuxMemory::new(n),
roots_2n: fft_get_roots(2 * n, false),
roots_n_inverted: fft_get_roots(n, true),
fft_memory: PolyFFTTempMemory::new(2 * n),
coeffs: vec![F::zero(); 2 * n],
})
}
}
Expand Down Expand Up @@ -191,30 +197,33 @@ pub(crate) fn unpack_proof_mut<F: FftFriendlyFieldElement>(
}
}

/// Interpolate a polynomial at the nth roots of unity, and then evaluate it at the 2nth roots of
/// unity.
///
/// # Arguments
///
/// * `n` - The number of points to interpolate a polynomial through.
/// * `points_in` - The values that the polynomial must take on when evaluated at the nth roots of
/// unity. This must have length n.
/// * `evals_out` - The values that the polynomial takes on when evaluated at the 2nth roots of
/// unity. This must have length 2 * n.
/// * `roots_n_inverted` - Precomputed inverses of the nth roots of unity.
/// * `roots_2n` - Precomputed 2nth roots of unity.
/// * `fft_memory` - Scratch space for the FFT algorithm.
/// * `coeffs` - Scratch space. This must have length 2 * n.
fn interpolate_and_evaluate_at_2n<F: FftFriendlyFieldElement>(
n: usize,
points_in: &[F],
evals_out: &mut [F],
mem: &mut PolyAuxMemory<F>,
roots_n_inverted: &[F],
roots_2n: &[F],
fft_memory: &mut PolyFFTTempMemory<F>,
coeffs: &mut [F],
) {
// interpolate through roots of unity
poly_fft(
&mut mem.coeffs,
points_in,
&mem.roots_n_inverted,
n,
true,
&mut mem.fft_memory,
);
poly_fft(coeffs, points_in, roots_n_inverted, n, true, fft_memory);
// evaluate at 2N roots of unity
poly_fft(
evals_out,
&mem.coeffs,
&mem.roots_2n,
2 * n,
false,
&mut mem.fft_memory,
);
poly_fft(evals_out, coeffs, roots_2n, 2 * n, false, fft_memory);
}

/// Proof construction
Expand Down Expand Up @@ -253,8 +262,24 @@ fn construct_proof<F: FftFriendlyFieldElement>(
}

// interpolate and evaluate at roots of unity
interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem);
interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem);
interpolate_and_evaluate_at_2n(
n,
&mem.points_f,
&mut mem.evals_f,
&mem.roots_n_inverted,
&mem.roots_2n,
&mut mem.fft_memory,
&mut mem.coeffs,
);
interpolate_and_evaluate_at_2n(
n,
&mem.points_g,
&mut mem.evals_g,
&mem.roots_n_inverted,
&mem.roots_2n,
&mut mem.fft_memory,
&mut mem.coeffs,
);

// calculate the proof polynomial as evals_f(r) * evals_g(r)
// only add non-zero points
Expand Down

0 comments on commit c8c37bb

Please sign in to comment.