From d0e37cca98364f072a3eb403a3ef35fe8736d1ad Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Sat, 9 Nov 2024 22:01:13 +0000 Subject: [PATCH] Revert "simplify at touch" This reverts commit 2ba3e70a1965af6e89071e28db7aeebbc3eb6cf9. --- src/circuit/ops/layouts.rs | 19 ++++++++++++++++--- src/tensor/ops.rs | 10 ++++++++-- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index a1eb0cf2b..41e7664e6 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -121,6 +121,17 @@ pub fn diff_less_than Ok(()) } +fn is_positive( + config: &BaseConfig, + region: &mut RegionCtx, + value: &[ValTensor; 1], +) -> Result, CircuitError> { + let neg_one = create_constant_tensor(integer_rep_to_felt(-1), 1); + let is_negative = equals(config, region, &[value[0].clone(), neg_one])?; + + not(config, region, &[is_negative]) +} + /// Div accumulated layout pub(crate) fn div( config: &BaseConfig, @@ -304,10 +315,10 @@ pub fn sqrt( region.increment(claimed_output.len()); // assert value is positive - let sign = sign(config, region, &[claimed_output.clone()])?; - let ones = create_constant_tensor(F::ONE, sign.len()); + let is_positive = is_positive(config, region, &[claimed_output.clone()])?; + let ones = create_constant_tensor(F::ONE, is_positive.len()); // assert the sign is positive - enforce_equality(config, region, &[sign, ones])?; + enforce_equality(config, region, &[is_positive, ones])?; // rescaled input let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?; @@ -4817,6 +4828,8 @@ pub fn ln( comparison_unit.reshape(is_closest.dims())?; let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?; + println!("is_closest {}", is_closest.show()); + enforce_equality(config, region, &[is_closest, assigned_unit])?; // get a linear interpolation now diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index b43092050..7ac05cf9a 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -31,8 +31,14 @@ pub fn get_rep( return Err(DecompositionError::TooLarge(*x, base, n)); } let mut rep = vec![0; n + 1]; - // sign bit, we omit 0 as it is not needed in our representation - rep[0] = if *x < 0 { -1 } else { 1 }; + // sign bit + rep[0] = if *x < 0 { + -1 + } else if *x > 0 { + 1 + } else { + 0 + }; let mut x = x.abs(); //