Skip to content

Commit

Permalink
Update layouts.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Nov 8, 2024
1 parent 5022f5a commit 576bda8
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,37 @@ pub fn is_closest_to<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
let l1_distance_1 = l1_distance(config, region, &[values[1].clone(), reference[0].clone()])?;
let l1_distance_2 = l1_distance(config, region, &[values[2].clone(), reference[0].clone()])?;

let is_closest_to_0 = less(config, region, &[l1_distance_0.clone(), l1_distance_1])?;
let is_closest_to_1 = less(config, region, &[l1_distance_0, l1_distance_2])?;
// we need to account for rounding up or down so it is a <= and not a < comparison but only one of the distances can be less than or equal to the reference
// else we could have two distances that are equal to the reference,
// which could result in soundness issues if we're looking for a unique solution in monotically increasing distance functions
let is_equal_0 = equals(
config,
region,
&[l1_distance_0.clone(), l1_distance_1.clone()],
)?;
let is_closest_to_0 = less_equal(config, region, &[l1_distance_0.clone(), l1_distance_1])?;
let is_equal_1 = equals(
config,
region,
&[l1_distance_0.clone(), l1_distance_2.clone()],
)?;
let is_closest_to_1 = less_equal(config, region, &[l1_distance_0, l1_distance_2])?;

let is_closest = and(config, region, &[is_closest_to_0, is_closest_to_1])?;
let both_equal = and(config, region, &[is_equal_0, is_equal_1])?;
let not_both_equal = not(config, region, &[both_equal])?;

let is_closest_to = and(config, region, &[is_closest_to_0, is_closest_to_1])?;
let is_closest = and(config, region, &[is_closest, not_both_equal])?;

let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest_to.len());
let mut comparison_unit = create_constant_tensor(integer_rep_to_felt(1), is_closest.len());

comparison_unit.reshape(is_closest_to.dims())?;
comparison_unit.reshape(is_closest.dims())?;
// assigned unit
let assigned_unit = region.assign(&config.custom_gates.inputs[1], &comparison_unit)?;
region.increment(assigned_unit.len());

// assert that the result is 1
enforce_equality(config, region, &[is_closest_to, assigned_unit])?;
enforce_equality(config, region, &[is_closest, assigned_unit])?;

Ok(())
}
Expand Down

0 comments on commit 576bda8

Please sign in to comment.