From 99c53ed8cd872fcab152753bb118d7a131fda792 Mon Sep 17 00:00:00 2001 From: Soham Zemse <22412996+zemse@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:52:58 +0530 Subject: [PATCH] combine `assert` and `less_than` --- ceno_zkvm/src/chip_handler/general.rs | 67 ++++++--------- ceno_zkvm/src/instructions/riscv/addsub.rs | 15 ++-- ceno_zkvm/src/scheme/mock_prover.rs | 99 ++++++++++++++++++++-- 3 files changed, 127 insertions(+), 54 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 893a02eec..d1d21b051 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -11,15 +11,11 @@ use crate::{ structs::ROMType, }; +// TODO move to correct file #[derive(Debug)] pub struct LtWtns { pub is_lt: Option, - pub diff_lo: WitIn, - pub diff_hi: WitIn, - #[cfg(feature = "riv64")] - pub diff_lo_2: WitIn, - #[cfg(feature = "riv64")] - pub diff_hi_2: WitIn, + pub diff: Vec, } impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { @@ -307,7 +303,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { (Some(is_lt), is_lt.expr()) }; - let mut witin_u16 = |var_name: &str| -> Result { + let mut witin_u16 = |var_name: String| -> Result { cb.namespace( || format!("var {var_name}"), |cb| { @@ -318,44 +314,35 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { ) }; - let diff_lo = witin_u16("diff_lo")?; - let diff_hi = witin_u16("diff_hi")?; - #[cfg(feature = "riv64")] - let diff_lo_2 = witin_u16("diff_lo_2")?; - #[cfg(feature = "riv64")] - let diff_hi_2 = witin_u16("diff_hi_2")?; + let diff = { + #[cfg(feature = "riv32")] + { + 0..2 + } + #[cfg(feature = "riv64")] + { + 0..4 + } + } + .map(|i| witin_u16(format!("diff_{i}"))) + .collect::, _>>()?; + + let diff_expr = diff + .iter() + .enumerate() + .map(|(i, diff)| (i, diff.expr())) + .fold(Expression::ZERO, |sum, (i, a)| { + sum + if i > 0 { a * (1 << (16 * i)).into() } else { a } + }); #[cfg(feature = "riv32")] let range = Expression::Constant((u32::MAX as u64).into()); #[cfg(feature = "riv64")] - let range = Expression::Constant(u64::MAX.into()); + let range = Expression::Constant(u64::MAX.into()); // TODO beyond modulus - cb.require_equal( - || name.clone(), - lhs - rhs, - #[cfg(feature = "riv32")] - { - diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt_expr * range - }, - #[cfg(feature = "riv64")] - { - diff_lo.expr() - + diff_hi.expr() * (1 << 16).into() - + diff_lo_2.expr() * (1 << 32).into() - + diff_hi_2.expr() * (1 << 48).into() - - is_lt.expr() * range - }, - )?; - - Ok(LtWtns { - is_lt, - diff_lo, - diff_hi, - #[cfg(feature = "riv64")] - diff_lo_2, - #[cfg(feature = "riv64")] - diff_hi_2, - }) + cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?; + + Ok(LtWtns { is_lt, diff }) }, ) } diff --git a/ceno_zkvm/src/instructions/riscv/addsub.rs b/ceno_zkvm/src/instructions/riscv/addsub.rs index d0e288052..2022d8725 100644 --- a/ceno_zkvm/src/instructions/riscv/addsub.rs +++ b/ceno_zkvm/src/instructions/riscv/addsub.rs @@ -212,17 +212,14 @@ impl Instruction for AddInstruction { let u16_max = u16::MAX as u64; - set_val!(instance, config.lt_wtns_rs1.is_lt, 1); - set_val!(instance, config.lt_wtns_rs1.diff_lo, u16_max - 2 + 1); // range - lhs + rhs - set_val!(instance, config.lt_wtns_rs1.diff_hi, u16_max); + set_val!(instance, config.lt_wtns_rs1.diff[0], u16_max - 2 + 1); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs1.diff[1], u16_max); - set_val!(instance, config.lt_wtns_rs2.is_lt, 1); - set_val!(instance, config.lt_wtns_rs2.diff_lo, u16_max - 3 + 2); // range - lhs + rhs - set_val!(instance, config.lt_wtns_rs2.diff_hi, u16_max); + set_val!(instance, config.lt_wtns_rs2.diff[0], u16_max - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rs2.diff[1], u16_max); - set_val!(instance, config.lt_wtns_rd.is_lt, 1); - set_val!(instance, config.lt_wtns_rd.diff_lo, u16_max - 3 + 2); // range - lhs + rhs - set_val!(instance, config.lt_wtns_rd.diff_hi, u16_max); + set_val!(instance, config.lt_wtns_rd.diff[0], u16_max - 3 + 2); // range - lhs + rhs + set_val!(instance, config.lt_wtns_rd.diff[1], u16_max); Ok(()) } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 2ec487517..3f3eb6adb 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -534,6 +534,97 @@ mod tests { ); } + #[allow(dead_code)] + #[derive(Debug)] + struct AssertLtCircuit { + pub a: WitIn, + pub b: WitIn, + pub lt_wtns: LtWtns, + } + + impl AssertLtCircuit { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + let a = cb.create_witin(|| "a")?; + let b = cb.create_witin(|| "b")?; + let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?; + Ok(Self { a, b, lt_wtns }) + } + } + + #[test] + fn test_assert_lt_1() { + let mut cs = ConstraintSystem::new(|| "test_lt_1"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + + let wits_in = vec![ + vec![Goldilocks::from(3u64)].into_mle().into(), + vec![Goldilocks::from(5u64)].into_mle().into(), + vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + ]; + + MockProver::assert_satisfied(&mut builder, &wits_in, None); + } + + #[test] + fn test_assert_lt_u32() { + let mut cs = ConstraintSystem::new(|| "test_lt_u32"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + + let wits_in = vec![ + vec![Goldilocks::from(u32::MAX as u64 - 5)] + .into_mle() + .into(), + vec![Goldilocks::from(u32::MAX as u64 - 3)] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + ]; + + MockProver::assert_satisfied(&mut builder, &wits_in, None); + } + + #[test] + #[cfg(feature = "riv64")] + fn test_assert_lt_u64() { + let mut cs = ConstraintSystem::new(|| "test_lt_u64"); + let mut builder = CircuitBuilder::::new(&mut cs); + + let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + + let wits_in = vec![ + vec![Goldilocks::from(u64::MAX - 5)].into_mle().into(), + vec![Goldilocks::from(u64::MAX - 3)].into_mle().into(), + vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)] + .into_mle() + .into(), + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + #[cfg(feature = "riv64")] + vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(), + ]; + + MockProver::assert_satisfied(&mut builder, &wits_in, None); + } + #[allow(dead_code)] #[derive(Debug)] struct LtCircuit { @@ -543,19 +634,17 @@ mod tests { } impl LtCircuit { - fn construct_circuit( - cb: &mut CircuitBuilder, - ) -> Result { + fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a")?; let b = cb.create_witin(|| "b")?; - let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?; + let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), None)?; Ok(Self { a, b, lt_wtns }) } } #[test] fn test_lt_1() { - let mut cs = ConstraintSystem::new(|| "test_lt_1"); + let mut cs = ConstraintSystem::new(|| "test_lt2_u32"); let mut builder = CircuitBuilder::::new(&mut cs); let _ = LtCircuit::construct_circuit(&mut builder).unwrap();