refactor deconv
kali committed Oct 10, 2023
1 parent 484bb81 commit 06b007f
Showing 5 changed files with 80 additions and 70 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -126,6 +126,7 @@ inherits="release"
 lto=false
 
 [profile.release]
+debug = true
 lto = true
 
 [profile.bench]
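The lone Cargo.toml change turns on `debug = true` for the release profile, so Cargo keeps full debug info in optimized binaries (presumably to make release builds easier to profile). It does not change the optimization level, and `lto = true` stays as before.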
33 changes: 7 additions & 26 deletions core/src/ops/array/broadcast.rs
@@ -5,23 +5,6 @@ pub struct MultiBroadcastTo {
     pub shape: ShapeFact,
 }
-
-
-
-impl MultiBroadcastTo {
-    pub fn eval_t<T: Datum>(input: &Tensor, shape: &[usize]) -> TractResult<TVec<TValue>> {
-        unsafe {
-            let view = input.to_array_view_unchecked::<T>();
-            let mut output = view
-                .broadcast(shape)
-                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
-                .into_owned()
-                .into_tensor();
-            output.set_datum_type(input.datum_type());
-            Ok(tvec![output.into_tvalue()])
-        }
-    }
-}
 
 impl Op for MultiBroadcastTo {
     fn name(&self) -> Cow<str> {
         "MultiBroadcastTo".into()
@@ -39,7 +22,8 @@ impl EvalOp for MultiBroadcastTo {
         let input = args_1!(inputs);
         let dims: Vec<usize> =
             self.shape.iter().map(|d| d.to_usize()).collect::<TractResult<_>>()?;
-        dispatch_datum!(Self::eval_t(input.datum_type())(&*input, &*dims))
+        let output = input.broadcast_to_shape(&dims)?;
+        Ok(tvec!(output.into_tvalue()))
     }
 
     fn state(
@@ -64,9 +48,7 @@ impl OpState for MultiBroadcastToState {
     ) -> TractResult<TVec<TValue>> {
         let op = op.downcast_ref::<MultiBroadcastTo>().context("Wrong op")?;
         let shape = op.shape.eval_to_usize(&session.resolved_symbols)?;
-        dispatch_datum_by_size!(MultiBroadcastTo::eval_t(inputs[0].datum_type())(
-            &inputs[0], &*shape
-        ))
+        Ok(tvec!(inputs[0].broadcast_to_shape(&*shape)?.into_tvalue()))
     }
 }
 
@@ -92,10 +74,10 @@ impl TypedOp for MultiBroadcastTo {
     }
 
     fn declutter(
-    &self,
-    model: &TypedModel,
-    node: &TypedNode,
-    ) -> TractResult<Option<TypedModelPatch>> {
+        &self,
+        model: &TypedModel,
+        node: &TypedNode,
+    ) -> TractResult<Option<TypedModelPatch>> {
         let input_fact = model.outlet_fact(node.inputs[0])?;
         if input_fact.shape == self.shape {
             TypedModelPatch::shunt_one_op(model, node)
@@ -106,4 +88,3 @@ impl TypedOp for MultiBroadcastTo {
 
     as_op!();
 }
-
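Net effect in this file: `MultiBroadcastTo` loses its private, datum-dispatched `eval_t`; both the direct `eval` path and the symbol-resolving `OpState` path now delegate to the new `Tensor::broadcast_to_shape` helper added in data/src/tensor.rs below, which performs the same per-datum-type dispatch internally.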
82 changes: 38 additions & 44 deletions core/src/ops/cnn/deconv/deconv_sum.rs
@@ -42,11 +42,7 @@ impl EvalOp for DeconvSum {
     }
 
     fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
-        dispatch_floatlike!(Self::eval_with_values(inputs[0].datum_type())(
-            self,
-            inputs,
-            &SymbolValues::default()
-        ))
+        self.eval_with_values(inputs, &Default::default())
     }
 
     fn state(
@@ -65,23 +61,19 @@ impl OpState for DeconvSum {
         _op: &dyn Op,
         inputs: TVec<TValue>,
     ) -> TractResult<TVec<TValue>> {
-        dispatch_floatlike!(Self::eval_with_values(inputs[0].datum_type())(
-            self,
-            inputs,
-            &session.resolved_symbols
-        ))
+        self.eval_with_values(inputs, &session.resolved_symbols)
     }
 }
 trivial_op_state_freeeze!(DeconvSum);
 
 impl DeconvSum {
-    fn eval_with_values<T: Datum + Float + Copy + AddAssign<T>>(
+    fn eval_with_values(
         &self,
         inputs: TVec<TValue>,
         values: &SymbolValues,
     ) -> TractResult<TVec<TValue>> {
         let gemm = args_1!(inputs).into_tensor();
-        debug_assert_eq!(gemm.datum_type(), T::datum_type());
+        let dt = gemm.datum_type();
         let input_shape = self.input_shape.eval_to_usize(values)?.into_owned();
         let input_shape = self.pool_spec.data_format.shape(input_shape)?;
         let output_shape =
@@ -95,33 +87,13 @@ impl DeconvSum {
             &self.adjustments,
         )?;
         let mut tensor = if let Some(b) = &self.bias {
-            if output_shape.shape[0..output_shape.c_axis()].iter().all(|d| *d == 1) {
-                unsafe {
-                    let mut tensor = Tensor::uninitialized::<T>(&output_shape.shape)?;
-                    let values = b.as_ptr::<T>()?;
-                    let slice = tensor.as_ptr_mut::<T>()?;
-                    let stride = *output_shape.c_stride();
-                    for ix in 0..b.len() {
-                        let v = *values.add(ix);
-                        for p in 0..stride {
-                            *slice.add(stride * ix + p) = v;
-                        }
-                    }
-                    tensor
-                }
-            } else {
-                let mut tensor = Tensor::zero::<T>(&output_shape.shape)?;
-                let mut output = tensor.to_array_view_mut::<T>()?;
-                let mut bias_shape = tvec!(1; output_shape.rank());
-                bias_shape[output_shape.c_axis()] = b.len();
-                let b = b.clone().into_tensor().into_shape(&bias_shape)?;
-                output += &b.to_array_view::<T>()?;
-                tensor
-            }
+            let mut bias_shape = tvec!(1; output_shape.rank());
+            bias_shape[output_shape.c_axis()] = b.len();
+            let b = b.clone().into_tensor().into_shape(&bias_shape)?;
+            b.broadcast_to_shape(&output_shape.shape)?
         } else {
-            Tensor::zero::<T>(&output_shape.shape)?
+            Tensor::zero_dt(dt, &output_shape.shape)?
         };
-        let mut output = tensor.to_array_view_mut::<T>()?;
         let hw = *gemm.shape().last().unwrap();
         let n = *output_shape.n().unwrap_or(&1);
         let n_o_hkwk_hw = gemm.into_shape(&[
@@ -130,41 +102,63 @@
             self.pool_spec.kernel_shape.iter().product(),
             hw,
         ])?;
-        let n_o_hkwk_hw: ArrayView4<T> = n_o_hkwk_hw.to_array_view::<T>()?.into_dimensionality()?;
         if !self.pool_spec.data_format.has_n() {
-            output = output.insert_axis(Axis(0));
+            tensor.insert_axis(0)?;
         }
+        dispatch_floatlike!(Self::eval_t(dt)(
+            self,
+            &input_shape,
+            &output_shape,
+            &spatial_output_details,
+            &n_o_hkwk_hw,
+            &mut tensor
+        ))?;
+        if !self.pool_spec.data_format.has_n() {
+            tensor.remove_axis(0)?;
+        }
+        Ok(tvec!(tensor.into_tvalue()))
+    }
+
+    fn eval_t<T: Datum + Float + Copy + AddAssign<T>>(
+        &self,
+        input_shape: &DataShape,
+        output_shape: &DataShape,
+        spatial_output_details: &[ComputedPaddedDim<usize>],
+        n_o_hkwk_hw: &Tensor,
+        output: &mut Tensor,
+    ) -> TractResult<()> {
+        let output = output.to_array_view_mut::<T>()?;
+        let n_o_hkwk_hw: ArrayView4<T> = n_o_hkwk_hw.to_array_view::<T>()?.into_dimensionality()?;
         match input_shape.hw_rank() {
             1 => self.main_loop_1d(
                 &input_shape,
                 &output_shape,
                 &spatial_output_details,
                 &n_o_hkwk_hw,
                 &mut output.into_dimensionality().unwrap(),
-            )?,
+            ),
             2 => self.main_loop_2d(
                 &input_shape,
                 &output_shape,
                 &spatial_output_details,
                 &n_o_hkwk_hw,
                 &mut output.into_dimensionality().unwrap(),
-            )?,
+            ),
             3 => self.main_loop_3d(
                 &input_shape,
                 &output_shape,
                 &spatial_output_details,
                 &n_o_hkwk_hw,
                 &mut output.into_dimensionality().unwrap(),
-            )?,
+            ),
             _ => self.main_loop(
                 &input_shape,
                 &output_shape,
                 &spatial_output_details,
                 &n_o_hkwk_hw,
                 &mut output.into_dimensionality().unwrap(),
-            )?,
+            ),
         }
-        Ok(tvec!(tensor.into_tvalue()))
     }
 
     pub fn main_loop_1d<T: Datum + Float + AddAssign>(
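The shape of the refactor here: `eval_with_values` is now dtype-agnostic (bias setup goes through reshape-plus-broadcast, zero-init through `Tensor::zero_dt`), and only the scatter-accumulate inner loops stay generic, reached through a single `dispatch_floatlike!` call. Below is a minimal, self-contained sketch of that split; `DatumType` and the match are illustrative stand-ins for tract's real `DatumType` and the `dispatch_floatlike!` macro, and `accumulate` stands in for `eval_t`:

```rust
use std::ops::AddAssign;

// Stand-in for tract's DatumType: the element type is only known at runtime.
#[derive(Clone, Copy)]
enum DatumType {
    F32,
    F64,
}

// Generic worker, playing the role of eval_t: the only code that needs the
// concrete element type, monomorphized once per float type.
fn accumulate<T: Copy + AddAssign<T>>(acc: &mut [T], delta: &[T]) {
    for (a, d) in acc.iter_mut().zip(delta) {
        *a += *d;
    }
}

fn main() {
    // Non-generic caller, playing the role of eval_with_values: the
    // bookkeeping happens once here, untyped, and the runtime dtype is
    // resolved in a single match, which is essentially what
    // dispatch_floatlike!(Self::eval_t(dt)(...)) expands to.
    let dt = DatumType::F32;
    let mut acc = [0.0f32; 4];
    let delta = [1.0f32, 2.0, 3.0, 4.0];
    match dt {
        DatumType::F32 => accumulate(&mut acc, &delta),
        DatumType::F64 => unreachable!("the buffers above are f32"),
    }
    assert_eq!(acc, [1.0, 2.0, 3.0, 4.0]);
}
```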
18 changes: 18 additions & 0 deletions data/src/tensor.rs
@@ -2,6 +2,7 @@
 use crate::datum::{round_ties_to_even, scale_by, Blob, ClampCast, Datum, DatumType, QParams};
 use crate::dim::TDim;
 use crate::TVec;
+use anyhow::Context;
 use half::f16;
 use itertools::Itertools;
 use ndarray::prelude::*;
@@ -538,6 +539,23 @@
         }
     }
 
+    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> anyhow::Result<Tensor> {
+        unsafe {
+            let view = self.to_array_view_unchecked::<T>();
+            let mut output = view
+                .broadcast(shape)
+                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
+                .into_owned()
+                .into_tensor();
+            output.set_datum_type(self.datum_type());
+            Ok(output)
+        }
+    }
+
+    pub fn broadcast_to_shape(&self, shape: &[usize]) -> anyhow::Result<Tensor> {
+        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
+    }
+
     fn clip_range_bounds(
         &self,
         axis: usize,
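A quick usage sketch of the new public helper, assuming the post-commit `tract-data` crate (`tensor2` is the prelude's literal-tensor constructor):

```rust
use tract_data::prelude::*;

fn main() -> TractResult<()> {
    // Broadcast a 1x3 tensor to 2x3: the size-1 axis is repeated and the
    // datum type is preserved.
    let t = tensor2(&[[1f32, 2.0, 3.0]]);
    let b = t.broadcast_to_shape(&[2, 3])?;
    assert_eq!(b, tensor2(&[[1f32, 2.0, 3.0], [1.0, 2.0, 3.0]]));

    // ndarray broadcasting rules apply: a non-1 axis that does not match
    // the target shape is an error, not a silent copy.
    assert!(t.broadcast_to_shape(&[2, 4]).is_err());
    Ok(())
}
```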
16 changes: 16 additions & 0 deletions test-rt/suite-unit/src/deconv.rs
@@ -525,6 +525,22 @@ pub fn suite() -> TractResult<TestSuite> {
         },
     );
 
+    suite.add(
+        "bias_1",
+        DeconvProblem {
+            data_format: HWC,
+            kernel_format: OIHW,
+            padding: PaddingSpec::Valid,
+            input: arr2(&[[0.0], [0.0]]).into_dyn(),
+            kernel: arr3(&[[[0.0]]]).into_dyn(),
+            bias: Some(arr1(&[1.0f32]).into_dyn()),
+            strides: tvec!(1),
+            dilations: tvec!(1),
+            adjustments: tvec!(0),
+            group: 1,
+        },
+    );
+
     suite.add(
         "rank_5_with_group",
         DeconvProblem {
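The new `bias_1` case exercises exactly the path rewritten above: an HWC input with no batch axis, a zero kernel, and a bias of `[1.0]`, so the expected output is just the bias broadcast over the output shape, i.e. all ones.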
