diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 7f62d7600..c0d218390 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -714,13 +714,18 @@ impl Tensor { /// assert_eq!(a.dims(), &[9, 3]); /// ``` pub fn reshape(&mut self, new_dims: &[usize]) { - let product = if new_dims != &[0] && !new_dims.is_empty() { - new_dims.iter().product::() + // in onnx parlance this corresponds to converting a tensor to a single element + if new_dims.is_empty() { + assert!(self.len() == 1 || self.len() == 0); } else { - 0 - }; - assert!(self.len() == product); - self.dims = Vec::from(new_dims); + let product = if new_dims != &[0] { + new_dims.iter().product::() + } else { + 0 + }; + assert!(self.len() == product); + self.dims = Vec::from(new_dims); + } } /// Move axis of the tensor