From 06b007f13a6119ef62ba3185c7414eb86dc6f6b5 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Tue, 10 Oct 2023 15:40:28 +0200 Subject: [PATCH] refactor deconv --- Cargo.toml | 1 + core/src/ops/array/broadcast.rs | 33 +++-------- core/src/ops/cnn/deconv/deconv_sum.rs | 82 +++++++++++++-------------- data/src/tensor.rs | 18 ++++++ test-rt/suite-unit/src/deconv.rs | 16 ++++++ 5 files changed, 80 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 867d4bc423..2c5e1a94af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,6 +126,7 @@ inherits="release" lto=false [profile.release] +debug = true lto = true [profile.bench] diff --git a/core/src/ops/array/broadcast.rs b/core/src/ops/array/broadcast.rs index d35a1e4ed8..f4c862c6c2 100644 --- a/core/src/ops/array/broadcast.rs +++ b/core/src/ops/array/broadcast.rs @@ -5,23 +5,6 @@ pub struct MultiBroadcastTo { pub shape: ShapeFact, } - - -impl MultiBroadcastTo { - pub fn eval_t(input: &Tensor, shape: &[usize]) -> TractResult> { - unsafe { - let view = input.to_array_view_unchecked::(); - let mut output = view - .broadcast(shape) - .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))? - .into_owned() - .into_tensor(); - output.set_datum_type(input.datum_type()); - Ok(tvec![output.into_tvalue()]) - } - } -} - impl Op for MultiBroadcastTo { fn name(&self) -> Cow { "MultiBroadcastTo".into() @@ -39,7 +22,8 @@ impl EvalOp for MultiBroadcastTo { let input = args_1!(inputs); let dims: Vec = self.shape.iter().map(|d| d.to_usize()).collect::>()?; - dispatch_datum!(Self::eval_t(input.datum_type())(&*input, &*dims)) + let output = input.broadcast_to_shape(&dims)?; + Ok(tvec!(output.into_tvalue())) } fn state( @@ -64,9 +48,7 @@ impl OpState for MultiBroadcastToState { ) -> TractResult> { let op = op.downcast_ref::().context("Wrong op")?; let shape = op.shape.eval_to_usize(&session.resolved_symbols)?; - dispatch_datum_by_size!(MultiBroadcastTo::eval_t(inputs[0].datum_type())( - &inputs[0], &*shape - )) + Ok(tvec!(inputs[0].broadcast_to_shape(&*shape)?.into_tvalue())) } } @@ -92,10 +74,10 @@ impl TypedOp for MultiBroadcastTo { } fn declutter( - &self, - model: &TypedModel, - node: &TypedNode, - ) -> TractResult> { + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { let input_fact = model.outlet_fact(node.inputs[0])?; if input_fact.shape == self.shape { TypedModelPatch::shunt_one_op(model, node) @@ -106,4 +88,3 @@ impl TypedOp for MultiBroadcastTo { as_op!(); } - diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 21633792d9..7a3e0b416f 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -42,11 +42,7 @@ impl EvalOp for DeconvSum { } fn eval(&self, inputs: TVec) -> TractResult> { - dispatch_floatlike!(Self::eval_with_values(inputs[0].datum_type())( - self, - inputs, - &SymbolValues::default() - )) + self.eval_with_values(inputs, &Default::default()) } fn state( @@ -65,23 +61,19 @@ impl OpState for DeconvSum { _op: &dyn Op, inputs: TVec, ) -> TractResult> { - dispatch_floatlike!(Self::eval_with_values(inputs[0].datum_type())( - self, - inputs, - &session.resolved_symbols - )) + self.eval_with_values(inputs, &session.resolved_symbols) } } trivial_op_state_freeeze!(DeconvSum); impl DeconvSum { - fn eval_with_values>( + fn eval_with_values( &self, inputs: TVec, values: &SymbolValues, ) -> TractResult> { let gemm = args_1!(inputs).into_tensor(); - debug_assert_eq!(gemm.datum_type(), T::datum_type()); + let dt = gemm.datum_type(); let input_shape = self.input_shape.eval_to_usize(values)?.into_owned(); let input_shape = self.pool_spec.data_format.shape(input_shape)?; let output_shape = @@ -95,33 +87,13 @@ impl DeconvSum { &self.adjustments, )?; let mut tensor = if let Some(b) = &self.bias { - if output_shape.shape[0..output_shape.c_axis()].iter().all(|d| *d == 1) { - unsafe { - let mut tensor = Tensor::uninitialized::(&output_shape.shape)?; - let values = b.as_ptr::()?; - let slice = tensor.as_ptr_mut::()?; - let stride = *output_shape.c_stride(); - for ix in 0..b.len() { - let v = *values.add(ix); - for p in 0..stride { - *slice.add(stride * ix + p) = v; - } - } - tensor - } - } else { - let mut tensor = Tensor::zero::(&output_shape.shape)?; - let mut output = tensor.to_array_view_mut::()?; - let mut bias_shape = tvec!(1; output_shape.rank()); - bias_shape[output_shape.c_axis()] = b.len(); - let b = b.clone().into_tensor().into_shape(&bias_shape)?; - output += &b.to_array_view::()?; - tensor - } + let mut bias_shape = tvec!(1; output_shape.rank()); + bias_shape[output_shape.c_axis()] = b.len(); + let b = b.clone().into_tensor().into_shape(&bias_shape)?; + b.broadcast_to_shape(&output_shape.shape)? } else { - Tensor::zero::(&output_shape.shape)? + Tensor::zero_dt(dt, &output_shape.shape)? }; - let mut output = tensor.to_array_view_mut::()?; let hw = *gemm.shape().last().unwrap(); let n = *output_shape.n().unwrap_or(&1); let n_o_hkwk_hw = gemm.into_shape(&[ @@ -130,10 +102,33 @@ impl DeconvSum { self.pool_spec.kernel_shape.iter().product(), hw, ])?; - let n_o_hkwk_hw: ArrayView4 = n_o_hkwk_hw.to_array_view::()?.into_dimensionality()?; if !self.pool_spec.data_format.has_n() { - output = output.insert_axis(Axis(0)); + tensor.insert_axis(0)?; + } + dispatch_floatlike!(Self::eval_t(dt)( + self, + &input_shape, + &output_shape, + &spatial_output_details, + &n_o_hkwk_hw, + &mut tensor + ))?; + if !self.pool_spec.data_format.has_n() { + tensor.remove_axis(0)?; } + Ok(tvec!(tensor.into_tvalue())) + } + + fn eval_t>( + &self, + input_shape: &DataShape, + output_shape: &DataShape, + spatial_output_details: &[ComputedPaddedDim], + n_o_hkwk_hw: &Tensor, + output: &mut Tensor, + ) -> TractResult<()> { + let output = output.to_array_view_mut::()?; + let n_o_hkwk_hw: ArrayView4 = n_o_hkwk_hw.to_array_view::()?.into_dimensionality()?; match input_shape.hw_rank() { 1 => self.main_loop_1d( &input_shape, @@ -141,30 +136,29 @@ impl DeconvSum { &spatial_output_details, &n_o_hkwk_hw, &mut output.into_dimensionality().unwrap(), - )?, + ), 2 => self.main_loop_2d( &input_shape, &output_shape, &spatial_output_details, &n_o_hkwk_hw, &mut output.into_dimensionality().unwrap(), - )?, + ), 3 => self.main_loop_3d( &input_shape, &output_shape, &spatial_output_details, &n_o_hkwk_hw, &mut output.into_dimensionality().unwrap(), - )?, + ), _ => self.main_loop( &input_shape, &output_shape, &spatial_output_details, &n_o_hkwk_hw, &mut output.into_dimensionality().unwrap(), - )?, + ), } - Ok(tvec!(tensor.into_tvalue())) } pub fn main_loop_1d( diff --git a/data/src/tensor.rs b/data/src/tensor.rs index 1c486ef9e8..e48263b6f1 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -2,6 +2,7 @@ use crate::datum::{round_ties_to_even, scale_by, Blob, ClampCast, Datum, DatumType, QParams}; use crate::dim::TDim; use crate::TVec; +use anyhow::Context; use half::f16; use itertools::Itertools; use ndarray::prelude::*; @@ -538,6 +539,23 @@ impl Tensor { } } + fn broadcast_to_shape_t(&self, shape: &[usize]) -> anyhow::Result { + unsafe { + let view = self.to_array_view_unchecked::(); + let mut output = view + .broadcast(shape) + .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))? + .into_owned() + .into_tensor(); + output.set_datum_type(self.datum_type()); + Ok(output) + } + } + + pub fn broadcast_to_shape(&self, shape: &[usize]) -> anyhow::Result { + dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape)) + } + fn clip_range_bounds( &self, axis: usize, diff --git a/test-rt/suite-unit/src/deconv.rs b/test-rt/suite-unit/src/deconv.rs index 2d4cf96bde..01759ad18d 100644 --- a/test-rt/suite-unit/src/deconv.rs +++ b/test-rt/suite-unit/src/deconv.rs @@ -525,6 +525,22 @@ pub fn suite() -> TractResult { }, ); + suite.add( + "bias_1", + DeconvProblem { + data_format: HWC, + kernel_format: OIHW, + padding: PaddingSpec::Valid, + input: arr2(&[[0.0], [0.0]]).into_dyn(), + kernel: arr3(&[[[0.0]]]).into_dyn(), + bias: Some(arr1(&[1.0f32]).into_dyn()), + strides: tvec!(1), + dilations: tvec!(1), + adjustments: tvec!(0), + group: 1, + }, + ); + suite.add( "rank_5_with_group", DeconvProblem {