diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index ceda6f3b9..2461243c9 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -66,39 +66,40 @@ pub fn l1_distance( /// Determines if from a set of 3 tensors the 1st is closest to a reference tensor. /// should only be used in the context of a monotonic function like the product used in the division, recipe, and sqrt arguments; -/// or the increasing powers of 2 in the ln argument. -fn is_closest_to( +/// or the increasing powers of 2 in the ln argument. Which is used to construct a convex error function. +fn optimum_convex_function( config: &BaseConfig, region: &mut RegionCtx, - values: &[ValTensor; 3], - reference: &[ValTensor; 1], + x: &ValTensor, + f: impl Fn(&BaseConfig, &mut RegionCtx, &ValTensor) -> Result, CircuitError>, ) -> Result<(), CircuitError> { - let l1_distance_0 = l1_distance(config, region, &[values[0].clone(), reference[0].clone()])?; - let l1_distance_1 = l1_distance(config, region, &[values[1].clone(), reference[0].clone()])?; - let l1_distance_2 = l1_distance(config, region, &[values[2].clone(), reference[0].clone()])?; + let two = create_constant_tensor(F::from(2), 1); + let two = region.assign(&config.custom_gates.inputs[1], &two)?; + region.increment(two.len()); - // one might expect this to be unsound as if both l1_distance_0 and l1_distance_1 AND l1_distance_2 are the same then one could expect the solution to not be unique. - // however if l1_distance_0 and l1_distance_1 are the same then l1_distance_2 must be different for a monotonic function like the product used in the division algorithm. - let is_closest_to_0 = less_equal(config, region, &[l1_distance_0.clone(), l1_distance_1])?; - let is_closest_to_1 = less_equal(config, region, &[l1_distance_0, l1_distance_2])?; + let f_x = f(config, region, x)?; - // if we wanted to be more explicit about this condition we would: - // let is_equal_0 = equals(config, region, &[l1_distance_0.clone(), l1_distance_1])?; - // let is_equal_1 = equals(config, region, &[l1_distance_0, l1_distance_2])?; - // let both_equal = and(config, region, &[is_equal_0, is_equal_1])?; - // enforce_equality(config, region, &[both_equal, F::ZERO])?; + let x_plus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Add)?; + let f_x_plus_2 = f(config, region, &x_plus_2)?; - let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?; + let x_minus_2 = pairwise(config, region, &[x.clone(), two.clone()], BaseOp::Sub)?; + let f_x_minus_2 = f(config, region, &x_minus_2)?; - let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len()); + // because the function is convex, we the result should be the minimum of the three + // not that we offset the x by 2 to get the other two points that due to the convexity of the function and symmetry of convex function, there can be 2 + let f_x_is_opt_rhs = less(config, region, &[f_x.clone(), f_x_plus_2])?; + let f_x_is_opt_lhs = less(config, region, &[f_x.clone(), f_x_minus_2])?; - comparison_unit.reshape(is_closest.dims())?; + let is_opt = and(config, region, &[f_x_is_opt_lhs, f_x_is_opt_rhs])?; + + let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_opt.len()); + comparison_unit.reshape(is_opt.dims())?; // assigned unit let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?; region.increment(assigned_unit.len()); // assert that the result is 1 - enforce_equality(config, region, &[is_closest, assigned_unit])?; + enforce_equality(config, region, &[is_opt, assigned_unit])?; Ok(()) } @@ -145,55 +146,16 @@ pub(crate) fn div( region.assign(&config.custom_gates.output, &claimed_output)?; region.increment(claimed_output.len()); - let product = pairwise( - config, - region, - &[claimed_output.clone(), divisor.clone()], - BaseOp::Mult, - )?; - - // take the claimed output and subtract 1 - let one = create_constant_tensor(F::ONE, 1); - let one = region.assign(&config.custom_gates.inputs[1], &one)?; - - let claimed_output_minus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Sub, - )?; - - let claimed_output_minus_one_product = pairwise( - config, - region, - &[claimed_output_minus_one.clone(), divisor.clone()], - BaseOp::Mult, - )?; - - let claimed_output_plus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Add, - )?; - - let claimed_output_plus_one_product = pairwise( - config, - region, - &[claimed_output_plus_one.clone(), divisor.clone()], - BaseOp::Mult, - )?; + let err_func = |config: &BaseConfig, + region: &mut RegionCtx, + x: &ValTensor| + -> Result, CircuitError> { + let product = pairwise(config, region, &[x.clone(), divisor.clone()], BaseOp::Mult)?; + let distance = l1_distance(config, region, &[product, input.clone()])?; + Ok(distance) + }; - is_closest_to( - config, - region, - &[ - product, - claimed_output_minus_one_product, - claimed_output_plus_one_product, - ], - &[input.clone()], - )?; + optimum_convex_function(config, region, &claimed_output, err_func)?; Ok(claimed_output) } @@ -209,9 +171,6 @@ pub(crate) fn recip( let input = value[0].clone(); let input_dims = input.dims(); - let one = create_constant_tensor(F::ONE, 1); - let one = region.assign(&config.custom_gates.inputs[0], &one)?; - let unit_scale = create_constant_tensor(output_scale * input_scale, 1); let unit_scale = region.assign(&config.custom_gates.inputs[1], &unit_scale)?; region.increment(1); @@ -256,89 +215,16 @@ pub(crate) fn recip( &[equal_zero_mask.clone(), equal_inverse_mask], )?; - // this is now of scale 2 * scale - let product = pairwise( - config, - region, - &[claimed_output.clone(), input.clone()], - BaseOp::Mult, - )?; - - let claimed_output_minus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Sub, - )?; - - let claimed_output_minus_one_product = pairwise( - config, - region, - &[claimed_output_minus_one.clone(), input.clone()], - BaseOp::Mult, - )?; - - let claimed_output_plus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Add, - )?; - - let claimed_output_plus_one_product = pairwise( - config, - region, - &[claimed_output_plus_one.clone(), input.clone()], - BaseOp::Mult, - )?; - - let scaled_equal_zero_mask = pairwise( - config, - region, - &[equal_zero_mask.clone(), unit_scale.clone()], - BaseOp::Mult, - )?; - - // add 1 where the mask is 0 - let product_masked = pairwise( - config, - region, - &[product.clone(), scaled_equal_zero_mask.clone()], - BaseOp::Add, - )?; - - // add 1 where the mask is 0 - let claimed_output_minus_one_product_masked = pairwise( - config, - region, - &[ - claimed_output_minus_one_product.clone(), - scaled_equal_zero_mask.clone(), - ], - BaseOp::Add, - )?; - - // add 1 where the mask is 0 - let claimed_output_plus_one_product_masked = pairwise( - config, - region, - &[ - claimed_output_plus_one_product.clone(), - scaled_equal_zero_mask.clone(), - ], - BaseOp::Add, - )?; + let err_func = |config: &BaseConfig, + region: &mut RegionCtx, + x: &ValTensor| + -> Result, CircuitError> { + let product = pairwise(config, region, &[x.clone(), input.clone()], BaseOp::Mult)?; + let distance = l1_distance(config, region, &[product.clone(), unit_scale.clone()])?; + Ok(distance) + }; - is_closest_to( - config, - region, - &[ - product_masked, - claimed_output_minus_one_product_masked, - claimed_output_plus_one_product_masked, - ], - &[unit_scale], - )?; + optimum_convex_function(config, region, &claimed_output, err_func)?; Ok(claimed_output) } @@ -374,9 +260,6 @@ pub fn sqrt( let input = value[0].clone(); let input_dims = input.dims(); - let one = create_constant_tensor(F::ONE, 1); - let one = region.assign(&config.custom_gates.inputs[0], &one)?; - let unit_scale = create_constant_tensor(integer_rep_to_felt(input_scale.0 as IntegerRep), 1); let unit_scale = region.assign(&config.custom_gates.inputs[1], &unit_scale)?; region.increment(1); @@ -411,61 +294,19 @@ pub fn sqrt( // assert the sign is positive enforce_equality(config, region, &[sign, ones.clone()])?; - // this is now of scale 2 * scale - let product = pairwise( - config, - region, - &[claimed_output.clone(), claimed_output.clone()], - BaseOp::Mult, - )?; - - let claimed_output_minus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Sub, - )?; - - let claimed_output_minus_one_product = pairwise( - config, - region, - &[ - claimed_output_minus_one.clone(), - claimed_output_minus_one.clone(), - ], - BaseOp::Mult, - )?; - - let claimed_output_plus_one = pairwise( - config, - region, - &[claimed_output.clone(), one.clone()], - BaseOp::Add, - )?; - - let claimed_output_plus_one_product = pairwise( - config, - region, - &[ - claimed_output_plus_one.clone(), - claimed_output_plus_one.clone(), - ], - BaseOp::Mult, - )?; - // rescaled input let rescaled_input = pairwise(config, region, &[input.clone(), unit_scale], BaseOp::Mult)?; - is_closest_to( - config, - region, - &[ - product, - claimed_output_minus_one_product, - claimed_output_plus_one_product, - ], - &[rescaled_input], - )?; + let err_func = |config: &BaseConfig, + region: &mut RegionCtx, + x: &ValTensor| + -> Result, CircuitError> { + let product = pairwise(config, region, &[x.clone(), x.clone()], BaseOp::Mult)?; + let distance = l1_distance(config, region, &[product.clone(), rescaled_input.clone()])?; + Ok(distance) + }; + + optimum_convex_function(config, region, &claimed_output, err_func)?; Ok(claimed_output) } @@ -4925,16 +4766,33 @@ pub fn ln( BaseOp::Sub, )?; - is_closest_to( + let abs_distance_to_claimed = abs(config, region, &[distance_to_claimed.clone()])?; + + let abs_distance_to_next_pow2 = + l1_distance(config, region, &[input.clone(), next_pow2.clone()])?; + + let abs_distance_to_prior_pow2 = + l1_distance(config, region, &[input.clone(), prior_pow2.clone()])?; + + // because we round up this can be equal + let is_closest_to_0: ValTensor = less_equal( config, region, - &[ - pow2_of_claimed_output.clone(), - prior_pow2.clone(), - next_pow2.clone(), - ], - &[input.clone()], + &[abs_distance_to_claimed.clone(), abs_distance_to_next_pow2], )?; + let is_closest_to_1 = less( + config, + region, + &[abs_distance_to_claimed.clone(), abs_distance_to_prior_pow2], + )?; + + let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?; + + let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len()); + comparison_unit.reshape(is_closest.dims())?; + let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?; + + enforce_equality(config, region, &[is_closest, assigned_unit])?; // get a linear interpolation now