From 6dd4338dd4c4a0d1dc2608034ccb2e765de40f27 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Thu, 31 Oct 2024 17:26:09 -0400 Subject: [PATCH] feat: bounded lookup round half to even --- src/circuit/ops/hybrid.rs | 11 +++ src/circuit/ops/layouts.rs | 149 +++++++++++++++++++++++++++++++++++++ src/circuit/ops/lookup.rs | 21 +----- src/circuit/ops/region.rs | 11 +++ src/circuit/table.rs | 9 ++- src/graph/utilities.rs | 63 ++++++++-------- src/tensor/ops.rs | 26 +++++++ 7 files changed, 239 insertions(+), 51 deletions(-) diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 5f58031a0..0ea1327c7 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -13,6 +13,10 @@ use serde::{Deserialize, Serialize}; /// An enum representing the operations that consist of both lookups and arithmetic operations. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum HybridOp { + RoundHalfToEven { + scale: utils::F32, + legs: usize, + }, Ceil { scale: utils::F32, legs: usize, @@ -108,9 +112,13 @@ impl Op for Hybrid fn as_string(&self) -> String { match self { + HybridOp::RoundHalfToEven { scale, legs } => { + format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs) + } HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs), HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs), HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs), + HybridOp::Max => format!("MAX"), HybridOp::Min => format!("MIN"), HybridOp::Recip { @@ -181,6 +189,9 @@ impl Op for Hybrid values: &[ValTensor], ) -> Result>, CircuitError> { Ok(Some(match self { + HybridOp::RoundHalfToEven { scale, legs } => { + layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)? + } HybridOp::Ceil { scale, legs } => { layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)? } diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index b3cc48fb1..2f3d9a31f 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -4654,6 +4654,155 @@ pub fn round( ) } +/// round half to even layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 1] +/// * `scale` - utils::F32 +/// * `legs` - usize +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::round; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[3, -2, 3, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let result = round::(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap(); +/// let expected = Tensor::::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +/// +pub fn round_half_to_even( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, + legs: usize, +) -> Result, CircuitError> { + // decompose with base scale and then set the last element to zero + let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?; + // set the last element to zero and then recompose, we don't actually need to assign here + // as this will automatically be assigned in the recompose function and uses the constant caching of RegionCtx + let zero = ValType::Constant(F::ZERO); + + // if scale is not exactly divisible by 2 we warn + if scale.0 % 2.0 != 0.0 { + log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate"); + } + + let midway_point: ValTensor = create_constant_tensor( + integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep), + 1, + ); + let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?; + + region.increment(1); + + let dims = decomposition.dims().to_vec(); + let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec(); + + let mut incremented_tensor = Tensor::new(None, &first_dims)?; + + let cartesian_coord = first_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = + |i: usize, region: &mut RegionCtx| -> Result>, CircuitError> { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + let mut sliced_input = decomposition.get_slice(&slice)?; + sliced_input.flatten(); + let last_elem = sliced_input.last()?; + + let penultimate_elem = + sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?; + + let is_equal_to_midway = equals( + config, + region, + &[last_elem.clone(), assigned_midway_point.clone()], + )?; + // penultimate_elem is equal to midway point and even, do nothing + let is_odd = nonlinearity( + config, + region, + &[penultimate_elem.clone()], + &LookupOp::IsOdd, + )?; + + let is_odd_and_equal_to_midway = and( + config, + region, + &[is_odd.clone(), is_equal_to_midway.clone()], + )?; + + let is_greater_than_midway = greater( + config, + region, + &[last_elem.clone(), assigned_midway_point.clone()], + )?; + + // if the number is equal to midway point and odd increment, or if it is is_greater_than_midway + let is_odd_and_equal_to_midway_or_greater_than_midway = or( + config, + region, + &[ + is_odd_and_equal_to_midway.clone(), + is_greater_than_midway.clone(), + ], + )?; + + // increment the penultimate element + let incremented_elem = pairwise( + config, + region, + &[ + sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?, + is_odd_and_equal_to_midway_or_greater_than_midway.clone(), + ], + BaseOp::Add, + )?; + + let mut inner_tensor = sliced_input.get_inner_tensor()?.clone(); + inner_tensor[sliced_input.len() - 2] = + incremented_elem.get_inner_tensor()?.clone()[0].clone(); + + // set the last elem to zero + inner_tensor[sliced_input.len() - 1] = zero.clone(); + + Ok(inner_tensor.clone()) + }; + + region.update_max_min_lookup_inputs_force(0, scale.0 as IntegerRep)?; + + region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?; + + let mut incremented_tensor = incremented_tensor.combine()?; + incremented_tensor.reshape(&dims)?; + + recompose( + config, + region, + &[incremented_tensor.into()], + &(scale.0 as usize), + ) +} + pub(crate) fn recompose( config: &BaseConfig, region: &mut RegionCtx, diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 9558b5464..0f7ca0852 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -4,7 +4,6 @@ use serde::{Deserialize, Serialize}; use crate::{ circuit::{layouts, table::Range, utils}, fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep}, - graph::multiplier_to_scale, tensor::{self, Tensor, TensorError, TensorType}, }; @@ -16,8 +15,7 @@ use halo2curves::ff::PrimeField; #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] pub enum LookupOp { Div { denom: utils::F32 }, - Cast { scale: utils::F32 }, - RoundHalfToEven { scale: utils::F32 }, + IsOdd, Sqrt { scale: utils::F32 }, Rsqrt { scale: utils::F32 }, Sigmoid { scale: utils::F32 }, @@ -51,10 +49,9 @@ impl LookupOp { /// as path pub fn as_path(&self) -> String { match self { - LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale), LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a), + LookupOp::IsOdd => "is_odd".to_string(), LookupOp::Div { denom } => format!("div_{}", denom), - LookupOp::Cast { scale } => format!("cast_{}", scale), LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale), LookupOp::Sqrt { scale } => format!("sqrt_{}", scale), LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale), @@ -85,18 +82,13 @@ impl LookupOp { let x = x[0].clone().map(|x| felt_to_integer_rep(x)); let res = match &self { - LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>( - tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()), - ), + LookupOp::IsOdd => Ok::<_, TensorError>(tensor::ops::nonlinearities::is_odd(&x)), LookupOp::Pow { scale, a } => Ok::<_, TensorError>( tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()), ), LookupOp::Div { denom } => Ok::<_, TensorError>( tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()), ), - LookupOp::Cast { scale } => Ok::<_, TensorError>( - tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()), - ), LookupOp::Sigmoid { scale } => { Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into())) } @@ -171,10 +163,9 @@ impl Op for Lookup /// Returns the name of the operation fn as_string(&self) -> String { match self { - LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale), + LookupOp::IsOdd => "IS_ODD".to_string(), LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a), LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom), - LookupOp::Cast { scale } => format!("CAST(scale={})", scale), LookupOp::Ln { scale } => format!("LN(scale={})", scale), LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale), LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale), @@ -214,10 +205,6 @@ impl Op for Lookup /// Returns the scale of the output of the operation. fn out_scale(&self, inputs_scale: Vec) -> Result { let scale = match self { - LookupOp::Cast { scale } => { - let in_scale = inputs_scale[0]; - in_scale + multiplier_to_scale(1. / scale.0 as f64) - } _ => inputs_scale[0], }; Ok(scale) diff --git a/src/circuit/ops/region.rs b/src/circuit/ops/region.rs index aa66df1ac..9c92c8901 100644 --- a/src/circuit/ops/region.rs +++ b/src/circuit/ops/region.rs @@ -474,6 +474,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a Ok(()) } + /// Update the max and min forcefully + pub fn update_max_min_lookup_inputs_force( + &mut self, + min: IntegerRep, + max: IntegerRep, + ) -> Result<(), CircuitError> { + self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max); + self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min); + Ok(()) + } + /// Update the max and min from inputs pub fn update_max_min_lookup_range(&mut self, range: Range) -> Result<(), CircuitError> { if range.0 > range.1 { diff --git a/src/circuit/table.rs b/src/circuit/table.rs index 4be1fa6df..5c82cf454 100644 --- a/src/circuit/table.rs +++ b/src/circuit/table.rs @@ -150,12 +150,16 @@ pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize { } impl Table { + /// get largest element represented by the range + pub fn largest(&self) -> IntegerRep { + self.range.0 + (self.col_size * self.table_inputs.len() - 1) as IntegerRep + } fn name(&self) -> String { format!( "{}_{}_{}", self.nonlinearity.as_path(), self.range.0, - self.range.1 + self.largest() ) } /// Configures the table. @@ -222,7 +226,7 @@ impl Table { } let smallest = self.range.0; - let largest = self.range.1; + let largest = self.largest(); let gen_table = || -> Result<(Tensor, Tensor), crate::tensor::TensorError> { let inputs = Tensor::from(smallest..=largest) @@ -291,6 +295,7 @@ impl Table { row_offset += chunk_idx * self.col_size; let (x, y) = self.cartesian_coord(row_offset); + if !preassigned_input { table.assign_cell( || format!("nl_i_col row {}", row_offset), diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index be4b42841..ae89df0a7 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -803,7 +803,7 @@ pub fn new_op_from_onnx( } } "Recip" => { - let in_scale = inputs[0].out_scales()[0]; + let in_scale = input_scales[0]; let max_scale = std::cmp::max(scales.get_max(), in_scale); // If the input scale is larger than the params scale SupportedOp::Hybrid(HybridOp::Recip { @@ -837,61 +837,61 @@ pub fn new_op_from_onnx( "Abs" => SupportedOp::Linear(PolyOp::Abs), "Neg" => SupportedOp::Linear(PolyOp::Neg), "HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Exp" => SupportedOp::Nonlinear(LookupOp::Exp { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Ln" => SupportedOp::Nonlinear(LookupOp::Ln { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Sin" => SupportedOp::Nonlinear(LookupOp::Sin { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Cos" => SupportedOp::Nonlinear(LookupOp::Cos { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Tan" => SupportedOp::Nonlinear(LookupOp::Tan { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Asin" => SupportedOp::Nonlinear(LookupOp::ASin { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Acos" => SupportedOp::Nonlinear(LookupOp::ACos { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Atan" => SupportedOp::Nonlinear(LookupOp::ATan { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Erf" => SupportedOp::Nonlinear(LookupOp::Erf { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), }), "Source" => { let dt = node.outputs[0].fact.datum_type; @@ -935,11 +935,9 @@ pub fn new_op_from_onnx( replace_const( 0, 0, - SupportedOp::Nonlinear(LookupOp::Cast { - scale: crate::circuit::utils::F32(scale_to_multiplier( - input_scales[0], - ) - as f32), + SupportedOp::Hybrid(HybridOp::Floor { + scale: scale_to_multiplier(input_scales[0]).into(), + legs: run_args.decomp_legs, }), )? } else { @@ -1045,7 +1043,7 @@ pub fn new_op_from_onnx( } }; - let in_scale = inputs[0].out_scales()[0]; + let in_scale = input_scales[0]; let max_scale = std::cmp::max(scales.get_max(), in_scale); SupportedOp::Hybrid(HybridOp::Softmax { @@ -1084,19 +1082,20 @@ pub fn new_op_from_onnx( }) } "Ceil" => SupportedOp::Hybrid(HybridOp::Ceil { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), legs: run_args.decomp_legs, }), "Floor" => SupportedOp::Hybrid(HybridOp::Floor { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), legs: run_args.decomp_legs, }), "Round" => SupportedOp::Hybrid(HybridOp::Round { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), legs: run_args.decomp_legs, }), - "RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + "RoundHalfToEven" => SupportedOp::Hybrid(HybridOp::RoundHalfToEven { + scale: scale_to_multiplier(input_scales[0]).into(), + legs: run_args.decomp_legs, }), "Sign" => SupportedOp::Linear(PolyOp::Sign), "Pow" => { @@ -1116,7 +1115,7 @@ pub fn new_op_from_onnx( SupportedOp::Linear(PolyOp::Pow(exponent as u32)) } else { SupportedOp::Nonlinear(LookupOp::Pow { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + scale: scale_to_multiplier(input_scales[0]).into(), a: crate::circuit::utils::F32(exponent), }) } diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 2666191ad..bdd79f651 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -1474,6 +1474,32 @@ pub mod nonlinearities { .unwrap() } + /// Checks if a tensor's elements are odd + /// # Arguments + /// * `a` - Tensor + /// * `scale` - Single value + /// # Examples + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::fieldutils::IntegerRep; + /// use ezkl::tensor::ops::nonlinearities::is_odd; + /// let x = Tensor::::new( + /// Some(&[2, 15, 2, 1, 1, 0]), + /// &[2, 3], + /// ).unwrap(); + /// + /// let result = is_odd(&x); + /// let expected = Tensor::::new(Some(&[0, 1, 0, 1, 1, 0]), &[2, 3]).unwrap(); + /// assert_eq!(result, expected); + /// ``` + pub fn is_odd(a: &Tensor) -> Tensor { + a.par_enum_map(|_, a_i| { + let rounded = if a_i % 2 == 0 { 0 } else { 1 }; + Ok::<_, TensorError>(rounded) + }) + .unwrap() + } + /// Elementwise applies sigmoid to a tensor of integers. /// # Arguments ///