Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import reduce naive from burn #314

Merged
merged 10 commits into from
Nov 28, 2024
5 changes: 5 additions & 0 deletions crates/cubecl-reduce/src/instructions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub struct ReduceArgMax;
pub struct ReduceArgMin;
pub struct ReduceMean;
pub struct ReduceSum;
pub struct ReduceProd;
7 changes: 6 additions & 1 deletion crates/cubecl-reduce/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
pub mod sum;
mod instructions;
mod naive;

#[cfg(feature = "export_tests")]
pub mod test;

pub use instructions::*;
pub use naive::*;

213 changes: 213 additions & 0 deletions crates/cubecl-reduce/src/naive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum};

/// An instruction for the [reduce_naive](reduce_naive) algorithm.
#[cube]
pub trait ReduceNaiveInstruction<EI: Numeric>: Send + Sync + 'static {
/// The reduction accumulator.
/// The implement works on lines. Most likely, the accumulator is `Line<T>`
/// for some CubePrimitive type `T` instead of simply `T`.
type Accumulator: CubeType;

/// Initialize the accumulator with a null value for the reduction.
///
/// This could be called many time during reduction. It is required
/// that reducing the initial accumulator any number of times do not change the outcome
/// of the reduction. For example, adding 0s in a sum do not change the outcome.
fn init_accumulator(line_size: u32) -> Self::Accumulator;

/// Reduce `current_value` into `accumulator`.
fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32);

/// Write the result of the reduction stored in `accumulator` into `output[index]`.
fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
shape_reduce_dim: u32,
);
}

/// A naive implementation of the reduction algorithm.
///
/// Each thread with absolute position P is responsible
/// to compute the reduction corresponding to index P of the `output`.
#[cube]
pub fn reduce_naive<RD: ReduceNaiveInstruction<EI>, EI: Numeric, EO: Numeric>(
input: &Tensor<Line<EI>>,
output: &mut Tensor<Line<EO>>,
dim: u32,
) {
if ABSOLUTE_POS >= output.len() * output.line_size() {
return;
}

// Compute the first index where to start the reduction for the current thread.
// First, compute the coordinate corresponding to the ABSOLUTE_POS element of the output tensor
// Then, use the strides of the input tensor to find the index of the same coordinate
// in the input tensor.
let mut offset_input = 0;
for axis in 0..input.rank() {
let coordinate = (ABSOLUTE_POS / output.stride(axis)) % output.shape(axis);
offset_input += coordinate * input.stride(axis);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come the if axis != dim was removed? Was it useless?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape of the dim axis in output is always 1, so for that particular axis, the coordinate is always 0. Thus, it has a null effect on offset_input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SHould I add a comment?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I was just making sure nothing was lost from the original

}

// Reduce all the lines along `dim` for the previously computed offset.
let mut accumulator = RD::init_accumulator(input.line_size());
for i in 0..input.shape(dim) {
let index = i * input.stride(dim) + offset_input;
RD::accumulate(
&mut accumulator,
unsafe { *input.index_unchecked(index) },
i,
);
}

// Write the local outcome into output.
RD::write::<EO>(output, accumulator, ABSOLUTE_POS, input.shape(dim));
}

// Implementations for common instructions.

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceSum {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(0))
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
*accumulator += current_value;
}

fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
_shape_reduce_dim: u32,
) {
output[index] = Line::cast_from(accumulator);
}
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceProd {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(1))
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
*accumulator *= current_value;
}

fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
_shape_reduce_dim: u32,
) {
output[index] = Line::cast_from(accumulator);
}
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceMean {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Self::Accumulator {
Line::empty(line_size).fill(EI::from_int(0))
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
*accumulator += current_value;
}

fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
shape_reduce_dim: u32,
) {
output[index] = Line::cast_from(
accumulator / Line::empty(output.line_size()).fill(EI::cast_from(shape_reduce_dim)),
);
}
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMax {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MIN),
Line::empty(line_size).fill(0u32),
)
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
// FIX: There is an issue when line_size is 1 on wgpu.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do if comptime!(current_value.size() == 1) and have a different logic without indexes

let (max, index) = accumulator;
#[unroll]
for k in 0..current_value.size() {
if current_value[k] > max[k] {
max[k] = current_value[k];
index[k] = i;
}
}
}

fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
_shape_reduce_dim: u32,
) {
let (_, position) = accumulator;
output[index] = Line::cast_from(position)
}
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMin {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MAX),
Line::empty(line_size).fill(0u32),
)
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
// FIX: There is an issue when line_size is 1 on wgpu.
let (min, index) = accumulator;
#[unroll]
for k in 0..current_value.size() {
if current_value[k] < min[k] {
min[k] = current_value[k];
index[k] = i;
}
}
}

fn write<EO: Numeric>(
output: &mut Tensor<Line<EO>>,
accumulator: Self::Accumulator,
index: u32,
_shape_reduce_dim: u32,
) {
let (_, position) = accumulator;
#[unroll]
for k in 0..output.line_size() {
output[index][k] = EO::cast_from(position[k]);
}
}
}
110 changes: 0 additions & 110 deletions crates/cubecl-reduce/src/sum.rs

This file was deleted.

Loading
Loading