diff --git a/.gitignore b/.gitignore index b152106..237315a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ cgt* # ignore instatiations of my test template # don't leak secret env vars .env +.history/* + # exclude compiled files and binaries debug/ target/ diff --git a/Cargo.lock b/Cargo.lock index e2d68e5..cda390d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,15 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "ark-crypto-primitives" version = "0.4.0" @@ -206,6 +215,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bytemuck" +version = "1.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" + [[package]] name = "cfg-if" version = "1.0.0" @@ -510,12 +525,49 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "nalgebra" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -526,6 +578,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -535,6 +596,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -641,6 +712,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "regex" version = "1.10.5" @@ -686,6 +763,7 @@ dependencies = [ "des", "hex", "itertools", + "nalgebra", "pretty_assertions", "rand", "rstest", @@ -731,6 +809,15 @@ dependencies = [ "semver", ] +[[package]] +name = "safe_arch" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3460605018fdc9612bce72735cba0d27efbcd9904780d44c7e3a9948f96148a" +dependencies = [ + "bytemuck", +] + [[package]] name = "semver" version = "1.0.23" @@ -748,6 +835,19 @@ dependencies = [ "digest", ] +[[package]] +name = "simba" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + [[package]] name = "slab" version = "0.4.9" @@ -861,6 +961,16 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wide" +version = "0.7.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "901e8597c777fa042e9e245bd56c0dc4418c5db3f845b6ff94fbac732c6a0692" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "winnow" version = "0.5.40" diff --git a/Cargo.toml b/Cargo.toml index b26431d..f26190e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,15 @@ description="""ronkathon""" edition ="2021" license ="Apache2.0 OR MIT" name ="ronkathon" -repository ="https://github.com/pluto/ronkathon" + +repository ="https://github.com/wu-s-john/ronkathon" version ="0.1.0" [dependencies] rand ="0.8.5" itertools="0.13.0" hex ="0.4.3" +nalgebra = "0.29" [dev-dependencies] rstest ="0.22.0" diff --git a/src/algebra/field/prime/mod.rs b/src/algebra/field/prime/mod.rs index 8794f80..b3d2908 100644 --- a/src/algebra/field/prime/mod.rs +++ b/src/algebra/field/prime/mod.rs @@ -8,7 +8,7 @@ use std::{fmt, str::FromStr}; use rand::{distributions::Standard, prelude::Distribution, Rng}; use super::*; -use crate::algebra::Finite; +use crate::{algebra::Finite, random::Random}; mod arithmetic; @@ -41,6 +41,13 @@ pub struct PrimeField { pub(crate) value: usize, } +impl Random for PlutoBaseField { + fn random(rng: &mut R) -> Self { + let value = rng.gen_range(0..PlutoPrime::Base as usize); + PlutoBaseField::new(value) + } +} + impl PrimeField

{ /// Creates a new element of the [`PrimeField`] and will automatically compute the modulus and /// return a congruent element between 0 and `P`. Given the `const fn is_prime`, a program that diff --git a/src/lib.rs b/src/lib.rs index 40ae5dc..e102dd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,8 @@ pub mod encryption; pub mod hashes; pub mod kzg; pub mod polynomial; +pub mod random; +pub mod sumcheck; pub mod tree; use core::{ diff --git a/src/polynomial/mod.rs b/src/polynomial/mod.rs index 49d9a74..7223aa6 100644 --- a/src/polynomial/mod.rs +++ b/src/polynomial/mod.rs @@ -23,6 +23,7 @@ use super::*; use crate::algebra::field::FiniteField; pub mod arithmetic; +pub mod multivariate_polynomial; #[cfg(test)] mod tests; // https://people.inf.ethz.ch/gander/papers/changing.pdf diff --git a/src/polynomial/multivariate_polynomial.rs b/src/polynomial/multivariate_polynomial.rs new file mode 100644 index 0000000..e02492f --- /dev/null +++ b/src/polynomial/multivariate_polynomial.rs @@ -0,0 +1,497 @@ +//! Represents a multivariate polynomial over a finite field. +//! +//! This implementation uses a novel and highly efficient representation for multivariate +//! polynomials. Each term in the polynomial is represented as a key-value pair in a HashMap, where: +//! - The key is a BTreeMap mapping variable indices to their exponents. +//! - The value is the coefficient of the term. +//! +//! This representation offers several significant advantages: +//! 1. Space Efficiency: Only non-zero terms are stored, making it ideal for sparse polynomials. +//! 2. Fast Term Lookup: The use of BTreeMap for exponents allows for quick term identification and +//! manipulation. +//! 3. Ordered Operations: BTreeMap's ordered nature facilitates efficient polynomial arithmetic. +//! 4. Memory Optimization: By using indices instead of full variable objects, we reduce memory +//! usage. +//! 5. Flexible Degree Handling: This structure naturally accommodates polynomials of arbitrary +//! degree. +//! 6. Efficient Iteration: Easy to iterate over terms, useful for various algorithms and +//! transformations. +//! +//! While this representation may have a slight overhead for very small polynomials, +//! its benefits become increasingly apparent as the polynomial's complexity grows, +//! making it an excellent choice for a wide range of cryptographic and algebraic applications. + +use std::{ + collections::{BTreeMap, HashMap}, + hash::Hash, + ops::{Add, Mul, Sub}, +}; + +use itertools::Itertools; + +use super::{Monomial, Polynomial, *}; +use crate::algebra::field::FiniteField; + +/// Represents a multivariate polynomial over a finite field. +/// +/// The polynomial is stored as a collection of terms, where each term is represented by: +/// - A `BTreeMap` as the key, mapping variable indices to their exponents. This +/// allows for efficient storage and manipulation of sparse polynomials. +/// - An `F` value as the coefficient, where `F` is a finite field. +/// +/// The use of `HashMap` for `terms` provides: +/// 1. O(1) average-case complexity for term lookup and insertion. +/// 2. Efficient storage for sparse polynomials, as only non-zero terms are stored. +/// +/// The use of `BTreeMap` for exponents provides: +/// 1. Ordered storage of variable exponents, facilitating polynomial arithmetic. +/// 2. Efficient comparison and manipulation of terms. +/// 3. Memory efficiency by using indices instead of full variable objects. +/// +/// This representation is particularly effective for large, sparse multivariate polynomials +/// commonly encountered in cryptographic and algebraic applications. +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct MultivariatePolynomial { + terms: HashMap, F>, +} + +impl MultivariatePolynomial { + /// Constructs a new `MultivariatePolynomial` representing the zero polynomial. + /// + /// This function creates an empty `MultivariatePolynomial`, which is equivalent to the zero + /// polynomial. The zero polynomial has no terms, and evaluates to zero for all inputs. + pub fn new() -> Self { Self { terms: HashMap::new() } } + + /// Creates a new `MultivariatePolynomial` from a vector of `MultivariateTerm`s. + /// + /// This is the preferred way to create a multivariate polynomial, as it allows + /// for a more intuitive representation of the polynomial's terms. + /// + /// # Arguments + /// + /// * `terms` - A vector of `MultivariateTerm`s representing the polynomial. + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` instance. + /// + /// # Example + /// + /// ``` + /// use your_crate::{ + /// MultivariatePolynomial, MultivariateTerm, MultivariateVariable, PlutoBaseField, + /// }; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// MultivariateTerm::new( + /// vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + /// index: 1, + /// exponent: 1, + /// }], + /// PlutoBaseField::new(3), + /// ), + /// MultivariateTerm::new( + /// vec![MultivariateVariable { index: 0, exponent: 1 }], + /// PlutoBaseField::new(2), + /// ), + /// MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + /// ]); + /// + /// // This creates the polynomial: 3x_0^2*x_1 + 2x_0 + 1 + /// ``` + pub fn from_terms(terms: Vec>) -> Self { + let mut poly = MultivariatePolynomial::new(); + for term in terms { + let mut btree_map = BTreeMap::new(); + for var in term.variables { + btree_map.insert(var.index, var.exponent); + } + poly.insert_term(btree_map, term.coefficient); + } + poly + } + + fn insert_term(&mut self, exponents: BTreeMap, coefficient: F) { + if coefficient != F::ZERO { + let entry = self.terms.entry(exponents.clone()).or_insert(F::ZERO); + *entry += coefficient; + if *entry == F::ZERO { + self.terms.remove(&exponents); + } + } + } + + /// Returns the coefficient of the term with the given exponents. + /// + /// # Arguments + /// + /// * `exponents` - A `BTreeMap` where the keys are variable indices and the values are their + /// exponents. + /// + /// # Returns + /// + /// * `Some(&F)` if a term with the given exponents exists in the polynomial, where `F` is the + /// coefficient. + /// * `None` if no term with the given exponents exists in the polynomial. + /// + /// # Example + /// + /// ``` + /// use std::collections::BTreeMap; + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // ... (terms as in the previous example) + /// ]); + /// + /// let mut exponents = BTreeMap::new(); + /// exponents.insert(0, 2); + /// exponents.insert(1, 1); + /// + /// assert_eq!(poly.coefficient(&exponents), Some(&PlutoBaseField::new(3))); + /// ``` + pub fn coefficient(&self, exponents: &BTreeMap) -> Option<&F> { + self.terms.get(exponents) + } + + /// Evaluates the multivariate polynomial at the given points. + /// + /// # Arguments + /// + /// * `points` - A slice of tuples where each tuple contains: + /// - The index of the variable (usize) + /// - The value to evaluate the variable at (F) + /// + /// # Returns + /// + /// * The result of evaluating the polynomial (F) + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2 + 2xy + 3z + /// // ... (terms definition) + /// ]); + /// + /// let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3)), (2, PlutoBaseField::new(1))]; + /// let result = poly.evaluate(&points); + /// // result will be the evaluation of x^2 + 2xy + 3z at x=2, y=3, z=1 + /// ``` + pub fn evaluate(&self, points: &[(usize, F)]) -> F { + self + .terms + .iter() + .map(|(exponents, coeff)| { + let term_value = exponents + .iter() + .map(|(&var, &exp)| { + points + .iter() + .find(|&&(v, _)| v == var) + .map(|&(_, value)| value.pow(exp)) + .unwrap_or(F::ONE) + }) + .product::(); + *coeff * term_value + }) + .sum() + } + + /// Applies the given variable assignments to the polynomial, reducing its degree. + /// + /// This method substitutes the specified variables with their corresponding values, + /// effectively reducing the polynomial's degree for those variables. The resulting + /// polynomial will have fewer variables if any were fully substituted. + /// + /// # Arguments + /// + /// * `variables` - A slice of tuples, where each tuple contains: + /// - The index of the variable to substitute (usize) + /// - The value to substitute for that variable (F) + /// + /// # Returns + /// + /// A new `MultivariatePolynomial` with the specified variables substituted. + /// + /// # Example + /// + /// ``` + /// use your_crate::{MultivariatePolynomial, PlutoBaseField}; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz + 2z + /// // ... (terms definition) + /// ]); + /// + /// let assignments = vec![(1, PlutoBaseField::new(2))]; // y = 2 + /// let reduced_poly = poly.apply_variables(&assignments); + /// // The resulting polynomial will be of the form: 2x^2 + 6xz + 2z + /// ``` + pub fn apply_variables(&self, variables: &[(usize, F)]) -> Self { + let mut result = MultivariatePolynomial::new(); + + for (exponents, coeff) in &self.terms { + let mut new_exponents = exponents.clone(); + let mut new_coeff = *coeff; + + for &(var, value) in variables { + if let Some(exp) = new_exponents.get(&var) { + new_coeff *= value.pow(*exp); + new_exponents.remove(&var); + } + } + + if !new_exponents.is_empty() { + result.insert_term(new_exponents, new_coeff); + } else { + result.insert_term(BTreeMap::new(), new_coeff); + } + } + + result + } + + /// Calculates the total degree of the multivariate polynomial. + /// + /// The total degree of a multivariate polynomial is the maximum sum of exponents + /// across all terms in the polynomial. + /// + /// # Returns + /// + /// * `usize` - The total degree of the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// assert_eq!(poly.degree(), 4); // The term xyz^2 has the highest total degree of 4 + /// ``` + pub fn degree(&self) -> usize { + self.terms.keys().map(|exponents| exponents.values().sum::()).max().unwrap_or(0) + } + + /// Returns a vector of all variables present in the polynomial. + /// + /// This method collects all unique variables (represented by their indices) + /// that appear in any term of the polynomial. + /// + /// # Returns + /// + /// * `Vec` - A vector containing the indices of all variables in the polynomial. + /// + /// # Example + /// + /// ``` + /// use your_crate::MultivariatePolynomial; + /// use your_crate::FiniteField; + /// + /// let poly = MultivariatePolynomial::::from_terms(vec![ + /// // x^2y + 3xyz^2 + 2z + /// // ... (terms definition) + /// ]); + /// + /// let vars = poly.variables(); + /// assert_eq!(vars, vec![0, 1, 2]); // Assuming x, y, z are represented by 0, 1, 2 respectively + /// ``` + pub fn variables(&self) -> Vec { + self + .terms + .keys() + .flat_map(|exponents| exponents.keys().cloned()) + .collect::>() + .into_iter() + .collect() + } +} + +impl Add for MultivariatePolynomial { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + self.insert_term(exponents, coeff); + } + self + } +} + +impl Sub for MultivariatePolynomial { + type Output = Self; + + fn sub(mut self, rhs: Self) -> Self::Output { + for (exponents, coeff) in rhs.terms { + // Negate the coefficient and insert + self.insert_term(exponents, -coeff); + } + self + } +} + +impl Mul for MultivariatePolynomial { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + let mut result = MultivariatePolynomial::new(); + for (exp1, coeff1) in &self.terms { + for (exp2, coeff2) in &rhs.terms { + let mut new_exp = exp1.clone(); + for (&var, &exp) in exp2 { + *new_exp.entry(var).or_insert(0) += exp; + } + result.insert_term(new_exp, *coeff1 * *coeff2); + } + } + result + } +} + +impl Display for MultivariatePolynomial { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut first = true; + for (exponents, coeff) in self.terms.iter().sorted_by(|(a_exp, _), (b_exp, _)| { + a_exp + .iter() + .zip(b_exp.iter()) + .find(|((a_var, a_pow), (b_var, b_pow))| { + a_var.cmp(b_var).then_with(|| b_pow.cmp(a_pow)).is_ne() + }) + .map_or(std::cmp::Ordering::Equal, |(..)| std::cmp::Ordering::Less) + }) { + if !first { + write!(f, " + ")?; + } + first = false; + + if *coeff != F::ONE || exponents.is_empty() { + write!(f, "{}", coeff)?; + } + + let mut first_var = true; + for (&var, &exp) in exponents { + if exp > 0 { + if !first_var || *coeff != F::ONE { + write!(f, "*")?; + } + write!(f, "x_{}", var)?; + if exp > 1 { + write!(f, "^{}", exp)?; + } + first_var = false; + } + } + } + + if first { + write!(f, "0")?; + } + + Ok(()) + } +} + +// Implement From for univariate polynomials +impl From> + for MultivariatePolynomial +{ + fn from(poly: Polynomial) -> Self { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in poly.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(0, i); + result.insert_term(exponents, coeff); + } + } + result + } +} + +// Extend Polynomial to support conversion to multivariate +impl Polynomial { + /// Converts a univariate polynomial to a multivariate polynomial. + /// + /// This method transforms the current univariate polynomial into a multivariate polynomial + /// where all terms use the same variable, specified by `variable_index`. + /// + /// # Arguments + /// + /// * `variable_index` - The index of the variable to use in the resulting multivariate + /// polynomial. + /// + /// # Returns + /// + /// A `MultivariatePolynomial` equivalent to the original univariate polynomial. + /// + /// # Example + /// + /// ``` + /// let univariate = Polynomial::new([F::ONE, F::TWO, F::THREE]); // x^2 + 2x + 1 + /// let multivariate = univariate.to_multivariate(0); + /// // Result: x_0^2 + 2*x_0 + 1 + /// ``` + pub fn to_multivariate(self, variable_index: usize) -> MultivariatePolynomial { + let mut result = MultivariatePolynomial::new(); + for (i, &coeff) in self.coefficients.iter().enumerate() { + if coeff != F::ZERO { + let mut exponents = BTreeMap::new(); + exponents.insert(variable_index, i); + result.insert_term(exponents, coeff); + } + } + result + } +} + +/// Represents a variable with an exponent in a multivariate polynomial. +/// Each variable is uniquely identified by its index and has an associated exponent. +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct MultivariateVariable { + /// The unique identifier for the variable. + /// This index distinguishes one variable from another in the polynomial. + pub index: usize, + + /// The power to which the variable is raised. + /// For example, if exponent is 2, it represents x^2 for the variable x. + pub exponent: usize, +} + +impl MultivariateVariable { + /// Creates a new multivariate variable with the given index and exponent. + pub fn new(index: usize, exponent: usize) -> Self { MultivariateVariable { index, exponent } } +} + +/// Represents a term in a multivariate polynomial. +/// +/// # Fields +/// +/// * `variables` - A vector of `MultivariateVariable`s representing the variables in this term. +/// * `coefficient` - The coefficient of this term, represented as a finite field element. +#[derive(PartialEq, Eq)] +pub struct MultivariateTerm { + /// A vector of `MultivariateVariable`s representing the variables in this term. + /// Each `MultivariateVariable` contains an index and an exponent. + pub variables: Vec, + + /// The coefficient of this term, represented as a finite field element. + /// This value multiplies the product of the variables in the term. + pub coefficient: F, +} + +/// Represents a term in a multivariate polynomial. +/// A term consists of a coefficient and a collection of variables with their exponents. +impl MultivariateTerm { + /// Creates a new multivariate term with the given variables and coefficient. + pub fn new(variables: Vec, coefficient: F) -> Self { + MultivariateTerm { variables, coefficient } + } +} diff --git a/src/polynomial/tests.rs b/src/polynomial/tests.rs index 7761e58..a35b12c 100644 --- a/src/polynomial/tests.rs +++ b/src/polynomial/tests.rs @@ -1,4 +1,7 @@ use super::*; +use crate::polynomial::multivariate_polynomial::{ + MultivariatePolynomial, MultivariateTerm, MultivariateVariable, +}; #[fixture] fn poly() -> Polynomial { @@ -126,3 +129,283 @@ fn dft(poly: Polynomial) { // Polynomial::::new(vec![PlutoBaseField::ZERO, // PlutoBaseField::ZERO]); assert_eq!(poly.coefficients, [PlutoBaseField::ZERO]); } + +#[test] +fn test_multivariate_polynomial_creation() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + assert_eq!(poly.degree(), 3); + + assert_eq!( + poly.variables().into_iter().collect::>(), + vec![0, 1].into_iter().collect::>() + ); +} + +#[test] +fn test_multivariate_polynomial_addition() { + let poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(1), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(2), + ), + ]); + + let poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(4), + ), + ]); + + let result = poly1 + poly2; + + println!("Addition Result polynomial: {}", result); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(4), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(4), + ), + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_multivariate_polynomial_multiplication() { + let poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + let poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + let result = poly1 * poly2; + + println!("Multiplication Result polynomial: {}", result); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(6), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_multivariate_polynomial_evaluation() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + let points = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + + println!("{}", poly); + + let result = poly.evaluate(&points); + // 3*(2^2)*(3) + 2*(2) + 1 = 3*4*3 + 4 + 1 = 36 + 4 + 1 = 41 + assert_eq!(result, PlutoBaseField::new(41)); +} + +#[test] +fn test_apply_variables_single_variable() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + let variables = vec![(0, PlutoBaseField::new(2))]; + let result = poly.apply_variables(&variables); + + let expected_result = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![], + PlutoBaseField::new(17), + )]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_apply_variables_multiple_variables() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(4), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(1)), + ]); + + println!("Apply Multiple Variables Polynomial: {}", poly); + + let variables = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + let result = poly.apply_variables(&variables); + + println!("Reduced Multiple Variable Polynomial: {}", result); + + let expected_result = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![], + PlutoBaseField::new(53), + )]); + + assert_eq!(result, expected_result); +} + +#[test] +fn test_apply_variables_partial_application() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![ + MultivariateVariable { index: 0, exponent: 2 }, + MultivariateVariable { index: 1, exponent: 1 }, + MultivariateVariable { index: 2, exponent: 1 }, + ], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + PlutoBaseField::new(2), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(4), + ), + ]); + + let variables = vec![(0, PlutoBaseField::new(2)), (1, PlutoBaseField::new(3))]; + let applied_poly = poly.apply_variables(&variables); + + let expected_result = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(40), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(12)), + ]); + + assert_eq!(applied_poly, expected_result); +} + +#[test] +fn test_apply_variables_no_effect() { + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + ), + MultivariateTerm::new( + vec![MultivariateVariable { index: 2, exponent: 1 }], + PlutoBaseField::new(2), + ), + ]); + + let variables = vec![(3, PlutoBaseField::new(5))]; + let result = poly.apply_variables(&variables); + + assert_eq!(result, poly); +} + +#[test] +fn test_apply_variables_empty() { + let poly = MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 2 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::new(3), + )]); // 3x_0^2*x_1 + + let variables = vec![]; + let result = poly.apply_variables(&variables); + + assert_eq!(result, poly); +} diff --git a/src/random/mod.rs b/src/random/mod.rs new file mode 100644 index 0000000..76db00e --- /dev/null +++ b/src/random/mod.rs @@ -0,0 +1,58 @@ +//! # Random Number Generation and Random Oracle Functionality +//! +//! This module provides traits and utilities for random number generation and +//! random oracle functionality, which are essential for various cryptographic +//! operations and protocols. +//! +//! ## Key Components +//! +//! - `Random`: A trait for types that can be randomly generated. +//! - `RandomOracle`: A trait for types that can be generated using a random oracle approach. +//! +//! These traits allow for flexible and secure generation of random instances +//! for implementing types, supporting both standard random generation and +//! more complex random oracle-based generation. +//! +//! The module is designed to work seamlessly with the `rand` crate's `Rng` trait, +//! providing a consistent interface for random number generation across the library. + +use rand::Rng; + +/// A trait for types that can be randomly generated. +/// +/// Types implementing this trait can create random instances of themselves +/// using a provided random number generator. +pub trait Random { + /// Generates a random instance of the implementing type. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// + /// # Returns + /// + /// A randomly generated instance of the implementing type. + fn random(rng: &mut R) -> Self; +} + +/// A trait for types that can be generated using a random oracle. +/// +/// Types implementing this trait can create instances of themselves +/// using a provided random number generator and an input byte slice, +/// simulating a random oracle functionality. +pub trait RandomOracle: Random { + /// Generates an instance of the implementing type using a random oracle approach. + /// + /// This method takes both a random number generator and an input byte slice, + /// allowing for deterministic yet unpredictable output based on the input. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to a random number generator. + /// * `input` - A byte slice used as input to the random oracle. + /// + /// # Returns + /// + /// An instance of the implementing type, generated using the random oracle approach. + fn random_oracle(rng: &mut R, input: &[u8]) -> Self; +} diff --git a/src/sumcheck/boolean_array.rs b/src/sumcheck/boolean_array.rs new file mode 100644 index 0000000..d50fa9c --- /dev/null +++ b/src/sumcheck/boolean_array.rs @@ -0,0 +1,34 @@ +struct BooleanArrayIter { + current: Vec, + done: bool, +} + +impl Iterator for BooleanArrayIter { + type Item = Vec; + + fn next(&mut self) -> Option { + if self.done { + return None; + } + + let result = self.current.clone(); + + // Generate next array + for i in 0..self.current.len() { + if self.current[i] { + self.current[i] = false; + } else { + self.current[i] = true; + return Some(result); + } + } + + // If we've reached here, we've generated all arrays + self.done = true; + Some(result) + } +} + +pub fn get_all_possible_boolean_values(length: usize) -> impl Iterator> { + BooleanArrayIter { current: vec![false; length], done: false } +} diff --git a/src/sumcheck/mod.rs b/src/sumcheck/mod.rs new file mode 100644 index 0000000..1cd0336 --- /dev/null +++ b/src/sumcheck/mod.rs @@ -0,0 +1,623 @@ +//! # Sumcheck Protocol Implementation +//! +//! This module implements the sumcheck protocol, a powerful interactive proof system +//! used in various zero-knowledge proof constructions. +//! +//! ## Overview of the Sumcheck Protocol +//! +//! The sumcheck protocol allows a prover to convince a verifier that the sum of a +//! multivariate polynomial over all boolean inputs (i.e., the boolean hypercube) is +//! equal to a claimed value, without the verifier having to compute the sum directly. +//! +//! The protocol proceeds in rounds, where in each round: +//! 1. The prover sends a univariate polynomial. +//! 2. The verifier checks certain properties of this polynomial and sends a random challenge. +//! 3. This process reduces the multivariate polynomial to a univariate one in each round. +//! +//! ## Implementation Details +//! +//! This implementation provides both interactive and non-interactive versions of the sumcheck +//! protocol. +//! +//! ### Prover Implementation +//! +//! The prover's implementation is split into several methods: +//! +//! - `prove_first_sumcheck_round`: Computes the claimed sum and the first univariate polynomial. +//! - `prove_sumcheck_round_i`: Generates the univariate polynomial for intermediate rounds. +//! - `prove_sumcheck_last_round`: Handles the final round of the protocol. +//! - `compute_univariate_polynomial`: A helper method to compute the univariate polynomial for each +//! round. +//! +//! This structure allows for a clear separation of concerns and follows the round-based +//! nature of the sumcheck protocol. +//! +//! ### Verifier Implementation +//! +//! The verifier's implementation is divided into separate functions for each stage: +//! +//! - `verify_sumcheck_first_round`: Verifies the first round, checking the claimed sum. +//! - `verify_sumcheck_univariate_poly_sum`: Verifies intermediate rounds. +//! - `verify_sumcheck_last_round`: Performs the final verification step. +//! +//! This separation allows for clear and modular verification logic, closely following +//! the structure of the sumcheck protocol. +//! +//! ### Non-Interactive Version +//! +//! The module also provides non-interactive versions of the protocol: +//! +//! - `non_interactive_sumcheck_prove`: Generates a complete proof in one step. +//! - `non_interactive_sumcheck_verify`: Verifies the complete proof. +//! +//! These functions use a random oracle model to simulate the interactive challenges, +//! making the protocol suitable for non-interactive scenarios. +//! +//! ## Correctness and Efficiency +//! +//! The implementation correctly follows the sumcheck protocol: +//! +//! 1. It reduces the multivariate polynomial to univariate polynomials in each round. +//! 2. It uses random challenges to ensure the prover cannot predict the verification path. +//! 3. The final verification step ties the protocol back to the original multivariate polynomial. +//! +//! The use of `MultivariatePolynomial` and efficient polynomial operations ensures that +//! the implementation is both correct and computationally efficient. +//! +//! ## Usage +//! +//! To use this implementation, create a `MultivariatePolynomial`, then use the +//! `non_interactive_sumcheck_prove` function to generate a proof, and +//! `non_interactive_sumcheck_verify` to verify it. +//! +//! For more fine-grained control, you can use the individual prover and verifier functions +//! to implement an interactive version of the protocol. + +use std::{ + fmt::Display, + hash::{Hash, Hasher}, +}; + +use rand::{Rng, SeedableRng}; + +use crate::{ + algebra::field::FiniteField, + polynomial::multivariate_polynomial::MultivariatePolynomial, + random::{Random, RandomOracle}, +}; + +mod boolean_array; +#[cfg(test)] mod tests; +mod to_bytes; + +use self::{boolean_array::get_all_possible_boolean_values, to_bytes::ToBytes}; + +impl RandomOracle for F { + fn random_oracle(_rng: &mut R, input: &[u8]) -> Self { + // This is a simplified example. In a real implementation, + // you'd want to use a cryptographic hash function here. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + input.hash(&mut hasher); + let hash = hasher.finish(); + + // Use the hash to seed a new RNG + let mut seeded_rng = rand::rngs::StdRng::seed_from_u64(hash); + + // Generate a random field element using the seeded RNG + Self::random(&mut seeded_rng) + } +} + +impl MultivariatePolynomial { + /// Proves the first round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is crucial for initiating the sumcheck protocol because: + /// 1. It computes the total sum of the polynomial over all boolean inputs, which is the claimed + /// sum that the prover wants to prove. + /// 2. It generates the first univariate polynomial g_1(X1), which is the first step in reducing + /// the multivariate sumcheck to a series of univariate sumchecks. + /// + /// The sumcheck protocol is essential for efficiently verifying the sum of a multivariate + /// polynomial over a boolean hypercube without evaluating every point, which would be + /// exponential in the number of variables. This function sets up the foundation for the + /// entire protocol. + /// + /// # Returns + /// - A tuple containing: + /// 1. The claimed sum (F): The total sum of the polynomial over all boolean inputs. + /// 2. The first univariate polynomial (MultivariatePolynomial): g_1(X1), which is actually + /// univariate despite the type name. + pub fn prove_first_sumcheck_round(&self) -> (F, MultivariatePolynomial) { + let variables = self.variables(); + let num_variables = variables.len(); + + let sum = get_all_possible_boolean_values(num_variables) + .map(|bool_values| { + let assignment: Vec<(usize, F)> = variables + .iter() + .enumerate() + .map(|(i, &var)| (var, if bool_values[i] { F::ONE } else { F::ZERO })) + .collect(); + self.evaluate(&assignment) + }) + .sum(); + + // Compute the univariate polynomial g_1(X1) + let univariate_poly = self.compute_univariate_polynomial(0, vec![]); + + (sum, univariate_poly) + } + + /// Proves the i-th round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is a key part of the sumcheck protocol, generating the univariate polynomial + /// for the i-th round based on the partial assignment from previous rounds. + /// + /// # Arguments + /// + /// * `i` - The current round number (0-indexed). + /// * `partial_assignment` - A vector of field elements representing the values chosen by the + /// verifier in previous rounds. + /// + /// # Returns + /// + /// * `MultivariatePolynomial` - The univariate polynomial g_i(X_i) for the i-th round. Despite + /// the type name, this polynomial is univariate in X_i. + /// + /// # Properties and Equalities + /// + /// 1. Degree Preservation: The degree of g_i(X_i) in X_i is at most the degree of the original + /// polynomial in X_i. + /// + /// 2. Sum Consistency: The sum of g_i(X_i) over {0,1} equals g_{i-1}(r_{i-1}), where r_{i-1} is + /// the random challenge from the previous round. + /// + /// 3. Randomized Reduction: g_i(X_i) reduces the sum check for i variables to a sum check for i-1 + /// variables when a random point is chosen. + /// + /// 4. Partial Evaluation: g_i(X_i) can be seen as a partial evaluation of the original + /// polynomial, with the first i-1 variables fixed to the values in partial_assignment. + /// + /// These properties ensure the soundness and completeness of the sumcheck protocol, + /// allowing for efficient verification of the claimed sum. + /// + /// # Note + /// + /// This function relies on `compute_univariate_polynomial` to perform the actual computation + /// of the univariate polynomial for the current round. + pub fn prove_sumcheck_round_i( + &self, + i: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + return self.compute_univariate_polynomial(i, partial_assignment); + } + + /// Proves the last round of the sumcheck protocol for this multivariate polynomial. + /// + /// This function is similar to `prove_sumcheck_round_i`, but specifically handles the last round + /// of the sumcheck protocol. It generates the final univariate polynomial based on all previous + /// assignments. + /// + /// # Arguments + /// + /// * `i` - The index of the last round (should be equal to the number of variables minus 1). + /// * `partial_assignment` - A vector of field elements representing all values chosen by the + /// verifier in previous rounds. + /// + /// # Returns + /// + /// * `MultivariatePolynomial` - The final univariate polynomial for the last round. This + /// polynomial is univariate in the last remaining variable. + /// + /// # Note + /// + /// This function relies on `compute_univariate_polynomial` to perform the actual computation + /// of the univariate polynomial for the last round. The result of this function is crucial + /// for the final verification step in the sumcheck protocol. + pub fn prove_sumcheck_last_round( + &self, + i: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + return self.compute_univariate_polynomial(i, partial_assignment); + } + + fn compute_univariate_polynomial( + &self, + round: usize, + partial_assignment: Vec, + ) -> MultivariatePolynomial { + let variables = self.variables(); + let num_variables = variables.len(); + + // Create a polynomial to store the result + // First create a partial evaluation + let partial_poly = self.apply_variables( + &partial_assignment.iter().enumerate().map(|(i, &v)| (i, v)).collect::>(), + ); + + let result_polynomial = get_all_possible_boolean_values(num_variables - round - 1) + .map(|bool_values| { + let further_assignments: Vec = + bool_values.iter().map(|&b| if b { F::ONE } else { F::ZERO }).collect(); + let further_variables = + ((round + 1)..num_variables).zip(further_assignments).collect::>(); + let poly = partial_poly.clone().apply_variables(&further_variables); + poly + }) + .fold(MultivariatePolynomial::new(), |acc, poly| acc + poly); + + // Assert that the resulting polynomial has only one variable + assert!( + result_polynomial.variables().len() <= 1, + "The univariate polynomial should have at most one variable" + ); + + result_polynomial + } +} + +/// Verifies the first round of the sumcheck protocol. +/// +/// This function is crucial for initiating the verification process in the sumcheck protocol. +/// The verifier needs these components to ensure the correctness of the prover's claim: +/// +/// 1. `claimed_sum`: The total sum claimed by the prover. This is the value that the verifier wants +/// to check without computing the entire sum themselves. +/// +/// 2. `univariate_poly`: The first univariate polynomial g_1(X_1) provided by the prover. This +/// polynomial is supposed to represent the sum over all but the first variable. +/// +/// The verification process involves: +/// +/// 1. Checking that the provided polynomial is indeed univariate. This ensures that the prover is +/// following the protocol correctly by reducing one variable at a time. +/// +/// 2. Verifying that g_1(0) + g_1(1) equals the claimed sum. This check is fundamental because it +/// connects the univariate polynomial to the original multivariate sum. If this equality holds, +/// it suggests that the prover has correctly computed the univariate polynomial for the first +/// round. +/// +/// 3. Generating a random challenge. This challenge will be used in subsequent rounds and is +/// crucial for the security of the protocol. It ensures that the prover cannot predict or +/// manipulate future rounds. +/// +/// # Arguments +/// +/// * `claimed_sum`: The sum claimed by the prover. +/// * `univariate_poly`: The univariate polynomial for the first round. +/// +/// # Returns +/// +/// A tuple containing: +/// - A boolean indicating whether the verification passed (true) or failed (false). +/// - The random challenge generated for the next round. +/// +/// # Type Parameters +/// +/// * `F`: A type that implements both `FiniteField` and `Random` traits. +pub fn verify_sumcheck_first_round( + claimed_sum: F, + univariate_poly: &MultivariatePolynomial, +) -> (bool, F) { + // Step 1: Verify that the polynomial is univariate (has only one variable) + if univariate_poly.variables().len() != 1 { + return (false, F::ZERO); + } + + // Step 2: Verify that g(0) + g(1) = claimed_sum + let var = 0; + let sum_at_endpoints = + univariate_poly.evaluate(&[(var, F::ZERO)]) + univariate_poly.evaluate(&[(var, F::ONE)]); + + if sum_at_endpoints != claimed_sum { + return (false, F::ZERO); + } + + // Step 3: Generate a random challenge + let mut rng = rand::thread_rng(); + let challenge: F = F::random(&mut rng); + + // Return true (verification passed) and the evaluation at the challenge point + (true, challenge) +} + +/// Verify the i-th round of the sumcheck protocol +/// +/// This function is crucial for verifying the correctness of each intermediate step in the sumcheck +/// protocol. It ensures that the prover is following the protocol correctly and not deviating from +/// the expected behavior. +/// +/// # Arguments +/// +/// * `round`: The current round number of the sumcheck protocol. This is needed to keep track of +/// which variable is being eliminated in the current round. +/// * `challenge`: The random challenge from the previous round. This is used to evaluate the +/// previous round's polynomial and connect it to the current round. +/// * `previous_univariate_poly`: The univariate polynomial from the previous round. This is needed +/// to verify the consistency between rounds. +/// * `current_univariate_poly`: The univariate polynomial for the current round. This is the +/// polynomial that the prover claims represents the sum over the current variable. +/// +/// # Returns +/// +/// A tuple containing: +/// - A boolean indicating whether the verification passed (true) or failed (false). +/// - The new random challenge generated for the next round. +/// +/// # Why these parameters are needed +/// +/// 1. `round`: Keeps track of the protocol's progress and ensures variables are eliminated in +/// order. +/// 2. `challenge`: Connects the current round to the previous one, preventing the prover from +/// deviating. +/// 3. `previous_univariate_poly`: Used to verify consistency between rounds. +/// 4. `current_univariate_poly`: The polynomial to be verified in the current round. +/// +/// These parameters allow the verifier to check: +/// - The univariate nature of the current polynomial (ensuring one variable is eliminated per +/// round). +/// - The consistency between rounds (g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1)). +/// - Generate a new random challenge for the next round, maintaining the protocol's +/// unpredictability. +pub fn verify_sumcheck_univariate_poly_sum( + round: usize, + challenge: F, + previous_univariate_poly: &MultivariatePolynomial, + current_univariate_poly: &MultivariatePolynomial, +) -> (bool, F) { + // Step 1: Verify that the current polynomial is univariate + if current_univariate_poly.variables().len() > 1 { + return (false, F::ZERO); + } + + // Step 2: Verify that g_i(r_{i-1}) = g_{i-1}(0) + g_{i-1}(1) + let prev_var = round - 1; + let sum_at_endpoints = previous_univariate_poly.evaluate(&[(prev_var, challenge)]); + + let eval_at_previous_challenge = current_univariate_poly.evaluate(&[(round, F::ZERO)]) + + current_univariate_poly.evaluate(&[(round, F::ONE)]); + if eval_at_previous_challenge != sum_at_endpoints { + return (false, F::ZERO); + } + + // Step 3: Generate a new random challenge + let mut rng = rand::thread_rng(); + let new_challenge: F = F::random(&mut rng); + + // Return true (verification passed) and the evaluation at the new challenge point + (true, new_challenge) +} + +/// Verifies the final round of the sumcheck protocol. +/// +/// This function is crucial for the verifier to ensure the prover's honesty in the final round +/// of the sumcheck protocol. It checks if the claimed univariate polynomial is consistent with +/// the original multivariate polynomial when all challenges are applied. +/// +/// # Arguments +/// +/// * `challenges`: A vector of all previous challenges from earlier rounds. +/// * `univariate_poly`: The final univariate polynomial claimed by the prover. +/// * `poly`: The original multivariate polynomial. +/// +/// # Returns +/// +/// A boolean indicating whether the verification passed (true) or failed (false). +/// +/// # Why the verifier needs this information +/// +/// 1. `challenges`: The verifier needs all previous challenges to reconstruct the point at which +/// the original polynomial should be evaluated. This ensures consistency across all rounds. +/// +/// 2. `univariate_poly`: This is the prover's final claim about the polynomial after all but one +/// variable have been fixed. The verifier needs to check if this claim is consistent with the +/// original polynomial. +/// +/// 3. `poly`: The original multivariate polynomial is necessary to independently compute the +/// correct evaluation and compare it with the prover's claim. +/// +/// By comparing the evaluation of the original polynomial at the challenge point with the +/// evaluation of the claimed univariate polynomial at a random point, the verifier can detect +/// any dishonesty from the prover with high probability. +pub fn verify_sumcheck_last_round( + challenges: Vec, + univariate_poly: &MultivariatePolynomial, + poly: &MultivariatePolynomial, +) -> bool { + // Step 1: Apply all challenges to the original polynomial + let mut challenges_with_indices = Vec::new(); + for (i, challenge) in challenges.iter().enumerate() { + challenges_with_indices.push((i, *challenge)); + } + let poly_evaluation = poly.evaluate(&challenges_with_indices); + + // Step 2: Generate a random challenge for the last variable + let mut rng = rand::thread_rng(); + let last_challenge: F = F::random(&mut rng); + + // Step 3: Evaluate the univariate polynomial at the last challenge + let last_var = challenges.len(); + let univariate_evaluation = univariate_poly.evaluate(&[(last_var, last_challenge)]); + + // Step 4: Compare the evaluations + poly_evaluation == univariate_evaluation +} + +impl ToBytes for F { + fn to_bytes(&self) -> Vec { + // Implement this based on how your field elements are represented + // This is just an example: + self.to_string().into_bytes() + } +} + +impl ToBytes for MultivariatePolynomial { + fn to_bytes(&self) -> Vec { + // Implement this based on how your polynomials are represented + // This is just an example: + self.to_string().into_bytes() + } +} + +/// Represents a proof for the sumcheck protocol over a finite field. +/// +/// This struct contains all the components necessary for verifying a sumcheck proof, +/// including the claimed sum, round polynomials, challenges, evaluations, and final results. +/// +/// # Type Parameters +/// +/// * `F`: A type that implements the `FiniteField` trait, representing the field over which the +/// sumcheck protocol is performed. +pub struct SumcheckProof { + /// The claimed sum of the polynomial over all boolean inputs. + pub claimed_sum: F, + + /// Vector of univariate polynomials, one for each round of the protocol. + pub round_polynomials: Vec>, + + /// Vector of challenges generated during the protocol. + pub challenges: Vec, + + /// Vector of evaluations of the round polynomials at the challenge points. + pub round_evaluations: Vec, + + /// The final evaluation point, consisting of all challenges combined. + pub final_point: Vec, + + /// The final evaluation of the original multivariate polynomial at the final point. + pub final_evaluation: F, +} + +/// Generates a non-interactive sumcheck proof for a given multivariate polynomial. +/// +/// This function implements the prover's side of the non-interactive sumcheck protocol. +/// It generates a proof that the sum of the polynomial over all boolean inputs equals +/// the claimed sum, without requiring interaction with the verifier. +/// +/// The non-interactive nature is achieved by using a random oracle to generate challenges, +/// which both the prover and verifier can compute independently. +/// +/// # Arguments +/// +/// * `polynomial` - The multivariate polynomial for which to generate the sumcheck proof. +/// +/// # Returns +/// +/// Returns a `SumcheckProof` containing all necessary components for verification: +/// - The claimed sum +/// - Univariate polynomials for each round +/// - Challenges generated using the random oracle +/// - Evaluations of the round polynomials at the challenge points +/// - The final evaluation point and the polynomial's evaluation at that point +/// +/// # Type Parameters +/// +/// * `F` - A finite field type that implements necessary traits for arithmetic, random number +/// generation, conversion to bytes, and display. +pub fn non_interactive_sumcheck_prove< + F: FiniteField + Random + RandomOracle + Display + ToBytes, +>( + polynomial: &MultivariatePolynomial, +) -> SumcheckProof { + let num_variables = polynomial.variables().len(); + let mut challenges = Vec::new(); + let mut round_polynomials = Vec::new(); + let mut round_evaluations = Vec::new(); + + // First round: compute the claimed sum and the first univariate polynomial + let (claimed_sum, first_univariate_poly) = polynomial.prove_first_sumcheck_round(); + round_polynomials.push(first_univariate_poly.clone()); + + // Generate the first challenge using the random oracle + let mut rng = rand::thread_rng(); + let challenge: F = F::random_oracle(&mut rng, &claimed_sum.to_bytes()); + challenges.push(challenge); + round_evaluations.push(first_univariate_poly.evaluate(&[(0, challenge)])); + + let mut previous_univariate_poly = first_univariate_poly; + + // Intermediate rounds: generate univariate polynomials and challenges + for i in 1..num_variables { + let univariate_poly = polynomial.prove_sumcheck_round_i(i, challenges.clone()); + round_polynomials.push(univariate_poly.clone()); + + // Generate challenge for this round using the random oracle + let challenge: F = F::random_oracle(&mut rng, &previous_univariate_poly.to_bytes()); + challenges.push(challenge); + round_evaluations.push(univariate_poly.evaluate(&[(i, challenge)])); + + previous_univariate_poly = univariate_poly; + } + + // Final evaluation: evaluate the original polynomial at the challenge point + let final_point = challenges.clone(); + let final_evaluation = + polynomial.evaluate(&final_point.iter().cloned().enumerate().collect::>()); + + // Construct and return the proof + SumcheckProof { + claimed_sum, + round_polynomials, + round_evaluations, + challenges, + final_point, + final_evaluation, + } +} + +/// Verifies a non-interactive sumcheck proof. +/// +/// This function allows a verifier to be easily convinced of the correctness of a sumcheck proof +/// without interacting with the prover. The verifier can be convinced by the following steps: +/// +/// 1. Check the consistency of the first round's claimed sum with the provided univariate +/// polynomial. +/// 2. Verify the consistency between consecutive rounds' univariate polynomials. +/// 3. Confirm that the final evaluation matches the original multivariate polynomial at the +/// challenge point. +/// +/// The non-interactive nature of this proof system comes from the use of a random oracle to +/// generate challenges, which both the prover and verifier can compute independently. +/// +/// # Arguments +/// +/// * `proof` - The `SumcheckProof` provided by the prover. +/// * `polynomial` - The original multivariate polynomial being summed over. +/// +/// # Returns +/// +/// Returns `true` if the proof is valid, `false` otherwise. +pub fn non_interactive_sumcheck_verify( + proof: &SumcheckProof, + polynomial: &MultivariatePolynomial, +) -> bool { + let num_variables = polynomial.variables().len(); + + // Verify first round + let (valid, _) = verify_sumcheck_first_round(proof.claimed_sum, &proof.round_polynomials[0]); + if !valid { + return false; + } + + // Verify intermediate rounds + for i in 1..num_variables { + let (valid, _) = verify_sumcheck_univariate_poly_sum( + i, + proof.challenges[i - 1], + &proof.round_polynomials[i - 1], + &proof.round_polynomials[i], + ); + if !valid { + return false; + } + } + + // Verify last round + verify_sumcheck_last_round( + proof.final_point.clone(), + &proof.round_polynomials.last().unwrap(), + polynomial, + ) +} diff --git a/src/sumcheck/tests.rs b/src/sumcheck/tests.rs new file mode 100644 index 0000000..186084b --- /dev/null +++ b/src/sumcheck/tests.rs @@ -0,0 +1,147 @@ +use crate::{ + algebra::field::{prime::PlutoBaseField, Field}, + polynomial::multivariate_polynomial::{ + MultivariatePolynomial, MultivariateTerm, MultivariateVariable, + }, + sumcheck::{ + verify_sumcheck_first_round, verify_sumcheck_last_round, verify_sumcheck_univariate_poly_sum, + }, +}; + +#[test] +fn test_full_sumcheck_protocol() { + // This test demonstrates the full sumcheck protocol for the polynomial: + // f(x0, x1, x2) = x0 * (x1 + x2) - (x1 * x2) + // We'll prove and verify the sum of this polynomial over the boolean hypercube {0,1}^3. + + // The sumcheck protocol is used to prove the sum of a multivariate polynomial over a boolean + // hypercube without explicitly computing all 2^n evaluations. This is particularly useful for + // large n, where computing all evaluations would be computationally infeasible. + + // Step 1: Define the polynomial + // We start with a multivariate polynomial because the sumcheck protocol is designed to work + // with functions over boolean inputs, which are naturally represented as multivariate + // polynomials. + let poly = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 1, + exponent: 1, + }], + PlutoBaseField::ONE, + ), // x0 * x1 + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + PlutoBaseField::ONE, + ), // x0 * x2 + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }, MultivariateVariable { + index: 2, + exponent: 1, + }], + -PlutoBaseField::ONE, + ), // -x1 * x2 + ]); + + // Step 2: First round of the sumcheck protocol + // The prover computes the actual sum over all boolean inputs and generates the first univariate + // polynomial. This step is crucial because it reduces the n-variate polynomial to a univariate + // polynomial in x0, while maintaining the property that its sum over {0,1} equals the original + // sum. + let (claimed_sum, univariate_poly1) = poly.prove_first_sumcheck_round(); + + // The verifier checks the first round + // This check ensures that the sum of the univariate polynomial over {0,1} equals the claimed sum. + // It's a key step in verifying the prover's claim without computing the full sum. + let (valid, _challenge) = verify_sumcheck_first_round(claimed_sum, &univariate_poly1); + assert!(valid, "First round verification failed"); + + // Verify that the first univariate polynomial is correct: f0(x0) = 4x0 - 1 + // This check confirms that the prover correctly computed the univariate polynomial. + let expected_poly1 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 0, exponent: 1 }], + PlutoBaseField::new(4), + ), + MultivariateTerm::new(vec![], -PlutoBaseField::ONE), + ]); + assert_eq!(univariate_poly1, expected_poly1, "First round polynomial is incorrect"); + + println!("Claimed sum: {:?}", claimed_sum); + println!("First univariate polynomial: {}", univariate_poly1); + + // Step 3: Second round of the sumcheck protocol + // The verifier sends a challenge. This challenge is used to reduce the problem further, + // from proving a statement about a sum to proving a statement about a single evaluation. + let random_challenge1 = PlutoBaseField::new(4); + + // The prover generates the second univariate polynomial + // This polynomial represents the partial evaluation of the original polynomial with x0 fixed to + // the challenge value. + let univariate_poly2 = poly.prove_sumcheck_round_i(1, vec![random_challenge1]); + println!("Round 2 univariate polynomial: {}", univariate_poly2); + + // Verify that the second univariate polynomial is correct: f1(x1) = 7x1 + 4 + // This check ensures that the prover correctly computed the second univariate polynomial. + let expected_poly2 = MultivariatePolynomial::::from_terms(vec![ + MultivariateTerm::new( + vec![MultivariateVariable { index: 1, exponent: 1 }], + PlutoBaseField::new(7), + ), + MultivariateTerm::new(vec![], PlutoBaseField::new(4)), + ]); + assert_eq!(univariate_poly2, expected_poly2, "Second round polynomial is incorrect"); + + // The verifier checks the second round + // This check ensures that the evaluation of the first univariate polynomial at the challenge + // point equals the sum of the second univariate polynomial over {0,1}. + let (valid, _challenge) = + verify_sumcheck_univariate_poly_sum(1, random_challenge1, &univariate_poly1, &univariate_poly2); + assert!(valid, "Second round verification failed"); + + // Step 4: Third (final) round of the sumcheck protocol + // The process continues, further reducing the problem to a single point evaluation of the + // original polynomial. + let random_challenge2 = PlutoBaseField::new(4); + + // The prover generates the final univariate polynomial + let univariate_poly3 = + poly.prove_sumcheck_last_round(2, vec![random_challenge1, random_challenge2]); + println!("Round 3 univariate polynomial: {}", univariate_poly3); + + // Verify that the final univariate polynomial is correct: f2(x2) = 16 + // This check confirms that the prover correctly computed the final univariate polynomial. + let expected_poly3 = + MultivariatePolynomial::::from_terms(vec![MultivariateTerm::new( + vec![], + PlutoBaseField::new(16), + )]); + assert_eq!(univariate_poly3, expected_poly3, "Final round polynomial is incorrect"); + + // The verifier checks the final round + // This check ensures that the evaluation of the second univariate polynomial at the challenge + // point equals the sum of the third univariate polynomial over {0,1}. + let (valid, _final_challenge) = + verify_sumcheck_univariate_poly_sum(2, random_challenge2, &univariate_poly2, &univariate_poly3); + assert!(valid, "Final round verification failed"); + + // Step 5: Final verification + // The verifier sends a final challenge and checks the entire protocol + // This step verifies that the final point evaluation claimed by the prover + // matches the evaluation of the original polynomial at the challenge points. + let random_challenge3 = PlutoBaseField::new(4); + let valid = verify_sumcheck_last_round( + vec![random_challenge1, random_challenge2, random_challenge3], + &univariate_poly3, + &poly, + ); + assert!(valid, "Overall sumcheck protocol verification failed"); + + // If we reach this point, the entire sumcheck protocol has been successfully demonstrated + // The verifier is convinced that the prover knows the correct sum without having to compute it + // directly. + println!("Sumcheck protocol successfully verified!"); +} diff --git a/src/sumcheck/to_bytes.rs b/src/sumcheck/to_bytes.rs new file mode 100644 index 0000000..2d5e3b1 --- /dev/null +++ b/src/sumcheck/to_bytes.rs @@ -0,0 +1,3 @@ +pub trait ToBytes { + fn to_bytes(&self) -> Vec; +}