From 3efbcb949b003fb98ef34b8b1ca6fa57541e2869 Mon Sep 17 00:00:00 2001 From: Alexander Camuto Date: Mon, 11 Sep 2023 11:53:01 +0100 Subject: [PATCH] Update mod.rs --- src/tensor/mod.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 7f62d7600..c0d218390 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -714,13 +714,18 @@ impl Tensor { /// assert_eq!(a.dims(), &[9, 3]); /// ``` pub fn reshape(&mut self, new_dims: &[usize]) { - let product = if new_dims != &[0] && !new_dims.is_empty() { - new_dims.iter().product::() + // in onnx parlance this corresponds to converting a tensor to a single element + if new_dims.is_empty() { + assert!(self.len() == 1 || self.len() == 0); } else { - 0 - }; - assert!(self.len() == product); - self.dims = Vec::from(new_dims); + let product = if new_dims != &[0] { + new_dims.iter().product::() + } else { + 0 + }; + assert!(self.len() == product); + self.dims = Vec::from(new_dims); + } } /// Move axis of the tensor