
Commit

add subtraction operation to tensor
PaytonWebber committed Jan 18, 2025
1 parent fabe761 · commit c6c5df6
Showing 2 changed files with 121 additions and 1 deletion.
tensor/src/lib.rs (58 changes: 57 additions & 1 deletion)
@@ -1,5 +1,5 @@
use core::{f32, fmt};
-use std::ops::{Add, Mul, Div, Index};
+use std::ops::{Add, Sub, Mul, Div, Index};

#[derive(Debug, Clone)]
pub struct Tensor {
@@ -156,6 +156,51 @@ impl Tensor {
        Tensor::new(bc_shape, result_data)
    }

    pub fn sub(&self, other: &Tensor) -> Result<Tensor, &'static str> {
        let self_shape = self.shape();
        let other_shape = other.shape();
        if !is_broadcastable(self_shape, other_shape) {
            return Err("The tensor shapes are not compatible for subtraction.");
        }

        // Fast path: identical shapes allow a direct element-wise subtraction.
        if self_shape == other_shape {
            let result_data: Vec<f32> = self
                .data
                .iter()
                .zip(other.data.iter())
                .map(|(a, b)| a - b)
                .collect();
            return Tensor::new(self_shape.clone(), result_data);
        }

        // Broadcast path: derive the common output shape and the per-operand strides.
        let (bc_shape, self_bc_strides, other_bc_strides) =
            compute_broadcast_shape_and_strides(self_shape, other_shape);

        let self_data = self.data();
        let other_data = other.data();

        let result_size: usize = bc_shape.iter().product();
        let mut result_data: Vec<f32> = Vec::with_capacity(result_size);

        for i in 0..result_size {
            // Convert the flat output index into a multi-dimensional index,
            // then map it back into each operand through its broadcast strides.
            let multi_idx = unravel_index(i, &bc_shape);

            let mut self_offset = 0;
            for (dim_i, &stride) in self_bc_strides.iter().enumerate() {
                self_offset += multi_idx[dim_i] * stride;
            }

            let mut other_offset = 0;
            for (dim_i, &stride) in other_bc_strides.iter().enumerate() {
                other_offset += multi_idx[dim_i] * stride;
            }

            let val = self_data[self_offset] - other_data[other_offset];
            result_data.push(val);
        }
        Tensor::new(bc_shape, result_data)
    }

    pub fn mul(&self, other: &Tensor) -> Result<Tensor, &'static str> {
        let self_shape = self.shape();
        let other_shape = other.shape();
@@ -494,6 +539,17 @@ impl Add<Tensor> for Tensor {
    }
}

impl Sub<Tensor> for Tensor {
    type Output = Tensor;

    fn sub(self, rhs: Tensor) -> Self::Output {
        match Tensor::sub(&self, &rhs) {
            Ok(result) => result,
            Err(_) => panic!("Shapes of the tensors do not match for subtraction."),
        }
    }
}

impl Mul<Tensor> for Tensor {
    type Output = Tensor;

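For orientation, here is a rough usage sketch of the subtraction support added above. It is not part of the commit; it assumes the crate is importable as tensor (as the tensor/ directory suggests) and that Tensor::new, Tensor::ones, data(), and shape() behave as the tests below indicate.

// Illustrative usage of the new subtraction support (a sketch, not part of this diff).
use tensor::Tensor;

fn main() {
    // Same shape: element-wise subtraction via the method form.
    let a = Tensor::ones(vec![2, 3]);
    let b = Tensor::ones(vec![2, 3]);
    let diff = a.sub(&b).unwrap();
    assert_eq!(*diff.data(), vec![0.0_f32; 6]);

    // Broadcast: a [3]-shaped tensor is subtracted from every row of a [2, 3] tensor,
    // using the operator form added in this commit.
    let m = Tensor::new(vec![2, 3], vec![10.0, 10.0, 10.0, 20.0, 20.0, 20.0]).unwrap();
    let v = Tensor::new(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
    let out = m - v;
    assert_eq!(*out.shape(), vec![2, 3]);
    assert_eq!(*out.data(), vec![9.0, 8.0, 7.0, 19.0, 18.0, 17.0]);
}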
tensor/tests/tensor_core_test.rs (64 changes: 64 additions & 0 deletions)
@@ -222,6 +222,70 @@ fn tensor_broadcasted_addition_operator() {
    assert_eq!(expected_data, *c.data());
}

#[test]
fn tensor_subtraction_method() {
    let shape = vec![4, 2];
    let a = Tensor::ones(shape.clone());
    let b = Tensor::ones(shape);
    let result = a.sub(&b).unwrap();

    let expected_data = vec![0.0_f32; 8];
    assert_eq!(expected_data, *result.data());
}

#[test]
fn tensor_broadcasted_subtraction_method() {
    let a_shape = vec![4, 3];
    let b_shape = vec![3];
    let a_data = vec![
        0_f32, 0_f32, 0_f32, 10_f32, 10_f32, 10_f32, 20_f32, 20_f32, 20_f32, 30_f32, 30_f32, 30_f32,
    ];
    let b_data = vec![1_f32, 2_f32, 3_f32];

    let a_tensor = Tensor::new(a_shape, a_data).unwrap();
    let b_tensor = Tensor::new(b_shape, b_data).unwrap();

    let c = a_tensor.sub(&b_tensor).unwrap();
    let expected_data = vec![
        -1_f32, -2_f32, -3_f32, 9_f32, 8_f32, 7_f32, 19_f32, 18_f32, 17_f32, 29_f32, 28_f32, 27_f32,
    ];

    assert_eq!(vec![4, 3], *c.shape());
    assert_eq!(expected_data, *c.data());
}

#[test]
fn tensor_subtraction_operator() {
    let shape = vec![4, 2];
    let a = Tensor::ones(shape.clone());
    let b = Tensor::ones(shape);
    let result = a - b;

    let expected_data = vec![0.0_f32; 8];
    assert_eq!(expected_data, *result.data());
}

#[test]
fn tensor_broadcasted_subtraction_operator() {
    let a_shape = vec![4, 3];
    let b_shape = vec![3];
    let a_data = vec![
        0_f32, 0_f32, 0_f32, 10_f32, 10_f32, 10_f32, 20_f32, 20_f32, 20_f32, 30_f32, 30_f32, 30_f32,
    ];
    let b_data = vec![1_f32, 2_f32, 3_f32];

    let a_tensor = Tensor::new(a_shape, a_data).unwrap();
    let b_tensor = Tensor::new(b_shape, b_data).unwrap();

    let c = a_tensor - b_tensor;
    let expected_data = vec![
        -1_f32, -2_f32, -3_f32, 9_f32, 8_f32, 7_f32, 19_f32, 18_f32, 17_f32, 29_f32, 28_f32, 27_f32,
    ];

    assert_eq!(vec![4, 3], *c.shape());
    assert_eq!(expected_data, *c.data());
}

#[test]
fn tensor_mul_method() {
    let a_shape = vec![1, 3];
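A brief note on error handling, again as a hedged sketch rather than part of the commit: the method form reports incompatible shapes through a Result, while the operator form panics. Assuming the same crate layout as in the sketch above:

// Illustrative error handling for the new subtraction (a sketch, not part of this diff).
use tensor::Tensor;

fn main() {
    let a = Tensor::ones(vec![2, 3]);
    let b = Tensor::ones(vec![4]); // not broadcastable against [2, 3]
    match a.sub(&b) {
        Ok(_) => unreachable!("shapes [2, 3] and [4] should not broadcast"),
        Err(msg) => println!("sub failed: {msg}"),
    }
    // `a - b` with the same operands would panic instead of returning an Err.
}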
