diff --git a/halo2_proofs/src/plonk/circuit.rs b/halo2_proofs/src/plonk/circuit.rs index e59f3dc2f0..6037e6be74 100644 --- a/halo2_proofs/src/plonk/circuit.rs +++ b/halo2_proofs/src/plonk/circuit.rs @@ -11,6 +11,7 @@ use ff::Field; use sealed::SealedPhase; use std::collections::BTreeMap; use std::fmt::Debug; +use std::iter::{Product, Sum}; use std::ops::Range; use std::{ convert::TryFrom, @@ -1375,6 +1376,20 @@ impl Mul for Expression { } } +impl Sum for Expression { + fn sum>(iter: I) -> Self { + iter.reduce(|acc, x| acc + x) + .unwrap_or(Expression::Constant(F::ZERO)) + } +} + +impl Product for Expression { + fn product>(iter: I) -> Self { + iter.reduce(|acc, x| acc * x) + .unwrap_or(Expression::Constant(F::ONE)) + } +} + /// Represents an index into a vector where each entry corresponds to a distinct /// point that polynomials are queried at. #[derive(Copy, Clone, Debug)] @@ -2473,3 +2488,47 @@ impl<'a, F: Field> VirtualCells<'a, F> { Expression::Challenge(challenge) } } + +#[cfg(test)] +mod tests { + use super::Expression; + use halo2curves::bn256::Fr; + + #[test] + fn iter_sum() { + let exprs: Vec> = vec![ + Expression::Constant(1.into()), + Expression::Constant(2.into()), + Expression::Constant(3.into()), + ]; + let happened: Expression = exprs.into_iter().sum(); + let expected: Expression = Expression::Sum( + Box::new(Expression::Sum( + Box::new(Expression::Constant(1.into())), + Box::new(Expression::Constant(2.into())), + )), + Box::new(Expression::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn iter_product() { + let exprs: Vec> = vec![ + Expression::Constant(1.into()), + Expression::Constant(2.into()), + Expression::Constant(3.into()), + ]; + let happened: Expression = exprs.into_iter().product(); + let expected: Expression = Expression::Product( + Box::new(Expression::Product( + Box::new(Expression::Constant(1.into())), + Box::new(Expression::Constant(2.into())), + )), + Box::new(Expression::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } +}