diff --git a/data/src/tensor.rs b/data/src/tensor.rs index b077dfccfe..0950f992d0 100644 --- a/data/src/tensor.rs +++ b/data/src/tensor.rs @@ -563,6 +563,11 @@ impl Tensor { ) -> anyhow::Result { ensure!(self.rank() == 1); ensure!(shape[axis] == self.len()); + if !self.datum_type().is_copy() { + let mut vec_shape = vec![1; shape.len()]; + vec_shape[axis] = self.len(); + return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape); + } unsafe { let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?; if output.len() == 0 { @@ -570,15 +575,18 @@ impl Tensor { } let inner_len = shape[axis + 1..].iter().product::(); - unsafe fn splat(input: &Tensor, output: &mut Tensor, inner_len: usize) { + unsafe fn splat(input: &Tensor, output: &mut Tensor, inner_len: usize) + where + T: Datum + Copy, + { for ix in 0..input.len() { - let value: &T = &input.as_slice_unchecked()[ix]; + let value: T = input.as_slice_unchecked()[ix]; output.as_slice_mut_unchecked::()[ix * inner_len..(ix + 1) * inner_len] .iter_mut() - .for_each(|item| *item = value.clone()); + .for_each(|item| *item = value); } } - dispatch_datum!(splat(self.datum_type())(&self, &mut output, inner_len)); + dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len)); let outer_len = shape[0..axis].iter().product::(); let repeat_bytes_len = inner_len * self.as_bytes().len();