diff --git a/core/src/ops/cnn/conv/unary.rs b/core/src/ops/cnn/conv/unary.rs index 195fe39c69..63b02d0581 100644 --- a/core/src/ops/cnn/conv/unary.rs +++ b/core/src/ops/cnn/conv/unary.rs @@ -120,32 +120,28 @@ impl ConvUnary { // group,bias fn bias_as_non_linear( &self, + model: &mut TypedModel, + name: &str, + bias: OutletId, c_group_axis: usize, - ) -> TractResult> { + ) -> TractResult<(ProtoFusedSpec, OutletId)> { use tract_linalg::mmm::BinOp::Add; - if let Some(bias) = &self.bias { - if let Some(uni) = bias.as_uniform() { - if uni.is_zero()? { - Ok(None) - } else { - Ok(Some((ProtoFusedSpec::BinScalar(2, Add), uni))) - } - } else { - let bias = bias - .clone() - .into_tensor() - .into_shape(&[self.group, bias.len() / self.group])?; - Ok(Some(( - ProtoFusedSpec::BinPerRow( - 2, - Add, - MapOutputAxisToInput(tvec!((c_group_axis, 0))), - ), - bias, - ))) - } + let fact = model.outlet_fact(bias)?; + if fact.uniform.is_some() { + Ok((ProtoFusedSpec::BinScalar(2, Add), bias)) } else { - Ok(None) + let bias = model.wire_node( + format!("{name}.reformat_bias"), + AxisOp::Reshape( + 0, + tvec!(fact.shape.volume()), + tvec!(self.group.to_dim(), fact.shape.volume() / self.group), + ), + &[bias], + )?[0]; + let pfs = + ProtoFusedSpec::BinPerRow(2, Add, MapOutputAxisToInput(tvec!((c_group_axis, 0)))); + Ok((pfs, bias)) } } @@ -502,8 +498,9 @@ impl ConvUnary { let mut wires: TVec = wire.into(); wires.push(kernels); let mut ops: Vec = vec![ProtoFusedSpec::AddMatMul(geo, 1, 0)]; - if let Some((fused, tensor)) = self.bias_as_non_linear(c_m_axis - 1)? { - let bias = model.add_const(format!("{name}.bias"), tensor)?; + if let Some(bias) = &self.bias { + let bias = model.add_const(format!("{name}.bias"), bias.clone())?; + let (fused, bias) = self.bias_as_non_linear(model, name, bias, c_m_axis - 1)?; wires.push(bias); ops.push(fused); }