From 5d7bba385c5502425926c7926602c45b79ddcc1b Mon Sep 17 00:00:00 2001 From: Payton Webber Date: Thu, 19 Dec 2024 02:34:16 -0800 Subject: [PATCH] add transpose operation to tensor --- tensor/src/lib.rs | 20 +++++++++++++ tensor/tests/tensor_core_test.rs | 49 ++++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 8 deletions(-) diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index e78119d..4033d3d 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -89,6 +89,26 @@ impl Tensor { self.strides = vec![1]; } + pub fn transpose(&mut self) -> Result<(), &'static str> { + if self.shape.len() != 2 { + return Err("transpose only supports 2D tensors currently."); + } + + let (m, n) = (self.shape[0], self.shape[1]); + let mut new_data = vec![0.0_f32; self.data.len()]; + for i in 0..m { + for j in 0..n { + let old_idx = i * n + j; + let new_idx = j * m + i; + new_data[new_idx] = self.data[old_idx]; + } + } + self.data = new_data; + self.shape = vec![n, m]; + self.strides = vec![m, 1]; + Ok(()) + } + /* BINARY OPS */ pub fn add(&self, other: &Tensor) -> Result { diff --git a/tensor/tests/tensor_core_test.rs b/tensor/tests/tensor_core_test.rs index 3ea99c3..c3e6893 100644 --- a/tensor/tests/tensor_core_test.rs +++ b/tensor/tests/tensor_core_test.rs @@ -40,6 +40,20 @@ fn create_ones_tensor() { assert_eq!(expected_data, *a.data()); } +#[test] +fn get_element_with_index() { + let shape = vec![2, 3]; + let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let a = Tensor::new(&shape, data).unwrap(); + + assert_eq!(a[&[0, 0]], 1.0); + assert_eq!(a[&[0, 1]], 2.0); + assert_eq!(a[&[0, 2]], 3.0); + assert_eq!(a[&[1, 0]], 4.0); + assert_eq!(a[&[1, 1]], 5.0); + assert_eq!(a[&[1, 2]], 6.0); +} + #[test] fn reshape_tensor_valid_shape() { let original_shape = vec![4, 2]; @@ -115,14 +129,33 @@ fn flatten_tensor() { } #[test] -fn get_element_with_index() { - let length: usize = 24; - let shape = vec![3, 2, 4]; - let data: Vec = (0..length).map(|v| v as f32 + 10.0).collect(); - let a = Tensor::new(&shape, data).unwrap(); - - let elem: f32 = a[&[1, 0, 3]]; - assert_eq!(elem, 21.0_f32); +fn test_transpose_2d() { + // Create a 2D tensor: + // A = [ [1, 2, 3], + // [4, 5, 6] ] + let shape = vec![2, 3]; + let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let mut a = Tensor::new(&shape, data).unwrap(); + + // Transpose A: + // A^T should be: + // [ [1, 4], + // [2, 5], + // [3, 6] ] + a.transpose().unwrap(); + assert_eq!(*a.shape(), vec![3, 2]); + assert_eq!(*a.strides(), vec![2, 1]); + + // Check values: + // A^T[0, 0] = 1, A^T[0, 1] = 4 + // A^T[1, 0] = 2, A^T[1, 1] = 5 + // A^T[2, 0] = 3, A^T[2, 1] = 6 + assert_eq!(a[&[0, 0]], 1.0); + assert_eq!(a[&[1, 0]], 2.0); + assert_eq!(a[&[2, 0]], 3.0); + assert_eq!(a[&[0, 1]], 4.0); + assert_eq!(a[&[1, 1]], 5.0); + assert_eq!(a[&[2, 1]], 6.0); } #[test]