Skip to content

Commit

Permalink
chore: unify leakyrelu and relu
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Oct 28, 2024
1 parent ebaee9e commit ca2adfe
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 67 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,20 @@ harness = false


[[bench]]
name = "relu"
name = "sigmoid"
harness = false

[[bench]]
name = "relu_lookupless"
harness = false

[[bench]]
name = "accum_matmul_relu"
name = "accum_matmul_sigmoid"
harness = false


[[bench]]
name = "accum_matmul_relu_overflow"
name = "accum_matmul_sigmoid_overflow"
harness = false

[[bin]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
k,
&LookupOp::LeakyReLU { slope: 0.0.into() },
&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
9 changes: 8 additions & 1 deletion benches/relu_lookupless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
}),
)
.unwrap();
Ok(())
},
Expand Down
4 changes: 2 additions & 2 deletions benches/relu.rs → benches/sigmoid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
let nl = LookupOp::Sigmoid { scale: 1.0.into() };

let mut config = Config::default();

Expand All @@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
Expand Down
17 changes: 4 additions & 13 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,6 @@ where
CheckMode::SAFE,
);

layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

layer_config
.configure_lookup(
cs,
Expand Down Expand Up @@ -224,7 +212,10 @@ where
.layout(
&mut region,
&[x.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap();

Expand Down
23 changes: 8 additions & 15 deletions examples/mlp_4d_einsum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,6 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
CheckMode::SAFE,
);

// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
.configure_lookup(
cs,
&input,
&output,
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
.configure_lookup(
Expand Down Expand Up @@ -144,7 +131,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap()
.unwrap();
Expand Down Expand Up @@ -184,7 +174,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
}),
)
.unwrap();
println!("6");
Expand Down
3 changes: 3 additions & 0 deletions src/circuit/ops/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,7 @@ pub enum CircuitError {
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
/// Invalid scale
#[error("negative scale for an op that requires positive inputs {0}")]
NegativeScale(String),
}
44 changes: 39 additions & 5 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4305,7 +4305,6 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?;
// get every n elements now, which correspond to the sign bit

decomp.get_every_n(region.legs() + 1)?;
decomp.reshape(values[0].dims())?;

Expand All @@ -4322,10 +4321,12 @@ pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
}

pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
alpha: &utils::F32,
input_scale: &i32,
) -> Result<ValTensor<F>, CircuitError> {
let sign = sign(config, region, values)?;

Expand All @@ -4334,12 +4335,45 @@ pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let relu_mask = equals(config, region, &[sign, unit])?;

pairwise(
let positive = pairwise(
config,
region,
&[values[0].clone(), relu_mask],
&[values[0].clone(), relu_mask.clone()],
BaseOp::Mult,
)
)?;

if alpha.0 == 0. {
return Ok(positive);
}

if input_scale < &0 {
return Err(CircuitError::NegativeScale("leaky_relu".to_string()));
}

let scale_constant = create_constant_tensor(F::from(2_i32.pow(*input_scale as u32) as u64), 1);

let rescaled_positive = pairwise(config, region, &[positive, scale_constant], BaseOp::Mult)?;

let neg_mask = not(config, region, &[relu_mask])?;

let quantized_alpha = quantize_tensor(
Tensor::from([alpha.0; 1].into_iter()),
*input_scale,
&crate::graph::Visibility::Fixed,
)?;

let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);

let scaled_neg_mask = pairwise(config, region, &[neg_mask, alpha_tensor], BaseOp::Mult)?;

let neg_part = pairwise(
config,
region,
&[values[0].clone(), scaled_neg_mask],
BaseOp::Mult,
)?;

pairwise(config, region, &[rescaled_positive, neg_part], BaseOp::Add)
}

fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
Expand Down
8 changes: 0 additions & 8 deletions src/circuit/ops/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ pub enum LookupOp {
input_scale: utils::F32,
output_scale: utils::F32,
},
LeakyReLU {
slope: utils::F32,
},
Sigmoid {
scale: utils::F32,
},
Expand Down Expand Up @@ -127,7 +124,6 @@ impl LookupOp {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
Expand Down Expand Up @@ -190,9 +186,6 @@ impl LookupOp {
input_scale.into(),
output_scale.into(),
)),
LookupOp::LeakyReLU { slope: a } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
}
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
Expand Down Expand Up @@ -283,7 +276,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
Expand Down
22 changes: 18 additions & 4 deletions src/circuit/ops/poly.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
circuit::layouts,
circuit::{
layouts,
utils::{self, F32},
},
tensor::{self, Tensor, TensorError},
};

Expand All @@ -9,9 +12,12 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
ReLU,
Abs,
Sign,
LeakyReLU {
slope: utils::F32,
scale: i32,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
Expand Down Expand Up @@ -112,9 +118,9 @@ impl<

fn as_string(&self) -> String {
match &self {
PolyOp::LeakyReLU { slope: a, .. } => format!("LEAKYRELU (slope={})", a),
PolyOp::Abs => "ABS".to_string(),
PolyOp::Sign => "SIGN".to_string(),
PolyOp::ReLU => "RELU".to_string(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
Expand Down Expand Up @@ -198,7 +204,9 @@ impl<
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::ReLU => layouts::relu(config, region, values[..].try_into()?)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
Expand Down Expand Up @@ -329,6 +337,12 @@ impl<

fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
// this corresponds to the relu operation
PolyOp::LeakyReLU {
slope: F32(0.0), ..
} => in_scales[0],
// this corresponds to the leaky relu operation with a slope which induces a change in scale
PolyOp::LeakyReLU { scale, .. } => in_scales[0] + *scale,
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
Expand Down
Loading

0 comments on commit ca2adfe

Please sign in to comment.