Skip to content

Commit

Permalink
add flatten operation to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 18, 2024
1 parent 4e2db76 commit 60ce1bc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ impl Tensor {
self.reshape(&new_shape)
}

pub fn flatten(&mut self) {
self.shape = vec![1];
self.strides = vec![1];
}

pub fn shape(&self) -> &Vec<usize> {
&self.shape
}
Expand Down
12 changes: 12 additions & 0 deletions tensor/tests/tensor_core_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ fn permute_tensor_invalid_order() {
}
}

#[test]
fn flatten_tensor() {
let length: usize = 42;
let expected_data: Vec<f32> = vec![1.0_f32; length];
let mut a: Tensor = Tensor::ones(&[7, 6]);

a.flatten();
assert_eq!(vec![1], *a.shape());
assert_eq!(vec![1], *a.strides());
assert_eq!(expected_data, *a.data());
}

#[test]
fn get_element_with_index() {
let length: usize = 24;
Expand Down

0 comments on commit 60ce1bc

Please sign in to comment.