From 10b0db6df65c974713af9ed999ee0a70ef611e77 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 13 Dec 2023 14:28:02 +0100 Subject: [PATCH] DeconvUnary->Deconv --- core/src/ops/cnn/deconv/{unary.rs => deconv.rs} | 12 ++++++------ core/src/ops/cnn/deconv/mod.rs | 5 +++-- core/src/ops/cnn/mod.rs | 4 ++-- harness/core-proptest-pulse/src/deconv.rs | 4 ++-- nnef/src/ops/nnef/deser.rs | 4 ++-- nnef/src/ops/nnef/ser.rs | 8 ++++---- onnx/src/ops/nn/conv_transpose.rs | 4 ++-- pulse/src/ops/cnn/deconv.rs | 10 +++++----- test-rt/suite-unit/src/deconv.rs | 4 ++-- 9 files changed, 28 insertions(+), 27 deletions(-) rename core/src/ops/cnn/deconv/{unary.rs => deconv.rs} (97%) diff --git a/core/src/ops/cnn/deconv/unary.rs b/core/src/ops/cnn/deconv/deconv.rs similarity index 97% rename from core/src/ops/cnn/deconv/unary.rs rename to core/src/ops/cnn/deconv/deconv.rs index 7620063686..fe68816514 100644 --- a/core/src/ops/cnn/deconv/unary.rs +++ b/core/src/ops/cnn/deconv/deconv.rs @@ -6,14 +6,14 @@ use crate::ops::cnn::wire_reshape_bias_for_bin; use crate::ops::einsum::EinSum; #[derive(Clone, Debug, new, Hash)] -pub struct DeconvUnary { +pub struct Deconv { pub pool_spec: PoolSpec, pub kernel_format: KernelFormat, pub adjustments: TVec, pub group: usize, } -impl DeconvUnary { +impl Deconv { fn wire_with_deconv_sum( &self, name: &str, @@ -123,9 +123,9 @@ impl DeconvUnary { } } -impl Op for DeconvUnary { +impl Op for Deconv { fn name(&self) -> Cow { - "DeconvUnary".into() + "Deconv".into() } fn info(&self) -> TractResult> { @@ -135,7 +135,7 @@ impl Op for DeconvUnary { op_as_typed_op!(); } -impl EvalOp for DeconvUnary { +impl EvalOp for Deconv { fn is_stateless(&self) -> bool { true } @@ -154,7 +154,7 @@ impl EvalOp for DeconvUnary { } } -impl TypedOp for DeconvUnary { +impl TypedOp for Deconv { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { ensure!(inputs.len() == 3); let x_fact = inputs[0]; diff --git a/core/src/ops/cnn/deconv/mod.rs b/core/src/ops/cnn/deconv/mod.rs index ed56fd3f06..c2d8bf1c61 100644 --- a/core/src/ops/cnn/deconv/mod.rs +++ b/core/src/ops/cnn/deconv/mod.rs @@ -1,10 +1,11 @@ use crate::internal::*; use crate::ops::cnn::{PaddingSpec, PoolSpec}; +#[allow(clippy::module_inception)] +mod deconv; mod deconv_sum; -mod unary; -pub use unary::DeconvUnary; +pub use deconv::Deconv; pub fn output_shape( pool_spec: &PoolSpec, diff --git a/core/src/ops/cnn/mod.rs b/core/src/ops/cnn/mod.rs index ef6a0ebca7..0870027605 100644 --- a/core/src/ops/cnn/mod.rs +++ b/core/src/ops/cnn/mod.rs @@ -10,7 +10,7 @@ pub mod pools; mod sumpool; pub use self::conv::{Conv, KernelFormat}; -pub use self::deconv::DeconvUnary; +pub use self::deconv::Deconv; pub use self::maxpool::MaxPool; pub use self::padding::PaddingSpec; pub use self::patch_axis::PatchAxis; @@ -95,7 +95,7 @@ pub fn rewrite_deconv_with_n_axis( model: &TypedModel, node: &TypedNode, name: &str, - deconv: &DeconvUnary, + deconv: &Deconv, ) -> TractResult> { if !deconv.pool_spec.data_format.has_n() { let mut new = deconv.clone(); diff --git a/harness/core-proptest-pulse/src/deconv.rs b/harness/core-proptest-pulse/src/deconv.rs index 90f726501f..9c629b0944 100644 --- a/harness/core-proptest-pulse/src/deconv.rs +++ b/harness/core-proptest-pulse/src/deconv.rs @@ -14,7 +14,7 @@ struct DeconvOp { impl DeconvOp { fn chain(&self, name: &str, model: &mut TypedModel, after: OutletId) -> OutletId { - let deconv = tract_core::ops::cnn::DeconvUnary { + let deconv = tract_core::ops::cnn::Deconv { pool_spec: PoolSpec { data_format: DataFormat::NCHW, kernel_shape: tvec!(self.ker.shape()[2]), @@ -262,7 +262,7 @@ fn deconv2d() { let a = model.add_source("a", f32::fact(dims!(1, 2, s, 8))).unwrap(); let mut kernel = Tensor::zero::(&[2, 2, 1, 3]).unwrap(); kernel.as_slice_mut::().unwrap().iter_mut().enumerate().for_each(|(ix, x)| *x = ix as f32); - let deconv = tract_core::ops::cnn::DeconvUnary { + let deconv = tract_core::ops::cnn::Deconv { pool_spec: PoolSpec { data_format: DataFormat::NCHW, kernel_shape: tvec!(1, 3), diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index 321305d50f..c90f29db0b 100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -1,6 +1,6 @@ use crate::ast::*; use crate::deser::Value; -use ops::cnn::deconv::DeconvUnary; +use ops::cnn::deconv::Deconv; use ops::cnn::{Conv, KernelFormat}; use tract_core::internal::*; use tract_core::ops::array::PadMode; @@ -367,7 +367,7 @@ pub fn conv_or_deconv( } else { tvec!(0; pool_spec.rank()) }; - Box::new(DeconvUnary::new(pool_spec, KernelFormat::OIHW, adjustments, group)) + Box::new(Deconv::new(pool_spec, KernelFormat::OIHW, adjustments, group)) } else { if let Some(odt) = &output_dt { for dt in &[&input_fact.datum_type, &kernel_fact.datum_type, odt] { diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs index cc8c1208c0..17b980e36f 100644 --- a/nnef/src/ops/nnef/ser.rs +++ b/nnef/src/ops/nnef/ser.rs @@ -6,7 +6,7 @@ use tract_core::num_traits::Zero; use tract_core::ops; use tract_core::ops::cast::cast; use tract_core::ops::cnn::Conv; -use tract_core::ops::cnn::DeconvUnary; +use tract_core::ops::cnn::Deconv; use tract_core::ops::cnn::KernelFormat; use tract_core::ops::cnn::PoolSpec; use tract_core::ops::einsum::BasicMatMul; @@ -255,7 +255,7 @@ pub fn conv( pub fn deconv( ast: &mut IntoAst, node: &TypedNode, - op: &ops::cnn::deconv::DeconvUnary, + op: &ops::cnn::deconv::Deconv, ) -> TractResult>> { conv_or_deconv(ast, node, &op.pool_spec, op.group, true, Some(&op.adjustments)) } @@ -516,7 +516,7 @@ pub fn rewrite_kernel_deconv_in_oihw( model: &TypedModel, node: &TypedNode, name: &str, - conv: &DeconvUnary, + conv: &Deconv, ) -> TractResult> { rewrite_kernel_in_oihw( model, @@ -524,7 +524,7 @@ pub fn rewrite_kernel_deconv_in_oihw( name, conv.kernel_format, conv.group, - Box::new(DeconvUnary { kernel_format: KernelFormat::OIHW, ..conv.clone() }), + Box::new(Deconv { kernel_format: KernelFormat::OIHW, ..conv.clone() }), ) } diff --git a/onnx/src/ops/nn/conv_transpose.rs b/onnx/src/ops/nn/conv_transpose.rs index def3a2d6b0..823327fa93 100644 --- a/onnx/src/ops/nn/conv_transpose.rs +++ b/onnx/src/ops/nn/conv_transpose.rs @@ -168,14 +168,14 @@ impl Expansion for ConvTranspose { &x_shape.as_concrete().context("expects concrete dim for deconv")?[2..], output_shape, )?; - tract_core::ops::cnn::DeconvUnary::new( + tract_core::ops::cnn::Deconv::new( pool_spec, KernelFormat::OIHW, adjustments, self.group, ) } else { - tract_core::ops::cnn::DeconvUnary::new( + tract_core::ops::cnn::Deconv::new( pool_spec, KernelFormat::OIHW, self.adjustments.clone().unwrap_or_else(|| tvec!(0; kernel_shape.len() - 2)), diff --git a/pulse/src/ops/cnn/deconv.rs b/pulse/src/ops/cnn/deconv.rs index 8f9b9bb60b..c51cfb1df7 100644 --- a/pulse/src/ops/cnn/deconv.rs +++ b/pulse/src/ops/cnn/deconv.rs @@ -1,14 +1,14 @@ use crate::internal::*; use tract_core::num_traits::Zero; -use tract_core::ops::cnn::DeconvUnary; +use tract_core::ops::cnn::Deconv; use tract_core::ops::cnn::PaddingSpec; use tract_pulse_opl::ops::DeconvDelay; use tract_pulse_opl::ops::PulseMask; -register_all!(DeconvUnary: pulsify); +register_all!(Deconv: pulsify); fn pulsify( - op: &DeconvUnary, + op: &Deconv, source: &TypedModel, node: &TypedNode, target: &mut PulsedModel, @@ -97,12 +97,12 @@ fn pulsify( Ok(Some(wire)) } -fn overlap(pulse_axis: usize, op: &DeconvUnary) -> usize { +fn overlap(pulse_axis: usize, op: &Deconv) -> usize { let geo_axis = pulse_axis - op.pool_spec.data_format.h_axis(); (op.pool_spec.kernel_shape[geo_axis] - 1) * op.pool_spec.dilation(geo_axis) } -impl PulsedOp for DeconvUnary { +impl PulsedOp for Deconv { fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult> { let mut fact = inputs[0].clone(); let stream = fact.stream.as_mut().unwrap(); diff --git a/test-rt/suite-unit/src/deconv.rs b/test-rt/suite-unit/src/deconv.rs index 47dc5963f8..01287788f3 100644 --- a/test-rt/suite-unit/src/deconv.rs +++ b/test-rt/suite-unit/src/deconv.rs @@ -135,7 +135,7 @@ impl Arbitrary for DeconvProblem { } impl DeconvProblem { - fn as_op(&self) -> TractResult { + fn as_op(&self) -> TractResult { let pool_spec = PoolSpec::new( self.data_format, self.kernel_format.spatial_shape(self.kernel.shape()).into(), @@ -146,7 +146,7 @@ impl DeconvProblem { self.kernel_format.output_channels(self.kernel.shape(), self.group).into_owned(), ); let op = - DeconvUnary::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group); + Deconv::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group); Ok(op) }