From 1f780238378d7a17a520b4efc932a78e9becffa3 Mon Sep 17 00:00:00 2001 From: Payton Webber Date: Thu, 19 Dec 2024 01:42:42 -0800 Subject: [PATCH] refactor and add comments to tensor --- tensor/src/lib.rs | 93 +++++++++++++++++++++++++---------------------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/tensor/src/lib.rs b/tensor/src/lib.rs index 89e5453..e78119d 100644 --- a/tensor/src/lib.rs +++ b/tensor/src/lib.rs @@ -1,8 +1,5 @@ use core::{f32, fmt}; -use std::{ - ops::{Add, Index}, - usize, -}; +use std::ops::{Add, Index}; #[derive(Debug, Clone)] pub struct Tensor { @@ -25,16 +22,6 @@ impl Tensor { }) } - fn calculate_strides(shape: &[usize]) -> Vec { - let length: usize = shape.len(); - let mut strides = vec![1; length]; - strides.iter_mut().enumerate().for_each(|(i, stride)| { - // stride[i] = (shape[i+1]*shape[i+2]*...*shape[N-1]) - *stride = shape.iter().take(length).skip(i + 1).product(); - }); - strides - } - pub fn zeros(shape: &[usize]) -> Self { let num_elements: usize = shape.iter().product(); let strides: Vec = Self::calculate_strides(shape); @@ -55,6 +42,18 @@ impl Tensor { } } + fn calculate_strides(shape: &[usize]) -> Vec { + let length: usize = shape.len(); + let mut strides = vec![1; length]; + strides.iter_mut().enumerate().for_each(|(i, stride)| { + // stride[i] = (shape[i+1]*shape[i+2]*...*shape[N-1]) + *stride = shape.iter().take(length).skip(i + 1).product(); + }); + strides + } + + /* MOVEMENT OPS */ + pub fn reshape(&mut self, shape: &[usize]) -> Result<(), &'static str> { let new_length: usize = shape.iter().product(); let current_length: usize = self.shape.iter().product(); @@ -90,36 +89,7 @@ impl Tensor { self.strides = vec![1]; } - pub fn shape(&self) -> &Vec { - &self.shape - } - - pub fn strides(&self) -> &Vec { - &self.strides - } - - pub fn data(&self) -> &Vec { - &self.data - } - - pub fn data_mut(&mut self) -> &mut Vec { - &mut self.data - } - - fn get(&self, indices: &[usize]) -> Option<&f32> { - if indices.len() != self.shape.len() { - return None; - } - - let mut idx: usize = 0; - for (i, &dim) in indices.iter().enumerate() { - if dim >= self.shape[i] { - return None; - } - idx += dim * self.strides[i]; - } - self.data.get(idx) - } + /* BINARY OPS */ pub fn add(&self, other: &Tensor) -> Result { if self.shape != other.shape { @@ -173,6 +143,39 @@ impl Tensor { Ok(c) } + + /* GETTERS */ + + fn get(&self, indices: &[usize]) -> Option<&f32> { + if indices.len() != self.shape.len() { + return None; + } + + let mut idx: usize = 0; + for (i, &dim) in indices.iter().enumerate() { + if dim >= self.shape[i] { + return None; + } + idx += dim * self.strides[i]; + } + self.data.get(idx) + } + + pub fn shape(&self) -> &Vec { + &self.shape + } + + pub fn strides(&self) -> &Vec { + &self.strides + } + + pub fn data(&self) -> &Vec { + &self.data + } + + pub fn data_mut(&mut self) -> &mut Vec { + &mut self.data + } } fn calculate_data_index(indices: &[usize], strides: &[usize]) -> usize { @@ -282,6 +285,8 @@ impl fmt::Display for Tensor { } } +/* TRAIT IMPLEMENTATIONS */ + impl Index<&[usize]> for Tensor { type Output = f32; fn index(&self, indices: &[usize]) -> &Self::Output {