Skip to content

Commit

Permalink
feat: bounded lookup round half to even
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Oct 31, 2024
1 parent a3c131d commit 6dd4338
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 51 deletions.
11 changes: 11 additions & 0 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize};
/// An enum representing the operations that consist of both lookups and arithmetic operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum HybridOp {
RoundHalfToEven {
scale: utils::F32,
legs: usize,
},
Ceil {
scale: utils::F32,
legs: usize,
Expand Down Expand Up @@ -108,9 +112,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid

fn as_string(&self) -> String {
match self {
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs),
HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs),
HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs),

HybridOp::Max => format!("MAX"),
HybridOp::Min => format!("MIN"),
HybridOp::Recip {
Expand Down Expand Up @@ -181,6 +189,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
HybridOp::Ceil { scale, legs } => {
layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)?
}
Expand Down
149 changes: 149 additions & 0 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4654,6 +4654,155 @@ pub fn round<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
)
}

/// round half to even layout
/// # Arguments
/// * `config` - BaseConfig
/// * `region` - RegionCtx
/// * `values` - &[ValTensor<F>; 1]
/// * `scale` - utils::F32
/// * `legs` - usize
/// # Returns
/// * ValTensor<F>
/// # Example
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::circuit::ops::layouts::round;
/// use ezkl::tensor::val::ValTensor;
/// use halo2curves::bn256::Fr as Fp;
/// use ezkl::circuit::region::RegionCtx;
/// use ezkl::circuit::region::RegionSettings;
/// use ezkl::circuit::BaseConfig;
/// let dummy_config = BaseConfig::dummy(12, 2);
/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2));
/// let x = ValTensor::from_integer_rep_tensor(Tensor::<IntegerRep>::new(
/// Some(&[3, -2, 3, 1]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = round::<Fp>(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.int_evals().unwrap(), expected);
/// ```
///
pub fn round_half_to_even<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
legs: usize,
) -> Result<ValTensor<F>, CircuitError> {
// decompose with base scale and then set the last element to zero
let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?;
// set the last element to zero and then recompose, we don't actually need to assign here
// as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx
let zero = ValType::Constant(F::ZERO);

// if scale is not exactly divisible by 2 we warn
if scale.0 % 2.0 != 0.0 {
log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate");
}

let midway_point: ValTensor<F> = create_constant_tensor(
integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep),
1,
);
let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?;

region.increment(1);

let dims = decomposition.dims().to_vec();
let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec();

let mut incremented_tensor = Tensor::new(None, &first_dims)?;

let cartesian_coord = first_dims
.iter()
.map(|x| 0..*x)
.multi_cartesian_product()
.collect::<Vec<_>>();

let inner_loop_function =
|i: usize, region: &mut RegionCtx<F>| -> Result<Tensor<ValType<F>>, CircuitError> {
let coord = cartesian_coord[i].clone();
let slice = coord.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
let mut sliced_input = decomposition.get_slice(&slice)?;
sliced_input.flatten();
let last_elem = sliced_input.last()?;

let penultimate_elem =
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?;

let is_equal_to_midway = equals(
config,
region,
&[last_elem.clone(), assigned_midway_point.clone()],
)?;
// penultimate_elem is equal to midway point and even, do nothing
let is_odd = nonlinearity(
config,
region,
&[penultimate_elem.clone()],
&LookupOp::IsOdd,
)?;

let is_odd_and_equal_to_midway = and(
config,
region,
&[is_odd.clone(), is_equal_to_midway.clone()],
)?;

let is_greater_than_midway = greater(
config,
region,
&[last_elem.clone(), assigned_midway_point.clone()],
)?;

// if the number is equal to midway point and odd increment, or if it is is_greater_than_midway
let is_odd_and_equal_to_midway_or_greater_than_midway = or(
config,
region,
&[
is_odd_and_equal_to_midway.clone(),
is_greater_than_midway.clone(),
],
)?;

// increment the penultimate element
let incremented_elem = pairwise(
config,
region,
&[
sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?,
is_odd_and_equal_to_midway_or_greater_than_midway.clone(),
],
BaseOp::Add,
)?;

let mut inner_tensor = sliced_input.get_inner_tensor()?.clone();
inner_tensor[sliced_input.len() - 2] =
incremented_elem.get_inner_tensor()?.clone()[0].clone();

// set the last elem to zero
inner_tensor[sliced_input.len() - 1] = zero.clone();

Ok(inner_tensor.clone())
};

region.update_max_min_lookup_inputs_force(0, scale.0 as IntegerRep)?;

region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?;

let mut incremented_tensor = incremented_tensor.combine()?;
incremented_tensor.reshape(&dims)?;

recompose(
config,
region,
&[incremented_tensor.into()],
&(scale.0 as usize),
)
}

pub(crate) fn recompose<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
Expand Down
21 changes: 4 additions & 17 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
};

Expand All @@ -16,8 +15,7 @@ use halo2curves::ff::PrimeField;
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)]
pub enum LookupOp {
Div { denom: utils::F32 },
Cast { scale: utils::F32 },
RoundHalfToEven { scale: utils::F32 },
IsOdd,
Sqrt { scale: utils::F32 },
Rsqrt { scale: utils::F32 },
Sigmoid { scale: utils::F32 },
Expand Down Expand Up @@ -51,10 +49,9 @@ impl LookupOp {
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::IsOdd => "is_odd".to_string(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
Expand Down Expand Up @@ -85,18 +82,13 @@ impl LookupOp {
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let res =
match &self {
LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()),
),
LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)),
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
LookupOp::Cast { scale } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()),
),
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
Expand Down Expand Up @@ -171,10 +163,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the name of the operation
fn as_string(&self) -> String {
match self {
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::IsOdd => "IS_ODD".to_string(),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
Expand Down Expand Up @@ -214,10 +205,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
/// Returns the scale of the output of the operation.
fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
LookupOp::Cast { scale } => {
let in_scale = inputs_scale[0];
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
_ => inputs_scale[0],
};
Ok(scale)
Expand Down
11 changes: 11 additions & 0 deletions src/circuit/ops/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
Ok(())
}

/// Update the max and min forcefully
pub fn update_max_min_lookup_inputs_force(
&mut self,
min: IntegerRep,
max: IntegerRep,
) -> Result<(), CircuitError> {
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}

/// Update the max and min from inputs
pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> {
if range.0 > range.1 {
Expand Down
9 changes: 7 additions & 2 deletions src/circuit/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
}

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get largest element represented by the range
pub fn largest(&self) -> IntegerRep {
self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep
}
fn name(&self) -> String {
format!(
"{}_{}_{}",
self.nonlinearity.as_path(),
self.range.0,
self.range.1
self.largest()
)
}
/// Configures the table.
Expand Down Expand Up @@ -222,7 +226,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}

let smallest = self.range.0;
let largest = self.range.1;
let largest = self.largest();

let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
Expand Down Expand Up @@ -291,6 +295,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {

row_offset += chunk_idx * self.col_size;
let (x, y) = self.cartesian_coord(row_offset);

if !preassigned_input {
table.assign_cell(
|| format!("nl_i_col row {}", row_offset),
Expand Down
Loading

0 comments on commit 6dd4338

Please sign in to comment.