Skip to content

Commit

Permalink
Update layouts.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Nov 8, 2024
1 parent 203fdc1 commit 7e3c3ff
Showing 1 changed file with 73 additions and 215 deletions.
288 changes: 73 additions & 215 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,39 +66,40 @@ pub fn l1_distance<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

/// Determines if from a set of 3 tensors the 1st is closest to a reference tensor.
/// should only be used in the context of a monotonic function like the product used in the division, recipe, and sqrt arguments;
/// or the increasing powers of 2 in the ln argument.
fn is_closest_to<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
/// or the increasing powers of 2 in the ln argument. Which is used to construct a convex error function.
fn optimum_convex_function<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
reference: &[ValTensor<F>; 1],
x: &ValTensor<F>,
f: impl Fn(&BaseConfig<F>, &mut RegionCtx<F>, &ValTensor<F>) -> Result<ValTensor<F>, CircuitError>,
) -> Result<(), CircuitError> {
let l1_distance_0 = l1_distance(config, region, &[values[0].clone(), reference[0].clone()])?;
let l1_distance_1 = l1_distance(config, region, &[values[1].clone(), reference[0].clone()])?;
let l1_distance_2 = l1_distance(config, region, &[values[2].clone(), reference[0].clone()])?;
let two = create_constant_tensor(F::from(2), 1);
let two = region.assign(&config.custom_gates.inputs[1], &two)?;
region.increment(two.len());

// one might expect this to be unsound as if both l1_distance_0 and l1_distance_1 AND l1_distance_2 are the same then one could expect the solution to not be unique.
// however if l1_distance_0 and l1_distance_1 are the same then l1_distance_2 must be different for a monotonic function like the product used in the division algorithm.
let is_closest_to_0 = less_equal(config, region, &[l1_distance_0.clone(), l1_distance_1])?;
let is_closest_to_1 = less_equal(config, region, &[l1_distance_0, l1_distance_2])?;
let f_x = f(config, region, x)?;

// if we wanted to be more explicit about this condition we would:
// let is_equal_0 = equals(config, region, &[l1_distance_0.clone(), l1_distance_1])?;
// let is_equal_1 = equals(config, region, &[l1_distance_0, l1_distance_2])?;
// let both_equal = and(config, region, &[is_equal_0, is_equal_1])?;
// enforce_equality(config, region, &[both_equal, F::ZERO])?;
let x_plus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Add)?;
let f_x_plus_2 = f(config, region, &x_plus_2)?;

let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?;
let x_minus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Sub)?;
let f_x_minus_2 = f(config, region, &x_minus_2)?;

let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len());
// because the function is convex, we the result should be the minimum of the three
// not that we offset the x by 2 to get the other two points that due to the convexity of the function and symmetry of convex function, there can be 2
let f_x_is_opt_rhs = less(config, region, &[f_x.clone(), f_x_plus_2])?;
let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_2])?;

comparison_unit.reshape(is_closest.dims())?;
let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?;

let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_opt.len());
comparison_unit.reshape(is_opt.dims())?;
// assigned unit
let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?;
region.increment(assigned_unit.len());

// assert that the result is 1
enforce_equality(config, region, &[is_closest, assigned_unit])?;
enforce_equality(config, region, &[is_opt, assigned_unit])?;

Ok(())
}
Expand Down Expand Up @@ -145,55 +146,16 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
region.assign(&config.custom_gates.output, &claimed_output)?;
region.increment(claimed_output.len());

let product = pairwise(
config,
region,
&[claimed_output.clone(), divisor.clone()],
BaseOp::Mult,
)?;

// take the claimed output and subtract 1
let one = create_constant_tensor(F::ONE, 1);
let one = region.assign(&config.custom_gates.inputs[1], &one)?;

let claimed_output_minus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Sub,
)?;

let claimed_output_minus_one_product = pairwise(
config,
region,
&[claimed_output_minus_one.clone(), divisor.clone()],
BaseOp::Mult,
)?;

let claimed_output_plus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Add,
)?;

let claimed_output_plus_one_product = pairwise(
config,
region,
&[claimed_output_plus_one.clone(), divisor.clone()],
BaseOp::Mult,
)?;
let err_func = |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>|
-> Result<ValTensor<F>, CircuitError> {
let product = pairwise(config, region, &[x.clone(), divisor.clone()], BaseOp::Mult)?;
let distance = l1_distance(config, region, &[product, input.clone()])?;
Ok(distance)
};

is_closest_to(
config,
region,
&[
product,
claimed_output_minus_one_product,
claimed_output_plus_one_product,
],
&[input.clone()],
)?;
optimum_convex_function(config, region, &claimed_output, err_func)?;

Ok(claimed_output)
}
Expand All @@ -209,9 +171,6 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = value[0].clone();
let input_dims = input.dims();

let one = create_constant_tensor(F::ONE, 1);
let one = region.assign(&config.custom_gates.inputs[0], &one)?;

let unit_scale = create_constant_tensor(output_scale * input_scale, 1);
let unit_scale = region.assign(&config.custom_gates.inputs[1], &unit_scale)?;
region.increment(1);
Expand Down Expand Up @@ -256,89 +215,16 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&[equal_zero_mask.clone(), equal_inverse_mask],
)?;

// this is now of scale 2 * scale
let product = pairwise(
config,
region,
&[claimed_output.clone(), input.clone()],
BaseOp::Mult,
)?;

let claimed_output_minus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Sub,
)?;

let claimed_output_minus_one_product = pairwise(
config,
region,
&[claimed_output_minus_one.clone(), input.clone()],
BaseOp::Mult,
)?;

let claimed_output_plus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Add,
)?;

let claimed_output_plus_one_product = pairwise(
config,
region,
&[claimed_output_plus_one.clone(), input.clone()],
BaseOp::Mult,
)?;

let scaled_equal_zero_mask = pairwise(
config,
region,
&[equal_zero_mask.clone(), unit_scale.clone()],
BaseOp::Mult,
)?;

// add 1 where the mask is 0
let product_masked = pairwise(
config,
region,
&[product.clone(), scaled_equal_zero_mask.clone()],
BaseOp::Add,
)?;

// add 1 where the mask is 0
let claimed_output_minus_one_product_masked = pairwise(
config,
region,
&[
claimed_output_minus_one_product.clone(),
scaled_equal_zero_mask.clone(),
],
BaseOp::Add,
)?;

// add 1 where the mask is 0
let claimed_output_plus_one_product_masked = pairwise(
config,
region,
&[
claimed_output_plus_one_product.clone(),
scaled_equal_zero_mask.clone(),
],
BaseOp::Add,
)?;
let err_func = |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>|
-> Result<ValTensor<F>, CircuitError> {
let product = pairwise(config, region, &[x.clone(), input.clone()], BaseOp::Mult)?;
let distance = l1_distance(config, region, &[product.clone(), unit_scale.clone()])?;
Ok(distance)
};

is_closest_to(
config,
region,
&[
product_masked,
claimed_output_minus_one_product_masked,
claimed_output_plus_one_product_masked,
],
&[unit_scale],
)?;
optimum_convex_function(config, region, &claimed_output, err_func)?;

Ok(claimed_output)
}
Expand Down Expand Up @@ -374,9 +260,6 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let input = value[0].clone();
let input_dims = input.dims();

let one = create_constant_tensor(F::ONE, 1);
let one = region.assign(&config.custom_gates.inputs[0], &one)?;

let unit_scale = create_constant_tensor(integer_rep_to_felt(input_scale.0 as IntegerRep), 1);
let unit_scale = region.assign(&config.custom_gates.inputs[1], &unit_scale)?;
region.increment(1);
Expand Down Expand Up @@ -411,61 +294,19 @@ pub fn sqrt<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// assert the sign is positive
enforce_equality(config, region, &[sign, ones.clone()])?;

// this is now of scale 2 * scale
let product = pairwise(
config,
region,
&[claimed_output.clone(), claimed_output.clone()],
BaseOp::Mult,
)?;

let claimed_output_minus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Sub,
)?;

let claimed_output_minus_one_product = pairwise(
config,
region,
&[
claimed_output_minus_one.clone(),
claimed_output_minus_one.clone(),
],
BaseOp::Mult,
)?;

let claimed_output_plus_one = pairwise(
config,
region,
&[claimed_output.clone(), one.clone()],
BaseOp::Add,
)?;

let claimed_output_plus_one_product = pairwise(
config,
region,
&[
claimed_output_plus_one.clone(),
claimed_output_plus_one.clone(),
],
BaseOp::Mult,
)?;

// rescaled input
let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?;

is_closest_to(
config,
region,
&[
product,
claimed_output_minus_one_product,
claimed_output_plus_one_product,
],
&[rescaled_input],
)?;
let err_func = |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
x: &ValTensor<F>|
-> Result<ValTensor<F>, CircuitError> {
let product = pairwise(config, region, &[x.clone(), x.clone()], BaseOp::Mult)?;
let distance = l1_distance(config, region, &[product.clone(), rescaled_input.clone()])?;
Ok(distance)
};

optimum_convex_function(config, region, &claimed_output, err_func)?;

Ok(claimed_output)
}
Expand Down Expand Up @@ -4925,16 +4766,33 @@ pub fn ln<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
BaseOp::Sub,
)?;

is_closest_to(
let abs_distance_to_claimed = abs(config, region, &[distance_to_claimed.clone()])?;

let abs_distance_to_next_pow2 =
l1_distance(config, region, &[input.clone(), next_pow2.clone()])?;

let abs_distance_to_prior_pow2 =
l1_distance(config, region, &[input.clone(), prior_pow2.clone()])?;

// because we round up this can be equal
let is_closest_to_0: ValTensor<F> = less_equal(
config,
region,
&[
pow2_of_claimed_output.clone(),
prior_pow2.clone(),
next_pow2.clone(),
],
&[input.clone()],
&[abs_distance_to_claimed.clone(), abs_distance_to_next_pow2],
)?;
let is_closest_to_1 = less(
config,
region,
&[abs_distance_to_claimed.clone(), abs_distance_to_prior_pow2],
)?;

let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?;

let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len());
comparison_unit.reshape(is_closest.dims())?;
let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?;

enforce_equality(config, region, &[is_closest, assigned_unit])?;

// get a linear interpolation now

Expand Down

0 comments on commit 7e3c3ff

Please sign in to comment.