Skip to content

Commit

Permalink
Update mod.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Sep 11, 2023
1 parent e5c52ae commit 3efbcb9
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,18 @@ impl<T: Clone + TensorType> Tensor<T> {
/// assert_eq!(a.dims(), &[9, 3]);
/// ```
pub fn reshape(&mut self, new_dims: &[usize]) {
let product = if new_dims != &[0] && !new_dims.is_empty() {
new_dims.iter().product::<usize>()
// in onnx parlance this corresponds to converting a tensor to a single element
if new_dims.is_empty() {
assert!(self.len() == 1 || self.len() == 0);
} else {
0
};
assert!(self.len() == product);
self.dims = Vec::from(new_dims);
let product = if new_dims != &[0] {
new_dims.iter().product::<usize>()
} else {
0
};
assert!(self.len() == product);
self.dims = Vec::from(new_dims);
}
}

/// Move axis of the tensor
Expand Down

0 comments on commit 3efbcb9

Please sign in to comment.