From deee548513907894364f675392bb4752e95c6920 Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 6 Dec 2023 09:34:29 +0100 Subject: [PATCH] wip --- core/src/floats.rs | 15 -- core/src/ops/change_axes.rs | 9 +- core/src/ops/cnn/conv/unary.rs | 188 +++++++++--------- core/src/ops/cnn/deconv/deconv_sum.rs | 19 +- core/src/ops/cnn/deconv/unary.rs | 91 ++++++--- core/src/ops/cnn/mod.rs | 40 ++++ core/src/ops/downsample/conv.rs | 4 +- core/src/plan.rs | 4 +- .../core-proptest-pulse/src/conv_plus_conv.rs | 7 +- harness/core-proptest-pulse/src/deconv.rs | 12 +- nnef/src/ops/core/qconv.rs | 63 ++---- nnef/src/ops/nnef/deser.rs | 20 +- nnef/src/ops/nnef/ser.rs | 33 +-- onnx/src/ops/nn/conv_transpose.rs | 132 ++++++------ pulse/src/ops/cnn/conv.rs | 7 +- pulse/src/ops/cnn/deconv.rs | 7 +- test-rt/suite-unit/src/deconv.rs | 52 ++++- 17 files changed, 356 insertions(+), 347 deletions(-) diff --git a/core/src/floats.rs b/core/src/floats.rs index 9a7be7d958..31550b4840 100644 --- a/core/src/floats.rs +++ b/core/src/floats.rs @@ -3,7 +3,6 @@ use tract_num_traits::Float; use crate::internal::translator::Translate; use crate::internal::*; use crate::ops::array::{Pad, PadMode}; -use crate::ops::cnn::{ConvUnary, DeconvUnary}; use crate::ops::einsum::EinSum; use crate::ops::konst::Const; use crate::ops::scan::Scan; @@ -27,14 +26,6 @@ impl Box::new(TypedSource::new(fact_float_precision_conversion::(&source.fact))) } else if let Some(konst) = node.op_as::() { Box::new(Const(tensor_float_precision_conversion::(&konst.0))) - /* - } else if let Some(op) = node.op_as::() { - Box::new(ConvUnary { - kernel: tensor_float_precision_conversion::(&op.kernel), - bias: op.bias.as_ref().map(tensor_float_precision_conversion::), - ..op.clone() - }) - */ } else if let Some(op) = node.op_as::() { let body = FloatPrecisionTranslator::::default().translate_model(&op.body)?; Box::new(Scan { body, ..op.clone() }) @@ -43,12 +34,6 @@ impl operating_dt: dt_float_precision_conversion::(op.operating_dt), ..op.clone() }) - } else if let Some(op) = node.op_as::() { - Box::new(DeconvUnary { - kernel: tensor_float_precision_conversion::(&op.kernel), - bias: op.bias.as_ref().map(tensor_float_precision_conversion::), - ..op.clone() - }) } else if let Some(op) = node.op_as::() { if let PadMode::Constant(t) = &op.mode { Box::new(Pad { diff --git a/core/src/ops/change_axes.rs b/core/src/ops/change_axes.rs index 5c2144e343..abcdac1ae5 100644 --- a/core/src/ops/change_axes.rs +++ b/core/src/ops/change_axes.rs @@ -267,6 +267,7 @@ impl AxisOp { shape.insert(*to, axis); } Reshape(at, from, to) => { + ensure!(from.iter().product::() == to.iter().product::()); if shape.len() >= from.len() + *at && tract_itertools::izip!(shape.iter().skip(*at), from) .all(|(shape, spec)| shape.to_dim() == *spec) @@ -435,7 +436,7 @@ impl AxisOp { pub fn wire_split_axis( model: &mut TypedModel, - name: &str, + name: impl ToString, outlet: OutletId, axis: usize, outer_dim: usize, @@ -444,12 +445,12 @@ impl AxisOp { let dim: TDim = fact.shape[axis].clone(); let inner_dim = dim.clone() / outer_dim; let op = Self::Reshape(axis, tvec!(dim.clone()), tvec!(outer_dim.to_dim(), inner_dim)); - model.wire_node(name, op, &[outlet]) + model.wire_node(name.to_string(), op, &[outlet]) } pub fn wire_collapse_axis( model: &mut TypedModel, - name: &str, + name: impl ToString, outlet: OutletId, axis: usize, ) -> TractResult> { @@ -457,7 +458,7 @@ impl AxisOp { let dim: TDim = fact.shape[axis].clone(); let next_dim: TDim = fact.shape[axis + 1].clone(); let op = Self::Reshape(axis, tvec!(dim.clone(), next_dim.clone()), tvec!(dim * next_dim)); - model.wire_node(name, op, &[outlet]) + model.wire_node(name.to_string(), op, &[outlet]) } } diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs index d44337cccd..9a1a87a201 100644 --- a/core/src/ops/cnn/conv/unary.rs +++ b/core/src/ops/cnn/conv/unary.rs @@ -5,17 +5,15 @@ use tract_num_traits::Zero; use crate::internal::*; use crate::model::*; use crate::ops; -use crate::ops::array::MultiBroadcastTo; use crate::ops::array::Pad; use crate::ops::array::PadMode; use crate::ops::binary::TypedBinOp; use crate::ops::cast::cast; +use crate::ops::cnn::wire_reshape_bias; use crate::ops::cnn::PaddingSpec::*; use crate::ops::einsum::EinSum; -use crate::ops::math::Add; -use crate::ops::math::Div; -use crate::ops::math::Mul; -use crate::ops::math::Sub; +use crate::ops::math::{add, div, mul, sub}; +use crate::ops::math::{Add, Div, Mul, Sub}; use crate::ops::matmul::lir_unary::AddMatMulGeometry; use crate::ops::matmul::lir_unary::MapOutputAxisToInput; use crate::ops::matmul::mir_quant::wire_offset_u8_as_i8; @@ -159,7 +157,8 @@ impl ConvUnary { Reduce::new(tvec!(2), ops::nn::Reducer::Sum), &g_o_ihw_as_i32, )?; - let sum_ker_a_g_c = model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?; + let sum_ker_a_g_c = + model.wire_node(format!("{name}.rm_k"), AxisOp::Rm(2), &sum_ker_g_c_k)?; // align sum_A from G,C to "C" shape: N,HW,G,C (or N,G,C,HW) let sum_ker_n_g_c = model.wire_node(format!("{name}.sum_ker_n_g_c"), AxisOp::Add(0), &sum_ker_a_g_c)?; @@ -185,7 +184,8 @@ impl ConvUnary { let x_dt = model.outlet_fact(x)?.datum_type; let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?; let b_storage = unsafe { mmm.b_packed(x_dt.size_of(), k) }; - let bias = model.wire_node(format!("{name}.cast_bias"), cast(mmm.internal_type()), &[bias])?[0]; + let bias = + model.wire_node(format!("{name}.cast_bias"), cast(mmm.internal_type()), &[bias])?[0]; let wire = self.wire_mm_weights_bias( model, name, @@ -503,20 +503,15 @@ impl ConvUnary { name: &str, wire: &[OutletId], ) -> TractResult { - let &[x, kernel, bias] = wire else { bail!("Wrong number of inputs") }; + let &[x, kernel, mut bias] = wire else { bail!("Wrong number of inputs") }; let x_fact = model.outlet_fact(x)?.clone(); let x_shape = x_fact.shape.as_concrete().unwrap(); let ConcretePoolGeometry { input_shape, patch, output_shape } = self.pool_spec.compute_geo(&x_fact.shape)?.to_concrete(x_shape)?.into_owned(); let kernel = self.wire_kernel_as_g_o_ihw(model, name, kernel)?; - let bias_fact = model.outlet_fact(bias)?; - let need_bias_shape = tvec!(self.output_channels().to_dim()); - let bias_reshape: Box = if bias_fact.shape.volume().is_one() { - Box::new(MultiBroadcastTo { shape: need_bias_shape.into() }) - } else { - Box::new(AxisOp::Reshape(0, bias_fact.shape.to_tvec(), need_bias_shape)) - }; - let bias = model.wire_node(format!("{name}.bias"), bias_reshape, &[bias])?[0]; + let c_axis = self.pool_spec.data_format.shape(x_shape)?.c_axis(); + bias = + wire_reshape_bias(model, name, bias, x_fact.rank(), c_axis, self.output_channels())?[0]; let op = DepthWise::new(patch, input_shape, output_shape); Ok(model.wire_node(name, op, &[x, kernel[0], bias])?[0]) } @@ -666,80 +661,91 @@ impl ConvUnary { Ok(Some(patch)) } - /* - fn declutter_channel_arithmetic_succ( - &self, - model: &TypedModel, - node: &TypedNode, - ) -> TractResult> { - if self.q_params.is_some() || self.group != 1 { - return Ok(None); - } - let &[succ] = &*node.outputs[0].successors else { return Ok(None) }; - let Some(bin) = model.node(succ.node).op_as::() else { return Ok(None) }; - let other_input = model.node(succ.node).inputs[1 - succ.slot]; - let other_fact = &model.outlet_fact(other_input)?; - let Some(konst) = &other_fact.konst else { return Ok(None) }; - let axes_mapping = model.node_axes_mapping(succ.node)?; - let input_shape = - self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?; - let conv_c_axis = input_shape.c_axis(); - let &[konst_c_axis] = - &*axes_mapping.axis((InOut::In(succ.slot), conv_c_axis))?.inputs[1 - succ.slot] - else { - return Ok(None); - }; - let Ok(co) = node.outputs[0].fact.shape[conv_c_axis].to_usize() else { return Ok(None) }; - let operand_for_bias = if konst.shape()[konst_c_axis] == co && konst.len() == co { - konst.clone().into_tensor().into_shape(&[co])? - } else if konst.len() == 1 { - konst.clone().to_scalar_tensor()?.broadcast_scalar_to_shape(&[co])? - } else { - return Ok(None); - }; - let mut bias = if let Some(b) = &self.bias { - b.clone() - } else { - Tensor::zero_dt(other_fact.datum_type, &[co])?.into_arc_tensor() - }; - let mut kernel = self.kernel.clone(); - let mut operand_shape_for_kernel = tvec!(1; 2 + input_shape.hw_rank()); - let o_axis = if self.kernel_fmt == KernelFormat::OIHW { 0 } else { self.kernel.rank() - 1 }; - operand_shape_for_kernel[o_axis] = co; - let operand_for_kernel = operand_for_bias.clone().into_shape(&operand_shape_for_kernel)?; - if bin.0.is::() && succ.slot == 0 { - bias = (bias.into_tensor().into_array::()? - - operand_for_bias.to_array_view::()?) - .into_arc_tensor() - } else if bin.0.is::
() && succ.slot == 0 { - bias = (bias.into_tensor().into_array::()? - / operand_for_bias.to_array_view::()?) - .into_arc_tensor(); - kernel = (kernel.into_tensor().into_array::()? - / operand_for_kernel.to_array_view::()?) - .into_arc_tensor(); - } else if bin.0.is::() { - bias = (bias.into_tensor().into_array::()? - + operand_for_bias.to_array_view::()?) - .into_arc_tensor(); - } else if bin.0.is::() { - bias = (bias.into_tensor().into_array::()? - * operand_for_bias.to_array_view::()?) - .into_arc_tensor(); - kernel = (kernel.into_tensor().into_array::()? - * operand_for_kernel.to_array_view::()?) - .into_arc_tensor(); - } else { - return Ok(None); - }; - let new_op = ConvUnary { bias: Some(bias), kernel, ..self.clone() }; - let mut patch = TypedModelPatch::default(); - let wire = patch.tap_model(model, node.inputs[0])?; - let wire = patch.wire_node(&node.name, new_op, &[wire])?[0]; - patch.shunt_outside(model, succ.node.into(), wire)?; + fn declutter_channel_arithmetic_succ( + &self, + model: &TypedModel, + node: &TypedNode, + ) -> TractResult> { + if self.q_params.is_some() || self.group != 1 { + return Ok(None); + } + let &[succ_outlet] = &*node.outputs[0].successors else { return Ok(None) }; + let succ = model.node(succ_outlet.node); + let Some(bin) = succ.op_as::() else { return Ok(None) }; + let other_input = succ.inputs[1 - succ_outlet.slot]; + let axes_mapping = model.node_axes_mapping(succ.id)?; + let input_shape = + self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?; + let conv_c_axis = input_shape.c_axis(); + if axes_mapping.axis((InOut::In(succ_outlet.slot), conv_c_axis))?.inputs + [1 - succ_outlet.slot] + .len() + != 1 + { + return Ok(None); + }; + + let mut patch = TypedModelPatch::default(); + let [input, mut kernel, mut bias] = &*patch.taps(model, &node.inputs)? else { + panic!("Expect three inputs"); + }; + let name = &node.name; + let succ_name = &succ.name; + + let operand = patch.tap_model(model, other_input)?; + + let renamed = format!("{name}.{succ_name}"); + bias = wire_reshape_bias( + &mut patch, + format!("{renamed}.reshape_bias"), + bias, + 1, + 0, + self.output_channels(), + )?[0]; + + let operand = wire_reshape_bias( + &mut patch, + format!("{renamed}.reshape_operand"), + operand, + 1, + 0, + self.output_channels(), + )?[0]; + + let operand_fact = patch.outlet_fact(operand)?.shape.to_tvec(); + let kernel_fact = patch.outlet_fact(kernel)?; + let mut operand_shape_for_kernel = tvec!(1.to_dim(); 2 + input_shape.hw_rank()); + operand_shape_for_kernel[self.kernel_fmt.o_axis(&kernel_fact.shape)] = + self.output_channels().to_dim(); + let operand_for_kernel = patch.wire_node( + format!("{renamed}.reshape_operand_for_kernel"), + AxisOp::Reshape(0, operand_fact, operand_shape_for_kernel), + &[operand], + )?[0]; + + if bin.0.is::() && succ_outlet.slot == 0 { + bias = patch.wire_node(&renamed, sub(), &[bias, operand])?[0]; + } else if bin.0.is::() { + bias = patch.wire_node(&renamed, sub(), &[operand, bias])?[0]; + } else if bin.0.is::
() && succ_outlet.slot == 0 { + bias = patch.wire_node(&renamed, div(), &[bias, operand])?[0]; + kernel = patch.wire_node(&renamed, div(), &[kernel, operand_for_kernel])?[0]; + } else if bin.0.is::
() { + bias = patch.wire_node(&renamed, div(), &[operand, bias])?[0]; + kernel = patch.wire_node(&renamed, div(), &[operand_for_kernel, kernel])?[0]; + } else if bin.0.is::() { + bias = patch.wire_node(&renamed, add(), &[bias, operand])?[0]; + } else if bin.0.is::() { + bias = patch.wire_node(&renamed, mul(), &[bias, operand])?[0]; + kernel = patch.wire_node(&renamed, mul(), &[kernel, operand_for_kernel])?[0]; + } else { + return Ok(None); + }; + let wire = patch.wire_node(&node.name, self.clone(), &[*input, kernel, bias])?[0]; + patch.shunt_outside(model, succ_outlet.node.into(), wire)?; Ok(Some(patch)) } - */ } impl Op for ConvUnary { @@ -894,7 +900,7 @@ impl TypedOp for ConvUnary { } pass!(declutter_stride_slice_to_downsample); pass!(declutter_as_einsum); - // pass!(declutter_channel_arithmetic_succ); + pass!(declutter_channel_arithmetic_succ); pass!(declutter_precursor_padding); Ok(None) } @@ -1150,8 +1156,8 @@ mod test { )?; model.set_output_outlets(&wire)?; model.declutter()?; - assert_eq!(model.nodes().len(), 2); // source + conv - let cv = model.nodes()[1].op_as::().unwrap(); + assert_eq!(model.nodes().len(), 4); // source + conv + kernel + bias + let cv = model.nodes()[3].op_as::().unwrap(); assert_eq!(cv.pool_spec.padding, Explicit(tvec![1], tvec![0])); // source + conv Ok(()) } diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 65aed7f7b3..d97ce8369b 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -26,7 +26,6 @@ pub struct DeconvSum { /// shape of the deconvolution input pub input_shape: ShapeFact, pub adjustments: TVec, - pub bias: Option>, pub group: usize, } @@ -74,8 +73,7 @@ impl DeconvSum { inputs: TVec, values: &SymbolValues, ) -> TractResult> { - let gemm = args_1!(inputs).into_tensor(); - let dt = gemm.datum_type(); + let (gemm, bias) = args_2!(inputs); let input_shape = self.input_shape.eval_to_usize(values)?.into_owned(); let input_shape = self.pool_spec.data_format.shape(input_shape)?; let output_shape = @@ -88,14 +86,10 @@ impl DeconvSum { &self.pool_spec.strides(), &self.adjustments, )?; - let mut tensor = if let Some(b) = &self.bias { - b.broadcast_vector_to_shape(&output_shape.shape, output_shape.c_axis())? - } else { - Tensor::zero_dt(dt, &output_shape.shape)? - }; + let mut tensor = bias.broadcast_to_shape(&output_shape.shape)?; let hw = *gemm.shape().last().unwrap(); let n = *output_shape.n().unwrap_or(&1); - let n_o_hkwk_hw = gemm.into_shape(&[ + let n_o_hkwk_hw = gemm.into_tensor().into_shape(&[ n, *output_shape.c(), self.pool_spec.kernel_shape.iter().product(), @@ -121,8 +115,11 @@ impl DeconvSum { impl TypedOp for DeconvSum { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + ensure!(inputs.len() == 2); let shape = super::output_shape(&self.pool_spec, &self.input_shape, &self.adjustments)?; - Ok(tvec!(inputs[0].datum_type.fact(&*shape))) + ensure!(inputs[1].rank() == shape.len()); + ensure!(inputs[1].shape.iter().zip(shape.iter()).all(|(b, o)| b.is_one() || b == o.to_dim())); + Ok(tvec!(inputs[0].datum_type.fact(shape))) } fn concretize_dims( @@ -136,7 +133,7 @@ impl TypedOp for DeconvSum { target.wire_node( &node.name, Self { input_shape: self.input_shape.eval(values)?.into_owned(), ..self.clone() }, - &[mapping[&node.inputs[0]]], + &[mapping[&node.inputs[0]], mapping[&node.inputs[1]]], ) } diff --git a/core/src/ops/cnn/deconv/unary.rs b/core/src/ops/cnn/deconv/unary.rs index 7af3d4e25e..20ae0298e1 100644 --- a/core/src/ops/cnn/deconv/unary.rs +++ b/core/src/ops/cnn/deconv/unary.rs @@ -1,4 +1,5 @@ use crate::internal::*; +use crate::ops::cnn::wire_reshape_bias; use crate::ops::cnn::KernelFormat; use crate::ops::cnn::PoolSpec; use crate::ops::einsum::EinSum; @@ -7,9 +8,6 @@ use crate::ops::einsum::EinSum; pub struct DeconvUnary { pub pool_spec: PoolSpec, pub kernel_format: KernelFormat, - pub kernel: Arc, - pub bias: Option>, - pub adjustments: TVec, pub group: usize, } @@ -19,9 +17,9 @@ impl DeconvUnary { &self, name: &str, target: &mut TypedModel, - input: OutletId, + inputs: &[OutletId], ) -> TractResult> { - let input_shape = target.outlet_fact(input)?.shape.clone(); + let input_shape = target.outlet_fact(inputs[0])?.shape.clone(); let shape = self.pool_spec.data_format.shape(input_shape.to_tvec())?; let geo_dim = shape.hw_dims().iter().product(); @@ -29,7 +27,7 @@ impl DeconvUnary { let mut input = target.wire_node( format!("{name}.reshaped_input"), AxisOp::Reshape(shape.h_axis(), shape.hw_dims().into(), tvec!(geo_dim)), - &[input], + &[inputs[0]], )?; // rework input to (N) (G) I/G HW or (N) (G) HW I/G @@ -58,15 +56,24 @@ impl DeconvUnary { )?; } } - let kernel_as_group_o_i_hw = - self.kernel_format.kernel_as_group_o_i_hw(&self.kernel, self.group)?; - let mut kernel_as_optg_ohw_i = - kernel_as_group_o_i_hw.move_axis(2, 3)?.collapse_axis_with_next(1); + + let mut kernel = tvec!(inputs[1]); + let kernel_fact = target.outlet_fact(kernel[0])?.clone(); + for (ix, op) in self + .kernel_format + .kernel_as_group_o_i_hw_ops(&kernel_fact.shape, self.group) + .into_iter() + .enumerate() + { + kernel = target.wire_node(format!("{name}.kernel.{ix}"), op, &kernel)?; + } + + kernel = target.wire_node(format!("{name}.kernel.mv_i"), AxisOp::Move(2, 3), &kernel)?; + kernel = + AxisOp::wire_collapse_axis(target, format!("{name}.kernel.col_ohw"), kernel[0], 1)?; if self.group == 1 { - kernel_as_optg_ohw_i.remove_axis(0)?; + kernel = target.wire_node(format!("{name}.kernel.rm_g"), AxisOp::Rm(0), &kernel)?; } - let kernel = - target.add_const(format!("{}.kernel", name), kernel_as_optg_ohw_i.into_arc_tensor())?; let mut expr = if self.pool_spec.data_format.c_is_last() { "gmk,Ngnk->Ngmn".to_string() } else { @@ -80,10 +87,19 @@ impl DeconvUnary { } let einsum = target.wire_node( format!("{name}.einsum"), - EinSum { axes: expr.parse()?, operating_dt: self.kernel.datum_type(), q_params: None }, - &[kernel, input[0]], + EinSum { axes: expr.parse()?, operating_dt: kernel_fact.datum_type, q_params: None }, + &[kernel[0], input[0]], )?; + let bias = wire_reshape_bias( + target, + format!("{name}.reshape_bias"), + inputs[2], + shape.rank(), + shape.c_axis(), + self.pool_spec.output_channels, + )?[0]; + // einsum must be (N_)CHkWk_HW let deconv_sum = target.wire_node( format!("{name}.deconv_sum"), @@ -92,10 +108,9 @@ impl DeconvUnary { self.kernel_format, input_shape, self.adjustments.clone(), - self.bias.clone(), self.group, ), - &einsum, + &[einsum[0], bias], )?; Ok(deconv_sum) } @@ -119,25 +134,31 @@ impl EvalOp for DeconvUnary { } fn eval(&self, inputs: TVec) -> TractResult> { - let input = args_1!(inputs); + ensure!(inputs.len() == 3); let mut model = TypedModel::default(); - let source = model.add_source("source", input.datum_type().fact(input.shape()))?; - let output = self.wire_with_deconv_sum("adhoc", &mut model, source)?; + let inputs = inputs + .into_iter() + .enumerate() + .map(|(ix, input)| model.add_const(format!("s{ix}"), input.into_tensor())) + .collect::>>()?; + let output = self.wire_with_deconv_sum("adhoc", &mut model, &inputs)?; model.set_output_outlets(&output)?; - model.into_runnable()?.run(tvec!(input)).context("In adhoc deconvolution eval") + model.into_runnable()?.run(tvec![]).context("In adhoc deconvolution eval") } } impl TypedOp for DeconvUnary { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + ensure!(inputs.len() == 3); let x_fact = inputs[0]; + let k_fact = inputs[1]; ensure!( &self.pool_spec.input_channels.to_dim() == self.pool_spec.data_format.shape(&inputs[0].shape)?.c() ); ensure!( - self.pool_spec.input_channels - == *self.kernel_format.input_channels(&self.kernel.shape(), self.group) + self.pool_spec.input_channels.to_dim() + == *self.kernel_format.input_channels(&k_fact.shape, self.group) ); let output_shape = super::output_shape(&self.pool_spec, &x_fact.shape, &self.adjustments)?; Ok(tvec!(x_fact.datum_type.fact(&output_shape))) @@ -149,6 +170,7 @@ impl TypedOp for DeconvUnary { outputs: &[&TypedFact], ) -> TractResult { let fact = &inputs[0]; + let k_fact = &inputs[1]; let shape = self.pool_spec.data_format.shape(fact.shape.iter().collect::>())?; let mut axes = AxesMapping::disconnected(inputs, outputs)? .renaming((InOut::In(0), shape.c_axis()), 'I')? @@ -160,10 +182,9 @@ impl TypedOp for DeconvUnary { } let h_axis = shape.h_axis(); let geo = "HWXYZ".chars().chain('a'..); - let kernel_spatial_shape = - &self.kernel.shape()[self.kernel_format.h_axis()..][..shape.hw_rank()]; - for ((ix, &dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) { - if dim == 1 + let kernel_spatial_shape = self.kernel_format.spatial_shape(&k_fact.shape); + for ((ix, &ref dim), repr) in kernel_spatial_shape.iter().enumerate().zip(geo) { + if dim.is_one() && self.pool_spec.stride(ix) == 1 && self.pool_spec.padding.valid_dim(ix, true) && self.adjustments[ix] == 0 @@ -182,11 +203,21 @@ impl TypedOp for DeconvUnary { node: &TypedNode, ) -> TractResult> { let mut patch = TypedModelPatch::default(); - let input = patch.tap_model(model, node.inputs[0])?; + let mut inputs = patch.taps(model, &node.inputs)?; + let x_shape = patch.outlet_fact(inputs[0])?; + let x_shape = self.pool_spec.data_format.shape(x_shape.shape.to_tvec())?; + inputs[2] = wire_reshape_bias( + &mut patch, + &node.name, + inputs[2], + x_shape.rank(), + x_shape.c_axis(), + self.pool_spec.output_channels, + )?[0]; let output = self - .wire_with_deconv_sum(&node.name, &mut patch, input) + .wire_with_deconv_sum(&node.name, &mut patch, &inputs) .context("In wire_with_deconv_sum")?; - patch.shunt_outside(model, (node.id, 0).into(), output[0])?; + patch.shunt_outside(model, node.id.into(), output[0])?; Ok(Some(patch)) } diff --git a/core/src/ops/cnn/mod.rs b/core/src/ops/cnn/mod.rs index a1eca3b853..21f6a53699 100644 --- a/core/src/ops/cnn/mod.rs +++ b/core/src/ops/cnn/mod.rs @@ -17,3 +17,43 @@ pub use self::patch_axis::PatchAxis; pub use self::patches::{Patch, PatchSpec}; pub use self::pools::PoolSpec; pub use self::sumpool::SumPool; + +use super::array::MultiBroadcastTo; + +fn wire_reshape_bias( + model: &mut TypedModel, + name: impl AsRef, + outlet: OutletId, + rank: usize, + c_axis: usize, + output_channels: usize, +) -> TractResult> { + let name = name.as_ref(); + let mut bias = tvec!(outlet); + let fact = model.outlet_fact(outlet)?.clone(); + if fact.shape.volume().is_one() && fact.rank() > 0 { + bias = model.wire_node( + format!("{name}.bias.broadcast_as_scalar"), + AxisOp::Reshape(0, fact.shape.to_tvec(), tvec![]), + &bias, + )?; + } + if model.outlet_fact(bias[0])?.rank() == 0 { + bias = model.wire_node( + format!("{name}.bias.broadcast"), + MultiBroadcastTo { shape: tvec!(output_channels).into() }, + &bias, + )?; + } + let fact = model.outlet_fact(bias[0])?.clone(); + let mut bias_final_shape = tvec![1.to_dim(); rank]; + bias_final_shape[c_axis] = output_channels.to_dim(); + if &*bias_final_shape != &*fact.shape { + bias = model.wire_node( + format!("{name}.bias"), + AxisOp::Reshape(0, fact.shape.to_tvec(), bias_final_shape), + &bias, + )?; + } + Ok(bias) +} diff --git a/core/src/ops/downsample/conv.rs b/core/src/ops/downsample/conv.rs index c4ebf34030..7512c9f4d5 100644 --- a/core/src/ops/downsample/conv.rs +++ b/core/src/ops/downsample/conv.rs @@ -30,8 +30,8 @@ pub fn fuse_downsample_into_conv( new_conv.pool_spec.strides.as_mut().unwrap()[geo_axis] *= down_op.stride as usize; let mut patch = TypedModelPatch::default(); - let tap = patch.tap_model(model, conv_node.inputs[0])?; - let new_output = patch.wire_node(&*conv_node.name, new_conv, [tap].as_ref())?[0]; + let taps = patch.taps(model, &conv_node.inputs)?; + let new_output = patch.wire_node(&*conv_node.name, new_conv, &taps)?[0]; patch.shunt_outside(model, OutletId::new(down_node.id, 0), new_output)?; Ok(Some(patch)) } diff --git a/core/src/plan.rs b/core/src/plan.rs index 663d84b1d2..0930541b65 100644 --- a/core/src/plan.rs +++ b/core/src/plan.rs @@ -554,13 +554,13 @@ where F: Fact + Clone + 'static, O: Debug + Display + AsRef + AsMut + Clone + 'static, { - eprint!("{node} {input:?}"); + // eprint!("{node} {input:?}"); let r = match state { Some(ref mut state) => state.eval(session_state, node.op(), input), None => node.op().eval(input), } .with_context(|| format!("Evaluating {node}")); - eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?); + // eprintln!(" ==> {}", r.as_ref().unwrap()[0].dump(true)?); r } diff --git a/harness/core-proptest-pulse/src/conv_plus_conv.rs b/harness/core-proptest-pulse/src/conv_plus_conv.rs index ab92c24ef3..0bba6ade2e 100644 --- a/harness/core-proptest-pulse/src/conv_plus_conv.rs +++ b/harness/core-proptest-pulse/src/conv_plus_conv.rs @@ -88,7 +88,12 @@ impl Arbitrary for ConvPlusConvProblem { impl ConvPlusConvProblem { pub fn min_input_size(ops: &[ConvOp]) -> usize { let model = Self::model(ops); - let dims: Vec<&TDim> = model.nodes.iter().map(|n| &n.outputs[0].fact.shape[2]).collect(); + let dims: Vec<&TDim> = model + .nodes + .iter() + .filter(|node| !node.outputs[0].fact.shape.is_concrete()) + .map(|n| &n.outputs[0].fact.shape[2]) + .collect(); for s in 0usize.. { let symbols = SymbolValues::default().with(&model.symbol_table.get("S").unwrap(), s as _); diff --git a/harness/core-proptest-pulse/src/deconv.rs b/harness/core-proptest-pulse/src/deconv.rs index 87c28821b3..90f726501f 100644 --- a/harness/core-proptest-pulse/src/deconv.rs +++ b/harness/core-proptest-pulse/src/deconv.rs @@ -25,12 +25,12 @@ impl DeconvOp { output_channels: self.ker.shape()[0], }, kernel_format: tract_core::ops::cnn::KernelFormat::OIHW, - kernel: self.ker.clone().into_arc_tensor(), - bias: None, adjustments: tvec!(self.adj), group: 1, }; - model.wire_node(name, deconv, &[after]).unwrap()[0] + let kernel = model.add_const("kernel", self.ker.clone()).unwrap(); + let bias = model.add_const("bias", rctensor0(0f32)).unwrap(); + model.wire_node(name, deconv, &[after, kernel, bias]).unwrap()[0] } } @@ -273,12 +273,12 @@ fn deconv2d() { output_channels: 2, }, kernel_format: tract_core::ops::cnn::KernelFormat::OIHW, - kernel: kernel.into_arc_tensor(), - bias: None, adjustments: tvec!(0, 0), group: 1, }; - let deconv = model.wire_node("deconv", deconv, &[a]).unwrap(); + let kernel = model.add_const("kernel", kernel).unwrap(); + let bias = model.add_const("bias", rctensor0(0f32)).unwrap(); + let deconv = model.wire_node("deconv", deconv, &[a, kernel, bias]).unwrap(); model.set_output_outlets(&deconv).unwrap(); model.declutter().unwrap(); diff --git a/nnef/src/ops/core/qconv.rs b/nnef/src/ops/core/qconv.rs index 76b1e893b4..22b0802430 100644 --- a/nnef/src/ops/core/qconv.rs +++ b/nnef/src/ops/core/qconv.rs @@ -2,7 +2,6 @@ use crate::deser::Value; use crate::internal::*; use crate::ops::nnef::deser::read_conv_parameters; use crate::ops::nnef::ser::make_conv_named_args; -use crate::ops::nnef::ser::ser_axis_op; use crate::ser::*; use tract_core::ops::cnn::ConvUnary; use tract_core::ops::cnn::KernelFormat; @@ -46,46 +45,17 @@ fn qconv_unary_dump( if op.q_params.is_none() || node.outputs[0].fact.datum_type.is_quantized() { return Ok(None); } - let name = &node.name; let mut named_args = make_conv_named_args(node, &op.pool_spec, op.group, false, None)?; for (ix, name) in ["a0", "a_scale", "b0", "b_scale", "c0", "c_scale"].iter().enumerate() { - named_args.push((name, (*ast.mapping[&node.inputs[1 + ix]]).clone())); + named_args.push((name, (*ast.mapping[&node.inputs[3 + ix]]).clone())); } - let mut wire = ast.mapping[&node.inputs[0]].clone(); + let wire = ast.mapping[&node.inputs[0]].clone(); ensure!(op.kernel_fmt == KernelFormat::OIHW); - let mut weights = ast.mapping[&node.inputs[1]].clone(); - let mut bias = ast.mapping[&node.inputs[2]].clone(); - /* - let mut weights = ast.konst_variable(format!("{name}_weights"), &op.kernel)?; - let mut rank = op.kernel.rank(); - for fix in op.kernel_fmt.kernel_as_group_o_i_h_w_ops(op.kernel.shape(), op.group) { - weights = ser_axis_op(&fix, weights, rank); - match fix { - AxisOp::Add(_) => rank += 1, - AxisOp::Rm(_) => rank -= 1, - AxisOp::Move(_, _) => (), - AxisOp::Reshape(_, before, after) => rank = rank + after.len() - before.len(), - } - } - weights = ser_axis_op( - &AxisOp::Reshape( - 0, - tvec!(op.group.to_dim(), op.pool_spec.output_channels.to_dim() / op.group), - tvec!(op.pool_spec.output_channels.to_dim()), - ), - weights, - rank, - ); - wire = ast.force_variable(format!("{name}_input"), &wire); - - if let Some(bias) = op.bias.as_ref() { - let bias = ast.konst(format!("{name}_bias"), bias)?; - inputs.push(bias) - } - */ - let mut inputs = tvec![wire, weights, bias]; + let weights = ast.mapping[&node.inputs[1]].clone(); + let bias = ast.mapping[&node.inputs[2]].clone(); + let inputs = tvec![wire, weights, bias]; Ok(Some(invocation("tract_core_qconv", &inputs, &named_args))) } @@ -95,18 +65,17 @@ fn qconv_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tr inputs.push(invocation.named_arg_as(builder, "filter")?); inputs.push(invocation.named_arg_as(builder, "bias")?); - /* - if input_fact.rank() != kernel.rank() { + let input_fact = builder.model.outlet_fact(inputs[0])?.clone(); + let kernel_fact = builder.model.outlet_fact(inputs[1])?.clone(); + + if input_fact.rank() != kernel_fact.rank() { bail!( "Convolution input expected as NCHW, filter as OIHW. Got {:?} and {:?}.", input_fact, - kernel + kernel_fact ); } - */ - let input_fact = builder.model.outlet_fact(inputs[0])?.clone(); - let kernel_fact = builder.model.outlet_fact(inputs[1])?.clone(); let (group, pool_spec) = read_conv_parameters( builder, invocation, @@ -116,10 +85,6 @@ fn qconv_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tr let qparams = qparams_as_outlets(builder, invocation).context("Loading qparams")?; inputs.extend(qparams.iter().cloned()); - let bias: Arc = invocation.named_arg_as(builder, "bias")?; - - let bias: Option> = - if bias.is_uniform() && bias.cast_to_scalar::()? == 0.0 { None } else { Some(bias) }; let Some(c0) = &builder.model.outlet_fact(qparams[4])?.konst else { bail!("For quantized convolution, output quantization must be static"); @@ -132,12 +97,8 @@ fn qconv_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tr scale: c_scale.cast_to_scalar()?, }); - let op: Box = Box::new(ConvUnary::new( - pool_spec, - KernelFormat::OIHW, - group, - Some(output_dt), - )); + let op: Box = + Box::new(ConvUnary::new(pool_spec, KernelFormat::OIHW, group, Some(output_dt))); builder.wire(op, &inputs) } diff --git a/nnef/src/ops/nnef/deser.rs b/nnef/src/ops/nnef/deser.rs index 224a13632a..bece8c7d00 100644 --- a/nnef/src/ops/nnef/deser.rs +++ b/nnef/src/ops/nnef/deser.rs @@ -323,16 +323,6 @@ pub fn conv_or_deconv( let bias: OutletId = invocation.named_arg_as(builder, "bias")?; let input_fact = builder.model.outlet_fact(input)?.clone(); let kernel_fact = builder.model.outlet_fact(kernel)?.clone(); - /* - let bias_fact = builder.model.outlet_fact(bias)?; - if input_fact.rank() != kernel.rank() { - bail!( - "Convolution input expected as NCHW, filter as OIHW. Got {:?} and {:?}.", - input_fact, - kernel - ); - } - */ let mut inputs = tvec!(input, kernel, bias); let (group, pool_spec) = read_conv_parameters( @@ -367,15 +357,7 @@ pub fn conv_or_deconv( } else { tvec!(0; pool_spec.rank()) }; - todo!() - /* - Box::new(DeconvUnary::new( - pool_spec, - KernelFormat::OIHW, - adjustments, - group, - )) - */ + Box::new(DeconvUnary::new(pool_spec, KernelFormat::OIHW, adjustments, group)) } else { Box::new(ConvUnary::new( pool_spec, diff --git a/nnef/src/ops/nnef/ser.rs b/nnef/src/ops/nnef/ser.rs index 7181265dcc..ba5c8e32ab 100644 --- a/nnef/src/ops/nnef/ser.rs +++ b/nnef/src/ops/nnef/ser.rs @@ -7,7 +7,6 @@ use tract_core::ops; use tract_core::ops::cnn::KernelFormat; use tract_core::ops::cnn::PoolSpec; use tract_core::ops::nn::DataFormat; -use tract_core::tract_data::itertools::Itertools; pub fn source( ast: &mut IntoAst, @@ -186,7 +185,6 @@ pub fn conv_or_deconv( ast: &mut IntoAst, node: &TypedNode, pool_spec: &PoolSpec, - kernel_format: KernelFormat, group: usize, deconv: bool, adjustments: Option<&[usize]>, @@ -206,22 +204,6 @@ pub fn conv_or_deconv( wire = ast.force_variable(format!("{}_input", node.name), &wire); let inputs = tvec![wire, kernel, bias]; - /* - // nnef: O I/g H W - let mut kernel_go_i_h_w = - kernel_format.kernel_as_group_o_i_hw(kernel, group)?.collapse_axis_with_next(0); - // split hw... as h_w_... - for (ix, dim) in kernel_format.hw(kernel.shape()).iter().dropping_back(1).enumerate() { - kernel_go_i_h_w = kernel_go_i_h_w.split_axis(ix + 2, *dim)?; - } - inputs.push( - ast.konst_variable(format!("{}_weigths", node.name), &kernel_go_i_h_w.into_arc_tensor())?, - ); - if let Some(bias) = bias.as_ref() { - inputs.push(ast.konst(format!("{}_bias", node.name), bias)?); - } - */ - let named_args = make_conv_named_args(node, pool_spec, group, deconv, adjustments)?; let name = if deconv { "deconv" } else { "conv" }; @@ -254,7 +236,7 @@ pub fn conv( if op.q_params.is_some() && !node.outputs[0].fact.datum_type.is_quantized() { return Ok(None); } - conv_or_deconv(ast, node, &op.pool_spec, op.kernel_fmt, op.group, false, None) + conv_or_deconv(ast, node, &op.pool_spec, op.group, false, None) } pub fn deconv( @@ -262,18 +244,7 @@ pub fn deconv( node: &TypedNode, op: &ops::cnn::deconv::DeconvUnary, ) -> TractResult>> { - todo!(); - /* - conv_or_deconv( - ast, - node, - &op.pool_spec, - op.kernel_format, - op.group, - true, - Some(&op.adjustments), - ) - */ + conv_or_deconv(ast, node, &op.pool_spec, op.group, true, Some(&op.adjustments)) } fn cnn_pool_fragment( diff --git a/onnx/src/ops/nn/conv_transpose.rs b/onnx/src/ops/nn/conv_transpose.rs index 51b2921458..f49acaca9a 100644 --- a/onnx/src/ops/nn/conv_transpose.rs +++ b/onnx/src/ops/nn/conv_transpose.rs @@ -115,75 +115,73 @@ impl Expansion for ConvTranspose { target: &mut TypedModel, inputs: &[OutletId], ) -> TractResult> { - if let Some(k) = target.outlet_fact(inputs[1])?.konst.clone() { - let zeros = tvec!(0; k.rank() - 2); - // ONNX deconv kernels are stored as gi_o_h_w (convolution are go_i_hw) - let kernel = k - .into_tensor() - .split_axis(0, self.group)? - .move_axis(1, 2)? - .collapse_axis_with_next(0); + // ONNX deconv kernels are stored as gi_o_h_w (convolution are go_i_hw) + /* + let kernel = + k.into_tensor().split_axis(0, self.group)?.move_axis(1, 2)?.collapse_axis_with_next(0); + */ + let mut kernel = AxisOp::wire_split_axis( + target, + format!("{prefix}.kernel_split_group"), + inputs[1], + 0, + self.group, + )?; + kernel = + target.wire_node(format!("{prefix}.kernel_reorder"), AxisOp::Move(1, 2), &kernel)?; + kernel = AxisOp::wire_collapse_axis( + target, + format!("{prefix}.kernel_merge_group"), + kernel[0], + 0, + )?; - let bias = if self.have_bias { - Some( - target - .outlet_fact(inputs[2])? - .konst - .clone() - .context("bias must be a constant")?, - ) - } else { - None - }; + let bias = if self.have_bias { + inputs[2] + } else { + target.add_const( + format!("{prefix}.bias"), + Tensor::zero_scalar_dt(target.outlet_fact(inputs[0])?.datum_type)?, + )? + }; - let ci = KernelFormat::OIHW.input_channels(kernel.shape(), self.group).into_owned(); - let co = KernelFormat::OIHW.output_channels(kernel.shape(), self.group).into_owned(); - let op = if let Some(output_shape) = &self.output_shape { - let x_shape = &target.outlet_fact(inputs[0])?.shape; - let pool_spec = PoolSpec::new( - DataFormat::NCHW, - kernel.shape()[2..].into(), - self.padding_spec.clone(), - self.dilations.clone(), - self.strides.clone(), - ci, - co, - ); - let adjustments = adjustments( - &pool_spec, - &x_shape.as_concrete().context("expects concrete dim for deconv")?[2..], - output_shape, - )?; - tract_core::ops::cnn::DeconvUnary::new( - pool_spec, - KernelFormat::OIHW, - kernel.into_arc_tensor(), - bias, - adjustments, - self.group, - ) - } else { - let pool_spec = PoolSpec::new( - DataFormat::NCHW, - kernel.shape()[2..].into(), - self.padding_spec.clone(), - self.dilations.clone(), - self.strides.clone(), - ci, - co, - ); - tract_core::ops::cnn::DeconvUnary::new( - pool_spec, - KernelFormat::OIHW, - kernel.into_arc_tensor(), - bias, - self.adjustments.clone().unwrap_or(zeros), - self.group, - ) - }; - target.wire_node(prefix, op, &[inputs[0]]) + let kernel_shape = target + .outlet_fact(kernel[0])? + .shape + .as_concrete() + .context("Expects concrete kernel shape")?; + let ci = KernelFormat::OIHW.input_channels(&kernel_shape, self.group).into_owned(); + let co = KernelFormat::OIHW.output_channels(&kernel_shape, self.group).into_owned(); + let pool_spec = PoolSpec::new( + DataFormat::NCHW, + kernel_shape[2..].into(), + self.padding_spec.clone(), + self.dilations.clone(), + self.strides.clone(), + ci, + co, + ); + let op = if let Some(output_shape) = &self.output_shape { + let x_shape = &target.outlet_fact(inputs[0])?.shape; + let adjustments = adjustments( + &pool_spec, + &x_shape.as_concrete().context("expects concrete dim for deconv")?[2..], + output_shape, + )?; + tract_core::ops::cnn::DeconvUnary::new( + pool_spec, + KernelFormat::OIHW, + adjustments, + self.group, + ) } else { - bail!("Kernel values are expected to be constant.") - } + tract_core::ops::cnn::DeconvUnary::new( + pool_spec, + KernelFormat::OIHW, + self.adjustments.clone().unwrap_or_else(|| tvec!(0; kernel_shape.len() - 2)), + self.group, + ) + }; + target.wire_node(prefix, op, &[inputs[0], kernel[0], bias]) } } diff --git a/pulse/src/ops/cnn/conv.rs b/pulse/src/ops/cnn/conv.rs index b0a92edee1..58a8ffe51d 100644 --- a/pulse/src/ops/cnn/conv.rs +++ b/pulse/src/ops/cnn/conv.rs @@ -13,15 +13,12 @@ fn pulsify( _symbol: &Symbol, _pulse: &TDim, ) -> TractResult>> { - fn zero() -> Tensor { - tensor0(D::default()) - } let fact = target.outlet_fact(mapping[&node.inputs[0]])?; - let zero = dispatch_numbers!(zero(fact.datum_type)()); + let zero = Tensor::zero_scalar_dt(fact.datum_type)?; if let Some((wire, pool_spec)) = pulsify_pooled_input(&op.pool_spec, source, node, target, mapping, Some(zero))? { - let mut wires:TVec<_> = node.inputs.iter().map(|i| mapping[i]).collect(); + let mut wires: TVec<_> = node.inputs.iter().map(|i| mapping[i]).collect(); wires[0] = wire; Ok(Some(target.wire_node(&node.name, ConvUnary { pool_spec, ..op.clone() }, &wires)?)) } else { diff --git a/pulse/src/ops/cnn/deconv.rs b/pulse/src/ops/cnn/deconv.rs index 0d7be5d8bb..8f9b9bb60b 100644 --- a/pulse/src/ops/cnn/deconv.rs +++ b/pulse/src/ops/cnn/deconv.rs @@ -46,6 +46,8 @@ fn pulsify( value: Tensor::zero_scalar_dt(fact.datum_type)?, }; wire = target.wire_node(format!("{}.mask", node.name), mask, &wire)?; + wire.push(mapping[&node.inputs[1]]); + wire.push(mapping[&node.inputs[2]]); wire = target.wire_node(format!("{}.deconv", node.name), pulse_op, &wire)?; let overlap = overlap(stream.axis, op); let deconv_input_dim = (stream.dim.clone() - 1) * stride + 1; @@ -54,7 +56,7 @@ fn pulsify( &fact.streaming_shape(), &op.adjustments, )?; - let kernel_spatial_shape = op.kernel_format.hw(op.kernel.shape()); + let kernel_spatial_shape = &op.pool_spec.kernel_shape; let shape = op.pool_spec.data_format.shape(fact.streaming_shape())?; let paddings = op.pool_spec.padding.compute_for_deconv( shape.hw_dims(), @@ -97,8 +99,7 @@ fn pulsify( fn overlap(pulse_axis: usize, op: &DeconvUnary) -> usize { let geo_axis = pulse_axis - op.pool_spec.data_format.h_axis(); - let axis_in_kernel = op.kernel_format.h_axis() + geo_axis; - (op.kernel.shape()[axis_in_kernel] - 1) * op.pool_spec.dilation(geo_axis) + (op.pool_spec.kernel_shape[geo_axis] - 1) * op.pool_spec.dilation(geo_axis) } impl PulsedOp for DeconvUnary { diff --git a/test-rt/suite-unit/src/deconv.rs b/test-rt/suite-unit/src/deconv.rs index 9a3617ace2..47dc5963f8 100644 --- a/test-rt/suite-unit/src/deconv.rs +++ b/test-rt/suite-unit/src/deconv.rs @@ -145,21 +145,23 @@ impl DeconvProblem { self.kernel_format.input_channels(self.kernel.shape(), self.group).into_owned(), self.kernel_format.output_channels(self.kernel.shape(), self.group).into_owned(), ); - let op = DeconvUnary::new( - pool_spec, - self.kernel_format, - self.kernel.clone().into_arc_tensor(), - self.bias.as_ref().map(|b| b.clone().into_arc_tensor()), - self.adjustments.clone(), - self.group, - ); + let op = + DeconvUnary::new(pool_spec, self.kernel_format, self.adjustments.clone(), self.group); Ok(op) } fn tract(&self) -> TractResult { let mut model = TypedModel::default(); let src = model.add_source("src", f32::fact(self.input.shape()))?; - let output = model.wire_node("deconv", self.as_op().context("Generating op")?, &[src])?; + let kernel = model.add_const("kernel", self.kernel.clone().into_tensor())?; + let bias = + self.bias.as_ref().map(|b| b.clone().into_tensor()).unwrap_or_else(|| tensor0(0f32)); + let bias = model.add_const("bias", bias)?; + let output = model.wire_node( + "deconv", + self.as_op().context("Generating op")?, + &[src, kernel, bias], + )?; model.set_output_outlets(&output)?; Ok(model) } @@ -556,6 +558,38 @@ pub fn suite() -> TractResult { }, ); + suite.add( + "bias_2", + DeconvProblem { + data_format: CHW, + kernel_format: OIHW, + padding: PaddingSpec::Valid, + input: arr2(&[[0f32, 1.]]).into_dyn(), + kernel: arr3(&[[[1f32]], [[0.]]]).into_dyn(), + bias: Some(arr1(&[0f32, 0.]).into_dyn()), + strides: tvec!(1), + dilations: tvec!(1), + adjustments: tvec!(0), + group: 1, + }, + ); + + suite.add( + "bias_group_0", + DeconvProblem { + data_format: CHW, + kernel_format: OIHW, + padding: PaddingSpec::Valid, + input: arr2(&[[0f32], [1.]]).into_dyn(), + kernel: arr3(&[[[1f32]], [[0.]]]).into_dyn(), + bias: Some(arr1(&[0f32, 0.]).into_dyn()), + strides: tvec!(1), + dilations: tvec!(1), + adjustments: tvec!(0), + group: 2, + }, + ); + suite.add( "rank_5_with_group", DeconvProblem {