From 96a39f40482b00835513990c6958ea2a4f2a60af Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Fri, 13 Oct 2023 17:29:14 +0200 Subject: [PATCH] optimise broadcast vec to shape --- core/src/ops/cnn/deconv/deconv_sum.rs | 5 +- data/src/tensor.rs | 90 ++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/core/src/ops/cnn/deconv/deconv_sum.rs b/core/src/ops/cnn/deconv/deconv_sum.rs index 7a3e0b416f..1c604886e6 100644 --- a/core/src/ops/cnn/deconv/deconv_sum.rs +++ b/core/src/ops/cnn/deconv/deconv_sum.rs @@ -87,10 +87,7 @@ impl DeconvSum { &self.adjustments, )?; let mut tensor = if let Some(b) = &self.bias { - let mut bias_shape = tvec!(1; output_shape.rank()); - bias_shape[output_shape.c_axis()] = b.len(); - let b = b.clone().into_tensor().into_shape(&bias_shape)?; - b.broadcast_to_shape(&output_shape.shape)? + b.broadcast_vector_to_shape(&output_shape.shape, output_shape.c_axis())? } else { Tensor::zero_dt(dt, &output_shape.shape)? }; diff --git a/data/src/tensor.rs b/data/src/tensor.rs index e48263b6f1..b077dfccfe 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -2,7 +2,7 @@ use crate::datum::{round_ties_to_even, scale_by, Blob, ClampCast, Datum, DatumType, QParams}; use crate::dim::TDim; use crate::TVec; -use anyhow::Context; +use anyhow::{ensure, Context}; use half::f16; use itertools::Itertools; use ndarray::prelude::*; @@ -556,6 +556,41 @@ impl Tensor { dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape)) } + pub fn broadcast_vector_to_shape( + &self, + shape: &[usize], + axis: usize, + ) -> anyhow::Result { + ensure!(self.rank() == 1); + ensure!(shape[axis] == self.len()); + unsafe { + let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?; + if output.len() == 0 { + return Ok(output); + } + let inner_len = shape[axis + 1..].iter().product::(); + + unsafe fn splat(input: &Tensor, output: &mut Tensor, inner_len: usize) { + for ix in 0..input.len() { + let value: &T = &input.as_slice_unchecked()[ix]; + output.as_slice_mut_unchecked::()[ix * inner_len..(ix + 1) * inner_len] + .iter_mut() + .for_each(|item| *item = value.clone()); + } + } + dispatch_datum!(splat(self.datum_type())(&self, &mut output, inner_len)); + + let outer_len = shape[0..axis].iter().product::(); + let repeat_bytes_len = inner_len * self.as_bytes().len(); + let bytes = output.as_bytes_mut(); + for ix in 1..outer_len { + bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len); + } + + Ok(output) + } + } + fn clip_range_bounds( &self, axis: usize, @@ -1512,7 +1547,10 @@ impl IntoArcTensor for Arc { #[cfg(test)] mod tests { + use crate::prelude::tensor1; + use super::*; + use proptest::collection::vec; use proptest::prelude::*; #[derive(Debug)] @@ -1579,6 +1617,56 @@ mod tests { PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap(); } + #[derive(Debug)] + struct BroadcastVecToShape { + vec: Vec, + axis: usize, + shape: TVec, + } + + impl BroadcastVecToShape { + fn check(&self) -> proptest::test_runner::TestCaseResult { + let input = tensor1(&self.vec); + let mut intermediate = tvec![1usize; self.shape.len()]; + intermediate[self.axis] = self.vec.len(); + let reference = input + .clone() + .into_shape(&intermediate) + .unwrap() + .broadcast_to_shape(&self.shape) + .unwrap(); + prop_assert_eq!( + reference, + input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap() + ); + Ok(()) + } + } + + impl Arbitrary for BroadcastVecToShape { + type Strategy = BoxedStrategy; + type Parameters = (); + + fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { + vec(0usize..5, 0usize..4) + .prop_flat_map(|shape| { + (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1) + }) + .prop_map(|(vec, mut shape, axis)| { + shape.insert(axis, vec.len()); + BroadcastVecToShape { vec, shape: shape.into(), axis } + }) + .boxed() + } + } + + proptest::proptest! { + #[test] + fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) { + pb.check().unwrap() + } + } + #[test] #[cfg(feature = "complex")] fn test_reinterpret_inner_dim_as_complex() -> anyhow::Result<()> {