Skip to content

Commit

Permalink
add transpose operation to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 19, 2024
1 parent 1f78023 commit 5d7bba3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 8 deletions.
20 changes: 20 additions & 0 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ impl Tensor {
self.strides = vec![1];
}

pub fn transpose(&mut self) -> Result<(), &'static str> {
if self.shape.len() != 2 {
return Err("transpose only supports 2D tensors currently.");
}

let (m, n) = (self.shape[0], self.shape[1]);
let mut new_data = vec![0.0_f32; self.data.len()];
for i in 0..m {
for j in 0..n {
let old_idx = i * n + j;
let new_idx = j * m + i;
new_data[new_idx] = self.data[old_idx];
}
}
self.data = new_data;
self.shape = vec![n, m];
self.strides = vec![m, 1];
Ok(())
}

/* BINARY OPS */

pub fn add(&self, other: &Tensor) -> Result<Tensor, &'static str> {
Expand Down
49 changes: 41 additions & 8 deletions tensor/tests/tensor_core_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ fn create_ones_tensor() {
assert_eq!(expected_data, *a.data());
}

#[test]
fn get_element_with_index() {
let shape = vec![2, 3];
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let a = Tensor::new(&shape, data).unwrap();

assert_eq!(a[&[0, 0]], 1.0);
assert_eq!(a[&[0, 1]], 2.0);
assert_eq!(a[&[0, 2]], 3.0);
assert_eq!(a[&[1, 0]], 4.0);
assert_eq!(a[&[1, 1]], 5.0);
assert_eq!(a[&[1, 2]], 6.0);
}

#[test]
fn reshape_tensor_valid_shape() {
let original_shape = vec![4, 2];
Expand Down Expand Up @@ -115,14 +129,33 @@ fn flatten_tensor() {
}

#[test]
fn get_element_with_index() {
let length: usize = 24;
let shape = vec![3, 2, 4];
let data: Vec<f32> = (0..length).map(|v| v as f32 + 10.0).collect();
let a = Tensor::new(&shape, data).unwrap();

let elem: f32 = a[&[1, 0, 3]];
assert_eq!(elem, 21.0_f32);
fn test_transpose_2d() {
// Create a 2D tensor:
// A = [ [1, 2, 3],
// [4, 5, 6] ]
let shape = vec![2, 3];
let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut a = Tensor::new(&shape, data).unwrap();

// Transpose A:
// A^T should be:
// [ [1, 4],
// [2, 5],
// [3, 6] ]
a.transpose().unwrap();
assert_eq!(*a.shape(), vec![3, 2]);
assert_eq!(*a.strides(), vec![2, 1]);

// Check values:
// A^T[0, 0] = 1, A^T[0, 1] = 4
// A^T[1, 0] = 2, A^T[1, 1] = 5
// A^T[2, 0] = 3, A^T[2, 1] = 6
assert_eq!(a[&[0, 0]], 1.0);
assert_eq!(a[&[1, 0]], 2.0);
assert_eq!(a[&[2, 0]], 3.0);
assert_eq!(a[&[0, 1]], 4.0);
assert_eq!(a[&[1, 1]], 5.0);
assert_eq!(a[&[2, 1]], 6.0);
}

#[test]
Expand Down

0 comments on commit 5d7bba3

Please sign in to comment.