Skip to content

Commit

Permalink
Weekly chores (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay authored Dec 2, 2024
1 parent 9a0df74 commit 67e845f
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 143 deletions.
21 changes: 21 additions & 0 deletions crates/cubecl-core/src/ir/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,27 @@ pub struct CubeDim {
}

impl CubeDim {
/// Create a new cube dim with x = y = z = 1.
pub fn new_single() -> Self {
Self { x: 1, y: 1, z: 1 }
}

/// Create a new cube dim with the given x, and y = z = 1.
pub fn new_1d(x: u32) -> Self {
Self { x, y: 1, z: 1 }
}

/// Create a new cube dim with the given x and y, and z = 1.
pub fn new_2d(x: u32, y: u32) -> Self {
Self { x, y, z: 1 }
}

/// Create a new cube dim with the given x, y and z.
/// This is equivalent to the [new](CubeDim::new) function.
pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
Self { x, y, z }
}

pub fn num_elems(&self) -> u32 {
self.x * self.y * self.z
}
Expand Down
14 changes: 7 additions & 7 deletions crates/cubecl-reduce/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

/// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality.
pub struct ReduceArgMax;
pub struct ArgMax;

#[cube]
impl ReduceArgMax {
impl ArgMax {
/// Compare two pairs of items and coordinates and return a new pair
/// where each element in the lines is the maximal item with its coordinate.
/// In case of equality, the lowest coordinate is selected.
Expand All @@ -27,10 +27,10 @@ impl ReduceArgMax {
}

/// Compute the coordinate of the minimum item returning the smallest coordinate in case of equality.
pub struct ReduceArgMin;
pub struct ArgMin;

#[cube]
impl ReduceArgMin {
impl ArgMin {
/// Compare two pairs of items and coordinates and return a new pair
/// where each element in the lines is the minimal item with its coordinate.
/// In case of equality, the lowest coordinate is selected.
Expand All @@ -51,6 +51,6 @@ impl ReduceArgMin {
}
}

pub struct ReduceMean;
pub struct ReduceSum;
pub struct ReduceProd;
pub struct Mean;
pub struct Sum;
pub struct Prod;
94 changes: 40 additions & 54 deletions crates/cubecl-reduce/src/naive.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum};
use crate::{ArgMax, ArgMin, Mean, Prod, Sum};

/// An instruction for the [reduce_naive](reduce_naive) algorithm.
#[cube]
Expand All @@ -16,10 +16,10 @@ pub trait ReduceNaiveInstruction<EI: Numeric>: Send + Sync + 'static {
/// This could be called many time during reduction. It is required
/// that reducing the initial accumulator any number of times do not change the outcome
/// of the reduction. For example, adding 0s in a sum do not change the outcome.
fn init_accumulator(line_size: u32) -> Self::Accumulator;
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator;

/// Reduce `current_value` into `accumulator`.
fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32);
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, coordinate: u32);

/// Write the result of the reduction stored in `accumulator` into `output[index]`.
fn write<EO: Numeric>(
Expand Down Expand Up @@ -56,12 +56,12 @@ pub fn reduce_naive<RD: ReduceNaiveInstruction<EI>, EI: Numeric, EO: Numeric>(

// Reduce all the lines along `dim` for the previously computed offset.
let mut accumulator = RD::init_accumulator(input.line_size());
for i in 0..input.shape(dim) {
let index = i * input.stride(dim) + offset_input;
for coordinate in 0..input.shape(dim) {
let index = coordinate * input.stride(dim) + offset_input;
RD::accumulate(
&mut accumulator,
unsafe { *input.index_unchecked(index) },
i,
coordinate,
);
}

Expand All @@ -72,10 +72,10 @@ pub fn reduce_naive<RD: ReduceNaiveInstruction<EI>, EI: Numeric, EO: Numeric>(
// Implementations for common instructions.

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceSum {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Sum {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
fn init_accumulator(#[comptime] line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(0))
}

Expand All @@ -94,15 +94,15 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceSum {
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceProd {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Prod {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Line<EI> {
fn init_accumulator(#[comptime] line_size: u32) -> Line<EI> {
Line::empty(line_size).fill(EI::from_int(1))
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
*accumulator *= current_value;
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, _coordinate: u32) {
*accumulator *= item;
}

fn write<EO: Numeric>(
Expand All @@ -116,15 +116,15 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceProd {
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceMean {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for Mean {
type Accumulator = Line<EI>;

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
Line::empty(line_size).fill(EI::from_int(0))
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, _i: u32) {
*accumulator += current_value;
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, _coordinate: u32) {
*accumulator += item;
}

fn write<EO: Numeric>(
Expand All @@ -140,34 +140,27 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceMean {
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMax {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ArgMax {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MIN),
Line::empty(line_size).fill(0u32),
)
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
let (max, index) = accumulator;
#[allow(clippy::collapsible_else_if)]
if comptime!(current_value.size() > 1) {
#[unroll]
for k in 0..current_value.size() {
if current_value[k] > max[k] {
max[k] = current_value[k];
index[k] = i;
}
}
} else {
if current_value > *max {
*max = current_value;
*index = Line::new(i);
}
}
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, coordinate: u32) {
let (acc_item, acc_coordinate) = accumulator;
let (new_item, new_coordinate) = Self::choose_argmax(
*acc_item,
*acc_coordinate,
item,
Line::empty(item.size()).fill(coordinate),
);
accumulator.0 = new_item;
accumulator.1 = new_coordinate;
}

fn write<EO: Numeric>(
Expand All @@ -182,34 +175,27 @@ impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMax {
}

#[cube]
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ReduceArgMin {
impl<EI: Numeric> ReduceNaiveInstruction<EI> for ArgMin {
type Accumulator = (Line<EI>, Line<u32>);

fn init_accumulator(line_size: u32) -> Self::Accumulator {
fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator {
(
// TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68
Line::empty(line_size).fill(EI::MAX),
Line::empty(line_size).fill(0u32),
)
}

fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line<EI>, i: u32) {
let (min, index) = accumulator;
#[allow(clippy::collapsible_else_if)]
if comptime!(current_value.size() > 1) {
#[unroll]
for k in 0..current_value.size() {
if current_value[k] < min[k] {
min[k] = current_value[k];
index[k] = i;
}
}
} else {
if current_value < *min {
*min = current_value;
*index = Line::new(i);
}
}
fn accumulate(accumulator: &mut Self::Accumulator, item: Line<EI>, coordinate: u32) {
let (acc_item, acc_coordinate) = accumulator;
let (new_item, new_coordinate) = Self::choose_argmin(
*acc_item,
*acc_coordinate,
item,
Line::empty(item.size()).fill(coordinate),
);
accumulator.0 = new_item;
accumulator.1 = new_coordinate;
}

fn write<EO: Numeric>(
Expand Down
12 changes: 6 additions & 6 deletions crates/cubecl-reduce/src/shared.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum};
use crate::{ArgMax, ArgMin, Mean, Prod, Sum};

/// An instruction for the [reduce_shared](reduce_shared) algorithm.
#[cube]
Expand Down Expand Up @@ -139,7 +139,7 @@ fn div_ceil(a: u32, b: u32) -> u32 {
// Implementations for common instructions.

#[cube]
impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceSum {
impl<EI: Numeric> ReduceSharedInstruction<EI> for Sum {
type Accumulator = SharedMemory<Line<EI>>;

fn create_accumulator(
Expand Down Expand Up @@ -183,7 +183,7 @@ impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceSum {
}

#[cube]
impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceProd {
impl<EI: Numeric> ReduceSharedInstruction<EI> for Prod {
type Accumulator = SharedMemory<Line<EI>>;

fn create_accumulator(
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceProd {
}

#[cube]
impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceMean {
impl<EI: Numeric> ReduceSharedInstruction<EI> for Mean {
type Accumulator = SharedMemory<Line<EI>>;

fn create_accumulator(
Expand Down Expand Up @@ -272,7 +272,7 @@ impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceMean {
}

#[cube]
impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceArgMax {
impl<EI: Numeric> ReduceSharedInstruction<EI> for ArgMax {
type Accumulator = (SharedMemory<Line<EI>>, SharedMemory<Line<u32>>);

fn create_accumulator(
Expand Down Expand Up @@ -340,7 +340,7 @@ impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceArgMax {
}

#[cube]
impl<EI: Numeric> ReduceSharedInstruction<EI> for ReduceArgMin {
impl<EI: Numeric> ReduceSharedInstruction<EI> for ArgMin {
type Accumulator = (SharedMemory<Line<EI>>, SharedMemory<Line<u32>>);

fn create_accumulator(
Expand Down
Loading

0 comments on commit 67e845f

Please sign in to comment.