Skip to content

Commit

Permalink
fix issue where flatten did not update shape correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 18, 2024
1 parent 60ce1bc commit 210c13d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl Tensor {
}

pub fn flatten(&mut self) {
self.shape = vec![1];
self.shape = vec![self.shape.iter().product()];
self.strides = vec![1];
}

Expand Down
4 changes: 3 additions & 1 deletion tensor/tests/tensor_core_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,11 @@ fn flatten_tensor() {
let mut a: Tensor = Tensor::ones(&[7, 6]);

a.flatten();
assert_eq!(vec![1], *a.shape());
let elem: f32 = a[&[22]];
assert_eq!(vec![length], *a.shape());
assert_eq!(vec![1], *a.strides());
assert_eq!(expected_data, *a.data());
assert_eq!(elem, 1.0_f32);
}

#[test]
Expand Down

0 comments on commit 210c13d

Please sign in to comment.