Commit: fix

alexander-camuto committed Nov 9, 2024
1 parent d5b5bf8 commit 0f91823
Showing 3 changed files with 87 additions and 56 deletions.
63 changes: 33 additions & 30 deletions src/circuit/ops/layouts.rs
@@ -71,31 +71,29 @@ fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     config: &BaseConfig<F>,
     region: &mut RegionCtx<F>,
     x: &ValTensor<F>,
-    ignore_mask: &Option<ValTensor<F>>,
     f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
 ) -> Result<(), CircuitError> {
-    let two = create_constant_tensor(F::from(2), 1);
-    let two = region.assign(&config.custom_gates.inputs[1], &two)?;
-    region.increment(two.len());
+    let one = create_constant_tensor(F::from(1), 1);
+    let one = region.assign(&config.custom_gates.inputs[1], &one)?;
+    region.increment(one.len());
 
     let f_x = f(config, region, x)?;
 
-    let x_plus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Add)?;
-    let f_x_plus_2 = f(config, region, &x_plus_2)?;
+    let x_plus_1 = pairwise(config, region, &[x.clone(), one.clone()], BaseOp::Add)?;
+    let f_x_plus_1 = f(config, region, &x_plus_1)?;
 
-    let x_minus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Sub)?;
-    let f_x_minus_2 = f(config, region, &x_minus_2)?;
+    let x_minus_1 = pairwise(config, region, &[x.clone(), one.clone()], BaseOp::Sub)?;
+    let f_x_minus_1 = f(config, region, &x_minus_1)?;
 
-    // because the function is convex, we the result should be the minimum of the three
-    // not that we offset the x by 2 to get the other two points that due to the convexity of the function and symmetry of convex function, there can be 2
-    let f_x_is_opt_rhs = less(config, region, &[f_x.clone(), f_x_plus_2])?;
-    let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_2])?;
+    // because the function is convex, the result should be the minimum of the three
+    // note that we offset x by 1 to get the neighbouring points
+    // f(x) <= f(x+1) and f(x) <= f(x-1)
+    // the result is 1 iff f(x) is optimal; this follows solely from the convexity of the function
+    // the distances can be equal, but only if f(x) and f(x+1) are both optimal (or f(x) and f(x-1))
+    let f_x_is_opt_rhs = less_equal(config, region, &[f_x.clone(), f_x_plus_1])?;
+    let f_x_is_opt_lhs = less_equal(config, region, &[f_x.clone(), f_x_minus_1])?;
 
-    let mut is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
-
-    if let Some(ignore_mask) = ignore_mask {
-        is_opt = or(config, region, &[is_opt.clone(), ignore_mask.clone()])?;
-    }
+    let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;
 
     let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_opt.len());
     comparison_unit.reshape(is_opt.dims())?;
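
The gadget above encodes a standard discrete optimality test: for a function that is convex over the integers, a point x is a global minimizer exactly when f(x) <= f(x+1) and f(x) <= f(x-1), which is what the two `less_equal` checks combined by `and` capture. A minimal plain-Rust sketch of the same test, outside the circuit and using none of the ezkl types:

```rust
/// Discrete optimality test for a convex function over the integers:
/// x is a global minimizer iff f(x) <= f(x + 1) and f(x) <= f(x - 1).
fn is_discrete_minimum(f: impl Fn(i64) -> i64, x: i64) -> bool {
    f(x) <= f(x + 1) && f(x) <= f(x - 1)
}

fn main() {
    // err(q) = |3q - 10| is convex in q; the rounded quotient of 10 / 3 is 3.
    let err = |q: i64| (3 * q - 10).abs();
    assert!(is_discrete_minimum(err, 3)); // 1 <= 2 and 1 <= 4
    assert!(!is_discrete_minimum(err, 4)); // 2 <= 5 but 2 > 1
}
```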
@@ -160,7 +158,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         Ok(distance)
     };
 
-    optimum_convex_function(config, region, &claimed_output, &None, err_func)?;
+    optimum_convex_function(config, region, &claimed_output, err_func)?;
 
     Ok(claimed_output)
 }
@@ -241,13 +239,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         Ok(distance)
     };
 
-    optimum_convex_function(
-        config,
-        region,
-        &claimed_output,
-        &Some(equal_zero_mask),
-        err_func,
-    )?;
+    optimum_convex_function(config, region, &claimed_output, err_func)?;
 
     Ok(claimed_output)
 }
@@ -329,7 +321,7 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         Ok(distance)
     };
 
-    optimum_convex_function(config, region, &claimed_output, &None, err_func)?;
+    optimum_convex_function(config, region, &claimed_output, err_func)?;
 
     Ok(claimed_output)
 }
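
div, recip, and sqrt all follow the same pattern: the prover supplies a claimed output as a witness, and the circuit checks that it minimizes an error function (convex, or at least unimodal, in the claimed value) via the neighbour test above, now without the mask argument. A sketch of the idea with illustrative values only, mirroring `optimum_convex_function` outside the circuit:

```rust
// Illustrative only: the neighbour test finds the minimum of each op's error function.
fn main() {
    let check = |f: &dyn Fn(i64) -> i64, x: i64| f(x) <= f(x + 1) && f(x) <= f(x - 1);

    // div: a claimed quotient q for a / b minimizes |b * q - a| (convex in q)
    let (a, b) = (22i64, 7i64);
    assert!(check(&|q| (b * q - a).abs(), 3)); // 22 / 7 rounds to 3

    // sqrt: a claimed root y for n minimizes |y * y - n| (unimodal for y >= 0)
    let n = 10i64;
    assert!(check(&|y| (y * y - n).abs(), 3)); // sqrt(10) rounds to 3
}
```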
@@ -4798,18 +4790,29 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
         l1_distance(config, region, &[input.clone(), prior_pow2.clone()])?;
 
-    // because we round up this can be equal
-    let is_closest_to_0: ValTensor<F> = less_equal(
+    let is_closest_to_0: ValTensor<F> = less(
         config,
         region,
-        &[abs_distance_to_claimed.clone(), abs_distance_to_next_pow2],
+        &[
+            abs_distance_to_claimed.clone(),
+            abs_distance_to_next_pow2.clone(),
+        ],
     )?;
 
     let is_closest_to_1 = less(
         config,
         region,
-        &[abs_distance_to_claimed.clone(), abs_distance_to_prior_pow2],
+        &[
+            abs_distance_to_claimed.clone(),
+            abs_distance_to_prior_pow2.clone(),
+        ],
     )?;
 
-    let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?;
+    let is_closest = and(
+        config,
+        region,
+        &[is_closest_to_0.clone(), is_closest_to_1.clone()],
+    )?;
 
     let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len());
     comparison_unit.reshape(is_closest.dims())?;
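
The switch from `less_equal` to strict `less` pairs with the ilog2 change below: once the claimed value is the power of two nearest to the input by linear distance, it is strictly closer than both neighbouring powers whenever the input is not exactly halfway between two powers. A worked check with illustrative numbers, assuming next/prior are double and half the claimed power:

```rust
// Illustrative only: input 5 at scale 1, claimed power of two 4,
// with the neighbouring powers assumed to be 8 (next) and 2 (prior).
fn main() {
    let d = |a: i64, b: i64| (a - b).abs();
    let (input, claimed, next, prior) = (5i64, 4i64, 8i64, 2i64);
    assert!(d(input, claimed) < d(input, next)); // 1 < 3
    assert!(d(input, claimed) < d(input, prior)); // 1 < 3
}
```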
14 changes: 11 additions & 3 deletions src/tensor/ops.rs
@@ -1546,9 +1546,17 @@ pub mod nonlinearities {
     pub fn ilog2(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
         a.par_enum_map(|_, a_i| {
             let kix = (a_i as f64) / scale_input;
-            let kix = (kix).log2();
-            let rounded = kix.round();
-            Ok::<_, TensorError>(rounded as IntegerRep)
+            let log = (kix).log2();
+            let floor = log.floor();
+            let ceil = log.ceil();
+            let floor_dist = ((2.0_f64).powf(floor) - kix).abs();
+            let ceil_dist = (kix - (2.0_f64).powf(ceil)).abs();
+
+            if floor_dist < ceil_dist {
+                Ok::<_, TensorError>(floor as IntegerRep)
+            } else {
+                Ok::<_, TensorError>(ceil as IntegerRep)
+            }
         })
         .unwrap()
     }
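
The old code rounded log2 directly, which is equivalent to splitting at the geometric midpoint sqrt(2)·2^k ≈ 1.414·2^k; the new code compares linear distances to the two candidate powers, whose midpoint is 1.5·2^k. The two disagree for inputs between those midpoints. A worked example with illustrative values:

```rust
// Illustrative only: a_i = 58, scale_input = 10.0, so kix = 5.8.
fn main() {
    let kix: f64 = 58.0 / 10.0;
    let log = kix.log2(); // ~2.536

    // old behaviour: round the log, picking 2^3 = 8
    assert_eq!(log.round() as i64, 3);

    // new behaviour: |4 - 5.8| = 1.8 < |8 - 5.8| = 2.2, so pick 2^2 = 4
    let (floor, ceil) = (log.floor(), log.ceil());
    let floor_dist = (2f64.powf(floor) - kix).abs();
    let ceil_dist = (kix - 2f64.powf(ceil)).abs();
    let chosen = if floor_dist < ceil_dist { floor } else { ceil };
    assert_eq!(chosen as i64, 2);
}
```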