From e2dab6786892511f389e0eba53ee93d7f9df7265 Mon Sep 17 00:00:00 2001 From: adria0 Date: Mon, 17 Jun 2024 11:39:08 +0200 Subject: [PATCH] some refactor --- halo2_frontend/src/circuit.rs | 6 +- halo2_frontend/src/dev.rs | 6 +- halo2_frontend/src/dev/{cost.rs_ => cost.rs} | 8 +- halo2_frontend/src/dev/cost_model.rs | 6 +- halo2_frontend/src/dev/graph.rs | 2 +- halo2_frontend/src/dev/graph/layout.rs | 2 +- halo2_frontend/src/plonk/assigned.rs | 1 - .../src/plonk/circuit/expression.rs | 225 +++++++++--------- halo2_proofs/src/plonk.rs | 8 +- halo2_proofs/src/plonk/keygen.rs | 24 +- halo2_proofs/src/plonk/prover.rs | 14 +- 11 files changed, 149 insertions(+), 153 deletions(-) rename halo2_frontend/src/dev/{cost.rs_ => cost.rs} (98%) diff --git a/halo2_frontend/src/circuit.rs b/halo2_frontend/src/circuit.rs index 0eae281295..cff44823b1 100644 --- a/halo2_frontend/src/circuit.rs +++ b/halo2_frontend/src/circuit.rs @@ -38,7 +38,7 @@ use crate::plonk::FieldFr; /// generation, and proof generation. /// If `compress_selectors` is true, multiple selector columns may be multiplexed. #[allow(clippy::type_complexity)] -pub fn compile_circuit, F: Field, ConcreteCircuit: Circuit>( +pub fn compile_circuit, F: Field, ConcreteCircuit: Circuit>( k: u32, circuit: &ConcreteCircuit, compress_selectors: bool, @@ -46,7 +46,7 @@ pub fn compile_circuit, F: Field, ConcreteCircuit: Circuit< ( CompiledCircuit, ConcreteCircuit::Config, - ConstraintSystem, + ConstraintSystem, ), Error, > { @@ -65,7 +65,7 @@ pub fn compile_circuit, F: Field, ConcreteCircuit: Circuit< let mut assembly = plonk::keygen::Assembly { k, - fixed: vec![vec![EF::ZERO.into(); n]; cs.num_fixed_columns], + fixed: vec![vec![FF::ZERO.into(); n]; cs.num_fixed_columns], permutation: permutation::Assembly::new(n, &cs.permutation), selectors: vec![vec![false; n]; cs.num_selectors], usable_rows: 0..n - (cs.blinding_factors() + 1), diff --git a/halo2_frontend/src/dev.rs b/halo2_frontend/src/dev.rs index 5c99e418bc..7e60c80625 100644 --- a/halo2_frontend/src/dev.rs +++ b/halo2_frontend/src/dev.rs @@ -29,8 +29,8 @@ mod util; mod failure; pub use failure::{FailureLocation, VerifyFailure}; -// pub mod cost; -// pub use cost::CircuitCost; +pub mod cost; +pub use cost::CircuitCost; #[cfg(feature = "cost-estimator")] pub mod cost_model; @@ -284,7 +284,7 @@ impl Mul for Value { /// MockProver::::run(2, &circuit, vec![]).unwrap_err() /// }); /// assert_eq!( -/// result.unwrap_err().downcast_ref::().unwrap(), +/// result.unwrap_err().downcast_rF::().unwrap(), /// "n=4, minimum_rows=8, k=2" /// ); /// ``` diff --git a/halo2_frontend/src/dev/cost.rs_ b/halo2_frontend/src/dev/cost.rs similarity index 98% rename from halo2_frontend/src/dev/cost.rs_ rename to halo2_frontend/src/dev/cost.rs index 7ad7e66055..df0ca90c57 100644 --- a/halo2_frontend/src/dev/cost.rs_ +++ b/halo2_frontend/src/dev/cost.rs @@ -25,7 +25,7 @@ use crate::{ /// Measures a circuit to determine its costs, and explain what contributes to them. #[allow(dead_code)] #[derive(Debug)] -pub struct CircuitCost> { +pub struct CircuitCost, ConcreteCircuit: Circuit> { /// Power-of-2 bound on the number of rows in the circuit. k: u32, /// Maximum degree of the circuit. @@ -53,7 +53,7 @@ pub struct CircuitCost> { num_instance_columns: usize, num_total_columns: usize, - _marker: PhantomData<(G, ConcreteCircuit)>, + _marker: PhantomData<(G, F, ConcreteCircuit)>, } /// Region implementation used by Layout @@ -267,7 +267,7 @@ impl Assignment for Layout { } } -impl> CircuitCost { +impl, ConcreteCircuit: Circuit> CircuitCost { /// Measures a circuit with parameter constant `k`. /// /// Panics if `k` is not large enough for the circuit. @@ -556,6 +556,6 @@ mod tests { Ok(()) } } - CircuitCost::::measure(K, &MyCircuit).proof_size(1); + CircuitCost::::measure(K, &MyCircuit).proof_size(1); } } diff --git a/halo2_frontend/src/dev/cost_model.rs b/halo2_frontend/src/dev/cost_model.rs index 86ce03800a..793c02f71f 100644 --- a/halo2_frontend/src/dev/cost_model.rs +++ b/halo2_frontend/src/dev/cost_model.rs @@ -4,7 +4,7 @@ use std::collections::HashSet; use std::{iter, num::ParseIntError, str::FromStr}; -use crate::plonk::Circuit; +use crate::plonk::{Circuit, FieldFr}; use halo2_middleware::ff::{Field, FromUniformBytes}; use serde::Deserialize; use serde_derive::Serialize; @@ -242,7 +242,7 @@ impl CostOptions { /// Given a Plonk circuit, this function returns a [ModelCircuit] pub fn from_circuit_to_model_circuit< - F: Ord + Field + FromUniformBytes<64>, + F: Ord + FieldFr + FromUniformBytes<64>, C: Circuit, const COMM: usize, const SCALAR: usize, @@ -257,7 +257,7 @@ pub fn from_circuit_to_model_circuit< } /// Given a Plonk circuit, this function returns [CostOptions] -pub fn from_circuit_to_cost_model_options, C: Circuit>( +pub fn from_circuit_to_cost_model_options, C: Circuit>( k: u32, circuit: &C, instances: Vec>, diff --git a/halo2_frontend/src/dev/graph.rs b/halo2_frontend/src/dev/graph.rs index e0f0d667c3..045df059cd 100644 --- a/halo2_frontend/src/dev/graph.rs +++ b/halo2_frontend/src/dev/graph.rs @@ -1,6 +1,6 @@ use crate::plonk::{ Advice, Assigned, Assignment, Challenge, Circuit, Column, ConstraintSystem, Error, Fixed, - FloorPlanner, Instance, Selector, + FloorPlanner, Instance, Selector, FieldFr }; use halo2_middleware::circuit::Any; use halo2_middleware::ff::Field; diff --git a/halo2_frontend/src/dev/graph/layout.rs b/halo2_frontend/src/dev/graph/layout.rs index 8c46a44dc3..2cb18b66f6 100644 --- a/halo2_frontend/src/dev/graph/layout.rs +++ b/halo2_frontend/src/dev/graph/layout.rs @@ -6,7 +6,7 @@ use plotters::{ use std::collections::HashSet; use std::ops::Range; -use crate::plonk::{Circuit, Column, ConstraintSystem, FloorPlanner}; +use crate::plonk::{Circuit, Column, ConstraintSystem, FloorPlanner, FieldFr}; use crate::{circuit::layouter::RegionColumn, dev::cost::Layout}; use halo2_middleware::circuit::Any; diff --git a/halo2_frontend/src/plonk/assigned.rs b/halo2_frontend/src/plonk/assigned.rs index 283186f781..755b5ddde4 100644 --- a/halo2_frontend/src/plonk/assigned.rs +++ b/halo2_frontend/src/plonk/assigned.rs @@ -371,7 +371,6 @@ mod proptests { ops::{Add, Mul, Neg, Sub}, }; - use group::ff::Field; use halo2curves::pasta::Fp; use proptest::{collection::vec, prelude::*, sample::select}; use crate::plonk::FieldFr; diff --git a/halo2_frontend/src/plonk/circuit/expression.rs b/halo2_frontend/src/plonk/circuit/expression.rs index ca89df3502..9ce4298e4a 100644 --- a/halo2_frontend/src/plonk/circuit/expression.rs +++ b/halo2_frontend/src/plonk/circuit/expression.rs @@ -46,27 +46,27 @@ impl Column { } /// Return expression from column at a relative position - pub fn query_cell>(&self, at: Rotation) -> Expression { + pub fn query_cell(&self, at: Rotation) -> Expression { self.column_type.query_cell(self.index, at) } /// Return expression from column at the current row - pub fn cur>(&self) -> Expression { + pub fn cur(&self) -> Expression { self.query_cell(Rotation::cur()) } /// Return expression from column at the next row - pub fn next>(&self) -> Expression { + pub fn next(&self) -> Expression { self.query_cell(Rotation::next()) } /// Return expression from column at the previous row - pub fn prev>(&self) -> Expression { + pub fn prev(&self) -> Expression { self.query_cell(Rotation::prev()) } /// Return expression from column at the specified rotation - pub fn rot>(&self, rotation: i32) -> Expression { + pub fn rot(&self, rotation: i32) -> Expression { self.query_cell(Rotation(rotation)) } } @@ -469,7 +469,7 @@ pub enum Expression { // Arena context pub trait FieldFr: Field { type Field: Field; - fn push(expr: Expression) -> ExprRef; + fn alloc(expr: Expression) -> ExprRef; fn get(ref_: &ExprRef) -> Expression; fn into_field(self) -> Self::Field; fn into_field_fr(f: Self::Field) -> Self; @@ -477,9 +477,9 @@ pub trait FieldFr: Field { #[derive(Clone, Copy, Eq, PartialEq)] pub struct ExprRef(usize, std::marker::PhantomData); -impl Into> for Expression { - fn into(self) -> ExprRef { - EF::push(self) +impl Into> for Expression { + fn into(self) -> ExprRef { + F::alloc(self) } } @@ -554,9 +554,9 @@ impl, IF: Field> From> for ExpressionMid Expression { +impl Expression { /// Make side effects - pub fn query_cells(&mut self, cells: &mut VirtualCells<'_, EF>) { + pub fn query_cells(&mut self, cells: &mut VirtualCells<'_, F>) { match self { Expression::Constant(_) => (), Expression::Selector(selector) => { @@ -595,17 +595,17 @@ impl Expression { } } Expression::Challenge(_) => (), - Expression::Negated(a) => EF::get(a).query_cells(cells), + Expression::Negated(a) => F::get(a).query_cells(cells), Expression::Sum(a, b) => { - EF::get(a).query_cells(cells); - EF::get(b).query_cells(cells); + F::get(a).query_cells(cells); + F::get(b).query_cells(cells); } Expression::Product(a, b) => { - EF::get(a).query_cells(cells); - EF::get(b).query_cells(cells); + F::get(a).query_cells(cells); + F::get(b).query_cells(cells); } - Expression::Scaled(a, _) => EF::get(a).query_cells(cells), - Expression::Ref(a) => EF::get(a).query_cells(cells), + Expression::Scaled(a, _) => F::get(a).query_cells(cells), + Expression::Ref(a) => F::get(a).query_cells(cells), }; } @@ -614,7 +614,7 @@ impl Expression { #[allow(clippy::too_many_arguments)] pub fn evaluate( &self, - constant: &impl Fn(EF) -> T, + constant: &impl Fn(F) -> T, selector_column: &impl Fn(Selector) -> T, fixed_column: &impl Fn(FixedQuery) -> T, advice_column: &impl Fn(AdviceQuery) -> T, @@ -623,7 +623,7 @@ impl Expression { negated: &impl Fn(T) -> T, sum: &impl Fn(T, T) -> T, product: &impl Fn(T, T) -> T, - scaled: &impl Fn(T, EF) -> T, + scaled: &impl Fn(T, F) -> T, ) -> T { match self { Expression::Constant(scalar) => constant(*scalar), @@ -633,7 +633,7 @@ impl Expression { Expression::Instance(query) => instance_column(*query), Expression::Challenge(value) => challenge(*value), Expression::Negated(a) => { - let a = EF::get(a).evaluate( + let a = F::get(a).evaluate( constant, selector_column, fixed_column, @@ -648,7 +648,7 @@ impl Expression { negated(a) } Expression::Sum(a, b) => { - let a = EF::get(a).evaluate( + let a = F::get(a).evaluate( constant, selector_column, fixed_column, @@ -660,7 +660,7 @@ impl Expression { product, scaled, ); - let b = EF::get(b).evaluate( + let b = F::get(b).evaluate( constant, selector_column, fixed_column, @@ -675,7 +675,7 @@ impl Expression { sum(a, b) } Expression::Product(a, b) => { - let a = EF::get(a).evaluate( + let a = F::get(a).evaluate( constant, selector_column, fixed_column, @@ -687,7 +687,7 @@ impl Expression { product, scaled, ); - let b = EF::get(b).evaluate( + let b = F::get(b).evaluate( constant, selector_column, fixed_column, @@ -702,7 +702,7 @@ impl Expression { product(a, b) } Expression::Scaled(a, f) => { - let a = EF::get(a).evaluate( + let a = F::get(a).evaluate( constant, selector_column, fixed_column, @@ -716,7 +716,7 @@ impl Expression { ); scaled(a, *f) } - Expression::Ref(ref_) => EF::get(ref_).evaluate( + Expression::Ref(ref_) => F::get(ref_).evaluate( constant, selector_column, fixed_column, @@ -736,7 +736,7 @@ impl Expression { #[allow(clippy::too_many_arguments)] pub fn evaluate_lazy( &self, - constant: &impl Fn(EF) -> T, + constant: &impl Fn(F) -> T, selector_column: &impl Fn(Selector) -> T, fixed_column: &impl Fn(FixedQuery) -> T, advice_column: &impl Fn(AdviceQuery) -> T, @@ -745,7 +745,7 @@ impl Expression { negated: &impl Fn(T) -> T, sum: &impl Fn(T, T) -> T, product: &impl Fn(T, T) -> T, - scaled: &impl Fn(T, EF) -> T, + scaled: &impl Fn(T, F) -> T, zero: &T, ) -> T { match self { @@ -756,7 +756,7 @@ impl Expression { Expression::Instance(query) => instance_column(*query), Expression::Challenge(value) => challenge(*value), Expression::Negated(a) => { - let a = EF::get(a).evaluate_lazy( + let a = F::get(a).evaluate_lazy( constant, selector_column, fixed_column, @@ -772,7 +772,7 @@ impl Expression { negated(a) } Expression::Sum(a, b) => { - let a = EF::get(a).evaluate_lazy( + let a = F::get(a).evaluate_lazy( constant, selector_column, fixed_column, @@ -785,7 +785,7 @@ impl Expression { scaled, zero, ); - let b = EF::get(b).evaluate_lazy( + let b = F::get(b).evaluate_lazy( constant, selector_column, fixed_column, @@ -801,7 +801,7 @@ impl Expression { sum(a, b) } Expression::Product(a, b) => { - let (a, b) = (EF::get(a), EF::get(b)); + let (a, b) = (F::get(a), F::get(b)); let (a, b) = if a.complexity() <= b.complexity() { (a, b) } else { @@ -841,7 +841,7 @@ impl Expression { } } Expression::Scaled(a, f) => { - let a = EF::get(a).evaluate_lazy( + let a = F::get(a).evaluate_lazy( constant, selector_column, fixed_column, @@ -856,7 +856,7 @@ impl Expression { ); scaled(a, *f) } - Expression::Ref(ref_) => EF::get(ref_).evaluate_lazy( + Expression::Ref(ref_) => F::get(ref_).evaluate_lazy( constant, selector_column, fixed_column, @@ -902,28 +902,28 @@ impl Expression { } Expression::Negated(a) => { writer.write_all(b"(-")?; - EF::get(a).write_identifier(writer)?; + F::get(a).write_identifier(writer)?; writer.write_all(b")") } Expression::Sum(a, b) => { writer.write_all(b"(")?; - EF::get(a).write_identifier(writer)?; + F::get(a).write_identifier(writer)?; writer.write_all(b"+")?; - EF::get(b).write_identifier(writer)?; + F::get(b).write_identifier(writer)?; writer.write_all(b")") } Expression::Product(a, b) => { writer.write_all(b"(")?; - EF::get(a).write_identifier(writer)?; + F::get(a).write_identifier(writer)?; writer.write_all(b"*")?; - EF::get(b).write_identifier(writer)?; + F::get(b).write_identifier(writer)?; writer.write_all(b")") } Expression::Scaled(a, f) => { - EF::get(a).write_identifier(writer)?; + F::get(a).write_identifier(writer)?; write!(writer, "*{f:?}") } - Expression::Ref(a) => EF::get(a).write_identifier(writer), + Expression::Ref(a) => F::get(a).write_identifier(writer), } } @@ -945,11 +945,11 @@ impl Expression { Expression::Advice(_) => 1, Expression::Instance(_) => 1, Expression::Challenge(_) => 0, - Expression::Negated(poly) => EF::get(poly).degree(), - Expression::Sum(a, b) => max(EF::get(a).degree(), EF::get(b).degree()), - Expression::Product(a, b) => EF::get(a).degree() + EF::get(b).degree(), - Expression::Scaled(poly, _) => EF::get(poly).degree(), - Expression::Ref(ref_) => EF::get(ref_).degree(), + Expression::Negated(poly) => F::get(poly).degree(), + Expression::Sum(a, b) => max(F::get(a).degree(), F::get(b).degree()), + Expression::Product(a, b) => F::get(a).degree() + F::get(b).degree(), + Expression::Scaled(poly, _) => F::get(poly).degree(), + Expression::Ref(ref_) => F::get(ref_).degree(), } } @@ -962,16 +962,16 @@ impl Expression { Expression::Advice(_) => 1, Expression::Instance(_) => 1, Expression::Challenge(_) => 0, - Expression::Negated(poly) => EF::get(poly).complexity() + 5, - Expression::Sum(a, b) => EF::get(a).complexity() + EF::get(b).complexity() + 15, - Expression::Product(a, b) => EF::get(a).complexity() + EF::get(b).complexity() + 30, - Expression::Scaled(poly, _) => EF::get(poly).complexity() + 30, - Expression::Ref(ref_) => EF::get(ref_).complexity(), + Expression::Negated(poly) => F::get(poly).complexity() + 5, + Expression::Sum(a, b) => F::get(a).complexity() + F::get(b).complexity() + 15, + Expression::Product(a, b) => F::get(a).complexity() + F::get(b).complexity() + 30, + Expression::Scaled(poly, _) => F::get(poly).complexity() + 30, + Expression::Ref(ref_) => F::get(ref_).complexity(), } } /// Square this expression. - pub fn square(self) -> Expression { + pub fn square(self) -> Expression { self * self } @@ -1021,7 +1021,7 @@ impl Expression { } } -impl std::fmt::Debug for Expression { +impl std::fmt::Debug for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Expression::Constant(scalar) => f.debug_tuple("Constant").field(scalar).finish(), @@ -1063,23 +1063,23 @@ impl std::fmt::Debug for Expression { Expression::Challenge(challenge) => { f.debug_tuple("Challenge").field(challenge).finish() } - Expression::Negated(poly) => f.debug_tuple("Negated").field(&EF::get(poly)).finish(), + Expression::Negated(poly) => f.debug_tuple("Negated").field(&F::get(poly)).finish(), Expression::Sum(a, b) => f .debug_tuple("Sum") - .field(&EF::get(a)) - .field(&EF::get(b)) + .field(&F::get(a)) + .field(&F::get(b)) .finish(), Expression::Product(a, b) => f .debug_tuple("Product") - .field(&EF::get(a)) - .field(&EF::get(b)) + .field(&F::get(a)) + .field(&F::get(b)) .finish(), Expression::Scaled(poly, scalar) => f .debug_tuple("Scaled") - .field(&EF::get(poly)) + .field(&F::get(poly)) .field(scalar) .finish(), - Expression::Ref(ref_) => f.debug_tuple("").field(&EF::get(ref_)).finish(), + Expression::Ref(ref_) => f.debug_tuple("").field(&F::get(ref_)).finish(), } } } @@ -1089,9 +1089,9 @@ impl Neg for Expression { fn neg(self) -> Self::Output { let ref_ = match self { Self::Ref(ref_) => ref_, - _ => F::push(self), + _ => F::alloc(self), }; - let ref_ = F::push(Expression::Negated(ref_)); + let ref_ = F::alloc(Expression::Negated(ref_)); Expression::Ref(ref_) } } @@ -1104,13 +1104,13 @@ impl Add for Expression { } let self_ = match self { Expression::Ref(ref_) => ref_, - _ => F::push(self), + _ => F::alloc(self), }; let rhs = match rhs { Expression::Ref(ref_) => ref_, - _ => F::push(rhs), + _ => F::alloc(rhs), }; - let ref_ = F::push(Expression::Sum(self_, rhs)); + let ref_ = F::alloc(Expression::Sum(self_, rhs)); Expression::Ref(ref_) } } @@ -1123,87 +1123,84 @@ impl Sub for Expression { } let self_ = match self { Expression::Ref(ref_) => ref_, - _ => F::push(self), + _ => F::alloc(self), }; let rhs = match rhs { Expression::Ref(ref_) => ref_, - _ => F::push(rhs), + _ => F::alloc(rhs), }; - let rhs = F::push(Expression::Negated(rhs)); + let rhs = F::alloc(Expression::Negated(rhs)); - Expression::Ref(F::push(Expression::Sum(self_, rhs))) + Expression::Ref(F::alloc(Expression::Sum(self_, rhs))) } } -impl Mul for Expression { - type Output = Expression; - fn mul(self, rhs: Expression) -> Expression { +impl Mul for Expression { + type Output = Expression; + fn mul(self, rhs: Expression) -> Expression { if self.contains_simple_selector() && rhs.contains_simple_selector() { panic!("attempted to multiply two expressions containing simple selectors"); } let self_ = match self { Expression::Ref(ref_) => ref_, - _ => EF::push(self), + _ => F::alloc(self), }; let rhs = match rhs { Expression::Ref(ref_) => ref_, - _ => EF::push(rhs), + _ => F::alloc(rhs), }; - Expression::Ref(EF::push(Expression::Product(self_, rhs))) + Expression::Ref(F::alloc(Expression::Product(self_, rhs))) } } -impl Mul for Expression { - type Output = Expression; - fn mul(self, rhs: EF) -> Expression { +impl Mul for Expression { + type Output = Expression; + fn mul(self, rhs: F) -> Expression { let self_ = match self { Expression::Ref(ref_) => ref_, - _ => EF::push(self), + _ => F::alloc(self), }; Expression::Scaled(self_, rhs) } } -impl Sum for Expression { +impl Sum for Expression { fn sum>(iter: I) -> Self { iter.reduce(|acc, x| acc + x) - .unwrap_or(Expression::Constant(EF::ZERO)) + .unwrap_or(Expression::Constant(F::ZERO)) } } -impl Product for Expression { +impl Product for Expression { fn product>(iter: I) -> Self { iter.reduce(|acc, x| acc * x) - .unwrap_or(Expression::Constant(EF::ONE)) + .unwrap_or(Expression::Constant(F::ONE)) + } +} + +#[allow(non_camel_case_types)] +pub struct ExprArena(Vec>); + +impl ExprArena { + pub fn new() -> Self { + Self(Vec::new()) + } + pub fn push(&mut self, expr: Expression) -> crate::plonk::ExprRef { + let index = self.0.len(); + self.0.push(expr); + ExprRef(index, std::marker::PhantomData) + } + pub fn get(&self, ref_: crate::plonk::ExprRef) -> &Expression { + &self.0[ref_.0] } } #[macro_export] macro_rules! expression_arena { ($arena:ident, $field:ty) => { - paste::paste! { - static $arena: once_cell::sync::Lazy]<$field>>> = - once_cell::sync::Lazy::new(|| parking_lot::RwLock::new([]::new())); - - #[allow(non_camel_case_types)] - pub struct [] { - expressions: Vec>, - } - impl [] { - pub fn new() -> Self { - Self { - expressions: Vec::new(), - } - } - pub fn push(&mut self, expr: Expression) -> crate::plonk::ExprRef { - let index = self.expressions.len(); - self.expressions.push(expr); - ExprRef(index, std::marker::PhantomData) - } - pub fn get(&self, ref_: crate::plonk::ExprRef) -> &Expression { - &self.expressions[ref_.0] - } - } + fn $arena() -> &'static std::sync::RwLock> { + static LINES: std::sync::OnceLock>> = std::sync::OnceLock::new(); + LINES.get_or_init(|| std::sync::RwLock::new(ExprArena::new())) } impl crate::plonk::FieldFr for $field { @@ -1214,26 +1211,26 @@ macro_rules! expression_arena { fn into_field_fr(f: Self::Field) -> Self { f } - fn push(expr: crate::plonk::Expression) -> crate::plonk::ExprRef { - $arena.write().push(expr) + fn alloc(expr: crate::plonk::Expression) -> crate::plonk::ExprRef { + $arena().write().unwrap().push(expr) } fn get(ref_: &crate::plonk::ExprRef) -> crate::plonk::Expression { - *$arena.read().get(*ref_) + *$arena().read().unwrap().get(*ref_) } } impl Into> for $field { fn into(self) -> crate::plonk::ExprRef<$field> { - crate::plonk::FieldFr::push(crate::plonk::Expression::Constant(self)) + crate::plonk::FieldFr::alloc(crate::plonk::Expression::Constant(self)) } } }; } //#[cfg(test)] -expression_arena!(ARENA_BN256_FR, halo2curves::bn256::Fr); -expression_arena!(ARENA_BN256_FQ, halo2curves::bn256::Fq); -expression_arena!(PASTA_FP_ARENA, halo2curves::pasta::Fp); +expression_arena!(arena_bn256_fr, halo2curves::bn256::Fr); +expression_arena!(arena_bn256_fq, halo2curves::bn256::Fq); +expression_arena!(arena_pasta_fp, halo2curves::pasta::Fp); #[cfg(test)] mod tests { diff --git a/halo2_proofs/src/plonk.rs b/halo2_proofs/src/plonk.rs index d9a2b7fbbe..0c13e1085f 100644 --- a/halo2_proofs/src/plonk.rs +++ b/halo2_proofs/src/plonk.rs @@ -41,7 +41,7 @@ use std::io; /// Checks that field elements are less than modulus, and then checks that the point is on the curve. /// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form; /// does not perform any checks -pub fn vk_read>( +pub fn vk_read>( reader: &mut R, format: SerdeFormat, k: u32, @@ -50,7 +50,7 @@ pub fn vk_read io::Result> where C::Scalar: SerdePrimeField + FromUniformBytes<64>, - EF: FieldFr, + F: FieldFr, { let (_, _, cs) = compile_circuit(k, circuit, compress_selectors) .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; @@ -70,7 +70,7 @@ where /// Checks that field elements are less than modulus, and then checks that the point is on the curve. /// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form; /// does not perform any checks -pub fn pk_read>( +pub fn pk_read>( reader: &mut R, format: SerdeFormat, k: u32, @@ -79,7 +79,7 @@ pub fn pk_read io::Result> where C::Scalar: SerdePrimeField + FromUniformBytes<64>, - EF: FieldFr, + F: FieldFr, { let (_, _, cs) = compile_circuit(k, circuit, compress_selectors) diff --git a/halo2_proofs/src/plonk/keygen.rs b/halo2_proofs/src/plonk/keygen.rs index 130631294a..079c8fd5a2 100644 --- a/halo2_proofs/src/plonk/keygen.rs +++ b/halo2_proofs/src/plonk/keygen.rs @@ -14,15 +14,15 @@ use halo2_middleware::ff::FromUniformBytes; /// **NOTE**: This `keygen_vk` is legacy one, assuming that `compress_selector: true`. /// Hence, it is HIGHLY recommended to pair this util with `keygen_pk`. /// In addition, when using this for key generation, user MUST use `compress_selectors: true`. -pub fn keygen_vk( +pub fn keygen_vk( params: &P, circuit: &ConcreteCircuit, ) -> Result, Error> where C: CurveAffine, P: Params, - EF: FieldFr, - ConcreteCircuit: Circuit, + F: FieldFr, + ConcreteCircuit: Circuit, C::Scalar: FromUniformBytes<64>, { keygen_vk_custom(params, circuit, true) @@ -36,7 +36,7 @@ where /// `ProvingKey` generation process. /// Otherwise, the user could get unmatching pk/vk pair. /// Hence, it is HIGHLY recommended to pair this util with `keygen_pk_custom`. -pub fn keygen_vk_custom( +pub fn keygen_vk_custom( params: &P, circuit: &ConcreteCircuit, compress_selectors: bool, @@ -44,8 +44,8 @@ pub fn keygen_vk_custom( where C: CurveAffine, P: Params, - EF: FieldFr, - ConcreteCircuit: Circuit, + F: FieldFr, + ConcreteCircuit: Circuit, C::Scalar: FromUniformBytes<64>, { let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit, compress_selectors)?; @@ -58,7 +58,7 @@ where /// **NOTE**: This `keygen_pk` is legacy one, assuming that `compress_selector: true`. /// Hence, it is HIGHLY recommended to pair this util with `keygen_vk`. /// In addition, when using this for key generation, user MUST use `compress_selectors: true`. -pub fn keygen_pk( +pub fn keygen_pk( params: &P, vk: VerifyingKey, circuit: &ConcreteCircuit, @@ -66,8 +66,8 @@ pub fn keygen_pk( where C: CurveAffine, P: Params, - EF: FieldFr, - ConcreteCircuit: Circuit, + F: FieldFr, + ConcreteCircuit: Circuit, { keygen_pk_custom(params, vk, circuit, true) } @@ -80,7 +80,7 @@ where /// `VerifyingKey` generation process. /// Otherwise, the user could get unmatching pk/vk pair. /// Hence, it is HIGHLY recommended to pair this util with `keygen_vk_custom`. -pub fn keygen_pk_custom( +pub fn keygen_pk_custom( params: &P, vk: VerifyingKey, circuit: &ConcreteCircuit, @@ -89,8 +89,8 @@ pub fn keygen_pk_custom( where C: CurveAffine, P: Params, - EF: FieldFr, - ConcreteCircuit: Circuit, + F: FieldFr, + ConcreteCircuit: Circuit, { let (compiled_circuit, _, _) = compile_circuit(params.k(), circuit, compress_selectors)?; Ok(backend_keygen_pk(params, vk, &compiled_circuit)?) diff --git a/halo2_proofs/src/plonk/prover.rs b/halo2_proofs/src/plonk/prover.rs index e9c3eb806a..7641ba9efe 100644 --- a/halo2_proofs/src/plonk/prover.rs +++ b/halo2_proofs/src/plonk/prover.rs @@ -23,15 +23,15 @@ pub fn create_proof_with_engine< E: EncodedChallenge, R: RngCore, T: TranscriptWrite, - ConcreteCircuit: Circuit, + ConcreteCircuit: Circuit, M: MsmAccel, - EF: FieldFr, + F: FieldFr, >( engine: PlonkEngine, params: &'params Scheme::ParamsProver, pk: &ProvingKey, circuits: &[ConcreteCircuit], - instances: &[Vec>], + instances: &[Vec>], rng: R, transcript: &mut T, ) -> Result<(), Error> @@ -86,7 +86,7 @@ where }) .collect(); - challenges = prover.commit_phase(*phase, witnesses).unwrap().into_iter().map(|(k, v)| (k, EF::into_field_fr(v))).collect(); + challenges = prover.commit_phase(*phase, witnesses).unwrap().into_iter().map(|(k, v)| (k, F::into_field_fr(v))).collect(); } Ok(prover.create_proof()?) } @@ -102,13 +102,13 @@ pub fn create_proof< E: EncodedChallenge, R: RngCore, T: TranscriptWrite, - ConcreteCircuit: Circuit, - EF: FieldFr, + ConcreteCircuit: Circuit, + F: FieldFr, >( params: &'params Scheme::ParamsProver, pk: &ProvingKey, circuits: &[ConcreteCircuit], - instances: &[Vec>], + instances: &[Vec>], rng: R, transcript: &mut T, ) -> Result<(), Error>