From 06a6a4a70598ffe2b3139bb9457874e357567063 Mon Sep 17 00:00:00 2001 From: han0110 Date: Tue, 30 Jan 2024 03:31:33 +0000 Subject: [PATCH 1/3] feat: implement for `BatchOpenScheme::Gwc19` --- examples/separately.rs | 2 +- src/codegen.rs | 36 +-- src/codegen/pcs.rs | 545 ++---------------------------------- src/codegen/pcs/bdfg21.rs | 492 ++++++++++++++++++++++++++++++++ src/codegen/pcs/gwc19.rs | 298 ++++++++++++++++++++ src/codegen/template.rs | 1 + src/codegen/util.rs | 91 +++++- src/test.rs | 148 ++++++---- templates/Halo2Verifier.sol | 17 +- 9 files changed, 1015 insertions(+), 615 deletions(-) create mode 100644 src/codegen/pcs/bdfg21.rs create mode 100644 src/codegen/pcs/gwc19.rs diff --git a/examples/separately.rs b/examples/separately.rs index b2e58c5..00c1d84 100644 --- a/examples/separately.rs +++ b/examples/separately.rs @@ -176,7 +176,7 @@ mod application { } fn configure(meta: &mut ConstraintSystem) -> Self::Config { - meta.set_minimum_degree(4); + meta.set_minimum_degree(5); StandardPlonkConfig::configure(meta) } diff --git a/src/codegen.rs b/src/codegen.rs index d950c27..2554c4f 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,9 +1,5 @@ use crate::codegen::{ evaluator::Evaluator, - pcs::{ - bdfg21_computations, queries, rotation_sets, - BatchOpenScheme::{Bdfg21, Gwc19}, - }, template::{Halo2Verifier, Halo2VerifyingKey}, util::{fr_to_u256, g1_to_u256s, g2_to_u256s, ConstraintSystemMeta, Data, Ptr}, }; @@ -108,11 +104,6 @@ impl<'a> SolidityGenerator<'a> { .any(|(_, rotation)| *rotation != Rotation::cur()), "Rotated query to instance column is not yet implemented" ); - assert_eq!( - scheme, - BatchOpenScheme::Bdfg21, - "BatchOpenScheme::Gwc19 is not yet implemented" - ); Self { params, @@ -237,7 +228,7 @@ impl<'a> SolidityGenerator<'a> { let vk = self.generate_vk(); let vk_len = vk.len(); - let vk_mptr = Ptr::memory(self.estimate_static_working_memory_size(&vk, proof_cptr)); + let vk_mptr = Ptr::memory(self.static_working_memory_size(&vk, proof_cptr)); let data = Data::new(&self.meta, &vk, vk_mptr, proof_cptr); let evaluator = Evaluator::new(self.vk.cs(), &self.meta, &data); @@ -260,10 +251,7 @@ impl<'a> SolidityGenerator<'a> { }) .collect(); - let pcs_computations = match self.scheme { - Bdfg21 => bdfg21_computations(&self.meta, &data), - Gwc19 => unimplemented!(), - }; + let pcs_computations = self.scheme.computations(&self.meta, &data); Halo2Verifier { scheme: self.scheme, @@ -273,6 +261,7 @@ impl<'a> SolidityGenerator<'a> { num_neg_lagranges: self.meta.rotation_last.unsigned_abs() as usize, num_advices: self.meta.num_advices(), num_challenges: self.meta.num_challenges(), + num_rotations: self.meta.num_rotations, num_evals: self.meta.num_evals, num_quotients: self.meta.num_quotients, proof_cptr, @@ -285,20 +274,11 @@ impl<'a> SolidityGenerator<'a> { } } - fn estimate_static_working_memory_size( - &self, - vk: &Halo2VerifyingKey, - proof_cptr: Ptr, - ) -> usize { - let pcs_computation = match self.scheme { - Bdfg21 => { - let mock_vk_mptr = Ptr::memory(0x100000); - let mock = Data::new(&self.meta, vk, mock_vk_mptr, proof_cptr); - let (superset, sets) = rotation_sets(&queries(&self.meta, &mock)); - let num_coeffs = sets.iter().map(|set| set.rots().len()).sum::(); - 2 * (1 + num_coeffs) + 6 + 2 * superset.len() + 1 + 3 * sets.len() - } - Gwc19 => unimplemented!(), + fn static_working_memory_size(&self, vk: &Halo2VerifyingKey, proof_cptr: Ptr) -> usize { + let pcs_computation = { + let mock_vk_mptr = Ptr::memory(0x100000); + let mock = Data::new(&self.meta, vk, mock_vk_mptr, proof_cptr); + self.scheme.static_working_memory_size(&self.meta, &mock) }; itertools::max([ diff --git a/src/codegen/pcs.rs b/src/codegen/pcs.rs index 2b7536a..e0021e1 100644 --- a/src/codegen/pcs.rs +++ b/src/codegen/pcs.rs @@ -1,8 +1,8 @@ -#![allow(clippy::useless_format)] +use crate::codegen::util::{ConstraintSystemMeta, Data, EcPoint, Word}; +use itertools::{chain, izip}; -use crate::codegen::util::{for_loop, ConstraintSystemMeta, Data, EcPoint, Location, Ptr, Word}; -use itertools::{chain, izip, Itertools}; -use std::collections::{BTreeMap, BTreeSet}; +mod bdfg21; +mod gwc19; /// KZG batch open schemes in `halo2`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -19,6 +19,30 @@ pub enum BatchOpenScheme { Bdfg21, } +impl BatchOpenScheme { + pub(crate) fn static_working_memory_size( + &self, + meta: &ConstraintSystemMeta, + data: &Data, + ) -> usize { + match self { + Self::Bdfg21 => bdfg21::static_working_memory_size(meta, data), + Self::Gwc19 => gwc19::static_working_memory_size(meta, data), + } + } + + pub(crate) fn computations( + &self, + meta: &ConstraintSystemMeta, + data: &Data, + ) -> Vec> { + match self { + Self::Bdfg21 => bdfg21::computations(meta, data), + Self::Gwc19 => gwc19::computations(meta, data), + } + } +} + #[derive(Debug)] pub(crate) struct Query { comm: EcPoint, @@ -77,516 +101,3 @@ pub(crate) fn queries(meta: &ConstraintSystemMeta, data: &Data) -> Vec { ] .collect() } - -#[derive(Debug)] -pub(crate) struct RotationSet { - rots: BTreeSet, - diffs: BTreeSet, - comms: Vec, - evals: Vec>, -} - -impl RotationSet { - pub(crate) fn rots(&self) -> &BTreeSet { - &self.rots - } - - pub(crate) fn diffs(&self) -> &BTreeSet { - &self.diffs - } - - pub(crate) fn comms(&self) -> &[EcPoint] { - &self.comms - } - - pub(crate) fn evals(&self) -> &[Vec] { - &self.evals - } -} - -pub(crate) fn rotation_sets(queries: &[Query]) -> (BTreeSet, Vec) { - let mut superset = BTreeSet::new(); - let comm_queries = queries.iter().fold( - Vec::<(EcPoint, BTreeMap)>::new(), - |mut comm_queries, query| { - superset.insert(query.rot); - if let Some(pos) = comm_queries - .iter() - .position(|(comm, _)| comm == &query.comm) - { - let (_, queries) = &mut comm_queries[pos]; - assert!(!queries.contains_key(&query.rot)); - queries.insert(query.rot, query.eval); - } else { - comm_queries.push((query.comm, BTreeMap::from_iter([(query.rot, query.eval)]))); - } - comm_queries - }, - ); - let superset = superset; - let sets = - comm_queries - .into_iter() - .fold(Vec::::new(), |mut sets, (comm, queries)| { - if let Some(pos) = sets - .iter() - .position(|set| itertools::equal(&set.rots, queries.keys())) - { - let set = &mut sets[pos]; - if !set.comms.contains(&comm) { - set.comms.push(comm); - set.evals.push(queries.into_values().collect_vec()); - } - } else { - let diffs = BTreeSet::from_iter( - superset - .iter() - .filter(|rot| !queries.contains_key(rot)) - .copied(), - ); - let set = RotationSet { - rots: BTreeSet::from_iter(queries.keys().copied()), - diffs, - comms: vec![comm], - evals: vec![queries.into_values().collect()], - }; - sets.push(set); - } - sets - }); - (superset, sets) -} - -pub(crate) fn bdfg21_computations(meta: &ConstraintSystemMeta, data: &Data) -> Vec> { - let queries = queries(meta, data); - let (superset, sets) = rotation_sets(&queries); - let min_rot = *superset.first().unwrap(); - let max_rot = *superset.last().unwrap(); - let num_coeffs = sets.iter().map(|set| set.rots().len()).sum::(); - - let w = EcPoint::from(data.w_cptr); - let w_prime = EcPoint::from(data.w_cptr + 2); - - let diff_0 = Word::from(Ptr::memory(0x00)); - let coeffs = sets - .iter() - .scan(diff_0.ptr() + 1, |state, set| { - let ptrs = Word::range(*state).take(set.rots().len()).collect_vec(); - *state = *state + set.rots().len(); - Some(ptrs) - }) - .collect_vec(); - - let first_batch_invert_end = diff_0.ptr() + 1 + num_coeffs; - let second_batch_invert_end = diff_0.ptr() + sets.len(); - let free_mptr = diff_0.ptr() + 2 * (1 + num_coeffs) + 6; - - let point_mptr = free_mptr; - let mu_minus_point_mptr = point_mptr + superset.len(); - let vanishing_0_mptr = mu_minus_point_mptr + superset.len(); - let diff_mptr = vanishing_0_mptr + 1; - let r_eval_mptr = diff_mptr + sets.len(); - let sum_mptr = r_eval_mptr + sets.len(); - - let point_vars = - izip!(&superset, (0..).map(|idx| format!("point_{idx}"))).collect::>(); - let points = izip!(&superset, Word::range(point_mptr)).collect::>(); - let mu_minus_points = - izip!(&superset, Word::range(mu_minus_point_mptr)).collect::>(); - let vanishing_0 = Word::from(vanishing_0_mptr); - let diffs = Word::range(diff_mptr).take(sets.len()).collect_vec(); - let r_evals = Word::range(r_eval_mptr).take(sets.len()).collect_vec(); - let sums = Word::range(sum_mptr).take(sets.len()).collect_vec(); - - let point_computations = chain![ - [ - "let x := mload(X_MPTR)", - "let omega := mload(OMEGA_MPTR)", - "let omega_inv := mload(OMEGA_INV_MPTR)", - "let x_pow_of_omega := mulmod(x, omega, r)" - ] - .map(str::to_string), - (1..=max_rot).flat_map(|rot| { - chain![ - points - .get(&rot) - .map(|point| format!("mstore({}, x_pow_of_omega)", point.ptr())), - (rot != max_rot) - .then(|| { "x_pow_of_omega := mulmod(x_pow_of_omega, omega, r)".to_string() }) - ] - }), - [ - format!("mstore({}, x)", points[&0].ptr()), - format!("x_pow_of_omega := mulmod(x, omega_inv, r)") - ], - (min_rot..0).rev().flat_map(|rot| { - chain![ - points - .get(&rot) - .map(|point| format!("mstore({}, x_pow_of_omega)", point.ptr())), - (rot != min_rot).then(|| { - "x_pow_of_omega := mulmod(x_pow_of_omega, omega_inv, r)".to_string() - }) - ] - }) - ] - .collect_vec(); - - let vanishing_computations = chain![ - ["let mu := mload(MU_MPTR)".to_string()], - { - let mptr = mu_minus_points.first_key_value().unwrap().1.ptr(); - let mptr_end = mptr + mu_minus_points.len(); - for_loop( - [ - format!("let mptr := {mptr}"), - format!("let mptr_end := {mptr_end}"), - format!("let point_mptr := {free_mptr}"), - ], - "lt(mptr, mptr_end)", - [ - "mptr := add(mptr, 0x20)", - "point_mptr := add(point_mptr, 0x20)", - ] - .map(str::to_string), - ["mstore(mptr, addmod(mu, sub(r, mload(point_mptr)), r))".to_string()], - ) - }, - ["let s".to_string()], - chain![ - [format!( - "s := {}", - mu_minus_points[sets[0].rots().first().unwrap()] - )], - chain![sets[0].rots().iter().skip(1)] - .map(|rot| { format!("s := mulmod(s, {}, r)", mu_minus_points[rot]) }), - [format!("mstore({}, s)", vanishing_0.ptr())], - ], - ["let diff".to_string()], - izip!(0.., &sets, &diffs).flat_map(|(set_idx, set, diff)| { - chain![ - [set.diffs() - .first() - .map(|rot| format!("diff := {}", mu_minus_points[rot])) - .unwrap_or_else(|| "diff := 1".to_string())], - chain![set.diffs().iter().skip(1)] - .map(|rot| { format!("diff := mulmod(diff, {}, r)", mu_minus_points[rot]) }), - [format!("mstore({}, diff)", diff.ptr())], - (set_idx == 0).then(|| format!("mstore({}, diff)", diff_0.ptr())), - ] - }) - ] - .collect_vec(); - - let coeff_computations = izip!(&sets, &coeffs) - .map(|(set, coeffs)| { - let coeff_points = set - .rots() - .iter() - .map(|rot| &point_vars[rot]) - .enumerate() - .map(|(i, rot_i)| { - set.rots() - .iter() - .map(|rot| &point_vars[rot]) - .enumerate() - .filter_map(|(j, rot_j)| (i != j).then_some((rot_i, rot_j))) - .collect_vec() - }) - .collect_vec(); - chain![ - set.rots() - .iter() - .map(|rot| { format!("let {} := {}", &point_vars[rot], points[rot]) }), - ["let coeff".to_string()], - izip!(set.rots(), &coeff_points, coeffs).flat_map( - |(rot_i, coeff_points, coeff)| chain![ - [coeff_points - .first() - .map(|(point_i, point_j)| { - format!("coeff := addmod({point_i}, sub(r, {point_j}), r)") - }) - .unwrap_or_else(|| { "coeff := 1".to_string() })], - coeff_points.iter().skip(1).map(|(point_i, point_j)| { - let item = format!("addmod({point_i}, sub(r, {point_j}), r)"); - format!("coeff := mulmod(coeff, {item}, r)") - }), - [ - format!("coeff := mulmod(coeff, {}, r)", mu_minus_points[rot_i]), - format!("mstore({}, coeff)", coeff.ptr()) - ], - ] - ) - ] - .collect_vec() - }) - .collect_vec(); - - let normalized_coeff_computations = chain![ - [ - format!("success := batch_invert(success, 0, {first_batch_invert_end}, r)"), - format!("let diff_0_inv := {diff_0}"), - format!("mstore({}, diff_0_inv)", diffs[0].ptr()), - ], - for_loop( - [ - format!("let mptr := {}", diffs[0].ptr() + 1), - format!("let mptr_end := {}", diffs[0].ptr() + sets.len()), - ], - "lt(mptr, mptr_end)", - ["mptr := add(mptr, 0x20)".to_string()], - ["mstore(mptr, mulmod(mload(mptr), diff_0_inv, r))".to_string()], - ), - ] - .collect_vec(); - - let r_evals_computations = izip!(0.., &sets, &coeffs, &diffs, &r_evals).map( - |(set_idx, set, coeffs, set_coeff, r_eval)| { - let is_single_rot_set = set.rots().len() == 1; - chain![ - is_single_rot_set.then(|| format!("let coeff := {}", coeffs[0])), - ["let zeta := mload(ZETA_MPTR)", "let r_eval := 0"].map(str::to_string), - if is_single_rot_set { - let eval_groups = set.evals().iter().rev().fold( - Vec::>::new(), - |mut eval_groups, evals| { - let eval = &evals[0]; - if let Some(last_group) = eval_groups.last_mut() { - let last_eval = **last_group.last().unwrap(); - if last_eval.ptr().value().is_integer() - && last_eval.ptr() - 1 == eval.ptr() - { - last_group.push(eval) - } else { - eval_groups.push(vec![eval]) - } - eval_groups - } else { - vec![vec![eval]] - } - }, - ); - chain![eval_groups.iter().enumerate()] - .flat_map(|(group_idx, evals)| { - if evals.len() < 3 { - chain![evals.iter().enumerate()] - .flat_map(|(eval_idx, eval)| { - let is_first_eval = group_idx == 0 && eval_idx == 0; - let item = format!("mulmod(coeff, {eval}, r)"); - chain![ - (!is_first_eval).then(|| format!( - "r_eval := mulmod(r_eval, zeta, r)" - )), - [format!("r_eval := addmod(r_eval, {item}, r)")], - ] - }) - .collect_vec() - } else { - let item = "mulmod(coeff, calldataload(mptr), r)"; - for_loop( - [ - format!("let mptr := {}", evals[0].ptr()), - format!("let mptr_end := {}", evals[0].ptr() - evals.len()), - ], - "lt(mptr_end, mptr)".to_string(), - ["mptr := sub(mptr, 0x20)".to_string()], - [format!( - "r_eval := addmod(mulmod(r_eval, zeta, r), {item}, r)" - )], - ) - } - }) - .collect_vec() - } else { - chain![set.evals().iter().enumerate().rev()] - .flat_map(|(idx, evals)| { - chain![ - izip!(evals, coeffs).map(|(eval, coeff)| { - let item = format!("mulmod({coeff}, {eval}, r)"); - format!("r_eval := addmod(r_eval, {item}, r)") - }), - (idx != 0).then(|| format!("r_eval := mulmod(r_eval, zeta, r)")), - ] - }) - .collect_vec() - }, - (set_idx != 0).then(|| format!("r_eval := mulmod(r_eval, {set_coeff}, r)")), - [format!("mstore({}, r_eval)", r_eval.ptr())], - ] - .collect_vec() - }, - ); - - let coeff_sums_computation = izip!(&coeffs, &sums).map(|(coeffs, sum)| { - let (coeff_0, rest_coeffs) = coeffs.split_first().unwrap(); - chain![ - [format!("let sum := {coeff_0}")], - rest_coeffs - .iter() - .map(|coeff_mptr| format!("sum := addmod(sum, {coeff_mptr}, r)")), - [format!("mstore({}, sum)", sum.ptr())], - ] - .collect_vec() - }); - - let r_eval_computations = chain![ - for_loop( - [ - format!("let mptr := 0x00"), - format!("let mptr_end := {second_batch_invert_end}"), - format!("let sum_mptr := {}", sums[0].ptr()), - ], - "lt(mptr, mptr_end)", - ["mptr := add(mptr, 0x20)", "sum_mptr := add(sum_mptr, 0x20)"].map(str::to_string), - ["mstore(mptr, mload(sum_mptr))".to_string()], - ), - [ - format!("success := batch_invert(success, 0, {second_batch_invert_end}, r)"), - format!( - "let r_eval := mulmod(mload({}), {}, r)", - second_batch_invert_end - 1, - r_evals.last().unwrap() - ) - ], - for_loop( - [ - format!("let sum_inv_mptr := {}", second_batch_invert_end - 2), - format!("let sum_inv_mptr_end := {second_batch_invert_end}"), - format!("let r_eval_mptr := {}", r_evals[r_evals.len() - 2].ptr()), - ], - "lt(sum_inv_mptr, sum_inv_mptr_end)", - [ - "sum_inv_mptr := sub(sum_inv_mptr, 0x20)", - "r_eval_mptr := sub(r_eval_mptr, 0x20)" - ] - .map(str::to_string), - [ - "r_eval := mulmod(r_eval, mload(NU_MPTR), r)", - "r_eval := addmod(r_eval, mulmod(mload(sum_inv_mptr), mload(r_eval_mptr), r), r)" - ] - .map(str::to_string), - ), - ["mstore(R_EVAL_MPTR, r_eval)".to_string()], - ] - .collect_vec(); - - let pairing_input_computations = chain![ - ["let nu := mload(NU_MPTR)".to_string()], - izip!(0.., &sets, &diffs).flat_map(|(set_idx, set, set_coeff)| { - let is_first_set = set_idx == 0; - let is_last_set = set_idx == sets.len() - 1; - - let ec_add = &format!("ec_add_{}", if is_first_set { "acc" } else { "tmp" }); - let ec_mul = &format!("ec_mul_{}", if is_first_set { "acc" } else { "tmp" }); - let acc_x = Ptr::memory(0x00) + if is_first_set { 0 } else { 4 }; - let acc_y = acc_x + 1; - - let comm_groups = set.comms().iter().rev().skip(1).fold( - Vec::<(Location, Vec<&EcPoint>)>::new(), - |mut comm_groups, comm| { - if let Some(last_group) = comm_groups.last_mut() { - let last_comm = **last_group.1.last().unwrap(); - if last_group.0 == comm.loc() - && last_comm.x().ptr().value().is_integer() - && last_comm.x().ptr() - 2 == comm.x().ptr() - { - last_group.1.push(comm) - } else { - comm_groups.push((comm.loc(), vec![comm])) - } - comm_groups - } else { - vec![(comm.loc(), vec![comm])] - } - }, - ); - - chain![ - set.comms() - .last() - .map(|comm| { - [ - format!("mstore({acc_x}, {})", comm.x()), - format!("mstore({acc_y}, {})", comm.y()), - ] - }) - .into_iter() - .flatten(), - comm_groups.into_iter().flat_map(move |(loc, comms)| { - if comms.len() < 3 { - comms - .iter() - .flat_map(|comm| { - let (x, y) = (comm.x(), comm.y()); - [ - format!("success := {ec_mul}(success, mload(ZETA_MPTR))"), - format!("success := {ec_add}(success, {x}, {y})"), - ] - }) - .collect_vec() - } else { - let mptr = comms.first().unwrap().x().ptr(); - let mptr_end = mptr - 2 * comms.len(); - let x = Word::from(Ptr::new(loc, "mptr")); - let y = Word::from(Ptr::new(loc, "add(mptr, 0x20)")); - for_loop( - [ - format!("let mptr := {mptr}"), - format!("let mptr_end := {mptr_end}"), - ], - "lt(mptr_end, mptr)", - ["mptr := sub(mptr, 0x40)".to_string()], - [ - format!("success := {ec_mul}(success, mload(ZETA_MPTR))"), - format!("success := {ec_add}(success, {x}, {y})"), - ], - ) - } - }), - (!is_first_set) - .then(|| { - let scalar = format!("mulmod(nu, {set_coeff}, r)"); - chain![ - [ - format!("success := ec_mul_tmp(success, {scalar})"), - format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), - ], - (!is_last_set).then(|| format!("nu := mulmod(nu, mload(NU_MPTR), r)")) - ] - }) - .into_iter() - .flatten(), - ] - .collect_vec() - }), - [ - format!("mstore(0x80, mload(G1_X_MPTR))"), - format!("mstore(0xa0, mload(G1_Y_MPTR))"), - format!("success := ec_mul_tmp(success, sub(r, mload(R_EVAL_MPTR)))"), - format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), - format!("mstore(0x80, {})", w.x()), - format!("mstore(0xa0, {})", w.y()), - format!("success := ec_mul_tmp(success, sub(r, {vanishing_0}))"), - format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), - format!("mstore(0x80, {})", w_prime.x()), - format!("mstore(0xa0, {})", w_prime.y()), - format!("success := ec_mul_tmp(success, mload(MU_MPTR))"), - format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), - format!("mstore(PAIRING_LHS_X_MPTR, mload(0x00))"), - format!("mstore(PAIRING_LHS_Y_MPTR, mload(0x20))"), - format!("mstore(PAIRING_RHS_X_MPTR, {})", w_prime.x()), - format!("mstore(PAIRING_RHS_Y_MPTR, {})", w_prime.y()), - ], - ] - .collect_vec(); - - chain![ - [point_computations, vanishing_computations], - coeff_computations, - [normalized_coeff_computations], - r_evals_computations, - coeff_sums_computation, - [r_eval_computations, pairing_input_computations], - ] - .collect_vec() -} diff --git a/src/codegen/pcs/bdfg21.rs b/src/codegen/pcs/bdfg21.rs new file mode 100644 index 0000000..fe551a6 --- /dev/null +++ b/src/codegen/pcs/bdfg21.rs @@ -0,0 +1,492 @@ +#![allow(clippy::useless_format)] + +use crate::codegen::{ + pcs::{queries, Query}, + util::{ + for_loop, group_backward_adjacent_ec_points, group_backward_adjacent_words, + ConstraintSystemMeta, Data, EcPoint, Location, Ptr, Word, + }, +}; +use itertools::{chain, izip, Itertools}; +use std::collections::{BTreeMap, BTreeSet}; + +pub(super) fn static_working_memory_size(meta: &ConstraintSystemMeta, data: &Data) -> usize { + let (superset, sets) = rotation_sets(&queries(meta, data)); + let num_coeffs = sets.iter().map(|set| set.rots().len()).sum::(); + 2 * (1 + num_coeffs) + 6 + 2 * superset.len() + 1 + 3 * sets.len() +} + +pub(super) fn computations(meta: &ConstraintSystemMeta, data: &Data) -> Vec> { + let (superset, sets) = rotation_sets(&queries(meta, data)); + let min_rot = *superset.first().unwrap(); + let max_rot = *superset.last().unwrap(); + let num_coeffs = sets.iter().map(|set| set.rots().len()).sum::(); + + let w = EcPoint::from(data.w_cptr); + let w_prime = EcPoint::from(data.w_cptr + 2); + + let diff_0 = Word::from(Ptr::memory(0x00)); + let coeffs = sets + .iter() + .scan(diff_0.ptr() + 1, |state, set| { + let ptrs = Word::range(*state).take(set.rots().len()).collect_vec(); + *state = *state + set.rots().len(); + Some(ptrs) + }) + .collect_vec(); + + let first_batch_invert_end = diff_0.ptr() + 1 + num_coeffs; + let second_batch_invert_end = diff_0.ptr() + sets.len(); + let free_mptr = diff_0.ptr() + 2 * (1 + num_coeffs) + 6; + + let point_mptr = free_mptr; + let mu_minus_point_mptr = point_mptr + superset.len(); + let vanishing_0_mptr = mu_minus_point_mptr + superset.len(); + let diff_mptr = vanishing_0_mptr + 1; + let r_eval_mptr = diff_mptr + sets.len(); + let sum_mptr = r_eval_mptr + sets.len(); + + let point_vars = + izip!(&superset, (0..).map(|idx| format!("point_{idx}"))).collect::>(); + let points = izip!(&superset, Word::range(point_mptr)).collect::>(); + let mu_minus_points = + izip!(&superset, Word::range(mu_minus_point_mptr)).collect::>(); + let vanishing_0 = Word::from(vanishing_0_mptr); + let diffs = Word::range(diff_mptr).take(sets.len()).collect_vec(); + let r_evals = Word::range(r_eval_mptr).take(sets.len()).collect_vec(); + let sums = Word::range(sum_mptr).take(sets.len()).collect_vec(); + + let point_computations = chain![ + [ + "let x := mload(X_MPTR)", + "let omega := mload(OMEGA_MPTR)", + "let omega_inv := mload(OMEGA_INV_MPTR)", + "let x_pow_of_omega := mulmod(x, omega, r)" + ] + .map(str::to_string), + (1..=max_rot).flat_map(|rot| { + chain![ + points + .get(&rot) + .map(|point| format!("mstore({}, x_pow_of_omega)", point.ptr())), + (rot != max_rot) + .then(|| "x_pow_of_omega := mulmod(x_pow_of_omega, omega, r)".to_string()) + ] + }), + [ + format!("mstore({}, x)", points[&0].ptr()), + format!("x_pow_of_omega := mulmod(x, omega_inv, r)") + ], + (min_rot..0).rev().flat_map(|rot| { + chain![ + points + .get(&rot) + .map(|point| format!("mstore({}, x_pow_of_omega)", point.ptr())), + (rot != min_rot).then(|| { + "x_pow_of_omega := mulmod(x_pow_of_omega, omega_inv, r)".to_string() + }) + ] + }) + ] + .collect_vec(); + + let vanishing_computations = chain![ + ["let mu := mload(MU_MPTR)".to_string()], + { + let mptr = mu_minus_points.first_key_value().unwrap().1.ptr(); + let mptr_end = mptr + mu_minus_points.len(); + for_loop( + [ + format!("let mptr := {mptr}"), + format!("let mptr_end := {mptr_end}"), + format!("let point_mptr := {free_mptr}"), + ], + "lt(mptr, mptr_end)", + [ + "mptr := add(mptr, 0x20)", + "point_mptr := add(point_mptr, 0x20)", + ], + ["mstore(mptr, addmod(mu, sub(r, mload(point_mptr)), r))"], + ) + }, + ["let s".to_string()], + chain![ + [format!( + "s := {}", + mu_minus_points[sets[0].rots().first().unwrap()] + )], + chain![sets[0].rots().iter().skip(1)] + .map(|rot| { format!("s := mulmod(s, {}, r)", mu_minus_points[rot]) }), + [format!("mstore({}, s)", vanishing_0.ptr())], + ], + ["let diff".to_string()], + izip!(0.., &sets, &diffs).flat_map(|(set_idx, set, diff)| { + chain![ + [set.diffs() + .first() + .map(|rot| format!("diff := {}", mu_minus_points[rot])) + .unwrap_or_else(|| "diff := 1".to_string())], + chain![set.diffs().iter().skip(1)] + .map(|rot| { format!("diff := mulmod(diff, {}, r)", mu_minus_points[rot]) }), + [format!("mstore({}, diff)", diff.ptr())], + (set_idx == 0).then(|| format!("mstore({}, diff)", diff_0.ptr())), + ] + }) + ] + .collect_vec(); + + let coeff_computations = izip!(&sets, &coeffs) + .map(|(set, coeffs)| { + let coeff_points = set + .rots() + .iter() + .map(|rot| &point_vars[rot]) + .enumerate() + .map(|(i, rot_i)| { + set.rots() + .iter() + .map(|rot| &point_vars[rot]) + .enumerate() + .filter_map(|(j, rot_j)| (i != j).then_some((rot_i, rot_j))) + .collect_vec() + }) + .collect_vec(); + chain![ + set.rots() + .iter() + .map(|rot| format!("let {} := {}", &point_vars[rot], points[rot])), + ["let coeff".to_string()], + izip!(set.rots(), &coeff_points, coeffs).flat_map( + |(rot_i, coeff_points, coeff)| chain![ + [coeff_points + .first() + .map(|(point_i, point_j)| { + format!("coeff := addmod({point_i}, sub(r, {point_j}), r)") + }) + .unwrap_or_else(|| "coeff := 1".to_string())], + coeff_points.iter().skip(1).map(|(point_i, point_j)| { + let item = format!("addmod({point_i}, sub(r, {point_j}), r)"); + format!("coeff := mulmod(coeff, {item}, r)") + }), + [ + format!("coeff := mulmod(coeff, {}, r)", mu_minus_points[rot_i]), + format!("mstore({}, coeff)", coeff.ptr()) + ], + ] + ) + ] + .collect_vec() + }) + .collect_vec(); + + let normalized_coeff_computations = chain![ + [ + format!("success := batch_invert(success, 0, {first_batch_invert_end}, r)"), + format!("let diff_0_inv := {diff_0}"), + format!("mstore({}, diff_0_inv)", diffs[0].ptr()), + ], + for_loop( + [ + format!("let mptr := {}", diffs[0].ptr() + 1), + format!("let mptr_end := {}", diffs[0].ptr() + sets.len()), + ], + "lt(mptr, mptr_end)", + ["mptr := add(mptr, 0x20)"], + ["mstore(mptr, mulmod(mload(mptr), diff_0_inv, r))"], + ), + ] + .collect_vec(); + + let r_evals_computations = izip!(0.., &sets, &coeffs, &diffs, &r_evals).map( + |(set_idx, set, coeffs, set_coeff, r_eval)| { + let is_single_rot_set = set.rots().len() == 1; + chain![ + is_single_rot_set.then(|| format!("let coeff := {}", coeffs[0])), + ["let zeta := mload(ZETA_MPTR)", "let r_eval"].map(str::to_string), + if is_single_rot_set { + let evals = set.evals().iter().map(|evals| evals[0]).collect_vec(); + let eval_groups = group_backward_adjacent_words(evals.iter().rev().skip(1)); + chain![ + evals + .last() + .map(|eval| format!("r_eval := mulmod(coeff, {eval}, r)")), + eval_groups.iter().flat_map(|(loc, evals)| { + if evals.len() < 3 { + evals + .iter() + .flat_map(|eval| { + let item = format!("mulmod(coeff, {eval}, r)"); + [ + format!("r_eval := mulmod(r_eval, zeta, r)"), + format!("r_eval := addmod(r_eval, {item}, r)"), + ] + }) + .collect_vec() + } else { + assert_eq!(*loc, Location::Calldata); + let item = "mulmod(coeff, calldataload(cptr), r)"; + for_loop( + [ + format!("let cptr := {}", evals[0].ptr()), + format!("let cptr_end := {}", evals[0].ptr() - evals.len()), + ], + "lt(cptr_end, cptr)", + ["cptr := sub(cptr, 0x20)"], + [format!( + "r_eval := addmod(mulmod(r_eval, zeta, r), {item}, r)" + )], + ) + } + }) + ] + .collect_vec() + } else { + chain![set.evals().iter().enumerate().rev()] + .flat_map(|(idx, evals)| { + chain![ + izip!(evals, coeffs).map(|(eval, coeff)| { + let item = format!("mulmod({coeff}, {eval}, r)"); + format!("r_eval := addmod(r_eval, {item}, r)") + }), + (idx != 0).then(|| format!("r_eval := mulmod(r_eval, zeta, r)")), + ] + }) + .collect_vec() + }, + (set_idx != 0).then(|| format!("r_eval := mulmod(r_eval, {set_coeff}, r)")), + [format!("mstore({}, r_eval)", r_eval.ptr())], + ] + .collect_vec() + }, + ); + + let coeff_sums_computation = izip!(&coeffs, &sums).map(|(coeffs, sum)| { + let (coeff_0, rest_coeffs) = coeffs.split_first().unwrap(); + chain![ + [format!("let sum := {coeff_0}")], + rest_coeffs + .iter() + .map(|coeff_mptr| format!("sum := addmod(sum, {coeff_mptr}, r)")), + [format!("mstore({}, sum)", sum.ptr())], + ] + .collect_vec() + }); + + let r_eval_computations = chain![ + for_loop( + [ + format!("let mptr := 0x00"), + format!("let mptr_end := {second_batch_invert_end}"), + format!("let sum_mptr := {}", sums[0].ptr()), + ], + "lt(mptr, mptr_end)", + ["mptr := add(mptr, 0x20)", "sum_mptr := add(sum_mptr, 0x20)"], + ["mstore(mptr, mload(sum_mptr))"], + ), + [ + format!("success := batch_invert(success, 0, {second_batch_invert_end}, r)"), + format!( + "let r_eval := mulmod(mload({}), {}, r)", + second_batch_invert_end - 1, + r_evals.last().unwrap() + ) + ], + for_loop( + [ + format!("let sum_inv_mptr := {}", second_batch_invert_end - 2), + format!("let sum_inv_mptr_end := {second_batch_invert_end}"), + format!("let r_eval_mptr := {}", r_evals[r_evals.len() - 2].ptr()), + ], + "lt(sum_inv_mptr, sum_inv_mptr_end)", + [ + "sum_inv_mptr := sub(sum_inv_mptr, 0x20)", + "r_eval_mptr := sub(r_eval_mptr, 0x20)" + ], + [ + "r_eval := mulmod(r_eval, mload(NU_MPTR), r)", + "r_eval := addmod(r_eval, mulmod(mload(sum_inv_mptr), mload(r_eval_mptr), r), r)" + ], + ), + ["mstore(G1_SCALAR_MPTR, sub(r, r_eval))".to_string()], + ] + .collect_vec(); + + let pairing_input_computations = chain![ + ["let zeta := mload(ZETA_MPTR)", "let nu := mload(NU_MPTR)"].map(str::to_string), + izip!(0.., &sets, &diffs).flat_map(|(set_idx, set, set_coeff)| { + let is_first_set = set_idx == 0; + let is_last_set = set_idx == sets.len() - 1; + let ec_add = &format!("ec_add_{}", if is_first_set { "acc" } else { "tmp" }); + let ec_mul = &format!("ec_mul_{}", if is_first_set { "acc" } else { "tmp" }); + let acc_x = Ptr::memory(0x00) + if is_first_set { 0 } else { 4 }; + let acc_y = acc_x + 1; + let comm_groups = group_backward_adjacent_ec_points(set.comms().iter().rev().skip(1)); + + chain![ + set.comms() + .last() + .map(|comm| { + [ + format!("mstore({acc_x}, {})", comm.x()), + format!("mstore({acc_y}, {})", comm.y()), + ] + }) + .into_iter() + .flatten(), + comm_groups.into_iter().flat_map(move |(loc, comms)| { + if comms.len() < 3 { + comms + .iter() + .flat_map(|comm| { + let (x, y) = (comm.x(), comm.y()); + [ + format!("success := {ec_mul}(success, zeta)"), + format!("success := {ec_add}(success, {x}, {y})"), + ] + }) + .collect_vec() + } else { + let ptr = comms.first().unwrap().x().ptr(); + let ptr_end = ptr - 2 * comms.len(); + let x = Word::from(Ptr::new(loc, "ptr")); + let y = Word::from(Ptr::new(loc, "add(ptr, 0x20)")); + for_loop( + [ + format!("let ptr := {ptr}"), + format!("let ptr_end := {ptr_end}"), + ], + "lt(ptr_end, ptr)", + ["ptr := sub(ptr, 0x40)"], + [ + format!("success := {ec_mul}(success, zeta)"), + format!("success := {ec_add}(success, {x}, {y})"), + ], + ) + } + }), + (!is_first_set) + .then(|| { + let scalar = format!("mulmod(nu, {set_coeff}, r)"); + chain![ + [ + format!("success := ec_mul_tmp(success, {scalar})"), + format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), + ], + (!is_last_set).then(|| format!("nu := mulmod(nu, mload(NU_MPTR), r)")) + ] + }) + .into_iter() + .flatten(), + ] + .collect_vec() + }), + [ + format!("mstore(0x80, mload(G1_X_MPTR))"), + format!("mstore(0xa0, mload(G1_Y_MPTR))"), + format!("success := ec_mul_tmp(success, mload(G1_SCALAR_MPTR))"), + format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), + format!("mstore(0x80, {})", w.x()), + format!("mstore(0xa0, {})", w.y()), + format!("success := ec_mul_tmp(success, sub(r, {vanishing_0}))"), + format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), + format!("mstore(0x80, {})", w_prime.x()), + format!("mstore(0xa0, {})", w_prime.y()), + format!("success := ec_mul_tmp(success, mload(MU_MPTR))"), + format!("success := ec_add_acc(success, mload(0x80), mload(0xa0))"), + format!("mstore(PAIRING_LHS_X_MPTR, mload(0x00))"), + format!("mstore(PAIRING_LHS_Y_MPTR, mload(0x20))"), + format!("mstore(PAIRING_RHS_X_MPTR, {})", w_prime.x()), + format!("mstore(PAIRING_RHS_Y_MPTR, {})", w_prime.y()), + ], + ] + .collect_vec(); + + chain![ + [point_computations, vanishing_computations], + coeff_computations, + [normalized_coeff_computations], + r_evals_computations, + coeff_sums_computation, + [r_eval_computations, pairing_input_computations], + ] + .collect_vec() +} + +#[derive(Debug)] +struct RotationSet { + rots: BTreeSet, + diffs: BTreeSet, + comms: Vec, + evals: Vec>, +} + +impl RotationSet { + fn rots(&self) -> &BTreeSet { + &self.rots + } + + fn diffs(&self) -> &BTreeSet { + &self.diffs + } + + fn comms(&self) -> &[EcPoint] { + &self.comms + } + + fn evals(&self) -> &[Vec] { + &self.evals + } +} + +fn rotation_sets(queries: &[Query]) -> (BTreeSet, Vec) { + let mut superset = BTreeSet::new(); + let comm_queries = queries.iter().fold( + Vec::<(EcPoint, BTreeMap)>::new(), + |mut comm_queries, query| { + superset.insert(query.rot); + if let Some(pos) = comm_queries + .iter() + .position(|(comm, _)| comm == &query.comm) + { + let (_, queries) = &mut comm_queries[pos]; + assert!(!queries.contains_key(&query.rot)); + queries.insert(query.rot, query.eval); + } else { + comm_queries.push((query.comm, BTreeMap::from_iter([(query.rot, query.eval)]))); + } + comm_queries + }, + ); + let superset = superset; + let sets = + comm_queries + .into_iter() + .fold(Vec::::new(), |mut sets, (comm, queries)| { + if let Some(pos) = sets + .iter() + .position(|set| itertools::equal(&set.rots, queries.keys())) + { + let set = &mut sets[pos]; + if !set.comms.contains(&comm) { + set.comms.push(comm); + set.evals.push(queries.into_values().collect_vec()); + } + } else { + let diffs = BTreeSet::from_iter( + superset + .iter() + .filter(|rot| !queries.contains_key(rot)) + .copied(), + ); + let set = RotationSet { + rots: BTreeSet::from_iter(queries.keys().copied()), + diffs, + comms: vec![comm], + evals: vec![queries.into_values().collect()], + }; + sets.push(set); + } + sets + }); + (superset, sets) +} diff --git a/src/codegen/pcs/gwc19.rs b/src/codegen/pcs/gwc19.rs new file mode 100644 index 0000000..be7562f --- /dev/null +++ b/src/codegen/pcs/gwc19.rs @@ -0,0 +1,298 @@ +#![allow(clippy::useless_format)] + +use crate::codegen::{ + pcs::{queries, Query}, + util::{ + for_loop, group_backward_adjacent_ec_points, group_backward_adjacent_words, + ConstraintSystemMeta, Data, EcPoint, Location, Ptr, Word, + }, +}; +use itertools::{chain, izip, Itertools}; +use std::collections::BTreeMap; + +pub(super) fn static_working_memory_size(meta: &ConstraintSystemMeta, _: &Data) -> usize { + 0x100 + meta.num_rotations * 0x40 +} + +pub(super) fn computations(meta: &ConstraintSystemMeta, data: &Data) -> Vec> { + let sets = rotation_sets(&queries(meta, data)); + let rots = sets.iter().map(|set| set.rot).collect_vec(); + let (min_rot, max_rot) = rots + .iter() + .copied() + .minmax() + .into_option() + .unwrap_or_default(); + + let ws = EcPoint::range(data.w_cptr).take(sets.len()).collect_vec(); + + let point_w_mptr = Ptr::memory(0x100); + let point_ws = izip!(rots, EcPoint::range(point_w_mptr)).collect::>(); + + let eval_computations = { + chain![ + [ + "let nu := mload(NU_MPTR)", + "let mu := mload(MU_MPTR)", + "let eval_acc", + "let eval_tmp", + ] + .map(str::to_string), + sets.iter().enumerate().rev().flat_map(|(set_idx, set)| { + let is_last_set = set_idx == sets.len() - 1; + let eval_acc = &format!("eval_{}", if is_last_set { "acc" } else { "tmp" }); + let eval_groups = group_backward_adjacent_words(set.evals().iter().rev().skip(1)); + + chain![ + set.evals() + .last() + .map(|eval| format!("{eval_acc} := {}", eval)), + eval_groups.iter().flat_map(|(loc, evals)| { + if evals.len() < 3 { + evals + .iter() + .map(|eval| { + format!( + "{eval_acc} := addmod(mulmod({eval_acc}, nu, r), {eval}, r)" + ) + }) + .collect_vec() + } else { + assert_eq!(*loc, Location::Calldata); + let eval = "calldataload(cptr)"; + for_loop( + [ + format!("let cptr := {}", evals[0].ptr()), + format!("let cptr_end := {}", evals[0].ptr() - evals.len()), + ], + "lt(cptr_end, cptr)", + ["cptr := sub(cptr, 0x20)"], + [format!( + "{eval_acc} := addmod(mulmod({eval_acc}, nu, r), {eval}, r)" + )], + ) + } + }), + (!is_last_set) + .then_some([ + "eval_acc := mulmod(eval_acc, mu, r)", + "eval_acc := addmod(eval_acc, eval_tmp, r)", + ]) + .into_iter() + .flatten() + .map(str::to_string), + ] + .collect_vec() + }), + ["mstore(G1_SCALAR_MPTR, sub(r, eval_acc))".to_string()], + ] + .collect_vec() + }; + + let point_computations = chain![ + [ + "let x := mload(X_MPTR)", + "let omega := mload(OMEGA_MPTR)", + "let omega_inv := mload(OMEGA_INV_MPTR)", + "let x_pow_of_omega := mulmod(x, omega, r)" + ] + .map(str::to_string), + (1..=max_rot).flat_map(|rot| { + chain![ + point_ws + .get(&rot) + .map(|point| format!("mstore({}, x_pow_of_omega)", point.x().ptr())), + (rot != max_rot) + .then(|| "x_pow_of_omega := mulmod(x_pow_of_omega, omega, r)".to_string()) + ] + }), + [ + format!("mstore({}, x)", point_ws[&0].x().ptr()), + format!("x_pow_of_omega := mulmod(x, omega_inv, r)") + ], + (min_rot..0).rev().flat_map(|rot| { + chain![ + point_ws + .get(&rot) + .map(|point| format!("mstore({}, x_pow_of_omega)", point.x().ptr())), + (rot != min_rot).then(|| { + "x_pow_of_omega := mulmod(x_pow_of_omega, omega_inv, r)".to_string() + }) + ] + }) + ] + .collect_vec(); + + let point_w_computations = for_loop( + [ + format!("let cptr := {}", data.w_cptr), + format!("let mptr := {point_w_mptr}"), + format!("let mptr_end := {}", point_w_mptr + 2 * sets.len()), + ], + "lt(mptr, mptr_end)".to_string(), + ["mptr := add(mptr, 0x40)", "cptr := add(cptr, 0x40)"].map(str::to_string), + [ + "mstore(0x00, calldataload(cptr))", + "mstore(0x20, calldataload(add(cptr, 0x20)))", + "success := ec_mul_acc(success, mload(mptr))", + "mstore(mptr, mload(0x00))", + "mstore(add(mptr, 0x20), mload(0x20))", + ] + .map(str::to_string), + ); + + let pairing_lhs_computations = chain![ + ["let nu := mload(NU_MPTR)", "let mu := mload(MU_MPTR)"].map(str::to_string), + sets.iter().enumerate().rev().flat_map(|(set_idx, set)| { + let is_last_set = set_idx == sets.len() - 1; + let ec_add = &format!("ec_add_{}", if is_last_set { "acc" } else { "tmp" }); + let ec_mul = &format!("ec_mul_{}", if is_last_set { "acc" } else { "tmp" }); + let acc_x = Ptr::memory(0x00) + if is_last_set { 0 } else { 4 }; + let acc_y = acc_x + 1; + let point_w = &point_ws[&set.rot]; + let comm_groups = group_backward_adjacent_ec_points(set.comms().iter().rev().skip(1)); + + chain![ + set.comms() + .last() + .map(|comm| { + [ + format!("mstore({acc_x}, {})", comm.x()), + format!("mstore({acc_y}, {})", comm.y()), + ] + }) + .into_iter() + .flatten(), + comm_groups.into_iter().flat_map(move |(loc, comms)| { + if comms.len() < 3 { + comms + .iter() + .flat_map(|comm| { + let (x, y) = (comm.x(), comm.y()); + [ + format!("success := {ec_mul}(success, nu)"), + format!("success := {ec_add}(success, {x}, {y})"), + ] + }) + .collect_vec() + } else { + let ptr = comms.first().unwrap().x().ptr(); + let ptr_end = ptr - 2 * comms.len(); + let x = Word::from(Ptr::new(loc, "ptr")); + let y = Word::from(Ptr::new(loc, "add(ptr, 0x20)")); + for_loop( + [ + format!("let ptr := {ptr}"), + format!("let ptr_end := {ptr_end}"), + ], + "lt(ptr_end, ptr)", + ["ptr := sub(ptr, 0x40)".to_string()], + [ + format!("success := {ec_mul}(success, nu)"), + format!("success := {ec_add}(success, {x}, {y})"), + ], + ) + } + }), + [format!( + "success := {ec_add}(success, {}, {})", + point_w.x(), + point_w.y() + )], + (!is_last_set) + .then_some([ + "success := ec_mul_acc(success, mu)", + "success := ec_add_acc(success, mload(0x80), mload(0xa0))", + ]) + .into_iter() + .flatten() + .map(str::to_string), + ] + .collect_vec() + }), + [ + "mstore(0x80, mload(G1_X_MPTR))", + "mstore(0xa0, mload(G1_Y_MPTR))", + "success := ec_mul_tmp(success, mload(G1_SCALAR_MPTR))", + "success := ec_add_acc(success, mload(0x80), mload(0xa0))", + "mstore(PAIRING_LHS_X_MPTR, mload(0x00))", + "mstore(PAIRING_LHS_Y_MPTR, mload(0x20))", + ] + .map(str::to_string), + ] + .collect_vec(); + + let pairing_rhs_computations = chain![ + [ + format!("let mu := mload(MU_MPTR)"), + format!("mstore(0x00, {})", ws.last().unwrap().x()), + format!("mstore(0x20, {})", ws.last().unwrap().y()), + ], + ws.iter() + .nth_back(1) + .map(|w_second_last| { + let x = "calldataload(cptr)"; + let y = "calldataload(add(cptr, 0x20))"; + for_loop( + [ + format!("let cptr := {}", w_second_last.x().ptr()), + format!("let cptr_end := {}", ws[0].x().ptr() - 1), + ], + "lt(cptr_end, cptr)", + ["cptr := sub(cptr, 0x40)"], + [ + format!("success := ec_mul_acc(success, mu)"), + format!("success := ec_add_acc(success, {x}, {y})"), + ], + ) + }) + .into_iter() + .flatten(), + [ + format!("mstore(PAIRING_RHS_X_MPTR, mload(0x00))"), + format!("mstore(PAIRING_RHS_Y_MPTR, mload(0x20))"), + ], + ] + .collect_vec(); + + vec![ + eval_computations, + point_computations, + point_w_computations, + pairing_lhs_computations, + pairing_rhs_computations, + ] +} + +#[derive(Debug)] +struct RotationSet { + rot: i32, + comms: Vec, + evals: Vec, +} + +impl RotationSet { + fn comms(&self) -> &[EcPoint] { + &self.comms + } + + fn evals(&self) -> &[Word] { + &self.evals + } +} + +fn rotation_sets(queries: &[Query]) -> Vec { + queries.iter().fold(Vec::new(), |mut sets, query| { + if let Some(pos) = sets.iter().position(|set| set.rot == query.rot) { + sets[pos].comms.push(query.comm); + sets[pos].evals.push(query.eval); + } else { + sets.push(RotationSet { + rot: query.rot, + comms: vec![query.comm], + evals: vec![query.eval], + }); + } + sets + }) +} diff --git a/src/codegen/template.rs b/src/codegen/template.rs index 55753f9..ef0ae87 100644 --- a/src/codegen/template.rs +++ b/src/codegen/template.rs @@ -36,6 +36,7 @@ pub(crate) struct Halo2Verifier { pub(crate) num_neg_lagranges: usize, pub(crate) num_advices: Vec, pub(crate) num_challenges: Vec, + pub(crate) num_rotations: usize, pub(crate) num_evals: usize, pub(crate) num_quotients: usize, pub(crate) quotient_eval_numer_computations: Vec>, diff --git a/src/codegen/util.rs b/src/codegen/util.rs index 9d3aea9..b4a2b64 100644 --- a/src/codegen/util.rs +++ b/src/codegen/util.rs @@ -26,6 +26,7 @@ pub(crate) struct ConstraintSystemMeta { pub(crate) num_quotients: usize, pub(crate) advice_queries: Vec<(usize, i32)>, pub(crate) fixed_queries: Vec<(usize, i32)>, + pub(crate) num_rotations: usize, pub(crate) num_evals: usize, pub(crate) num_user_advices: Vec, pub(crate) num_user_challenges: Vec, @@ -91,6 +92,21 @@ impl ConstraintSystemMeta { let (num_user_advices, advice_indices) = remapping(cs.advice_column_phase()); let (num_user_challenges, challenge_indices) = remapping(cs.challenge_phase()); let rotation_last = -(cs.blinding_factors() as i32 + 1); + let num_rotations = chain![ + advice_queries.iter().map(|query| query.1), + fixed_queries.iter().map(|query| query.1), + (num_permutation_zs > 0) + .then_some([0, 1]) + .into_iter() + .flatten(), + (num_permutation_zs > 1).then_some(rotation_last), + (num_lookup_zs > 0) + .then_some([-1, 0, 1]) + .into_iter() + .flatten(), + ] + .unique() + .count(); Self { num_fixeds, permutation_columns, @@ -102,6 +118,7 @@ impl ConstraintSystemMeta { advice_queries, fixed_queries, num_evals, + num_rotations, num_user_advices, num_user_challenges, advice_indices, @@ -159,12 +176,10 @@ impl ConstraintSystemMeta { } pub(crate) fn batch_open_proof_len(&self, scheme: BatchOpenScheme) -> usize { - match scheme { - Bdfg21 => 2 * 0x40, - Gwc19 => { - unimplemented!() - } - } + (match scheme { + Bdfg21 => 2, + Gwc19 => self.num_rotations, + }) * 0x40 } } @@ -477,6 +492,10 @@ impl Word { pub(crate) fn ptr(&self) -> Ptr { self.0 } + + pub(crate) fn loc(&self) -> Location { + self.0.loc() + } } impl Display for Word { @@ -530,10 +549,12 @@ impl From for EcPoint { } /// Add indention to given lines by `4 * N` spaces. -pub(crate) fn indent(lines: impl IntoIterator) -> Vec { +pub(crate) fn indent( + lines: impl IntoIterator>, +) -> Vec { lines .into_iter() - .map(|line| format!("{}{line}", " ".repeat(N * 4))) + .map(|line| format!("{}{}", " ".repeat(N * 4), line.into())) .collect() } @@ -541,9 +562,9 @@ pub(crate) fn indent(lines: impl IntoIterator) -> /// /// If `PACKED` is true, single line code block will be packed into single line. pub(crate) fn code_block( - lines: impl IntoIterator, + lines: impl IntoIterator>, ) -> Vec { - let lines = lines.into_iter().collect_vec(); + let lines = lines.into_iter().map_into().collect_vec(); let bracket_indent = " ".repeat((N - 1) * 4); match lines.len() { 0 => vec![format!("{bracket_indent}{{}}")], @@ -559,10 +580,10 @@ pub(crate) fn code_block( /// Create a for loop with proper indention. pub(crate) fn for_loop( - initialization: impl IntoIterator, + initialization: impl IntoIterator>, condition: impl Into, - advancement: impl IntoIterator, - body: impl IntoIterator, + advancement: impl IntoIterator>, + body: impl IntoIterator>, ) -> Vec { chain![ ["for".to_string()], @@ -574,6 +595,50 @@ pub(crate) fn for_loop( .collect() } +pub(crate) fn group_backward_adjacent_words<'a>( + words: impl IntoIterator, +) -> Vec<(Location, Vec<&'a Word>)> { + words.into_iter().fold(Vec::new(), |mut word_groups, word| { + if let Some(last_group) = word_groups.last_mut() { + let last_word = **last_group.1.last().unwrap(); + if last_group.0 == word.loc() + && last_word.ptr().value().is_integer() + && last_word.ptr() - 1 == word.ptr() + { + last_group.1.push(word) + } else { + word_groups.push((word.loc(), vec![word])) + } + word_groups + } else { + vec![(word.loc(), vec![word])] + } + }) +} + +pub(crate) fn group_backward_adjacent_ec_points<'a>( + ec_point: impl IntoIterator, +) -> Vec<(Location, Vec<&'a EcPoint>)> { + ec_point + .into_iter() + .fold(Vec::new(), |mut ec_point_groups, ec_point| { + if let Some(last_group) = ec_point_groups.last_mut() { + let last_ec_point = **last_group.1.last().unwrap(); + if last_group.0 == ec_point.loc() + && last_ec_point.x().ptr().value().is_integer() + && last_ec_point.x().ptr() - 2 == ec_point.x().ptr() + { + last_group.1.push(ec_point) + } else { + ec_point_groups.push((ec_point.loc(), vec![ec_point])) + } + ec_point_groups + } else { + vec![(ec_point.loc(), vec![ec_point])] + } + }) +} + pub(crate) fn g1_to_u256s(ec_point: impl Borrow) -> [U256; 2] { let coords = ec_point.borrow().coordinates().unwrap(); [coords.x(), coords.y()].map(fq_to_u256) diff --git a/src/test.rs b/src/test.rs index fb31148..7e54ab0 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,5 +1,9 @@ use crate::{ - codegen::{AccumulatorEncoding, BatchOpenScheme::Bdfg21, SolidityGenerator}, + codegen::{ + AccumulatorEncoding, + BatchOpenScheme::{self, Bdfg21, Gwc19}, + SolidityGenerator, + }, encode_calldata, evm::test::{compile_solidity, Evm}, FN_SIG_VERIFY_PROOF, FN_SIG_VERIFY_PROOF_WITH_VK_ADDRESS, @@ -26,31 +30,51 @@ fn function_signature() { } #[test] -fn render_huge() { - run_render::>() +fn render_bdfg21_huge() { + run_render::>(Bdfg21) } #[test] -fn render_maingate() { - run_render::>() +fn render_bdfg21_maingate() { + run_render::>(Bdfg21) } #[test] -fn render_separately_huge() { - run_render_separately::>() +fn render_gwc19_huge() { + run_render::>(Gwc19) } #[test] -fn render_separately_maingate() { - run_render_separately::>() +fn render_gwc19_maingate() { + run_render::>(Gwc19) } -fn run_render>() { +#[test] +fn render_separately_bdfg21_huge() { + run_render_separately::>(Bdfg21) +} + +#[test] +fn render_separately_bdfg21_maingate() { + run_render_separately::>(Bdfg21) +} + +#[test] +fn render_separately_gwc19_huge() { + run_render_separately::>(Gwc19) +} + +#[test] +fn render_separately_gwc19_maingate() { + run_render_separately::>(Gwc19) +} + +fn run_render>(scheme: BatchOpenScheme) { let acc_encoding = AccumulatorEncoding::new(0, 4, 68).into(); let (params, vk, instances, proof) = - halo2::create_testdata_bdfg21::(C::min_k(), acc_encoding, std_rng()); + halo2::create_testdata::(C::min_k(), scheme, acc_encoding, std_rng()); - let generator = SolidityGenerator::new(¶ms, &vk, Bdfg21, instances.len()) + let generator = SolidityGenerator::new(¶ms, &vk, scheme, instances.len()) .set_acc_encoding(acc_encoding); let verifier_solidity = generator.render().unwrap(); let verifier_creation_code = compile_solidity(verifier_solidity); @@ -68,12 +92,12 @@ fn run_render>() { println!("Gas cost: {gas_cost}"); } -fn run_render_separately>() { +fn run_render_separately>(scheme: BatchOpenScheme) { let acc_encoding = AccumulatorEncoding::new(0, 4, 68).into(); let (params, vk, instances, _) = - halo2::create_testdata_bdfg21::(C::min_k(), acc_encoding, std_rng()); + halo2::create_testdata::(C::min_k(), scheme, acc_encoding, std_rng()); - let generator = SolidityGenerator::new(¶ms, &vk, Bdfg21, instances.len()) + let generator = SolidityGenerator::new(¶ms, &vk, scheme, instances.len()) .set_acc_encoding(acc_encoding); let (verifier_solidity, _vk_solidity) = generator.render_separately().unwrap(); let verifier_creation_code = compile_solidity(&verifier_solidity); @@ -90,8 +114,8 @@ fn run_render_separately>() { for k in C::min_k()..C::min_k() + 4 { let (params, vk, instances, proof) = - halo2::create_testdata_bdfg21::(k, acc_encoding, std_rng()); - let generator = SolidityGenerator::new(¶ms, &vk, Bdfg21, instances.len()) + halo2::create_testdata::(k, scheme, acc_encoding, std_rng()); + let generator = SolidityGenerator::new(¶ms, &vk, scheme, instances.len()) .set_acc_encoding(acc_encoding); let (verifier_solidity, vk_solidity) = generator.render_separately().unwrap(); @@ -131,7 +155,11 @@ fn save_generated(verifier: &str, vk: Option<&str>) { } mod halo2 { - use crate::{codegen::AccumulatorEncoding, transcript::Keccak256Transcript}; + use crate::{ + codegen::AccumulatorEncoding, + transcript::Keccak256Transcript, + BatchOpenScheme::{self, Bdfg21, Gwc19}, + }; use halo2_proofs::{ arithmetic::CurveAffine, halo2curves::{ @@ -143,7 +171,7 @@ mod halo2 { plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, VerifyingKey}, poly::kzg::{ commitment::ParamsKZG, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, strategy::SingleStrategy, }, transcript::TranscriptWriterBuffer, @@ -162,8 +190,9 @@ mod halo2 { } #[allow(clippy::type_complexity)] - pub fn create_testdata_bdfg21>( + pub fn create_testdata>( k: u32, + scheme: BatchOpenScheme, acc_encoding: Option, mut rng: impl RngCore + Clone, ) -> ( @@ -172,42 +201,55 @@ mod halo2 { Vec, Vec, ) { - let circuit = C::new(acc_encoding, rng.clone()); - let instances = circuit.instances(); - - let params = ParamsKZG::::setup(k, &mut rng); - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk.clone(), &circuit).unwrap(); - - let proof = { - let mut transcript = Keccak256Transcript::new(Vec::new()); - create_proof::<_, ProverSHPLONK<_>, _, _, _, _>( - ¶ms, - &pk, - &[circuit], - &[&[&instances]], - &mut rng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - - let result = { - let mut transcript = Keccak256Transcript::new(proof.as_slice()); - verify_proof::<_, VerifierSHPLONK<_>, _, _, SingleStrategy<_>>( - ¶ms, - pk.get_vk(), - SingleStrategy::new(¶ms), - &[&[&instances]], - &mut transcript, - ) - }; - assert!(result.is_ok()); + match scheme { + Bdfg21 => { + create_testdata_inner!(ProverSHPLONK<_>, VerifierSHPLONK<_>, k, acc_encoding, rng) + } + Gwc19 => create_testdata_inner!(ProverGWC<_>, VerifierGWC<_>, k, acc_encoding, rng), + } + } - (params, vk, instances, proof) + macro_rules! create_testdata_inner { + ($p:ty, $v:ty, $k:ident, $acc_encoding:ident, $rng:ident) => {{ + let circuit = C::new($acc_encoding, $rng.clone()); + let instances = circuit.instances(); + + let params = ParamsKZG::::setup($k, &mut $rng); + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk.clone(), &circuit).unwrap(); + + let proof = { + let mut transcript = Keccak256Transcript::new(Vec::new()); + create_proof::<_, $p, _, _, _, _>( + ¶ms, + &pk, + &[circuit], + &[&[&instances]], + &mut $rng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let result = { + let mut transcript = Keccak256Transcript::new(proof.as_slice()); + verify_proof::<_, $v, _, _, SingleStrategy<_>>( + ¶ms, + pk.get_vk(), + SingleStrategy::new(¶ms), + &[&[&instances]], + &mut transcript, + ) + }; + assert!(result.is_ok()); + + (params, vk, instances, proof) + }}; } + use create_testdata_inner; + fn random_accumulator_limbs( acc_encoding: AccumulatorEncoding, mut rng: impl RngCore, diff --git a/templates/Halo2Verifier.sol b/templates/Halo2Verifier.sol index 7abfb0b..0e93967 100644 --- a/templates/Halo2Verifier.sol +++ b/templates/Halo2Verifier.sol @@ -47,7 +47,8 @@ contract Halo2Verifier { uint256 internal constant NU_MPTR = {{ theta_mptr + 6 }}; uint256 internal constant MU_MPTR = {{ theta_mptr + 7 }}; {%- when Gwc19 %} - // TODO + uint256 internal constant NU_MPTR = {{ theta_mptr + 5 }}; + uint256 internal constant MU_MPTR = {{ theta_mptr + 6 }}; {%- endmatch %} uint256 internal constant ACC_LHS_X_MPTR = {{ theta_mptr + 8 }}; @@ -63,7 +64,7 @@ contract Halo2Verifier { uint256 internal constant QUOTIENT_EVAL_MPTR = {{ theta_mptr + 18 }}; uint256 internal constant QUOTIENT_X_MPTR = {{ theta_mptr + 19 }}; uint256 internal constant QUOTIENT_Y_MPTR = {{ theta_mptr + 20 }}; - uint256 internal constant R_EVAL_MPTR = {{ theta_mptr + 21 }}; + uint256 internal constant G1_SCALAR_MPTR = {{ theta_mptr + 21 }}; uint256 internal constant PAIRING_LHS_X_MPTR = {{ theta_mptr + 22 }}; uint256 internal constant PAIRING_LHS_Y_MPTR = {{ theta_mptr + 23 }}; uint256 internal constant PAIRING_RHS_X_MPTR = {{ theta_mptr + 24 }}; @@ -303,7 +304,17 @@ contract Halo2Verifier { success, proof_cptr, hash_mptr := read_ec_point(success, proof_cptr, hash_mptr, q) // W' {%- when Gwc19 %} - // TODO + challenge_mptr, hash_mptr := squeeze_challenge(challenge_mptr, hash_mptr, r) // nu + + for + { let proof_cptr_end := add(proof_cptr, {{ (2 * 32 * num_rotations)|hex() }}) } + lt(proof_cptr, proof_cptr_end) + {} + { + success, proof_cptr, hash_mptr := read_ec_point(success, proof_cptr, hash_mptr, q) + } + + challenge_mptr, hash_mptr := squeeze_challenge(challenge_mptr, hash_mptr, r) // mu {%- endmatch %} {%~ match self.embedded_vk %} From 2ab5ecf95363eb969a36cfd5fa6d05ef28311b92 Mon Sep 17 00:00:00 2001 From: han0110 Date: Tue, 30 Jan 2024 03:33:46 +0000 Subject: [PATCH 2/3] doc: update `README.md` --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 7cae107..7798fad 100644 --- a/README.md +++ b/README.md @@ -34,9 +34,8 @@ Note that function selector is already included. ## Limitations & Caveats -- It only allows circuit with **exact 1 instance column** and **no rotated query to this instance column**. +- It only allows circuit with **less or equal than 1 instance column** and **no rotated query to this instance column**. - Currently even the `configure` is same, the [selector compression](https://github.com/privacy-scaling-explorations/halo2/blob/7a2165617195d8baa422ca7b2b364cef02380390/halo2_proofs/src/plonk/circuit/compress_selectors.rs#L51) might lead to different configuration when selector assignments are different. To avoid this, please use [`keygen_vk_custom`](https://github.com/privacy-scaling-explorations/halo2/blob/6fc6d7ca018f3899b030618cb18580249b1e7c82/halo2_proofs/src/plonk/keygen.rs#L223) with `compress_selectors: false` to do key generation without selector compression. -- Now it only supports BDFG21 batch open scheme (aka SHPLONK), GWC19 is not yet implemented. ## Compatibility From 43dada6924306308c0eb50dd041ac2a82ddf387a Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 2 Feb 2024 10:29:32 +0000 Subject: [PATCH 3/3] fix: make sure accumulator coordinates are less than base field modulus --- templates/Halo2Verifier.sol | 2 ++ 1 file changed, 2 insertions(+) diff --git a/templates/Halo2Verifier.sol b/templates/Halo2Verifier.sol index 0e93967..3dfcf09 100644 --- a/templates/Halo2Verifier.sol +++ b/templates/Halo2Verifier.sol @@ -367,7 +367,9 @@ contract Halo2Verifier { shift := add(shift, num_limb_bits) } + success := and(success, and(lt(lhs_x, q), lt(lhs_y, q))) success := and(success, eq(mulmod(lhs_y, lhs_y, q), addmod(mulmod(lhs_x, mulmod(lhs_x, lhs_x, q), q), 3, q))) + success := and(success, and(lt(rhs_x, q), lt(rhs_y, q))) success := and(success, eq(mulmod(rhs_y, rhs_y, q), addmod(mulmod(rhs_x, mulmod(rhs_x, rhs_x, q), q), 3, q))) mstore(ACC_LHS_X_MPTR, lhs_x)