diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 554ff9d..17bfb37 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -78,7 +78,7 @@ impl Tensor { } pub fn flatten(&mut self) { - self.shape = vec![1]; + self.shape = vec![self.shape.iter().product()]; self.strides = vec![1]; } diff --git a/tensor/tests/tensor_core_test.rs b/tensor/tests/tensor_core_test.rs index facd5f0..544758f 100644 --- a/tensor/tests/tensor_core_test.rs +++ b/tensor/tests/tensor_core_test.rs @@ -109,9 +109,11 @@ fn flatten_tensor() { let mut a: Tensor = Tensor::ones(&[7, 6]); a.flatten(); - assert_eq!(vec![1], *a.shape()); + let elem: f32 = a[&[22]]; + assert_eq!(vec![length], *a.shape()); assert_eq!(vec![1], *a.strides()); assert_eq!(expected_data, *a.data()); + assert_eq!(elem, 1.0_f32); } #[test]