Skip to content

Commit

Permalink
fix issue where permute did not update strides correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 18, 2024
1 parent 210c13d commit 5292d83
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 56 deletions.
8 changes: 6 additions & 2 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ impl Tensor {
return Err("Index out of range for shape.");
}

let new_shape: Vec<usize> = order.iter().map(|i| self.shape[*i]).collect();
self.reshape(&new_shape)
let new_shape: Vec<usize> = order.iter().map(|&i| self.shape[i]).collect();
let new_strides: Vec<usize> = order.iter().map(|&i| self.strides[i]).collect();

self.shape = new_shape;
self.strides = new_strides;
Ok(())
}

pub fn flatten(&mut self) {
Expand Down
106 changes: 52 additions & 54 deletions tensor/tests/tensor_core_test.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use std::usize;

use tensor::Tensor;

#[test]
fn create_tensor_from_data() {
let shape: Vec<usize> = vec![3, 4, 3];
let strides: Vec<usize> = vec![12, 3, 1];
let shape = vec![3, 4, 3];
let strides = vec![12, 3, 1];
let length: usize = shape.iter().product();
let data: Vec<f32> = (0..length).map(|v| v as f32 + 10.0).collect();
let expected_data: Vec<f32> = data.to_vec().clone();
let a: Tensor = Tensor::new(&shape, data).unwrap();
let expected_data: Vec<f32> = data.to_vec();
let a = Tensor::new(&shape, data).unwrap();

assert_eq!(shape, *a.shape());
assert_eq!(strides, *a.strides());
Expand All @@ -18,11 +16,11 @@ fn create_tensor_from_data() {

#[test]
fn create_zeros_tensor() {
let shape: Vec<usize> = vec![4, 2];
let strides: Vec<usize> = vec![2, 1];
let shape = vec![4, 2];
let strides = vec![2, 1];
let length: usize = shape.iter().product();
let expected_data: Vec<f32> = vec![0.0; length];
let a: Tensor = Tensor::zeros(&shape);
let expected_data = vec![0.0; length];
let a = Tensor::zeros(&shape);

assert_eq!(shape, *a.shape());
assert_eq!(strides, *a.strides());
Expand All @@ -31,11 +29,11 @@ fn create_zeros_tensor() {

#[test]
fn create_ones_tensor() {
let shape: Vec<usize> = vec![1, 9, 2, 5];
let strides: Vec<usize> = vec![90, 10, 5, 1];
let shape = vec![1, 9, 2, 5];
let strides = vec![90, 10, 5, 1];
let length: usize = shape.iter().product();
let expected_data: Vec<f32> = vec![1.0; length];
let a: Tensor = Tensor::ones(&shape);
let expected_data = vec![1.0; length];
let a = Tensor::ones(&shape);

assert_eq!(shape, *a.shape());
assert_eq!(strides, *a.strides());
Expand All @@ -44,69 +42,69 @@ fn create_ones_tensor() {

#[test]
fn reshape_tensor_valid_shape() {
let original_shape: Vec<usize> = vec![4, 2];
let mut a: Tensor = Tensor::ones(&original_shape);
let original_shape = vec![4, 2];
let mut a = Tensor::ones(&original_shape);

let new_shape: Vec<usize> = vec![2, 2, 2];
let new_strides: Vec<usize> = vec![4, 2, 1];
let new_shape = vec![2, 2, 2];
let new_strides = vec![4, 2, 1];
a.reshape(&new_shape).unwrap();

assert_eq!(new_strides, *a.strides());
assert_eq!(new_shape, *a.shape());
assert_eq!(new_strides, *a.strides());
}

#[test]
fn reshape_tensor_invalid_shape() {
let original_shape: Vec<usize> = vec![4, 2];
let mut a: Tensor = Tensor::ones(&original_shape);
let original_shape = vec![4, 2];
let mut a = Tensor::ones(&original_shape);

let new_shape: Vec<usize> = vec![7, 6];
if let Ok(()) = a.reshape(&new_shape) {
panic!("The new shape should've been invalid.")
let new_shape = vec![7, 6];
if a.reshape(&new_shape).is_ok() {
panic!("The new shape should've been invalid.");
}
}

#[test]
fn permute_tensor_valid_order() {
let original_shape: Vec<usize> = vec![4, 2, 2];
let mut a: Tensor = Tensor::ones(&original_shape);
let original_shape = vec![1, 4, 2];
let mut a = Tensor::ones(&original_shape);

let permutation: Vec<usize> = vec![1, 0, 2];
let new_strides: Vec<usize> = vec![8, 2, 1];
let new_shape: Vec<usize> = vec![2, 4, 2];
let permutation = vec![1, 2, 0];
let new_strides = vec![2, 1, 8];
let new_shape = vec![4, 2, 1];
a.permute(&permutation).unwrap();

assert_eq!(new_strides, *a.strides());
assert_eq!(new_shape, *a.shape());
assert_eq!(new_strides, *a.strides());
}

#[test]
fn permute_tensor_index_out_of_range() {
let original_shape: Vec<usize> = vec![4, 2];
let mut a: Tensor = Tensor::ones(&original_shape);
let original_shape = vec![4, 2];
let mut a = Tensor::ones(&original_shape);

let permutation: Vec<usize> = vec![4, 2, 1];
if let Ok(()) = a.permute(&permutation) {
panic!("The permutation should've been invalid.")
let permutation = vec![4, 2, 1]; // invalid since original_shape.len() = 2
if a.permute(&permutation).is_ok() {
panic!("The permutation should've been invalid.");
}
}

#[test]
fn permute_tensor_invalid_order() {
let original_shape: Vec<usize> = vec![4, 2];
let mut a: Tensor = Tensor::ones(&original_shape);
let original_shape = vec![4, 2];
let mut a = Tensor::ones(&original_shape);

let permutation: Vec<usize> = vec![2, 0, 1];
if let Ok(()) = a.permute(&permutation) {
panic!("The permutation should've been invalid.")
let permutation = vec![2, 0, 1]; // not a proper permutation of [0, 1]
if a.permute(&permutation).is_ok() {
panic!("The permutation should've been invalid.");
}
}

#[test]
fn flatten_tensor() {
let length: usize = 42;
let expected_data: Vec<f32> = vec![1.0_f32; length];
let mut a: Tensor = Tensor::ones(&[7, 6]);
let expected_data = vec![1.0_f32; length];
let mut a = Tensor::ones(&[7, 6]);

a.flatten();
let elem: f32 = a[&[22]];
Expand All @@ -119,32 +117,32 @@ fn flatten_tensor() {
#[test]
fn get_element_with_index() {
let length: usize = 24;
let shape: Vec<usize> = vec![3, 2, 4];
let shape = vec![3, 2, 4];
let data: Vec<f32> = (0..length).map(|v| v as f32 + 10.0).collect();
let a: Tensor = Tensor::new(&shape, data).unwrap();
let a = Tensor::new(&shape, data).unwrap();

let elem: f32 = a[&[1, 0, 3]];
assert_eq!(elem, 21.0_f32);
}

#[test]
fn tensor_addition_method() {
let shape: Vec<usize> = vec![4, 2];
let a: Tensor = Tensor::ones(&shape);
let b: Tensor = Tensor::ones(&shape);
let result: Tensor = a.add(&b).unwrap();
let shape = vec![4, 2];
let a = Tensor::ones(&shape);
let b = Tensor::ones(&shape);
let result = a.add(&b).unwrap();

let expected_data: Vec<f32> = vec![2.0_f32; 8];
let expected_data = vec![2.0_f32; 8];
assert_eq!(expected_data, *result.data());
}

#[test]
fn tensor_addition_operator() {
let shape: Vec<usize> = vec![4, 2];
let a: Tensor = Tensor::ones(&shape);
let b: Tensor = Tensor::ones(&shape);
let result: Tensor = a + b;
let shape = vec![4, 2];
let a = Tensor::ones(&shape);
let b = Tensor::ones(&shape);
let result = a + b;

let expected_data: Vec<f32> = vec![2.0_f32; 8];
let expected_data = vec![2.0_f32; 8];
assert_eq!(expected_data, *result.data());
}

0 comments on commit 5292d83

Please sign in to comment.