Skip to content

Commit

Permalink
broadcast bias out of eval for deconv_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Dec 7, 2023
1 parent 000bc16 commit 21a70a5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
5 changes: 2 additions & 3 deletions core/src/ops/cnn/deconv/deconv_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl DeconvSum {
&self.pool_spec.strides(),
&self.adjustments,
)?;
let mut tensor = bias.broadcast_to_shape(&output_shape.shape)?;
let mut tensor = bias.into_tensor();
let hw = *gemm.shape().last().unwrap();
let n = *output_shape.n().unwrap_or(&1);
let n_o_hkwk_hw = gemm.into_tensor().into_shape(&[
Expand Down Expand Up @@ -117,8 +117,7 @@ impl TypedOp for DeconvSum {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs.len() == 2);
let shape = super::output_shape(&self.pool_spec, &self.input_shape, &self.adjustments)?;
ensure!(inputs[1].rank() == shape.len());
ensure!(inputs[1].shape.iter().zip(shape.iter()).all(|(b, o)| b.is_one() || b == o.to_dim()));
ensure!(*inputs[1].shape == *shape);
Ok(tvec!(inputs[0].datum_type.fact(shape)))
}

Expand Down
21 changes: 9 additions & 12 deletions core/src/ops/cnn/deconv/unary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::internal::*;
use crate::ops::array::MultiBroadcastTo;
use crate::ops::cnn::wire_reshape_bias;
use crate::ops::cnn::KernelFormat;
use crate::ops::cnn::PoolSpec;
Expand Down Expand Up @@ -91,14 +92,20 @@ impl DeconvUnary {
&[kernel[0], input[0]],
)?;

let bias = wire_reshape_bias(
let mut bias = wire_reshape_bias(
target,
format!("{name}.reshape_bias"),
inputs[2],
shape.rank(),
shape.c_axis(),
self.pool_spec.output_channels,
)?[0];
let output_shape = super::output_shape(&self.pool_spec, &shape.shape, &self.adjustments)?;
bias = target.wire_node(
&format!("{name}.broadcast_bias"),
MultiBroadcastTo { shape: output_shape.into() },
&[bias],
)?[0];

// einsum must be (N_)CHkWk_HW
let deconv_sum = target.wire_node(
Expand Down Expand Up @@ -203,17 +210,7 @@ impl TypedOp for DeconvUnary {
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut patch = TypedModelPatch::default();
let mut inputs = patch.taps(model, &node.inputs)?;
let x_shape = patch.outlet_fact(inputs[0])?;
let x_shape = self.pool_spec.data_format.shape(x_shape.shape.to_tvec())?;
inputs[2] = wire_reshape_bias(
&mut patch,
&node.name,
inputs[2],
x_shape.rank(),
x_shape.c_axis(),
self.pool_spec.output_channels,
)?[0];
let inputs = patch.taps(model, &node.inputs)?;
let output = self
.wire_with_deconv_sum(&node.name, &mut patch, &inputs)
.context("In wire_with_deconv_sum")?;
Expand Down

0 comments on commit 21a70a5

Please sign in to comment.