From a8e33936b9c933158894b5a655f217f3d581f147 Mon Sep 17 00:00:00 2001 From: Payton Webber Date: Tue, 17 Dec 2024 19:14:53 -0800 Subject: [PATCH] add fmt::Display trait to tensor --- tensor/src/lib.rs | 108 +++++++++++++++++++++++++++++++ tensor/tests/tensor_core_test.rs | 20 ++++++ 2 files changed, 128 insertions(+) diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 96ce13f..481b9f3 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -1,3 +1,4 @@ +use core::fmt; use std::ops::{Add, Index}; #[derive(Debug, Clone)] @@ -129,6 +130,113 @@ impl Tensor { } } +fn calculate_data_index(indices: &[usize], strides: &[usize]) -> usize { + indices + .iter() + .enumerate() + .map(|(i, &idx)| idx * strides[i]) + .sum() +} + +fn print_tensor_recursive( + f: &mut fmt::Formatter<'_>, + data: &[f32], + shape: &[usize], + strides: &[usize], + current_index: &mut [usize], + dim: usize, + ndims: usize, +) -> fmt::Result { + if ndims == 0 { + // 0-D tensor (scalar) + if let Some(value) = data.first() { + return write!(f, "{:.4}", value); + } else { + return write!(f, ""); + } + } + + if dim == ndims - 1 { + // Last dimension: print elements in a row + write!(f, "[")?; + for i in 0..shape[dim] { + current_index[dim] = i; + let idx = calculate_data_index(current_index, strides); + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{:.4}", data[idx])?; + } + write!(f, "]")?; + } else { + // Not the last dimension + write!(f, "[")?; + for i in 0..shape[dim] { + current_index[dim] = i; + + if i > 0 { + // Subsequent slices/rows + if dim == 0 { + // Top-level dimension + if ndims >= 3 { + // For 3D or more: blank line between top-level slices + write!(f, "\n\n")?; + // 7 spaces indentation + for _ in 0..7 { + write!(f, " ")?; + } + } else if ndims == 2 { + // For 2D: no blank line, just newline + 8 spaces + writeln!(f)?; + for _ in 0..8 { + write!(f, " ")?; + } + } + } else { + // Inner dimension (dim > 0) + // newline + 8 spaces + writeln!(f)?; + for _ in 0..8 { + write!(f, " ")?; + } + } + } + + print_tensor_recursive(f, data, shape, strides, current_index, dim + 1, ndims)?; + } + write!(f, "]")?; + } + + Ok(()) +} + +impl fmt::Display for Tensor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.shape.is_empty() { + // 0-D tensor + if let Some(value) = self.data.first() { + return write!(f, "tensor({:.4})", value); + } else { + return write!(f, "tensor()"); + } + } + + write!(f, "tensor(")?; + let mut current_index = vec![0; self.shape.len()]; + print_tensor_recursive( + f, + &self.data, + &self.shape, + &self.strides, + &mut current_index, + 0, + self.shape.len(), + )?; + write!(f, ")")?; + Ok(()) + } +} + impl Index<&[usize]> for Tensor { type Output = f32; fn index(&self, indices: &[usize]) -> &Self::Output { diff --git a/tensor/tests/tensor_core_test.rs b/tensor/tests/tensor_core_test.rs index d789cf4..7f1e44f 100644 --- a/tensor/tests/tensor_core_test.rs +++ b/tensor/tests/tensor_core_test.rs @@ -146,3 +146,23 @@ fn tensor_addition_operator() { let expected_data = vec![2.0_f32; 8]; assert_eq!(expected_data, *result.data()); } + +#[test] +fn test_display_1d() { + let a = Tensor::new(&[3], vec![0.0, 1.0, 2.0]).unwrap(); + assert_eq!(format!("{}", a), "tensor([0.0000, 1.0000, 2.0000])"); +} + +#[test] +fn test_display_2d() { + let t = Tensor::new(&[2, 3], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]).unwrap(); + let expected = "tensor([[0.0000, 1.0000, 2.0000]\n [3.0000, 4.0000, 5.0000]])"; + assert_eq!(format!("{}", t), expected); +} + +#[test] +fn test_display_3d() { + let t = Tensor::new(&[2, 2, 2], vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).unwrap(); + let expected = "tensor([[[0.0000, 1.0000]\n [2.0000, 3.0000]]\n\n [[4.0000, 5.0000]\n [6.0000, 7.0000]]])"; + assert_eq!(format!("{}", t), expected); +}