Skip to content

Commit

Permalink
clone variant
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Oct 16, 2023
1 parent 96a39f4 commit f724c4f
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions data/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,22 +563,30 @@ impl Tensor {
) -> anyhow::Result<Tensor> {
ensure!(self.rank() == 1);
ensure!(shape[axis] == self.len());
if !self.datum_type().is_copy() {
let mut vec_shape = vec![1; shape.len()];
vec_shape[axis] = self.len();
return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
}
unsafe {
let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
if output.len() == 0 {
return Ok(output);
}
let inner_len = shape[axis + 1..].iter().product::<usize>();

unsafe fn splat<T: Datum>(input: &Tensor, output: &mut Tensor, inner_len: usize) {
unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
where
T: Datum + Copy,
{
for ix in 0..input.len() {
let value: &T = &input.as_slice_unchecked()[ix];
let value: T = input.as_slice_unchecked()[ix];
output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
.iter_mut()
.for_each(|item| *item = value.clone());
.for_each(|item| *item = value);
}
}
dispatch_datum!(splat(self.datum_type())(&self, &mut output, inner_len));
dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));

let outer_len = shape[0..axis].iter().product::<usize>();
let repeat_bytes_len = inner_len * self.as_bytes().len();
Expand Down

0 comments on commit f724c4f

Please sign in to comment.