Skip to content

Commit

Permalink
add fmt::Display trait to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 18, 2024
1 parent 5292d83 commit a8e3393
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
108 changes: 108 additions & 0 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use core::fmt;
use std::ops::{Add, Index};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -129,6 +130,113 @@ impl Tensor {
}
}

fn calculate_data_index(indices: &[usize], strides: &[usize]) -> usize {
indices
.iter()
.enumerate()
.map(|(i, &idx)| idx * strides[i])
.sum()
}

fn print_tensor_recursive(
f: &mut fmt::Formatter<'_>,
data: &[f32],
shape: &[usize],
strides: &[usize],
current_index: &mut [usize],
dim: usize,
ndims: usize,
) -> fmt::Result {
if ndims == 0 {
// 0-D tensor (scalar)
if let Some(value) = data.first() {
return write!(f, "{:.4}", value);
} else {
return write!(f, "");
}
}

if dim == ndims - 1 {
// Last dimension: print elements in a row
write!(f, "[")?;
for i in 0..shape[dim] {
current_index[dim] = i;
let idx = calculate_data_index(current_index, strides);
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{:.4}", data[idx])?;
}
write!(f, "]")?;
} else {
// Not the last dimension
write!(f, "[")?;
for i in 0..shape[dim] {
current_index[dim] = i;

if i > 0 {
// Subsequent slices/rows
if dim == 0 {
// Top-level dimension
if ndims >= 3 {
// For 3D or more: blank line between top-level slices
write!(f, "\n\n")?;
// 7 spaces indentation
for _ in 0..7 {
write!(f, " ")?;
}
} else if ndims == 2 {
// For 2D: no blank line, just newline + 8 spaces
writeln!(f)?;
for _ in 0..8 {
write!(f, " ")?;
}
}
} else {
// Inner dimension (dim > 0)
// newline + 8 spaces
writeln!(f)?;
for _ in 0..8 {
write!(f, " ")?;
}
}
}

print_tensor_recursive(f, data, shape, strides, current_index, dim + 1, ndims)?;
}
write!(f, "]")?;
}

Ok(())
}

impl fmt::Display for Tensor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.shape.is_empty() {
// 0-D tensor
if let Some(value) = self.data.first() {
return write!(f, "tensor({:.4})", value);
} else {
return write!(f, "tensor()");
}
}

write!(f, "tensor(")?;
let mut current_index = vec![0; self.shape.len()];
print_tensor_recursive(
f,
&self.data,
&self.shape,
&self.strides,
&mut current_index,
0,
self.shape.len(),
)?;
write!(f, ")")?;
Ok(())
}
}

impl Index<&[usize]> for Tensor {
type Output = f32;
fn index(&self, indices: &[usize]) -> &Self::Output {
Expand Down
20 changes: 20 additions & 0 deletions tensor/tests/tensor_core_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,23 @@ fn tensor_addition_operator() {
let expected_data = vec![2.0_f32; 8];
assert_eq!(expected_data, *result.data());
}

#[test]
fn test_display_1d() {
let a = Tensor::new(&[3], vec![0.0, 1.0, 2.0]).unwrap();
assert_eq!(format!("{}", a), "tensor([0.0000, 1.0000, 2.0000])");
}

#[test]
fn test_display_2d() {
let t = Tensor::new(&[2, 3], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
let expected = "tensor([[0.0000, 1.0000, 2.0000]\n [3.0000, 4.0000, 5.0000]])";
assert_eq!(format!("{}", t), expected);
}

#[test]
fn test_display_3d() {
let t = Tensor::new(&[2, 2, 2], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap();
let expected = "tensor([[[0.0000, 1.0000]\n [2.0000, 3.0000]]\n\n [[4.0000, 5.0000]\n [6.0000, 7.0000]]])";
assert_eq!(format!("{}", t), expected);
}

0 comments on commit a8e3393

Please sign in to comment.