diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 68f56bb..e5299d9 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -66,6 +66,13 @@ impl Tensor { if order.len() != self.shape.len() { return Err("The permutation does not align with the current shape."); } + + let mut sorted_order: Vec = order.to_vec(); + sorted_order.sort(); + if sorted_order != (0..self.shape.len()).collect::>() { + return Err("Index out of range for shape."); + } + let new_shape: Vec = order.iter().map(|i| self.shape[*i]).collect(); self.reshape(&new_shape) } diff --git a/tensor/tests/tensor_core_test.rs b/tensor/tests/tensor_core_test.rs index 6fad4cf..7af5854 100644 --- a/tensor/tests/tensor_core_test.rs +++ b/tensor/tests/tensor_core_test.rs @@ -80,6 +80,17 @@ fn permute_tensor_valid_order() { assert_eq!(new_shape, *a.shape()); } +#[test] +fn permute_tensor_index_out_of_range() { + let original_shape: Vec = vec![4, 2]; + let mut a: Tensor = Tensor::ones(&original_shape); + + let permutation: Vec = vec![4, 2, 1]; + if let Ok(()) = a.permute(&permutation) { + panic!("The permutation should've been invalid.") + } +} + #[test] fn permute_tensor_invalid_order() { let original_shape: Vec = vec![4, 2];