From 455df38d58e8096ffe4e3e85d005e6d693914f21 Mon Sep 17 00:00:00 2001 From: David Nevado Date: Mon, 9 Sep 2024 10:02:11 +0200 Subject: [PATCH] feat: extended_lagrange function --- halo2_backend/src/plonk/keygen.rs | 71 ++----------------------------- halo2_backend/src/poly/domain.rs | 62 +++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 68 deletions(-) diff --git a/halo2_backend/src/plonk/keygen.rs b/halo2_backend/src/plonk/keygen.rs index dfd35a9bd1..2771ee09de 100644 --- a/halo2_backend/src/plonk/keygen.rs +++ b/halo2_backend/src/plonk/keygen.rs @@ -4,7 +4,6 @@ #![allow(clippy::int_plus_one)] -use ff::{BatchInvert, WithSmallOrderMulGroup}; use group::Curve; use halo2_middleware::ff::{Field, FromUniformBytes}; use halo2_middleware::zal::impls::H2cEngine; @@ -131,68 +130,7 @@ where .map(Polynomial::new_lagrange_from_vec) .collect(); - // Compute L_0(X) in the extended co-domain. - // L_0(X) the 0th Lagrange polynomial in the original domain. - // Its representation in the original domain H = {1, g, g^2, ..., g^(n-1)} - // is [1, 0, ..., 0]. - // We compute its represenation in the extended co-domain - // zH = {z, z*w, z*w^2, ... , z*w^(n*k - 1)}, where k is the extension factor - // of the domain, and z is the extended root such that w^k = g. - // We assume z = F::ZETA, a cubic root the field. This simplifies the computation. - // - // The computation uses the fomula: - // L_i(X) = g^i/n * (X^n -1)/(X-g^i) - // L_0(X) = 1/n * (X^n -1)/(X-1) - let start = std::time::Instant::now(); - let l0 = { - let one = C::ScalarExt::ONE; - let zeta = >::ZETA; - - let n: u64 = 1 << vk.domain.k(); - let c = (C::ScalarExt::from(n)).invert().unwrap(); - let mut l0 = vec![C::ScalarExt::ZERO; vk.domain.extended_len()]; - - let w = vk.domain.get_extended_omega(); - let wn = w.pow_vartime(&[n]); - let zeta_n = match n % 3 { - 1 => zeta, - 2 => zeta * zeta, - _ => one, - }; - - // Compute denominators. - parallelize(&mut l0, |e, mut index| { - let mut acc = zeta * w.pow_vartime(&[index as u64]); - for e in e { - *e = acc - one; - acc *= w; - index += 1; - } - }); - l0.batch_invert(); - - // Compute numinators. - // C * (zeta * w^i)^n = (C * zeta^n) * w^(i*n) - // We use w^k = g and g^n = 1 to save multiplications. - let k = 1 << (vk.domain.extended_k() - vk.domain.k()); - let mut wn_powers = vec![zeta_n * c; k]; - for i in 1..k { - wn_powers[i] = wn_powers[i - 1] * wn - } - - parallelize(&mut l0, |e, mut index| { - for e in e { - *e *= wn_powers[index % k] - c; - index += 1; - } - }); - - Polynomial { - values: l0, - _marker: std::marker::PhantomData, - } - }; - println!("L0 gen: {:?}", start.elapsed()); + let l0 = vk.domain.lagrange_extended(0usize); // Compute l_blind(X) which evaluates to 1 for each blinding factor row // and 0 otherwise over the domain. @@ -205,11 +143,8 @@ where // Compute l_last(X) which evaluates to 1 on the first inactive row (just // before the blinding factors) and 0 otherwise over the domain - // TODO L_0 method could be used here too. - let mut l_last = vk.domain.empty_lagrange(); - l_last[params.n() as usize - vk.cs.blinding_factors() - 1] = C::Scalar::ONE; - let l_last = vk.domain.lagrange_to_coeff(l_last); - let l_last = vk.domain.coeff_to_extended(l_last); + let idx = params.n() as usize - vk.cs.blinding_factors() - 1; + let l_last = vk.domain.lagrange_extended(idx); // Compute l_active_row(X) let one = C::Scalar::ONE; diff --git a/halo2_backend/src/poly/domain.rs b/halo2_backend/src/poly/domain.rs index dabc797da4..9eb277ca63 100644 --- a/halo2_backend/src/poly/domain.rs +++ b/halo2_backend/src/poly/domain.rs @@ -243,6 +243,68 @@ impl> EvaluationDomain { } } + // Compute L_i(X) in the extended co-domain, where + // L_i(X)is the ith Lagrange polynomial in the original domain, + // H = {1, g, g^2, ..., g^(n-1)}. + // We compute its represenation in the extended co-domain + // zH = {z, z*w, z*w^2, ... , z*w^(n*k - 1)}, where k is the extension factor + // of the domain, and z is the extended root such that w^k = g. + // We assume z = F::ZETA, a cubic root the field. This simplifies the computation. + // + // The computation uses the fomula: + // L_i(X) = g^i/n * (X^n -1)/(X-g^i) + pub fn lagrange_extended(&self, idx: usize) -> Polynomial { + let one = F::ONE; + let zeta = >::ZETA; + + let n: u64 = 1 << self.k(); + // c = g^i / n + let g_i = self.omega.pow_vartime([idx as u64]); + let mut lag_poly = vec![F::ZERO; self.extended_len()]; + + let w = self.get_extended_omega(); + let wn = w.pow_vartime([n]); + let zeta_n = match n % 3 { + 1 => zeta, + 2 => zeta * zeta, + _ => one, + }; + + // Compute denominators. ( n * (w^j - g_i)) + let n = F::from(n); + let n_g_i = n * g_i; + parallelize(&mut lag_poly, |e, mut index| { + let mut acc = n * zeta * w.pow_vartime([index as u64]); + for e in e { + *e = acc - n_g_i; + acc *= w; + index += 1; + } + }); + lag_poly.batch_invert(); + + // Compute numerators. + // g_i * (zeta * w^i)^n = (g_i * zeta^n) * w^(i*n) + // We use w^k = g and g^n = 1 to save multiplications. + let k = 1 << (self.extended_k() - self.k()); + let mut wn_powers = vec![zeta_n * g_i; k]; + for i in 1..k { + wn_powers[i] = wn_powers[i - 1] * wn + } + + parallelize(&mut lag_poly, |e, mut index| { + for e in e { + *e *= wn_powers[index % k] - g_i; + index += 1; + } + }); + + Polynomial { + values: lag_poly, + _marker: std::marker::PhantomData, + } + } + /// Rotate the extended domain polynomial over the original domain. pub fn rotate_extended( &self,