Skip to content

Commit

Permalink
combine assert and less_than
Browse files Browse the repository at this point in the history
  • Loading branch information
zemse committed Sep 9, 2024
1 parent 6bee0a1 commit 99c53ed
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 54 deletions.
67 changes: 27 additions & 40 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,11 @@ use crate::{
structs::ROMType,
};

// TODO move to correct file
#[derive(Debug)]
pub struct LtWtns {
pub is_lt: Option<WitIn>,
pub diff_lo: WitIn,
pub diff_hi: WitIn,
#[cfg(feature = "riv64")]
pub diff_lo_2: WitIn,
#[cfg(feature = "riv64")]
pub diff_hi_2: WitIn,
pub diff: Vec<WitIn>,
}

impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
Expand Down Expand Up @@ -307,7 +303,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
(Some(is_lt), is_lt.expr())
};

let mut witin_u16 = |var_name: &str| -> Result<WitIn, ZKVMError> {
let mut witin_u16 = |var_name: String| -> Result<WitIn, ZKVMError> {
cb.namespace(
|| format!("var {var_name}"),
|cb| {
Expand All @@ -318,44 +314,35 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
)
};

let diff_lo = witin_u16("diff_lo")?;
let diff_hi = witin_u16("diff_hi")?;
#[cfg(feature = "riv64")]
let diff_lo_2 = witin_u16("diff_lo_2")?;
#[cfg(feature = "riv64")]
let diff_hi_2 = witin_u16("diff_hi_2")?;
let diff = {
#[cfg(feature = "riv32")]
{
0..2
}
#[cfg(feature = "riv64")]
{
0..4
}
}
.map(|i| witin_u16(format!("diff_{i}")))
.collect::<Result<Vec<WitIn>, _>>()?;

let diff_expr = diff
.iter()
.enumerate()
.map(|(i, diff)| (i, diff.expr()))
.fold(Expression::ZERO, |sum, (i, a)| {
sum + if i > 0 { a * (1 << (16 * i)).into() } else { a }
});

#[cfg(feature = "riv32")]
let range = Expression::Constant((u32::MAX as u64).into());
#[cfg(feature = "riv64")]
let range = Expression::Constant(u64::MAX.into());
let range = Expression::Constant(u64::MAX.into()); // TODO beyond modulus

cb.require_equal(
|| name.clone(),
lhs - rhs,
#[cfg(feature = "riv32")]
{
diff_lo.expr() + diff_hi.expr() * (1 << 16).into() - is_lt_expr * range
},
#[cfg(feature = "riv64")]
{
diff_lo.expr()
+ diff_hi.expr() * (1 << 16).into()
+ diff_lo_2.expr() * (1 << 32).into()
+ diff_hi_2.expr() * (1 << 48).into()
- is_lt.expr() * range
},
)?;

Ok(LtWtns {
is_lt,
diff_lo,
diff_hi,
#[cfg(feature = "riv64")]
diff_lo_2,
#[cfg(feature = "riv64")]
diff_hi_2,
})
cb.require_equal(|| name.clone(), lhs - rhs, diff_expr - is_lt_expr * range)?;

Ok(LtWtns { is_lt, diff })
},
)
}
Expand Down
15 changes: 6 additions & 9 deletions ceno_zkvm/src/instructions/riscv/addsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,17 +212,14 @@ impl<E: ExtensionField> Instruction<E> for AddInstruction {

let u16_max = u16::MAX as u64;

set_val!(instance, config.lt_wtns_rs1.is_lt, 1);
set_val!(instance, config.lt_wtns_rs1.diff_lo, u16_max - 2 + 1); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs1.diff_hi, u16_max);
set_val!(instance, config.lt_wtns_rs1.diff[0], u16_max - 2 + 1); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs1.diff[1], u16_max);

set_val!(instance, config.lt_wtns_rs2.is_lt, 1);
set_val!(instance, config.lt_wtns_rs2.diff_lo, u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs2.diff_hi, u16_max);
set_val!(instance, config.lt_wtns_rs2.diff[0], u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rs2.diff[1], u16_max);

set_val!(instance, config.lt_wtns_rd.is_lt, 1);
set_val!(instance, config.lt_wtns_rd.diff_lo, u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rd.diff_hi, u16_max);
set_val!(instance, config.lt_wtns_rd.diff[0], u16_max - 3 + 2); // range - lhs + rhs
set_val!(instance, config.lt_wtns_rd.diff[1], u16_max);
Ok(())
}
}
Expand Down
99 changes: 94 additions & 5 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,97 @@ mod tests {
);
}

#[allow(dead_code)]
#[derive(Debug)]
struct AssertLtCircuit {
pub a: WitIn,
pub b: WitIn,
pub lt_wtns: LtWtns,
}

impl AssertLtCircuit {
fn construct_circuit(cb: &mut CircuitBuilder<GoldilocksExt2>) -> Result<Self, ZKVMError> {
let a = cb.create_witin(|| "a")?;
let b = cb.create_witin(|| "b")?;
let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?;
Ok(Self { a, b, lt_wtns })
}
}

#[test]
fn test_assert_lt_1() {
let mut cs = ConstraintSystem::new(|| "test_lt_1");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap();

let wits_in = vec![
vec![Goldilocks::from(3u64)].into_mle().into(),
vec![Goldilocks::from(5u64)].into_mle().into(),
vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
];

MockProver::assert_satisfied(&mut builder, &wits_in, None);
}

#[test]
fn test_assert_lt_u32() {
let mut cs = ConstraintSystem::new(|| "test_lt_u32");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap();

let wits_in = vec![
vec![Goldilocks::from(u32::MAX as u64 - 5)]
.into_mle()
.into(),
vec![Goldilocks::from(u32::MAX as u64 - 3)]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
];

MockProver::assert_satisfied(&mut builder, &wits_in, None);
}

#[test]
#[cfg(feature = "riv64")]
fn test_assert_lt_u64() {
let mut cs = ConstraintSystem::new(|| "test_lt_u64");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = AssertLtCircuit::construct_circuit(&mut builder).unwrap();

let wits_in = vec![
vec![Goldilocks::from(u64::MAX - 5)].into_mle().into(),
vec![Goldilocks::from(u64::MAX - 3)].into_mle().into(),
vec![Goldilocks::from(u16::MAX as u64 + 3 - 5)]
.into_mle()
.into(),
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
#[cfg(feature = "riv64")]
vec![Goldilocks::from(u16::MAX as u64)].into_mle().into(),
];

MockProver::assert_satisfied(&mut builder, &wits_in, None);
}

#[allow(dead_code)]
#[derive(Debug)]
struct LtCircuit {
Expand All @@ -543,19 +634,17 @@ mod tests {
}

impl LtCircuit {
fn construct_circuit(
cb: &mut CircuitBuilder<GoldilocksExt2>,
) -> Result<LtCircuit, ZKVMError> {
fn construct_circuit(cb: &mut CircuitBuilder<GoldilocksExt2>) -> Result<Self, ZKVMError> {
let a = cb.create_witin(|| "a")?;
let b = cb.create_witin(|| "b")?;
let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), Some(true))?;
let lt_wtns = cb.less_than(|| "lt", a.expr(), b.expr(), None)?;
Ok(Self { a, b, lt_wtns })
}
}

#[test]
fn test_lt_1() {
let mut cs = ConstraintSystem::new(|| "test_lt_1");
let mut cs = ConstraintSystem::new(|| "test_lt2_u32");
let mut builder = CircuitBuilder::<GoldilocksExt2>::new(&mut cs);

let _ = LtCircuit::construct_circuit(&mut builder).unwrap();
Expand Down

0 comments on commit 99c53ed

Please sign in to comment.