Skip to content

Commit

Permalink
optimise broadcast vec to shape
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 13, 2023
1 parent 06b007f commit 96a39f4
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 5 deletions.
5 changes: 1 addition & 4 deletions core/src/ops/cnn/deconv/deconv_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,7 @@ impl DeconvSum {
&self.adjustments,
)?;
let mut tensor = if let Some(b) = &self.bias {
let mut bias_shape = tvec!(1; output_shape.rank());
bias_shape[output_shape.c_axis()] = b.len();
let b = b.clone().into_tensor().into_shape(&bias_shape)?;
b.broadcast_to_shape(&output_shape.shape)?
b.broadcast_vector_to_shape(&output_shape.shape, output_shape.c_axis())?
} else {
Tensor::zero_dt(dt, &output_shape.shape)?
};
Expand Down
90 changes: 89 additions & 1 deletion data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::datum::{round_ties_to_even, scale_by, Blob, ClampCast, Datum, DatumType, QParams};
use crate::dim::TDim;
use crate::TVec;
use anyhow::Context;
use anyhow::{ensure, Context};
use half::f16;
use itertools::Itertools;
use ndarray::prelude::*;
Expand Down Expand Up @@ -556,6 +556,41 @@ impl Tensor {
dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
}

/// Broadcasts a rank-1 tensor into a tensor of the given `shape`, laying the
/// vector along `axis`.
///
/// In the result, every position whose coordinate on `axis` is `i` holds a
/// clone of element `i` of `self`.
///
/// # Errors
///
/// Returns an error if `self` is not of rank 1, if `axis` is not a valid
/// axis of `shape`, if `shape[axis]` differs from `self.len()`, or if the
/// output tensor can not be allocated.
pub fn broadcast_vector_to_shape(
    &self,
    shape: &[usize],
    axis: usize,
) -> anyhow::Result<Tensor> {
    ensure!(self.rank() == 1);
    // Validate `axis` before indexing so a bad axis is reported as an error
    // instead of an out-of-bounds panic.
    ensure!(axis < shape.len());
    ensure!(shape[axis] == self.len());
    // SAFETY: `output` is created with the datum type of `self`, and
    // `dispatch_datum!` instantiates `splat` with the type matching that
    // datum type, so both unchecked slice views are taken at the right type.
    // NOTE(review): assigning to a slot drops its previous value, so this
    // relies on `uninitialized_dt` handing back droppable values for
    // non-Copy datum types -- confirm.
    unsafe {
        let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
        if output.len() == 0 {
            return Ok(output);
        }
        // Number of consecutive output items sharing one coordinate on `axis`.
        let inner_len = shape[axis + 1..].iter().product::<usize>();

        /// Fills `output` from `input`: writes the first outer block by
        /// repeating each input value `inner_len` times, then clones that
        /// block over every remaining outer block.
        unsafe fn splat<T: Datum>(input: &Tensor, output: &mut Tensor, inner_len: usize) {
            let values = input.as_slice_unchecked::<T>();
            // `output` is non-empty here, so `inner_len` and `values.len()`
            // are both non-zero and `block_len` is a valid chunk size.
            let block_len = values.len() * inner_len;
            let (first, rest) = output.as_slice_mut_unchecked::<T>().split_at_mut(block_len);
            for (value, run) in values.iter().zip(first.chunks_exact_mut(inner_len)) {
                run.iter_mut().for_each(|item| *item = value.clone());
            }
            // Typed clone rather than a raw byte copy: a byte copy would
            // bit-duplicate owned values for non-Copy datum types. For Copy
            // types this is still a plain memory copy.
            for block in rest.chunks_exact_mut(block_len) {
                block.clone_from_slice(first);
            }
        }
        dispatch_datum!(splat(self.datum_type())(self, &mut output, inner_len));

        Ok(output)
    }
}

fn clip_range_bounds(
&self,
axis: usize,
Expand Down Expand Up @@ -1512,7 +1547,10 @@ impl IntoArcTensor for Arc<Tensor> {

#[cfg(test)]
mod tests {
use crate::prelude::tensor1;

use super::*;
use proptest::collection::vec;
use proptest::prelude::*;

#[derive(Debug)]
Expand Down Expand Up @@ -1579,6 +1617,56 @@ mod tests {
PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
}

// One generated case for the `broadcast_vector_to_shape` property test.
#[derive(Debug)]
struct BroadcastVecToShape {
// Values of the rank-1 input tensor.
vec: Vec<f32>,
// Axis of `shape` along which the vector is laid out.
axis: usize,
// Full target shape; `shape[axis]` is expected to equal `vec.len()`.
shape: TVec<usize>,
}

impl BroadcastVecToShape {
    /// Compares `broadcast_vector_to_shape` against the generic path:
    /// reshape the vector to a shape of ones (except on `axis`), then
    /// `broadcast_to_shape` to the full target shape.
    fn check(&self) -> proptest::test_runner::TestCaseResult {
        let vector = tensor1(&self.vec);

        // Shape of rank `shape.len()` holding the vector on `axis` only.
        let rank = self.shape.len();
        let mut unit_shape = tvec![1usize; rank];
        unit_shape[self.axis] = self.vec.len();

        // Reference result is computed first, as in the generic code path.
        let reshaped = vector.clone().into_shape(&unit_shape).unwrap();
        let expected = reshaped.broadcast_to_shape(&self.shape).unwrap();

        let found = vector.broadcast_vector_to_shape(&self.shape, self.axis).unwrap();
        prop_assert_eq!(expected, found);
        Ok(())
    }
}

impl Arbitrary for BroadcastVecToShape {
    type Parameters = ();
    type Strategy = BoxedStrategy<BroadcastVecToShape>;

    /// Draws up to three surrounding dimensions (each in `0..5`), a vector of
    /// up to four values in `-10.0..10.0`, and an insertion position; the
    /// vector length is then inserted into the dimensions at that position.
    fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
        let surrounding_dims = vec(0usize..5, 0usize..4);
        surrounding_dims
            .prop_flat_map(|dims| {
                // Any position from the front up to one past the last dim.
                let insert_at = 0..dims.len() + 1;
                let values = vec(-10f32..10f32, 0usize..5);
                (values, Just(dims), insert_at)
            })
            .prop_map(|(values, mut dims, axis)| {
                dims.insert(axis, values.len());
                BroadcastVecToShape { vec: values, shape: dims.into(), axis }
            })
            .boxed()
    }
}

proptest::proptest! {
    // Runs `BroadcastVecToShape::check` on generated cases; a failed check
    // surfaces as a panic through `unwrap`.
    #[test]
    fn broadcast_vector_to_shape_prop(problem: BroadcastVecToShape) {
        problem.check().unwrap()
    }
}

#[test]
#[cfg(feature = "complex")]
fn test_reinterpret_inner_dim_as_complex() -> anyhow::Result<()> {
Expand Down

0 comments on commit 96a39f4

Please sign in to comment.