Skip to content

Commit

Permalink
chore: simplify the code for impl Op<Fp> for SupportedOp (#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
field-worker authored Dec 15, 2023
1 parent faf4db2 commit 2ff4470
Showing 1 changed file with 23 additions and 90 deletions.
113 changes: 23 additions & 90 deletions src/graph/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,20 @@ impl SupportedOp {
let op = self.clone_dyn();
super::homogenize_input_scales(op, in_scales, inputs_to_scale)
}

/// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it.
fn as_op(&self) -> &dyn Op<Fp> {
match self {
SupportedOp::Linear(op) => op,
SupportedOp::Nonlinear(op) => op,
SupportedOp::Hybrid(op) => op,
SupportedOp::Input(op) => op,
SupportedOp::Constant(op) => op,
SupportedOp::Unknown(op) => op,
SupportedOp::Rescaled(op) => op,
SupportedOp::RebaseScale(op) => op,
}
}
}

impl From<Box<dyn Op<Fp>>> for SupportedOp {
Expand Down Expand Up @@ -383,16 +397,7 @@ impl Op<Fp> for SupportedOp {
&self,
inputs: &[Tensor<Fp>],
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
match self {
SupportedOp::Linear(op) => op.f(inputs),
SupportedOp::Nonlinear(op) => op.f(inputs),
SupportedOp::Hybrid(op) => op.f(inputs),
SupportedOp::Input(op) => op.f(inputs),
SupportedOp::Constant(op) => op.f(inputs),
SupportedOp::Unknown(op) => op.f(inputs),
SupportedOp::Rescaled(op) => op.f(inputs),
SupportedOp::RebaseScale(op) => op.f(inputs),
}
self.as_op().f(inputs)
}

fn layout(
Expand All @@ -401,111 +406,39 @@ impl Op<Fp> for SupportedOp {
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
match self {
SupportedOp::Linear(op) => op.layout(config, region, values),
SupportedOp::Nonlinear(op) => op.layout(config, region, values),
SupportedOp::Hybrid(op) => op.layout(config, region, values),
SupportedOp::Input(op) => op.layout(config, region, values),
SupportedOp::Constant(op) => op.layout(config, region, values),
SupportedOp::Unknown(op) => op.layout(config, region, values),
SupportedOp::Rescaled(op) => op.layout(config, region, values),
SupportedOp::RebaseScale(op) => op.layout(config, region, values),
}
self.as_op().layout(config, region, values)
}

fn is_input(&self) -> bool {
match self {
SupportedOp::Linear(op) => Op::<Fp>::is_input(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::is_input(op),
SupportedOp::Hybrid(op) => Op::<Fp>::is_input(op),
SupportedOp::Input(op) => Op::<Fp>::is_input(op),
SupportedOp::Constant(op) => Op::<Fp>::is_input(op),
SupportedOp::Unknown(op) => Op::<Fp>::is_input(op),
SupportedOp::Rescaled(op) => Op::<Fp>::is_input(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::is_input(op),
}
self.as_op().is_input()
}

fn is_constant(&self) -> bool {
match self {
SupportedOp::Linear(op) => Op::<Fp>::is_constant(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::is_constant(op),
SupportedOp::Hybrid(op) => Op::<Fp>::is_constant(op),
SupportedOp::Input(op) => Op::<Fp>::is_constant(op),
SupportedOp::Constant(op) => Op::<Fp>::is_constant(op),
SupportedOp::Unknown(op) => Op::<Fp>::is_constant(op),
SupportedOp::Rescaled(op) => Op::<Fp>::is_constant(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::is_constant(op),
}
self.as_op().is_constant()
}

fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
SupportedOp::Linear(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Hybrid(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Input(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Constant(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Unknown(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::Rescaled(op) => Op::<Fp>::requires_homogenous_input_scales(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::requires_homogenous_input_scales(op),
}
self.as_op().requires_homogenous_input_scales()
}

fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
match self {
SupportedOp::Linear(op) => Box::new(op.clone()),
SupportedOp::Nonlinear(op) => Box::new(op.clone()),
SupportedOp::Hybrid(op) => Box::new(op.clone()),
SupportedOp::Input(op) => Box::new(op.clone()),
SupportedOp::Constant(op) => Box::new(op.clone()),
SupportedOp::Unknown(op) => Box::new(op.clone()),
SupportedOp::Rescaled(op) => Box::new(op.clone()),
SupportedOp::RebaseScale(op) => Box::new(op.clone()),
}
self.as_op().clone_dyn()
}

fn as_string(&self) -> String {
match self {
SupportedOp::Linear(op) => Op::<Fp>::as_string(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::as_string(op),
SupportedOp::Hybrid(op) => Op::<Fp>::as_string(op),
SupportedOp::Input(op) => Op::<Fp>::as_string(op),
SupportedOp::Constant(op) => Op::<Fp>::as_string(op),
SupportedOp::Unknown(op) => Op::<Fp>::as_string(op),
SupportedOp::Rescaled(op) => Op::<Fp>::as_string(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::as_string(op),
}
self.as_op().as_string()
}

fn as_any(&self) -> &dyn std::any::Any {
self
}

fn required_lookups(&self) -> Vec<LookupOp> {
match self {
SupportedOp::Linear(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Nonlinear(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Hybrid(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Input(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Constant(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Unknown(op) => Op::<Fp>::required_lookups(op),
SupportedOp::Rescaled(op) => Op::<Fp>::required_lookups(op),
SupportedOp::RebaseScale(op) => Op::<Fp>::required_lookups(op),
}
self.as_op().required_lookups()
}

fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
match self {
SupportedOp::Linear(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Nonlinear(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Hybrid(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Input(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Constant(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Unknown(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::Rescaled(op) => Op::<Fp>::out_scale(op, in_scales),
SupportedOp::RebaseScale(op) => Op::<Fp>::out_scale(op, in_scales),
}
self.as_op().out_scale(in_scales)
}
}

Expand Down

0 comments on commit 2ff4470

Please sign in to comment.