Skip to content

Commit

Permalink
refactor and add comments to tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
PaytonWebber committed Dec 19, 2024
1 parent 8b65479 commit 1f78023
Showing 1 changed file with 49 additions and 44 deletions.
93 changes: 49 additions & 44 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use core::{f32, fmt};
use std::{
ops::{Add, Index},
usize,
};
use std::ops::{Add, Index};

#[derive(Debug, Clone)]
pub struct Tensor {
Expand All @@ -25,16 +22,6 @@ impl Tensor {
})
}

fn calculate_strides(shape: &[usize]) -> Vec<usize> {
let length: usize = shape.len();
let mut strides = vec![1; length];
strides.iter_mut().enumerate().for_each(|(i, stride)| {
// stride[i] = (shape[i+1]*shape[i+2]*...*shape[N-1])
*stride = shape.iter().take(length).skip(i + 1).product();
});
strides
}

pub fn zeros(shape: &[usize]) -> Self {
let num_elements: usize = shape.iter().product();
let strides: Vec<usize> = Self::calculate_strides(shape);
Expand All @@ -55,6 +42,18 @@ impl Tensor {
}
}

fn calculate_strides(shape: &[usize]) -> Vec<usize> {
let length: usize = shape.len();
let mut strides = vec![1; length];
strides.iter_mut().enumerate().for_each(|(i, stride)| {
// stride[i] = (shape[i+1]*shape[i+2]*...*shape[N-1])
*stride = shape.iter().take(length).skip(i + 1).product();
});
strides
}

/* MOVEMENT OPS */

pub fn reshape(&mut self, shape: &[usize]) -> Result<(), &'static str> {
let new_length: usize = shape.iter().product();
let current_length: usize = self.shape.iter().product();
Expand Down Expand Up @@ -90,36 +89,7 @@ impl Tensor {
self.strides = vec![1];
}

pub fn shape(&self) -> &Vec<usize> {
&self.shape
}

pub fn strides(&self) -> &Vec<usize> {
&self.strides
}

pub fn data(&self) -> &Vec<f32> {
&self.data
}

pub fn data_mut(&mut self) -> &mut Vec<f32> {
&mut self.data
}

fn get(&self, indices: &[usize]) -> Option<&f32> {
if indices.len() != self.shape.len() {
return None;
}

let mut idx: usize = 0;
for (i, &dim) in indices.iter().enumerate() {
if dim >= self.shape[i] {
return None;
}
idx += dim * self.strides[i];
}
self.data.get(idx)
}
/* BINARY OPS */

pub fn add(&self, other: &Tensor) -> Result<Tensor, &'static str> {
if self.shape != other.shape {
Expand Down Expand Up @@ -173,6 +143,39 @@ impl Tensor {

Ok(c)
}

/* GETTERS */

fn get(&self, indices: &[usize]) -> Option<&f32> {
if indices.len() != self.shape.len() {
return None;
}

let mut idx: usize = 0;
for (i, &dim) in indices.iter().enumerate() {
if dim >= self.shape[i] {
return None;
}
idx += dim * self.strides[i];
}
self.data.get(idx)
}

pub fn shape(&self) -> &Vec<usize> {
&self.shape
}

pub fn strides(&self) -> &Vec<usize> {
&self.strides
}

pub fn data(&self) -> &Vec<f32> {
&self.data
}

pub fn data_mut(&mut self) -> &mut Vec<f32> {
&mut self.data
}
}

fn calculate_data_index(indices: &[usize], strides: &[usize]) -> usize {
Expand Down Expand Up @@ -282,6 +285,8 @@ impl fmt::Display for Tensor {
}
}

/* TRAIT IMPLEMENTATIONS */

impl Index<&[usize]> for Tensor {
type Output = f32;
fn index(&self, indices: &[usize]) -> &Self::Output {
Expand Down

0 comments on commit 1f78023

Please sign in to comment.