diff --git a/examples/onnx/rounding_ops/gen.py b/examples/onnx/rounding_ops/gen.py index 5731a1157..30b26a6e1 100644 --- a/examples/onnx/rounding_ops/gen.py +++ b/examples/onnx/rounding_ops/gen.py @@ -21,9 +21,9 @@ def main(): torch_model = Circuit() # Input to the model shape = [3, 2, 3] - w = 0.1*torch.rand(1, *shape, requires_grad=True) - x = 0.1*torch.rand(1, *shape, requires_grad=True) - y = 0.1*torch.rand(1, *shape, requires_grad=True) + w = 2 * torch.rand(1, *shape, requires_grad=True) - 1 + x = 2 * torch.rand(1, *shape, requires_grad=True) - 1 + y = 2 * torch.rand(1, *shape, requires_grad=True) - 1 torch_out = torch_model(w, x, y) # Export the model torch.onnx.export(torch_model, # model being run diff --git a/examples/onnx/rounding_ops/input.json b/examples/onnx/rounding_ops/input.json index 4de30e437..8288460ec 100644 --- a/examples/onnx/rounding_ops/input.json +++ b/examples/onnx/rounding_ops/input.json @@ -1 +1,148 @@ -{"input_shapes": [[3, 2, 3], [3, 2, 3], [3, 2, 3], [3, 2, 3]], "input_data": [[0.0025284828152507544, 0.04976580664515495, 0.025840921327471733, 0.0829394981265068, 0.09595223516225815, 0.08764562010765076, 0.06308566778898239, 0.062386948615312576, 0.08090643584728241, 0.09267748892307281, 0.07428313046693802, 0.08987367898225784, 0.005716216750442982, 0.0666426345705986, 0.012837404385209084, 0.05769496038556099, 0.05761152133345604, 0.08006472885608673], [0.007834953255951405, 0.011380612850189209, 0.08560049533843994, 0.022283583879470825, 0.07879520952701569, 0.04422441124916077, 0.030812596902251244, 0.006081616971641779, 0.011045408435165882, 0.08776585012674332, 0.044985152781009674, 0.015603715553879738, 0.07923348993062973, 0.04872611165046692, 0.0036642670165747404, 0.05142095685005188, 0.0963878259062767, 0.03225792199373245], [0.09952805936336517, 0.002214533044025302, 0.011696457862854004, 0.022422820329666138, 0.04151459410786629, 0.027647346258163452, 0.011919880285859108, 0.006539052817970514, 0.06569185107946396, 0.034328874200582504, 0.0032284557819366455, 0.004105025436729193, 0.022395813837647438, 0.07135921716690063, 0.07882415503263474, 0.09764843434095383, 0.05335796996951103, 0.0525360181927681]], "output_data": [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]} \ No newline at end of file +{ + "input_shapes": [ + [ + 3, + 2, + 3 + ], + [ + 3, + 2, + 3 + ], + [ + 3, + 2, + 3 + ], + [ + 3, + 2, + 3 + ] + ], + "input_data": [ + [ + 0.6261028051376343, + 0.49872446060180664, + -0.04514765739440918, + 0.5936200618743896, + 0.9271858930587769, + 0.6688600778579712, + -0.20331168174743652, + -0.7016235589981079, + 0.025863051414489746, + -0.19426143169403076, + 0.9827852249145508, + 0.4897397756576538, + 0.2992602586746216, + 0.7011144161224365, + 0.9278832674026489, + 0.5943725109100342, + -0.573331356048584, + 0.3675816059112549 + ], + [ + 0.7803324460983276, + -0.9616303443908691, + 0.6070173978805542, + -0.028337717056274414, + -0.5080242156982422, + -0.9280107021331787, + 0.6150380373001099, + 0.3865993022918701, + -0.43668973445892334, + 0.17152702808380127, + 0.5144252777099609, + -0.28881049156188965, + 0.8932310342788696, + 0.059034109115600586, + 0.6865451335906982, + 0.009820222854614258, + 0.23011493682861328, + -0.9492779970169067 + ], + [ + -0.21352827548980713, + -0.16015326976776123, + -0.38964390754699707, + 0.13464701175689697, + -0.8814496994018555, + 0.5037975311279297, + -0.804405927658081, + 0.9858957529067993, + 0.19567716121673584, + 0.9777265787124634, + 0.6151977777481079, + 0.568595290184021, + 0.10584986209869385, + -0.8975653648376465, + 0.6235959529876709, + -0.547879695892334, + 0.9289869070053101, + 0.7567293643951416 + ] + ], + "output_data": [ + [ + 1.0, + 0.0, + -0.0, + 1.0, + 1.0, + 1.0, + -0.0, + -1.0, + 0.0, + -0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + 1.0, + -1.0, + 0.0 + ], + [ + 0.0, + -1.0, + 0.0, + -1.0, + -1.0, + -1.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + -1.0 + ], + [ + -0.0, + -0.0, + -0.0, + 1.0, + -0.0, + 1.0, + -0.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + -0.0, + 1.0, + -0.0, + 1.0, + 1.0 + ] + ] +} \ No newline at end of file diff --git a/examples/onnx/rounding_ops/network.onnx b/examples/onnx/rounding_ops/network.onnx index 60f43b01c..7f319e031 100644 --- a/examples/onnx/rounding_ops/network.onnx +++ b/examples/onnx/rounding_ops/network.onnx @@ -1,10 +1,11 @@ -pytorch2.0.1:â +pytorch2.2.2:ã  woutput_w/Round"Round  xoutput_x/Floor"Floor  -youtput_y/Ceil"Ceil torch_jitZ% +youtput_y/Ceil"Ceil +main_graphZ% w   diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index a145e5997..5f58031a0 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -13,6 +13,18 @@ use serde::{Deserialize, Serialize}; /// An enum representing the operations that consist of both lookups and arithmetic operations. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum HybridOp { + Ceil { + scale: utils::F32, + legs: usize, + }, + Floor { + scale: utils::F32, + legs: usize, + }, + Round { + scale: utils::F32, + legs: usize, + }, Recip { input_scale: utils::F32, output_scale: utils::F32, @@ -96,6 +108,9 @@ impl Op for Hybrid fn as_string(&self) -> String { match self { + HybridOp::Ceil { scale, legs } => format!("CEIL(scale={}, legs={})", scale, legs), + HybridOp::Floor { scale, legs } => format!("FLOOR(scale={}, legs={})", scale, legs), + HybridOp::Round { scale, legs } => format!("ROUND(scale={}, legs={})", scale, legs), HybridOp::Max => format!("MAX"), HybridOp::Min => format!("MIN"), HybridOp::Recip { @@ -166,6 +181,15 @@ impl Op for Hybrid values: &[ValTensor], ) -> Result>, CircuitError> { Ok(Some(match self { + HybridOp::Ceil { scale, legs } => { + layouts::ceil(config, region, values[..].try_into()?, *scale, *legs)? + } + HybridOp::Floor { scale, legs } => { + layouts::floor(config, region, values[..].try_into()?, *scale, *legs)? + } + HybridOp::Round { scale, legs } => { + layouts::round(config, region, values[..].try_into()?, *scale, *legs)? + } HybridOp::Max => layouts::max_comp(config, region, values[..].try_into()?)?, HybridOp::Min => layouts::min_comp(config, region, values[..].try_into()?)?, HybridOp::SumPool { diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs index 685ae1df9..fd8f4f31f 100644 --- a/src/circuit/ops/layouts.rs +++ b/src/circuit/ops/layouts.rs @@ -4155,8 +4155,40 @@ pub(crate) fn argmin( Ok(assigned_argmin) } -/// max layout -pub(crate) fn max_comp( +/// Max layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 2] +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::max_comp; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[5, 2, 3, 0]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let y = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[5, 1, 1, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// +/// let result = max_comp::(&dummy_config, &mut dummy_region, &[x, y]).unwrap(); +/// let expected = Tensor::::new(Some(&[5, 2, 3, 1]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +/// +pub fn max_comp( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -4176,8 +4208,38 @@ pub(crate) fn max_comp( +/// Min comp layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 2] +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::min_comp; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[5, 2, 3, 0]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let y = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[5, 1, 1, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let result = min_comp::(&dummy_config, &mut dummy_region, &[x, y]).unwrap(); +/// let expected = Tensor::::new(Some(&[5, 1, 1, 0]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +pub fn min_comp( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 2], @@ -4220,6 +4282,435 @@ pub(crate) fn min( .map_err(|e| e.into()) } +/// floor layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 1] +/// * `scale` - utils::F32 +/// * `legs` - usize +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::floor; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[3, -2, -3, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let result = floor::(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap(); +/// let expected = Tensor::::new(Some(&[2, -2, -4, 0]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +pub fn floor( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, + legs: usize, +) -> Result, CircuitError> { + // decompose with base scale and then set the last element to zero + let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?; + // set the last element to zero and then recompose + let zero = create_constant_tensor(F::ZERO, 1); + + let assigned_zero = region.assign(&config.custom_gates.inputs[1], &zero)?; + let assigned_zero = assigned_zero.get_inner_tensor()?[0].clone(); + let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1); + let assigned_negative_one = region.assign(&config.custom_gates.inputs[1], &negative_one)?; + + let dims = decomposition.dims().to_vec(); + let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec(); + + let mut incremented_tensor = Tensor::new(None, &first_dims)?; + + let cartesian_coord = first_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = + |i: usize, region: &mut RegionCtx| -> Result>, CircuitError> { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + let mut sliced_input = decomposition.get_slice(&slice)?; + sliced_input.flatten(); + let last_elem = sliced_input.last()?; + + let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?; + let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?; + + let sign = sliced_input.first()?; + let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?; + + let is_negative_and_not_zero = and( + config, + region, + &[last_elem_is_not_zero.clone(), is_negative.clone()], + )?; + + // increment the penultimate element + let incremented_elem = pairwise( + config, + region, + &[ + sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?, + is_negative_and_not_zero.clone(), + ], + BaseOp::Add, + )?; + + let mut inner_tensor = sliced_input.get_inner_tensor()?.clone(); + inner_tensor[sliced_input.len() - 2] = + incremented_elem.get_inner_tensor()?.clone()[0].clone(); + + // set the last elem to zero + inner_tensor[sliced_input.len() - 1] = assigned_zero.clone(); + + Ok(inner_tensor.clone()) + }; + + region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?; + + let mut incremented_tensor = incremented_tensor.combine()?; + incremented_tensor.reshape(&dims)?; + + recompose( + config, + region, + &[incremented_tensor.into()], + &(scale.0 as usize), + ) +} + +/// ceil layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 1] +/// * `scale` - utils::F32 +/// * `legs` - usize +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::ceil; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[3, -2, 3, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let result = ceil::(&dummy_config, &mut dummy_region, &[x], 2.0.into(), 2).unwrap(); +/// let expected = Tensor::::new(Some(&[4, -2, 4, 2]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +/// +pub fn ceil( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, + legs: usize, +) -> Result, CircuitError> { + // decompose with base scale and then set the last element to zero + let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?; + // set the last element to zero and then recompose + let zero = create_constant_tensor(F::ZERO, 1); + + let assigned_zero = region.assign(&config.custom_gates.inputs[1], &zero)?; + let assigned_zero = assigned_zero.get_inner_tensor()?[0].clone(); + let one = create_constant_tensor(integer_rep_to_felt(1), 1); + let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?; + + let dims = decomposition.dims().to_vec(); + let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec(); + + let mut incremented_tensor = Tensor::new(None, &first_dims)?; + + let cartesian_coord = first_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = + |i: usize, region: &mut RegionCtx| -> Result>, CircuitError> { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + let mut sliced_input = decomposition.get_slice(&slice)?; + sliced_input.flatten(); + let last_elem = sliced_input.last()?; + + let last_elem_is_zero = equals_zero(config, region, &[last_elem.clone()])?; + let last_elem_is_not_zero = not(config, region, &[last_elem_is_zero.clone()])?; + + let sign = sliced_input.first()?; + let is_positive = equals(config, region, &[sign, assigned_one.clone()])?; + + let is_positive_and_not_zero = and( + config, + region, + &[last_elem_is_not_zero.clone(), is_positive.clone()], + )?; + + // increment the penultimate element + let incremented_elem = pairwise( + config, + region, + &[ + sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?, + is_positive_and_not_zero.clone(), + ], + BaseOp::Add, + )?; + + let mut inner_tensor = sliced_input.get_inner_tensor()?.clone(); + inner_tensor[sliced_input.len() - 2] = + incremented_elem.get_inner_tensor()?.clone()[0].clone(); + + // set the last elem to zero + inner_tensor[sliced_input.len() - 1] = assigned_zero.clone(); + + Ok(inner_tensor.clone()) + }; + + region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?; + + let mut incremented_tensor = incremented_tensor.combine()?; + incremented_tensor.reshape(&dims)?; + + recompose( + config, + region, + &[incremented_tensor.into()], + &(scale.0 as usize), + ) +} + +/// round layout +/// # Arguments +/// * `config` - BaseConfig +/// * `region` - RegionCtx +/// * `values` - &[ValTensor; 1] +/// * `scale` - utils::F32 +/// * `legs` - usize +/// # Returns +/// * ValTensor +/// # Example +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::circuit::ops::layouts::round; +/// use ezkl::tensor::val::ValTensor; +/// use halo2curves::bn256::Fr as Fp; +/// use ezkl::circuit::region::RegionCtx; +/// use ezkl::circuit::region::RegionSettings; +/// use ezkl::circuit::BaseConfig; +/// let dummy_config = BaseConfig::dummy(12, 2); +/// let mut dummy_region = RegionCtx::new_dummy(0,2,RegionSettings::all_true(128,2)); +/// let x = ValTensor::from_integer_rep_tensor(Tensor::::new( +/// Some(&[3, -2, 3, 1]), +/// &[1, 1, 2, 2], +/// ).unwrap()); +/// let result = round::(&dummy_config, &mut dummy_region, &[x], 4.0.into(), 2).unwrap(); +/// let expected = Tensor::::new(Some(&[4, -4, 4, 0]), &[1, 1, 2, 2]).unwrap(); +/// assert_eq!(result.int_evals().unwrap(), expected); +/// ``` +/// +pub fn round( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + scale: utils::F32, + legs: usize, +) -> Result, CircuitError> { + // decompose with base scale and then set the last element to zero + let decomposition = decompose(config, region, values, &(scale.0 as usize), &legs)?; + // set the last element to zero and then recompose + let zero = create_constant_tensor(F::ZERO, 1); + + let assigned_zero = region.assign(&config.custom_gates.inputs[1], &zero)?; + let assigned_zero = assigned_zero.get_inner_tensor()?[0].clone(); + let one = create_constant_tensor(integer_rep_to_felt(1), 1); + let assigned_one = region.assign(&config.custom_gates.inputs[1], &one)?; + let negative_one = create_constant_tensor(integer_rep_to_felt(-1), 1); + let assigned_negative_one = region.assign(&config.custom_gates.inputs[1], &negative_one)?; + + // if scale is not exactly divisible by 2 we warn + if scale.0 % 2.0 != 0.0 { + log::warn!("Scale is not exactly divisible by 2.0, rounding may not be accurate"); + } + + let midway_point: ValTensor = create_constant_tensor( + integer_rep_to_felt((scale.0 / 2.0).round() as IntegerRep), + 1, + ); + let assigned_midway_point = region.assign(&config.custom_gates.inputs[1], &midway_point)?; + + let dims = decomposition.dims().to_vec(); + let first_dims = decomposition.dims().to_vec()[..decomposition.dims().len() - 1].to_vec(); + + let mut incremented_tensor = Tensor::new(None, &first_dims)?; + + let cartesian_coord = first_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = + |i: usize, region: &mut RegionCtx| -> Result>, CircuitError> { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + let mut sliced_input = decomposition.get_slice(&slice)?; + sliced_input.flatten(); + let last_elem = sliced_input.last()?; + + let sign = sliced_input.first()?; + let is_positive = equals(config, region, &[sign.clone(), assigned_one.clone()])?; + let is_negative = equals(config, region, &[sign, assigned_negative_one.clone()])?; + + let is_greater_than_midway = greater_equal( + config, + region, + &[last_elem.clone(), assigned_midway_point.clone()], + )?; + + // if greater than midway point and positive, increment + let is_positive_and_more_than_midway = and( + config, + region, + &[is_positive.clone(), is_greater_than_midway.clone()], + )?; + + // is less than midway point and negative, decrement + let is_negative_and_more_than_midway = and( + config, + region, + &[is_negative.clone(), is_greater_than_midway], + )?; + + let conditions_for_increment = or( + config, + region, + &[ + is_positive_and_more_than_midway.clone(), + is_negative_and_more_than_midway.clone(), + ], + )?; + + // increment the penultimate element + let incremented_elem = pairwise( + config, + region, + &[ + sliced_input.get_slice(&[sliced_input.len() - 2..sliced_input.len() - 1])?, + conditions_for_increment.clone(), + ], + BaseOp::Add, + )?; + + let mut inner_tensor = sliced_input.get_inner_tensor()?.clone(); + inner_tensor[sliced_input.len() - 2] = + incremented_elem.get_inner_tensor()?.clone()[0].clone(); + + // set the last elem to zero + inner_tensor[sliced_input.len() - 1] = assigned_zero.clone(); + + Ok(inner_tensor.clone()) + }; + + region.apply_in_loop(&mut incremented_tensor, inner_loop_function)?; + + let mut incremented_tensor = incremented_tensor.combine()?; + incremented_tensor.reshape(&dims)?; + + recompose( + config, + region, + &[incremented_tensor.into()], + &(scale.0 as usize), + ) +} + +pub(crate) fn recompose( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: &usize, +) -> Result, CircuitError> { + let input = values[0].clone(); + + let first_dims = input.dims().to_vec()[..input.dims().len() - 1].to_vec(); + let n = input.dims().last().unwrap() - 1; + + let is_assigned = !input.all_prev_assigned(); + + let bases: ValTensor = Tensor::from( + (0..n) + .rev() + .map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))), + ) + .into(); + + // multiply and sum the values + let mut output: Tensor>> = Tensor::new(None, &first_dims)?; + + let cartesian_coord = first_dims + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let inner_loop_function = + |i: usize, region: &mut RegionCtx| -> Result>, CircuitError> { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + let mut sliced_input = input.get_slice(&slice)?; + sliced_input.flatten(); + + if !is_assigned { + sliced_input = region.assign(&config.custom_gates.inputs[0], &sliced_input)?; + } + + // get the sign bit and make sure it is valid + let sign = sliced_input.first()?; + let rest = sliced_input.get_slice(&[1..sliced_input.len()])?; + + let prod_decomp = dot(config, region, &[rest, bases.clone()])?; + + let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?; + + Ok(signed_decomp.get_inner_tensor()?.clone()) + }; + + region.apply_in_loop(&mut output, inner_loop_function)?; + + let mut combined_output = output.combine()?; + + combined_output.reshape(&first_dims)?; + + Ok(combined_output.into()) +} + pub(crate) fn decompose( config: &BaseConfig, region: &mut RegionCtx, diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 825eb0806..9558b5464 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -17,9 +17,6 @@ use halo2curves::ff::PrimeField; pub enum LookupOp { Div { denom: utils::F32 }, Cast { scale: utils::F32 }, - Ceil { scale: utils::F32 }, - Floor { scale: utils::F32 }, - Round { scale: utils::F32 }, RoundHalfToEven { scale: utils::F32 }, Sqrt { scale: utils::F32 }, Rsqrt { scale: utils::F32 }, @@ -54,9 +51,6 @@ impl LookupOp { /// as path pub fn as_path(&self) -> String { match self { - LookupOp::Ceil { scale } => format!("ceil_{}", scale), - LookupOp::Floor { scale } => format!("floor_{}", scale), - LookupOp::Round { scale } => format!("round_{}", scale), LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale), LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a), LookupOp::Div { denom } => format!("div_{}", denom), @@ -91,15 +85,6 @@ impl LookupOp { let x = x[0].clone().map(|x| felt_to_integer_rep(x)); let res = match &self { - LookupOp::Ceil { scale } => { - Ok::<_, TensorError>(tensor::ops::nonlinearities::ceil(&x, scale.into())) - } - LookupOp::Floor { scale } => { - Ok::<_, TensorError>(tensor::ops::nonlinearities::floor(&x, scale.into())) - } - LookupOp::Round { scale } => { - Ok::<_, TensorError>(tensor::ops::nonlinearities::round(&x, scale.into())) - } LookupOp::RoundHalfToEven { scale } => Ok::<_, TensorError>( tensor::ops::nonlinearities::round_half_to_even(&x, scale.into()), ), @@ -186,9 +171,6 @@ impl Op for Lookup /// Returns the name of the operation fn as_string(&self) -> String { match self { - LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale), - LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale), - LookupOp::Round { scale } => format!("ROUND(scale={})", scale), LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale), LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a), LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom), diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index aa6b73f8e..be4b42841 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -1083,14 +1083,17 @@ pub fn new_op_from_onnx( pool_dims: kernel_shape.to_vec(), }) } - "Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil { + "Ceil" => SupportedOp::Hybrid(HybridOp::Ceil { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + legs: run_args.decomp_legs, }), - "Floor" => SupportedOp::Nonlinear(LookupOp::Floor { + "Floor" => SupportedOp::Hybrid(HybridOp::Floor { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + legs: run_args.decomp_legs, }), - "Round" => SupportedOp::Nonlinear(LookupOp::Round { + "Round" => SupportedOp::Hybrid(HybridOp::Round { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + legs: run_args.decomp_legs, }), "RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 292a28046..2666191ad 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -27,7 +27,7 @@ pub fn get_rep( n: usize, ) -> Result, DecompositionError> { // check if x is too large - if x.abs() > (base.pow(n as u32) as IntegerRep) { + if x.abs() > (base.pow(n as u32) as IntegerRep) - 1 { return Err(DecompositionError::TooLarge(*x, base, n)); } let mut rep = vec![0; n + 1]; @@ -1421,85 +1421,6 @@ pub fn slice( pub mod nonlinearities { use super::*; - /// Ceiling operator. - /// # Arguments - /// * `a` - Tensor - /// * `scale` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// - /// use ezkl::tensor::ops::nonlinearities::ceil; - /// let x = Tensor::::new( - /// Some(&[1, 2, 3, 4, 5, 6]), - /// &[3, 2], - /// ).unwrap(); - /// let result = ceil(&x, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn ceil(a: &Tensor, scale: f64) -> Tensor { - a.par_enum_map(|_, a_i| { - let kix = (a_i as f64) / scale; - let rounded = kix.ceil() * scale; - Ok::<_, TensorError>(rounded as IntegerRep) - }) - .unwrap() - } - - /// Floor operator. - /// # Arguments - /// * `a` - Tensor - /// * `scale` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::floor; - /// let x = Tensor::::new( - /// Some(&[1, 2, 3, 4, 5, 6]), - /// &[3, 2], - /// ).unwrap(); - /// let result = floor(&x, 2.0); - /// let expected = Tensor::::new(Some(&[0, 2, 2, 4, 4, 6]), &[3, 2]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn floor(a: &Tensor, scale: f64) -> Tensor { - a.par_enum_map(|_, a_i| { - let kix = (a_i as f64) / scale; - let rounded = kix.floor() * scale; - Ok::<_, TensorError>(rounded as IntegerRep) - }) - .unwrap() - } - - /// Round operator. - /// # Arguments - /// * `a` - Tensor - /// * `scale` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::round; - /// let x = Tensor::::new( - /// Some(&[1, 2, 3, 4, 5, 6]), - /// &[3, 2], - /// ).unwrap(); - /// let result = round(&x, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 4, 4, 6, 6]), &[3, 2]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn round(a: &Tensor, scale: f64) -> Tensor { - a.par_enum_map(|_, a_i| { - let kix = (a_i as f64) / scale; - let rounded = kix.round() * scale; - Ok::<_, TensorError>(rounded as IntegerRep) - }) - .unwrap() - } - /// Round half to even operator. /// # Arguments /// * `a` - Tensor @@ -1721,27 +1642,6 @@ pub mod nonlinearities { .unwrap() } - /// Elementwise applies sign to a tensor of integers. - /// # Arguments - /// * `a` - Tensor - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::sign; - /// let x = Tensor::::new( - /// Some(&[-2, 15, 2, 1, 1, 0]), - /// &[2, 3], - /// ).unwrap(); - /// let result = sign(&x); - /// let expected = Tensor::::new(Some(&[-1, 1, 1, 1, 1, 0]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn sign(a: &Tensor) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(a_i.signum())) - .unwrap() - } - /// Elementwise applies square root to a tensor of integers. /// # Arguments /// @@ -2225,101 +2125,6 @@ pub mod nonlinearities { .unwrap() } - /// Elementwise applies leaky relu to a tensor of integers. - /// # Arguments - /// - /// * `a` - Tensor - /// * `scale` - Single value - /// * `slope` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::leakyrelu; - /// let x = Tensor::::new( - /// Some(&[2, 15, 2, 1, 1, -5]), - /// &[2, 3], - /// ).unwrap(); - /// let result = leakyrelu(&x, 0.1); - /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, -1]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn leakyrelu(a: &Tensor, slope: f64) -> Tensor { - a.par_enum_map(|_, a_i| { - let rounded = if a_i < 0 { - let d_inv_x = (slope) * (a_i as f64); - d_inv_x.round() as IntegerRep - } else { - let d_inv_x = a_i as f64; - d_inv_x.round() as IntegerRep - }; - Ok::<_, TensorError>(rounded) - }) - .unwrap() - } - - /// Elementwise applies max to a tensor of integers. - /// # Arguments - /// * `a` - Tensor - /// * `b` - scalar - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::max; - /// let x = Tensor::::new( - /// Some(&[2, 15, 2, 1, 1, -5]), - /// &[2, 3], - /// ).unwrap(); - /// let result = max(&x, 1.0, 1.0); - /// let expected = Tensor::::new(Some(&[2, 15, 2, 1, 1, 1]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn max(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { - // calculate value of output - a.par_enum_map(|_, a_i| { - let d_inv_x = (a_i as f64) / scale_input; - let rounded = if d_inv_x <= threshold { - (threshold * scale_input).round() as IntegerRep - } else { - (d_inv_x * scale_input).round() as IntegerRep - }; - Ok::<_, TensorError>(rounded) - }) - .unwrap() - } - - /// Elementwise applies min to a tensor of integers. - /// # Arguments - /// * `a` - Tensor - /// * `b` - scalar - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::min; - /// let x = Tensor::::new( - /// Some(&[2, 15, 2, 1, 1, -5]), - /// &[2, 3], - /// ).unwrap(); - /// let result = min(&x, 1.0, 2.0); - /// let expected = Tensor::::new(Some(&[2, 2, 2, 1, 1, -5]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn min(a: &Tensor, scale_input: f64, threshold: f64) -> Tensor { - // calculate value of output - a.par_enum_map(|_, a_i| { - let d_inv_x = (a_i as f64) / scale_input; - let rounded = if d_inv_x >= threshold { - (threshold * scale_input).round() as IntegerRep - } else { - (d_inv_x * scale_input).round() as IntegerRep - }; - Ok::<_, TensorError>(rounded) - }) - .unwrap() - } - /// Elementwise divides a tensor with a const integer element. /// # Arguments /// @@ -2400,104 +2205,6 @@ pub mod nonlinearities { }) .unwrap() } - - /// Elementwise greater than - /// # Arguments - /// - /// * `a` - Tensor - /// * `b` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::greater_than; - /// let x = Tensor::::new( - /// Some(&[2, 1, 2, 7, 1, 1]), - /// &[2, 3], - /// ).unwrap(); - /// let k = 2.0; - /// let result = greater_than(&x, k); - /// let expected = Tensor::::new(Some(&[0, 0, 0, 1, 0, 0]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn greater_than(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) > 0_f64))) - .unwrap() - } - - /// Elementwise greater than - /// # Arguments - /// - /// * `a` - Tensor - /// * `b` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::greater_than_equal; - /// let x = Tensor::::new( - /// Some(&[2, 1, 2, 7, 1, 1]), - /// &[2, 3], - /// ).unwrap(); - /// let k = 2.0; - /// let result = greater_than_equal(&x, k); - /// let expected = Tensor::::new(Some(&[1, 0, 1, 1, 0, 0]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn greater_than_equal(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) >= 0_f64))) - .unwrap() - } - - /// Elementwise less than - /// # Arguments - /// * `a` - Tensor - /// * `b` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::less_than; - /// - /// let x = Tensor::::new( - /// Some(&[2, 1, 2, 7, 1, 1]), - /// &[2, 3], - /// ).unwrap(); - /// let k = 2.0; - /// - /// let result = less_than(&x, k); - /// let expected = Tensor::::new(Some(&[0, 1, 0, 0, 1, 1]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn less_than(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) < 0_f64))) - .unwrap() - } - - /// Elementwise less than - /// # Arguments - /// * `a` - Tensor - /// * `b` - Single value - /// # Examples - /// ``` - /// use ezkl::tensor::Tensor; - /// use ezkl::fieldutils::IntegerRep; - /// use ezkl::tensor::ops::nonlinearities::less_than_equal; - /// - /// let x = Tensor::::new( - /// Some(&[2, 1, 2, 7, 1, 1]), - /// &[2, 3], - /// ).unwrap(); - /// let k = 2.0; - /// - /// let result = less_than_equal(&x, k); - /// let expected = Tensor::::new(Some(&[1, 1, 1, 0, 1, 1]), &[2, 3]).unwrap(); - /// assert_eq!(result, expected); - /// ``` - pub fn less_than_equal(a: &Tensor, b: f64) -> Tensor { - a.par_enum_map(|_, a_i| Ok::<_, TensorError>(IntegerRep::from((a_i as f64 - b) <= 0_f64))) - .unwrap() - } } /// Ops that return the transcript i.e intermediate calcs of an op