From 2ff44701f9e9d19abf877ea2ee6601cbc73da17e Mon Sep 17 00:00:00 2001 From: field-worker <151173028+field-worker@users.noreply.github.com> Date: Fri, 15 Dec 2023 12:27:24 +0100 Subject: [PATCH] chore: simplify the code for `impl Op for SupportedOp` (#659) --- src/graph/node.rs | 113 ++++++++++------------------------------------ 1 file changed, 23 insertions(+), 90 deletions(-) diff --git a/src/graph/node.rs b/src/graph/node.rs index fb1bb3fbb..d10e92754 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -338,6 +338,20 @@ impl SupportedOp { let op = self.clone_dyn(); super::homogenize_input_scales(op, in_scales, inputs_to_scale) } + + /// Since each associated value of `SupportedOp` implements `Op`, let's define a helper method to retrieve it. + fn as_op(&self) -> &dyn Op { + match self { + SupportedOp::Linear(op) => op, + SupportedOp::Nonlinear(op) => op, + SupportedOp::Hybrid(op) => op, + SupportedOp::Input(op) => op, + SupportedOp::Constant(op) => op, + SupportedOp::Unknown(op) => op, + SupportedOp::Rescaled(op) => op, + SupportedOp::RebaseScale(op) => op, + } + } } impl From>> for SupportedOp { @@ -383,16 +397,7 @@ impl Op for SupportedOp { &self, inputs: &[Tensor], ) -> Result, crate::tensor::TensorError> { - match self { - SupportedOp::Linear(op) => op.f(inputs), - SupportedOp::Nonlinear(op) => op.f(inputs), - SupportedOp::Hybrid(op) => op.f(inputs), - SupportedOp::Input(op) => op.f(inputs), - SupportedOp::Constant(op) => op.f(inputs), - SupportedOp::Unknown(op) => op.f(inputs), - SupportedOp::Rescaled(op) => op.f(inputs), - SupportedOp::RebaseScale(op) => op.f(inputs), - } + self.as_op().f(inputs) } fn layout( @@ -401,81 +406,27 @@ impl Op for SupportedOp { region: &mut crate::circuit::region::RegionCtx, values: &[crate::tensor::ValTensor], ) -> Result>, Box> { - match self { - SupportedOp::Linear(op) => op.layout(config, region, values), - SupportedOp::Nonlinear(op) => op.layout(config, region, values), - SupportedOp::Hybrid(op) => op.layout(config, region, values), - SupportedOp::Input(op) => op.layout(config, region, values), - SupportedOp::Constant(op) => op.layout(config, region, values), - SupportedOp::Unknown(op) => op.layout(config, region, values), - SupportedOp::Rescaled(op) => op.layout(config, region, values), - SupportedOp::RebaseScale(op) => op.layout(config, region, values), - } + self.as_op().layout(config, region, values) } fn is_input(&self) -> bool { - match self { - SupportedOp::Linear(op) => Op::::is_input(op), - SupportedOp::Nonlinear(op) => Op::::is_input(op), - SupportedOp::Hybrid(op) => Op::::is_input(op), - SupportedOp::Input(op) => Op::::is_input(op), - SupportedOp::Constant(op) => Op::::is_input(op), - SupportedOp::Unknown(op) => Op::::is_input(op), - SupportedOp::Rescaled(op) => Op::::is_input(op), - SupportedOp::RebaseScale(op) => Op::::is_input(op), - } + self.as_op().is_input() } fn is_constant(&self) -> bool { - match self { - SupportedOp::Linear(op) => Op::::is_constant(op), - SupportedOp::Nonlinear(op) => Op::::is_constant(op), - SupportedOp::Hybrid(op) => Op::::is_constant(op), - SupportedOp::Input(op) => Op::::is_constant(op), - SupportedOp::Constant(op) => Op::::is_constant(op), - SupportedOp::Unknown(op) => Op::::is_constant(op), - SupportedOp::Rescaled(op) => Op::::is_constant(op), - SupportedOp::RebaseScale(op) => Op::::is_constant(op), - } + self.as_op().is_constant() } fn requires_homogenous_input_scales(&self) -> Vec { - match self { - SupportedOp::Linear(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Nonlinear(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Hybrid(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Input(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Constant(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Unknown(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::Rescaled(op) => Op::::requires_homogenous_input_scales(op), - SupportedOp::RebaseScale(op) => Op::::requires_homogenous_input_scales(op), - } + self.as_op().requires_homogenous_input_scales() } fn clone_dyn(&self) -> Box> { - match self { - SupportedOp::Linear(op) => Box::new(op.clone()), - SupportedOp::Nonlinear(op) => Box::new(op.clone()), - SupportedOp::Hybrid(op) => Box::new(op.clone()), - SupportedOp::Input(op) => Box::new(op.clone()), - SupportedOp::Constant(op) => Box::new(op.clone()), - SupportedOp::Unknown(op) => Box::new(op.clone()), - SupportedOp::Rescaled(op) => Box::new(op.clone()), - SupportedOp::RebaseScale(op) => Box::new(op.clone()), - } + self.as_op().clone_dyn() } fn as_string(&self) -> String { - match self { - SupportedOp::Linear(op) => Op::::as_string(op), - SupportedOp::Nonlinear(op) => Op::::as_string(op), - SupportedOp::Hybrid(op) => Op::::as_string(op), - SupportedOp::Input(op) => Op::::as_string(op), - SupportedOp::Constant(op) => Op::::as_string(op), - SupportedOp::Unknown(op) => Op::::as_string(op), - SupportedOp::Rescaled(op) => Op::::as_string(op), - SupportedOp::RebaseScale(op) => Op::::as_string(op), - } + self.as_op().as_string() } fn as_any(&self) -> &dyn std::any::Any { @@ -483,29 +434,11 @@ impl Op for SupportedOp { } fn required_lookups(&self) -> Vec { - match self { - SupportedOp::Linear(op) => Op::::required_lookups(op), - SupportedOp::Nonlinear(op) => Op::::required_lookups(op), - SupportedOp::Hybrid(op) => Op::::required_lookups(op), - SupportedOp::Input(op) => Op::::required_lookups(op), - SupportedOp::Constant(op) => Op::::required_lookups(op), - SupportedOp::Unknown(op) => Op::::required_lookups(op), - SupportedOp::Rescaled(op) => Op::::required_lookups(op), - SupportedOp::RebaseScale(op) => Op::::required_lookups(op), - } + self.as_op().required_lookups() } fn out_scale(&self, in_scales: Vec) -> Result> { - match self { - SupportedOp::Linear(op) => Op::::out_scale(op, in_scales), - SupportedOp::Nonlinear(op) => Op::::out_scale(op, in_scales), - SupportedOp::Hybrid(op) => Op::::out_scale(op, in_scales), - SupportedOp::Input(op) => Op::::out_scale(op, in_scales), - SupportedOp::Constant(op) => Op::::out_scale(op, in_scales), - SupportedOp::Unknown(op) => Op::::out_scale(op, in_scales), - SupportedOp::Rescaled(op) => Op::::out_scale(op, in_scales), - SupportedOp::RebaseScale(op) => Op::::out_scale(op, in_scales), - } + self.as_op().out_scale(in_scales) } }