Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rework uninitialized tensor funciton
Browse files Browse the repository at this point in the history
kali committed Jan 31, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent ff770ae commit f1600fa
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions data/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ use itertools::Itertools;
use ndarray::prelude::*;
#[cfg(feature = "complex")]
use num_complex::Complex;
use num_traits::Zero;
use std::alloc;
use std::borrow::Cow;
use std::fmt;
@@ -194,14 +195,6 @@ impl Tensor {
shape: &[usize],
alignment: usize,
) -> anyhow::Result<Tensor> {
if dt == String::datum_type() {
return Ok(ndarray::ArrayD::<String>::default(shape).into());
} else if dt == Blob::datum_type() {
return Ok(ndarray::ArrayD::<Blob>::default(shape).into());
} else if dt == TDim::datum_type() {
return Ok(ndarray::ArrayD::<TDim>::default(shape).into());
}
assert!(dt.is_copy());
let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
let layout = alloc::Layout::from_size_align(bytes, alignment)?;
let data = if bytes == 0 {
@@ -213,13 +206,22 @@ impl Tensor {
} as *mut u8;
let mut tensor = Tensor { strides: tvec!(), layout, dt, shape: shape.into(), data, len: 0 };
tensor.update_strides_and_len();
#[cfg(debug_assertions)]
if !data.is_null() {
if dt == DatumType::F32 {
tensor.as_slice_mut_unchecked::<f32>().iter_mut().for_each(|f| *f = std::f32::NAN);
} else {
// safe, non copy types have been dealt with
tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
if dt == String::datum_type() || dt == Blob::datum_type() {
tensor.as_bytes_mut().iter_mut().for_each(|x| *x = 0);
} else if dt == TDim::datum_type() {
tensor
.as_slice_mut_unchecked::<TDim>()
.iter_mut()
.for_each(|dim| std::ptr::write(dim, TDim::zero()))
} else if cfg!(debug_assertions) {
assert!(dt.is_copy());
if dt == DatumType::F32 {
tensor.fill_t(std::f32::NAN).unwrap();
} else {
// safe, non copy types have been dealt with
tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
}
}
}
Ok(tensor)

0 comments on commit f1600fa

Please sign in to comment.