From 72d51d683f723f33641547d543e20b85fcda70a7 Mon Sep 17 00:00:00 2001 From: Payton Webber Date: Tue, 17 Dec 2024 16:53:07 -0800 Subject: [PATCH] add permute operation to tensor --- tensor/src/lib.rs | 8 ++++++++ tensor/tests/tensor_core_test.rs | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 6664c32..68f56bb 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -62,6 +62,14 @@ impl Tensor { Ok(()) } + pub fn permute(&mut self, order: &[usize]) -> Result<(), &'static str> { + if order.len() != self.shape.len() { + return Err("The permutation does not align with the current shape."); + } + let new_shape: Vec = order.iter().map(|i| self.shape[*i]).collect(); + self.reshape(&new_shape) + } + pub fn shape(&self) -> &Vec { &self.shape } diff --git a/tensor/tests/tensor_core_test.rs b/tensor/tests/tensor_core_test.rs index 35dc27f..6fad4cf 100644 --- a/tensor/tests/tensor_core_test.rs +++ b/tensor/tests/tensor_core_test.rs @@ -1,3 +1,5 @@ +use std::usize; + use tensor::Tensor; #[test] @@ -64,6 +66,31 @@ fn reshape_tensor_invalid_shape() { } } +#[test] +fn permute_tensor_valid_order() { + let original_shape: Vec = vec![4, 2, 2]; + let mut a: Tensor = Tensor::ones(&original_shape); + + let permutation: Vec = vec![1, 0, 2]; + let new_strides: Vec = vec![8, 2, 1]; + let new_shape: Vec = vec![2, 4, 2]; + a.permute(&permutation).unwrap(); + + assert_eq!(new_strides, *a.strides()); + assert_eq!(new_shape, *a.shape()); +} + +#[test] +fn permute_tensor_invalid_order() { + let original_shape: Vec = vec![4, 2]; + let mut a: Tensor = Tensor::ones(&original_shape); + + let permutation: Vec = vec![2, 0, 1]; + if let Ok(()) = a.permute(&permutation) { + panic!("The permutation should've been invalid.") + } +} + #[test] fn get_element_with_index() { let length: usize = 24;