From ca2adfeefadee34c831ba28b6a94397062f92b0e Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Mon, 28 Oct 2024 14:41:35 -0400 Subject: [PATCH] chore: unify leakyrelu and relu --- Cargo.toml | 6 +-- ...matmul_relu.rs => accum_matmul_sigmoid.rs} | 4 +- ...ow.rs => accum_matmul_sigmoid_overflow.rs} | 4 +- benches/relu_lookupless.rs | 9 +++- benches/{relu.rs => sigmoid.rs} | 4 +- examples/conv2d_mnist/main.rs | 17 ++----- examples/mlp_4d_einsum.rs | 23 ++++------ src/circuit/ops/errors.rs | 3 ++ src/circuit/ops/layouts.rs | 44 ++++++++++++++++--- src/circuit/ops/lookup.rs | 8 ---- src/circuit/ops/poly.rs | 22 ++++++++-- src/circuit/tests.rs | 37 +++++++++++----- src/graph/utilities.rs | 8 +++- 13 files changed, 122 insertions(+), 67 deletions(-) rename benches/{accum_matmul_relu.rs => accum_matmul_sigmoid.rs} (97%) rename benches/{accum_matmul_relu_overflow.rs => accum_matmul_sigmoid_overflow.rs} (97%) rename benches/{relu.rs => sigmoid.rs} (96%) diff --git a/Cargo.toml b/Cargo.toml index 8e55321bb..f54c88976 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -169,7 +169,7 @@ harness = false [[bench]] -name = "relu" +name = "sigmoid" harness = false [[bench]] @@ -177,12 +177,12 @@ name = "relu_lookupless" harness = false [[bench]] -name = "accum_matmul_relu" +name = "accum_matmul_sigmoid" harness = false [[bench]] -name = "accum_matmul_relu_overflow" +name = "accum_matmul_sigmoid_overflow" harness = false [[bin]] diff --git a/benches/accum_matmul_relu.rs b/benches/accum_matmul_sigmoid.rs similarity index 97% rename from benches/accum_matmul_relu.rs rename to benches/accum_matmul_sigmoid.rs index 32155e2e1..dba9e93aa 100644 --- a/benches/accum_matmul_relu.rs +++ b/benches/accum_matmul_sigmoid.rs @@ -64,7 +64,7 @@ impl Circuit for MyCircuit { &a, BITS, K, - &LookupOp::LeakyReLU { slope: 0.0.into() }, + &LookupOp::Sigmoid { scale: 1.0.into() }, ) .unwrap(); @@ -93,7 +93,7 @@ impl Circuit for MyCircuit { .layout( &mut region, 
&[output.unwrap()], - Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + Box::new(LookupOp::Sigmoid { scale: 1.0.into() }), ) .unwrap(); Ok(()) diff --git a/benches/accum_matmul_relu_overflow.rs b/benches/accum_matmul_sigmoid_overflow.rs similarity index 97% rename from benches/accum_matmul_relu_overflow.rs rename to benches/accum_matmul_sigmoid_overflow.rs index db55e5f86..a8ffe4c54 100644 --- a/benches/accum_matmul_relu_overflow.rs +++ b/benches/accum_matmul_sigmoid_overflow.rs @@ -65,7 +65,7 @@ impl Circuit for MyCircuit { &a, BITS, k, - &LookupOp::LeakyReLU { slope: 0.0.into() }, + &LookupOp::Sigmoid { scale: 1.0.into() }, ) .unwrap(); @@ -94,7 +94,7 @@ impl Circuit for MyCircuit { .layout( &mut region, &[output.unwrap()], - Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + Box::new(LookupOp::Sigmoid { scale: 1.0.into() }), ) .unwrap(); Ok(()) diff --git a/benches/relu_lookupless.rs b/benches/relu_lookupless.rs index f9610bbea..693ed9b5e 100644 --- a/benches/relu_lookupless.rs +++ b/benches/relu_lookupless.rs @@ -68,7 +68,14 @@ impl Circuit for NLCircuit { |region| { let mut region = RegionCtx::new(region, 0, 1, 1024, 2); config - .layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU)) + .layout( + &mut region, + &[self.input.clone()], + Box::new(PolyOp::LeakyReLU { + slope: 0.0.into(), + scale: 1, + }), + ) .unwrap(); Ok(()) }, diff --git a/benches/relu.rs b/benches/sigmoid.rs similarity index 96% rename from benches/relu.rs rename to benches/sigmoid.rs index 4a3c2dadd..bd13a06a5 100644 --- a/benches/relu.rs +++ b/benches/sigmoid.rs @@ -42,7 +42,7 @@ impl Circuit for NLCircuit { .map(|_| VarTensor::new_advice(cs, K, 1, LEN)) .collect::>(); - let nl = LookupOp::LeakyReLU { slope: 0.0.into() }; + let nl = LookupOp::Sigmoid { scale: 1.0.into() }; let mut config = Config::default(); @@ -68,7 +68,7 @@ impl Circuit for NLCircuit { .layout( &mut region, &[self.input.clone()], - Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + 
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }), ) .unwrap(); Ok(()) diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index dd56a3c98..a9506a3fb 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -155,18 +155,6 @@ where CheckMode::SAFE, ); - layer_config - .configure_lookup( - cs, - &input, - &output, - &params, - (LOOKUP_MIN, LOOKUP_MAX), - K, - &LookupOp::LeakyReLU { slope: 0.0.into() }, - ) - .unwrap(); - layer_config .configure_lookup( cs, @@ -224,7 +212,10 @@ where .layout( &mut region, &[x.unwrap()], - Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + Box::new(PolyOp::LeakyReLU { + scale: 1, + slope: 0.0.into(), + }), ) .unwrap(); diff --git a/examples/mlp_4d_einsum.rs b/examples/mlp_4d_einsum.rs index 5dadfe16b..bc3092942 100644 --- a/examples/mlp_4d_einsum.rs +++ b/examples/mlp_4d_einsum.rs @@ -60,19 +60,6 @@ impl( ) -> Result, CircuitError> { let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?; // get every n elements now, which correspond to the sign bit - decomp.get_every_n(region.legs() + 1)?; decomp.reshape(values[0].dims())?; @@ -4322,10 +4321,12 @@ pub(crate) fn abs( pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult) } -pub(crate) fn relu( +pub(crate) fn leaky_relu( config: &BaseConfig, region: &mut RegionCtx, values: &[ValTensor; 1], + alpha: &utils::F32, + input_scale: &i32, ) -> Result, CircuitError> { let sign = sign(config, region, values)?; @@ -4334,12 +4335,45 @@ pub(crate) fn relu( let relu_mask = equals(config, region, &[sign, unit])?; - pairwise( + let positive = pairwise( config, region, - &[values[0].clone(), relu_mask], + &[values[0].clone(), relu_mask.clone()], BaseOp::Mult, - ) + )?; + + if alpha.0 == 0. 
{ + return Ok(positive); + } + + if input_scale < &0 { + return Err(CircuitError::NegativeScale("leaky_relu".to_string())); + } + + let scale_constant = create_constant_tensor(F::from(2_i32.pow(*input_scale as u32) as u64), 1); + + let rescaled_positive = pairwise(config, region, &[positive, scale_constant], BaseOp::Mult)?; + + let neg_mask = not(config, region, &[relu_mask])?; + + let quantized_alpha = quantize_tensor( + Tensor::from([alpha.0; 1].into_iter()), + *input_scale, + &crate::graph::Visibility::Fixed, + )?; + + let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1); + + let scaled_neg_mask = pairwise(config, region, &[neg_mask, alpha_tensor], BaseOp::Mult)?; + + let neg_part = pairwise( + config, + region, + &[values[0].clone(), scaled_neg_mask], + BaseOp::Mult, + )?; + + pairwise(config, region, &[rescaled_positive, neg_part], BaseOp::Add) } fn multi_dim_axes_op( diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index f6c30d1b3..d13540a7b 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -43,9 +43,6 @@ pub enum LookupOp { input_scale: utils::F32, output_scale: utils::F32, }, - LeakyReLU { - slope: utils::F32, - }, Sigmoid { scale: utils::F32, }, @@ -127,7 +124,6 @@ impl LookupOp { input_scale, output_scale, } => format!("recip_{}_{}", input_scale, output_scale), - LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a), LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale), LookupOp::Sqrt { scale } => format!("sqrt_{}", scale), LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale), @@ -190,9 +186,6 @@ impl LookupOp { input_scale.into(), output_scale.into(), )), - LookupOp::LeakyReLU { slope: a } => { - Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into())) - } LookupOp::Sigmoid { scale } => { Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into())) } @@ -283,7 +276,6 @@ impl Op for Lookup LookupOp::Div { denom, .. 
} => format!("DIV(denom={})", denom), LookupOp::Cast { scale } => format!("CAST(scale={})", scale), LookupOp::Ln { scale } => format!("LN(scale={})", scale), - LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a), LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale), LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale), LookupOp::Erf { scale } => format!("ERF(scale={})", scale), diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs index 37baa9101..cda837238 100644 --- a/src/circuit/ops/poly.rs +++ b/src/circuit/ops/poly.rs @@ -1,5 +1,8 @@ use crate::{ - circuit::layouts, + circuit::{ + layouts, + utils::{self, F32}, + }, tensor::{self, Tensor, TensorError}, }; @@ -9,9 +12,12 @@ use super::{base::BaseOp, *}; /// An enum representing the operations that can be expressed as arithmetic (non lookup) operations. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum PolyOp { - ReLU, Abs, Sign, + LeakyReLU { + slope: utils::F32, + scale: i32, + }, GatherElements { dim: usize, constant_idx: Option>, @@ -112,9 +118,9 @@ impl< fn as_string(&self) -> String { match &self { + PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a), PolyOp::Abs => "ABS".to_string(), PolyOp::Sign => "SIGN".to_string(), - PolyOp::ReLU => "RELU".to_string(), PolyOp::GatherElements { dim, constant_idx } => format!( "GATHERELEMENTS (dim={}, constant_idx{})", dim, @@ -198,7 +204,9 @@ impl< Ok(Some(match self { PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?, PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?, - PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?, + PolyOp::LeakyReLU { slope, scale } => { + layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)? + } PolyOp::MultiBroadcastTo { shape } => { layouts::expand(config, region, values[..].try_into()?, shape)? 
} @@ -329,6 +337,12 @@ impl< fn out_scale(&self, in_scales: Vec) -> Result { let scale = match self { + // this corresponds to the relu operation + PolyOp::LeakyReLU { + slope: F32(0.0), .. + } => in_scales[0], + // this corresponds to the leaky relu operation with a slope which induces a change in scale + PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale, PolyOp::MeanOfSquares { .. } => 2 * in_scales[0], PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0, PolyOp::Iff => in_scales[1], diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index 3ba499544..674fb5623 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1379,7 +1379,10 @@ mod conv_relu_col_ultra_overflow { .layout( &mut region, &[output.unwrap().unwrap()], - Box::new(PolyOp::ReLU), + Box::new(PolyOp::LeakyReLU { + slope: 0.0.into(), + scale: 1, + }), ) .unwrap(); Ok(()) @@ -2347,7 +2350,14 @@ mod matmul_relu { .unwrap(); let _output = config .base_config - .layout(&mut region, &[output.unwrap()], Box::new(PolyOp::ReLU)) + .layout( + &mut region, + &[output.unwrap()], + Box::new(PolyOp::LeakyReLU { + slope: 0.0.into(), + scale: 1, + }), + ) .unwrap(); Ok(()) }, @@ -2439,7 +2449,14 @@ mod relu { |region| { let mut region = RegionCtx::new(region, 0, 1, 2, 2); Ok(config - .layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU)) + .layout( + &mut region, + &[self.input.clone()], + Box::new(PolyOp::LeakyReLU { + slope: 0.0.into(), + scale: 1, + }), + ) .unwrap()) }, ) @@ -2482,11 +2499,11 @@ mod lookup_ultra_overflow { use snark_verifier::system::halo2::transcript::evm::EvmTranscript; #[derive(Clone)] - struct ReLUCircuit { + struct SigmoidCircuit { pub input: ValTensor, } - impl Circuit for ReLUCircuit { + impl Circuit for SigmoidCircuit { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; type Params = TestParams; @@ -2500,7 +2517,7 @@ mod lookup_ultra_overflow { .map(|_| VarTensor::new_advice(cs, 4, 1, 3)) .collect::>(); - let nl = LookupOp::LeakyReLU 
{ slope: 0.0.into() }; + let nl = LookupOp::Sigmoid { scale: 1.0.into() }; let mut config = BaseConfig::default(); @@ -2533,7 +2550,7 @@ mod lookup_ultra_overflow { .layout( &mut region, &[self.input.clone()], - Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + Box::new(LookupOp::Sigmoid { scale: 1.0.into() }), ) .map_err(|_| Error::Synthesis) }, @@ -2546,13 +2563,13 @@ mod lookup_ultra_overflow { #[test] #[ignore] - fn relucircuit() { + fn sigmoidcircuit() { // get some logs fam crate::logger::init_logger(); // parameters let a = Tensor::from((0..4).map(|i| Value::known(F::from(i + 1)))); - let circuit = ReLUCircuit:: { + let circuit = SigmoidCircuit:: { input: ValTensor::from(a), }; @@ -2562,7 +2579,7 @@ mod lookup_ultra_overflow { let pk = crate::pfsys::create_keys::< halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme, - ReLUCircuit, + SigmoidCircuit, >(&circuit, &params, true) .unwrap(); diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 94b08de1c..8aa9de6ab 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -781,7 +781,10 @@ pub fn new_op_from_onnx( node.decrement_use(); deleted_indices.push(const_idx); } - SupportedOp::Linear(PolyOp::ReLU) + SupportedOp::Linear(PolyOp::LeakyReLU { + slope: 0.0.into(), + scale: 1, + }) } else { SupportedOp::Hybrid(HybridOp::Max) } @@ -821,8 +824,9 @@ pub fn new_op_from_onnx( } }; - SupportedOp::Nonlinear(LookupOp::LeakyReLU { + SupportedOp::Linear(PolyOp::LeakyReLU { slope: crate::circuit::utils::F32(leaky_op.alpha), + scale: scales.params, }) } "Scan" => {