From f04cb5c4283254478fbcd90ee468f4d8627d8220 Mon Sep 17 00:00:00 2001 From: maxime Date: Tue, 26 Nov 2024 10:27:04 -0500 Subject: [PATCH 1/9] import reduce from burn --- crates/cubecl-reduce/src/base.rs | 83 ++++++++++++ crates/cubecl-reduce/src/lib.rs | 15 ++- crates/cubecl-reduce/src/naive/argmax.rs | 32 +++++ crates/cubecl-reduce/src/naive/argmin.rs | 32 +++++ crates/cubecl-reduce/src/naive/base.rs | 21 +++ crates/cubecl-reduce/src/naive/kernel.rs | 61 +++++++++ crates/cubecl-reduce/src/naive/mean_dim.rs | 23 ++++ crates/cubecl-reduce/src/naive/mod.rs | 7 + crates/cubecl-reduce/src/naive/prod_dim.rs | 22 +++ crates/cubecl-reduce/src/naive/sum_dim.rs | 22 +++ crates/cubecl-reduce/src/prod.rs | 15 +++ crates/cubecl-reduce/src/shared/argmax.rs | 63 +++++++++ crates/cubecl-reduce/src/shared/argmin.rs | 64 +++++++++ crates/cubecl-reduce/src/shared/base.rs | 35 +++++ crates/cubecl-reduce/src/shared/kernel.rs | 117 ++++++++++++++++ crates/cubecl-reduce/src/shared/mean_dim.rs | 44 ++++++ crates/cubecl-reduce/src/shared/mod.rs | 7 + crates/cubecl-reduce/src/shared/prod_dim.rs | 43 ++++++ crates/cubecl-reduce/src/shared/sum_dim.rs | 43 ++++++ crates/cubecl-reduce/src/subcube/argmax.rs | 54 ++++++++ crates/cubecl-reduce/src/subcube/argmin.rs | 54 ++++++++ crates/cubecl-reduce/src/subcube/base.rs | 17 +++ crates/cubecl-reduce/src/subcube/kernel.rs | 134 +++++++++++++++++++ crates/cubecl-reduce/src/subcube/mean_dim.rs | 45 +++++++ crates/cubecl-reduce/src/subcube/mod.rs | 7 + crates/cubecl-reduce/src/subcube/prod_dim.rs | 44 ++++++ crates/cubecl-reduce/src/subcube/sum_dim.rs | 44 ++++++ crates/cubecl-reduce/src/sum.rs | 123 ++--------------- crates/cubecl-reduce/src/tune/base.rs | 94 +++++++++++++ crates/cubecl-reduce/src/tune/key.rs | 39 ++++++ crates/cubecl-reduce/src/tune/mod.rs | 7 + 31 files changed, 1300 insertions(+), 111 deletions(-) create mode 100644 crates/cubecl-reduce/src/base.rs create mode 100644 crates/cubecl-reduce/src/naive/argmax.rs create mode 100644 crates/cubecl-reduce/src/naive/argmin.rs create mode 100644 crates/cubecl-reduce/src/naive/base.rs create mode 100644 crates/cubecl-reduce/src/naive/kernel.rs create mode 100644 crates/cubecl-reduce/src/naive/mean_dim.rs create mode 100644 crates/cubecl-reduce/src/naive/mod.rs create mode 100644 crates/cubecl-reduce/src/naive/prod_dim.rs create mode 100644 crates/cubecl-reduce/src/naive/sum_dim.rs create mode 100644 crates/cubecl-reduce/src/prod.rs create mode 100644 crates/cubecl-reduce/src/shared/argmax.rs create mode 100644 crates/cubecl-reduce/src/shared/argmin.rs create mode 100644 crates/cubecl-reduce/src/shared/base.rs create mode 100644 crates/cubecl-reduce/src/shared/kernel.rs create mode 100644 crates/cubecl-reduce/src/shared/mean_dim.rs create mode 100644 crates/cubecl-reduce/src/shared/mod.rs create mode 100644 crates/cubecl-reduce/src/shared/prod_dim.rs create mode 100644 crates/cubecl-reduce/src/shared/sum_dim.rs create mode 100644 crates/cubecl-reduce/src/subcube/argmax.rs create mode 100644 crates/cubecl-reduce/src/subcube/argmin.rs create mode 100644 crates/cubecl-reduce/src/subcube/base.rs create mode 100644 crates/cubecl-reduce/src/subcube/kernel.rs create mode 100644 crates/cubecl-reduce/src/subcube/mean_dim.rs create mode 100644 crates/cubecl-reduce/src/subcube/mod.rs create mode 100644 crates/cubecl-reduce/src/subcube/prod_dim.rs create mode 100644 crates/cubecl-reduce/src/subcube/sum_dim.rs create mode 100644 crates/cubecl-reduce/src/tune/base.rs create mode 100644 crates/cubecl-reduce/src/tune/key.rs create 
mode 100644 crates/cubecl-reduce/src/tune/mod.rs diff --git a/crates/cubecl-reduce/src/base.rs b/crates/cubecl-reduce/src/base.rs new file mode 100644 index 00000000..5cb9ebfb --- /dev/null +++ b/crates/cubecl-reduce/src/base.rs @@ -0,0 +1,83 @@ +use cubecl::prelude::Numeric; + +#[cfg(feature = "autotune")] +use crate::kernel::reduce::reduce_dim_autotune; +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; + +use super::{ + naive::{base::ReduceDimNaive, kernel::reduce_dim_naive}, + shared::{base::ReduceDimShared, kernel::reduce_dim_shared}, + subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube}, +}; + +#[allow(dead_code)] +pub(crate) trait ReduceDimAlgorithm: + core::fmt::Debug + ReduceDimNaive + ReduceDimShared + ReduceDimSubcube +{ +} + +/// Creates an empty output tensor with reduce output shape +pub fn init_reduce_output( + input: &JitTensor, + reduce_dim: usize, +) -> JitTensor { + let mut shape_out = input.shape.clone(); + shape_out.dims[reduce_dim] = 1; + + empty_device::(input.client.clone(), input.device.clone(), shape_out) +} + +#[derive(Copy, Clone, Debug)] +#[allow(missing_docs)] +pub enum ReduceStrategy { + /// Naive + Naive, + /// Use shared memory as an accumulator + SharedMemory, + /// Use subcube functions + Subcube, + #[cfg(feature = "autotune")] + Autotune, +} + +impl Default for ReduceStrategy { + fn default() -> Self { + // if autotune is enabled, default to autotune + #[cfg(feature = "autotune")] + return ReduceStrategy::Autotune; + + #[cfg(not(feature = "autotune"))] + ReduceStrategy::Naive + } +} + +macro_rules! reduce_operation { + ($name:ident, $ops:ident) => { + #[derive(Debug)] + pub(crate) struct $ops; + + impl ReduceDimAlgorithm for $ops {} + + /// Executes the reduce operation with the given strategy. 
+ pub fn $name( + tensor: JitTensor, + dim: usize, + strategy: ReduceStrategy, + ) -> JitTensor { + match strategy { + ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), + ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), + ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), + #[cfg(feature = "autotune")] + ReduceStrategy::Autotune => reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim), + } + } + }; +} + +// Autotunable reduce operation variants +reduce_operation!(sum_dim, SumDim); +reduce_operation!(mean_dim, MeanDim); +reduce_operation!(prod_dim, ProdDim); +reduce_operation!(argmin, Argmin); +reduce_operation!(argmax, Argmax); diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs index 5f157485..ebc9a5b0 100644 --- a/crates/cubecl-reduce/src/lib.rs +++ b/crates/cubecl-reduce/src/lib.rs @@ -1,4 +1,15 @@ -pub mod sum; +mod base; +mod naive; +mod prod; +mod shared; +mod subcube; +mod sum; +mod tune; -#[cfg(feature = "export_tests")] +pub use base::*; +pub use prod::*; +pub use sum::*; +pub use tune::*; + +#[cfg(export_tests)] pub mod test; diff --git a/crates/cubecl-reduce/src/naive/argmax.rs b/crates/cubecl-reduce/src/naive/argmax.rs new file mode 100644 index 00000000..ba51ec72 --- /dev/null +++ b/crates/cubecl-reduce/src/naive/argmax.rs @@ -0,0 +1,32 @@ +use cubecl::prelude::*; + +use crate::{kernel::reduce::Argmax, JitElement}; + +use super::base::ReduceDimNaive; + +#[cube] +impl ReduceDimNaive for Argmax { + type Accumulator = (EI, u32); + + fn initialize_naive() -> Self::Accumulator { + // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 + (comptime![EI::minimum_value()].runtime(), 0u32) + } + + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { + let (max, index) = accumulator; + if current_value > *max { + *max = current_value; + *index = i; + } + } + + fn assign_naive( + output: &mut Tensor, + accumulator: Self::Accumulator, + _shape_reduce_dim: u32, + ) { + let (_, index) = accumulator; + output[ABSOLUTE_POS] = EO::cast_from(index); + } +} diff --git a/crates/cubecl-reduce/src/naive/argmin.rs b/crates/cubecl-reduce/src/naive/argmin.rs new file mode 100644 index 00000000..0418d20a --- /dev/null +++ b/crates/cubecl-reduce/src/naive/argmin.rs @@ -0,0 +1,32 @@ +use cubecl::prelude::*; + +use crate::{kernel::reduce::Argmin, JitElement}; + +use super::base::ReduceDimNaive; + +#[cube] +impl ReduceDimNaive for Argmin { + type Accumulator = (EI, u32); + + fn initialize_naive() -> Self::Accumulator { + // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 + (comptime![EI::maximum_value()].runtime(), 0u32) + } + + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { + let (min, index) = accumulator; + if current_value < *min { + *min = current_value; + *index = i; + } + } + + fn assign_naive( + output: &mut Tensor, + accumulator: Self::Accumulator, + _shape_reduce_dim: u32, + ) { + let (_, index) = accumulator; + output[ABSOLUTE_POS] = EO::cast_from(index); + } +} diff --git a/crates/cubecl-reduce/src/naive/base.rs b/crates/cubecl-reduce/src/naive/base.rs new file mode 100644 index 00000000..82194ea3 --- /dev/null +++ b/crates/cubecl-reduce/src/naive/base.rs @@ -0,0 +1,21 @@ +use cubecl::prelude::*; + +/// Specifies the reduce dim algorithm in use +#[cube] +pub trait ReduceDimNaive: Send + Sync + 'static { + 
/// The reduction accumulator + type Accumulator: CubeType; + + /// Initialization for naive algorithm + fn initialize_naive() -> Self::Accumulator; + + /// Inner loop for naive algorithm + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32); + + /// Assignation for naive algorithm + fn assign_naive( + output: &mut Tensor, + accumulator: Self::Accumulator, + shape_reduce_dim: u32, + ); +} diff --git a/crates/cubecl-reduce/src/naive/kernel.rs b/crates/cubecl-reduce/src/naive/kernel.rs new file mode 100644 index 00000000..c001edca --- /dev/null +++ b/crates/cubecl-reduce/src/naive/kernel.rs @@ -0,0 +1,61 @@ +use crate::{ + element::JitElement, kernel::reduce::init_reduce_output, tensor::JitTensor, JitRuntime, +}; +use cubecl::calculate_cube_count_elemwise; +use cubecl::prelude::*; + +use super::base::ReduceDimNaive; + +#[cube(launch_unchecked)] +pub(crate) fn naive_reduce_dim_kernel, EI: Numeric, EO: Numeric>( + input: &Tensor, + output: &mut Tensor, + dim: u32, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + let mut offset_input = 0; + + for i in 0..input.rank() { + let mut offset_local = ABSOLUTE_POS / output.stride(i); + offset_local %= output.shape(i); + if i != dim { + offset_input += offset_local * input.stride(i); + } + } + + let mut accumulator = RD::initialize_naive(); + + for i in 0..input.shape(dim) { + let index = i * input.stride(dim) + offset_input; + RD::inner_loop_naive(&mut accumulator, input[index], i); + } + + RD::assign_naive::(output, accumulator, input.shape(dim)); +} + +/// Executes the naive kernel for reduce dim +pub fn reduce_dim_naive, R: JitRuntime, EI: JitElement, EO: JitElement>( + input: JitTensor, + dim: usize, +) -> JitTensor { + let output = init_reduce_output::(&input, dim); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); + + unsafe { + naive_reduce_dim_kernel::launch_unchecked::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), + ScalarArg::new(dim as u32), + ); + } + + output +} diff --git a/crates/cubecl-reduce/src/naive/mean_dim.rs b/crates/cubecl-reduce/src/naive/mean_dim.rs new file mode 100644 index 00000000..c3210faf --- /dev/null +++ b/crates/cubecl-reduce/src/naive/mean_dim.rs @@ -0,0 +1,23 @@ +use cubecl::prelude::*; + +use crate::kernel::reduce::MeanDim; + +use super::base::ReduceDimNaive; + +#[cube] +impl ReduceDimNaive for MeanDim { + type Accumulator = EI; + + fn initialize_naive() -> EI { + EI::from_int(0) + } + + fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + *accumulator += current_value; + } + + fn assign_naive(output: &mut Tensor, accumulator: EI, shape_reduce_dim: u32) { + let mean = accumulator / EI::cast_from(shape_reduce_dim); + output[ABSOLUTE_POS] = EO::cast_from(mean); + } +} diff --git a/crates/cubecl-reduce/src/naive/mod.rs b/crates/cubecl-reduce/src/naive/mod.rs new file mode 100644 index 00000000..b11ee5e2 --- /dev/null +++ b/crates/cubecl-reduce/src/naive/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod argmax; +pub(crate) mod argmin; +pub(crate) mod base; +pub(crate) mod kernel; +pub(crate) mod mean_dim; +pub(crate) mod prod_dim; +pub(crate) mod sum_dim; diff --git a/crates/cubecl-reduce/src/naive/prod_dim.rs b/crates/cubecl-reduce/src/naive/prod_dim.rs new file mode 100644 index 00000000..5dfedf73 --- /dev/null +++ b/crates/cubecl-reduce/src/naive/prod_dim.rs @@ -0,0 +1,22 @@ +use cubecl::prelude::*; + +use 
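
The naive path above assigns one output element to each unit (ABSOLUTE_POS): the unit recovers the input offset of its reduced line from the output strides, then folds the whole line with the instruction's `inner_loop_naive`. The plain-Rust sketch below replays that index arithmetic on the CPU so the stride handling is easy to check; `reduce_dim_naive_cpu`, the hard-coded `f32`, and the sum accumulator are illustrative assumptions, not part of the patch or of the CubeCL API.

    // CPU-side mirror of naive_reduce_dim_kernel's indexing (illustration only).
    fn reduce_dim_naive_cpu(
        input: &[f32],
        in_shape: &[usize],
        in_strides: &[usize],
        dim: usize,
    ) -> Vec<f32> {
        // Output keeps the rank but collapses `dim` to 1, stored contiguously
        // (what init_reduce_output produces).
        let mut out_shape = in_shape.to_vec();
        out_shape[dim] = 1;
        let mut out_strides = vec![1usize; out_shape.len()];
        for i in (0..out_shape.len().saturating_sub(1)).rev() {
            out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
        }
        let out_len: usize = out_shape.iter().product();
        let mut output = vec![0.0f32; out_len];

        // One loop iteration plays the role of one unit (pos = ABSOLUTE_POS).
        for pos in 0..out_len {
            let mut offset_input = 0;
            for i in 0..in_shape.len() {
                let coord = pos / out_strides[i] % out_shape[i];
                if i != dim {
                    offset_input += coord * in_strides[i];
                }
            }
            // Fold along the reduced axis; the sum stands in for the generic
            // inner_loop_naive of SumDim/MeanDim/ProdDim/ArgMax/ArgMin.
            let mut acc = 0.0f32;
            for k in 0..in_shape[dim] {
                acc += input[k * in_strides[dim] + offset_input];
            }
            output[pos] = acc;
        }
        output
    }

    fn main() {
        // 2x3 row-major matrix reduced over dim 1 -> per-row sums.
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(reduce_dim_naive_cpu(&input, &[2, 3], &[3, 1], 1), vec![6.0, 15.0]);
    }
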
crate::kernel::reduce::ProdDim; + +use super::base::ReduceDimNaive; + +#[cube] +impl ReduceDimNaive for ProdDim { + type Accumulator = EI; + + fn initialize_naive() -> EI { + EI::from_int(1) + } + + fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + *accumulator *= current_value; + } + + fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { + output[ABSOLUTE_POS] = EO::cast_from(accumulator); + } +} diff --git a/crates/cubecl-reduce/src/naive/sum_dim.rs b/crates/cubecl-reduce/src/naive/sum_dim.rs new file mode 100644 index 00000000..6d16669e --- /dev/null +++ b/crates/cubecl-reduce/src/naive/sum_dim.rs @@ -0,0 +1,22 @@ +use cubecl::prelude::*; + +use crate::kernel::reduce::SumDim; + +use super::base::ReduceDimNaive; + +#[cube] +impl ReduceDimNaive for SumDim { + type Accumulator = EI; + + fn initialize_naive() -> EI { + EI::from_int(0) + } + + fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + *accumulator += current_value; + } + + fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { + output[ABSOLUTE_POS] = EO::cast_from(accumulator); + } +} diff --git a/crates/cubecl-reduce/src/prod.rs b/crates/cubecl-reduce/src/prod.rs new file mode 100644 index 00000000..77227bae --- /dev/null +++ b/crates/cubecl-reduce/src/prod.rs @@ -0,0 +1,15 @@ +use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use burn_tensor::Shape; + +use super::{prod_dim, ReduceStrategy}; + +/// Multiply all elements in the input buffer. +pub fn prod( + input: JitTensor, + strategy: ReduceStrategy, +) -> JitTensor { + let shape = Shape::new([input.shape.num_elements()]); + let input: JitTensor = + JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); + prod_dim::(input, 0, strategy) +} diff --git a/crates/cubecl-reduce/src/shared/argmax.rs b/crates/cubecl-reduce/src/shared/argmax.rs new file mode 100644 index 00000000..1685a200 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/argmax.rs @@ -0,0 +1,63 @@ +use crate::{kernel::reduce::Argmax, JitElement}; +use cubecl::prelude::*; + +use super::base::ReduceDimShared; + +#[cube] +impl ReduceDimShared for Argmax { + /// The reduction accumulator + type Accumulator = (SharedMemory, SharedMemory); + type Value = (EIn, u32); + + /// Initialization for shared algorithm + fn initialize_shared( + shared_memory_size: u32, + write_position: u32, + ) -> (SharedMemory, SharedMemory) { + let mut value_shared = SharedMemory::new(shared_memory_size); + let mut index_shared = SharedMemory::new(shared_memory_size); + value_shared[write_position] = comptime![EIn::minimum_value()].runtime(); + index_shared[write_position] = 0; + (value_shared, index_shared) + } + + /// How to write to shared memory + fn write_to_shared( + shared_memory: &mut (SharedMemory, SharedMemory), + write_position: u32, + value: (EIn, u32), + ) { + let (values, indices) = shared_memory; + let (value, index) = value; + + if value > values[write_position] { + values[write_position] = value; + indices[write_position] = index; + } + } + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { + (input[read_position], i) + } + + /// How to read from shared memory + fn read_from_shared( + shared_memory: &(SharedMemory, SharedMemory), + read_position: u32, + ) -> (EIn, u32) { + let (values, indices) = shared_memory; + (values[read_position], indices[read_position]) + } + + /// How to assign from shared memory + 
fn assign_shared( + shared_memory: &(SharedMemory, SharedMemory), + output: &mut Tensor, + write_position: u32, + _shape_reduce_dim: u32, + ) { + let (_, indices) = shared_memory; + output[write_position] = EOut::cast_from(indices[0]); + } +} diff --git a/crates/cubecl-reduce/src/shared/argmin.rs b/crates/cubecl-reduce/src/shared/argmin.rs new file mode 100644 index 00000000..ff7826b1 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/argmin.rs @@ -0,0 +1,64 @@ +use cubecl::prelude::*; + +use crate::{kernel::reduce::Argmin, JitElement}; + +use super::base::ReduceDimShared; + +#[cube] +impl ReduceDimShared for Argmin { + /// The reduction accumulator + type Accumulator = (SharedMemory, SharedMemory); + type Value = (EIn, u32); + + /// Initialization for shared algorithm + fn initialize_shared( + shared_memory_size: u32, + write_position: u32, + ) -> (SharedMemory, SharedMemory) { + let mut value_shared = SharedMemory::new(shared_memory_size); + let mut index_shared = SharedMemory::new(shared_memory_size); + value_shared[write_position] = comptime![EIn::maximum_value()].runtime(); + index_shared[write_position] = 0; + (value_shared, index_shared) + } + + /// How to write to shared memory + fn write_to_shared( + shared_memory: &mut (SharedMemory, SharedMemory), + write_position: u32, + value: (EIn, u32), + ) { + let (values, indices) = shared_memory; + let (value, index) = value; + + if value < values[write_position] { + values[write_position] = value; + indices[write_position] = index; + } + } + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { + (input[read_position], i) + } + + /// How to read from shared memory + fn read_from_shared( + shared_memory: &(SharedMemory, SharedMemory), + read_position: u32, + ) -> (EIn, u32) { + let (values, indices) = shared_memory; + (values[read_position], indices[read_position]) + } + + /// How to assign from shared memory + fn assign_shared( + shared_memory: &(SharedMemory, SharedMemory), + output: &mut Tensor, + write_position: u32, + _shape_reduce_dim: u32, + ) { + let (_, indices) = shared_memory; + output[write_position] = EOut::cast_from(indices[0]); + } +} diff --git a/crates/cubecl-reduce/src/shared/base.rs b/crates/cubecl-reduce/src/shared/base.rs new file mode 100644 index 00000000..bdb70e6a --- /dev/null +++ b/crates/cubecl-reduce/src/shared/base.rs @@ -0,0 +1,35 @@ +use cubecl::prelude::*; + +use crate::JitElement; + +/// Specifies the reduce dim algorithm in use +#[cube] +pub trait ReduceDimShared: Send + Sync + 'static { + /// The reduction accumulator + type Accumulator: CubeType; + type Value: CubeType; + + /// Initialization for shared algorithm + fn initialize_shared(shared_memory_size: u32, write_position: u32) -> Self::Accumulator; + + /// How to write to shared memory + fn write_to_shared( + shared_memory: &mut Self::Accumulator, + write_position: u32, + value: Self::Value, + ); + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> Self::Value; + + /// How to read from shared memory + fn read_from_shared(shared_memory: &Self::Accumulator, read_position: u32) -> Self::Value; + + /// How to assign from shared memory + fn assign_shared( + shared_memory: &Self::Accumulator, + output: &mut Tensor, + write_position: u32, + shape_reduce_dim: u32, + ); +} diff --git a/crates/cubecl-reduce/src/shared/kernel.rs b/crates/cubecl-reduce/src/shared/kernel.rs new file mode 100644 index 00000000..1b2dcb35 
--- /dev/null +++ b/crates/cubecl-reduce/src/shared/kernel.rs @@ -0,0 +1,117 @@ +use cubecl::prelude::*; + +use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, JitRuntime}; + +use super::base::ReduceDimShared; + +#[cube(launch)] +pub fn reduce_dim_shared_kernel< + RD: ReduceDimShared, + EIn: JitElement, + EOut: JitElement, +>( + input: &Tensor, + output: &mut Tensor, + #[comptime] dim: u32, + #[comptime] smem_size: u32, + #[comptime] elems_per_thread: u32, + #[comptime] divisible_shape: bool, +) { + let reduce_group_id = CUBE_POS; + + let stride_reduce_dim_input = input.stride(dim); + let shape_reduce_dim_input = input.shape(dim); + + let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS); + + let mut index_offset = 0; + + for i in 0..input.rank() { + let num_block = reduce_group_id / output.stride(i) % output.shape(i); + index_offset += num_block * input.stride(i); + } + + for i in 0..elems_per_thread { + let nth = i * CUBE_DIM + UNIT_POS; + + #[allow(clippy::collapsible_else_if)] + if divisible_shape { + let current_pos = nth * stride_reduce_dim_input + index_offset; + + let new_value = RD::read_from_input(input, current_pos, nth); + RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); + } else { + if nth < shape_reduce_dim_input { + let current_pos = nth * stride_reduce_dim_input + index_offset; + + let new_value = RD::read_from_input(input, current_pos, nth); + RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); + } + } + } + + sync_units(); + + let mut n_threads = CUBE_DIM; + + while n_threads > 1 { + n_threads /= 2; + + if UNIT_POS < n_threads { + let read_pos = n_threads + UNIT_POS; + let read_value = RD::read_from_shared(&shared_memory, read_pos); + RD::write_to_shared(&mut shared_memory, UNIT_POS, read_value); + } + + sync_units(); + } + + if UNIT_POS == 0 { + RD::assign_shared( + &shared_memory, + output, + reduce_group_id, + shape_reduce_dim_input, + ); + } +} + +/// Executes the shared memory kernel for reduce dim +pub fn reduce_dim_shared< + RD: ReduceDimShared, + R: JitRuntime, + EI: JitElement, + EO: JitElement, +>( + input: JitTensor, + dim: usize, +) -> JitTensor { + let output = init_reduce_output::(&input, dim); + + let num_elems_output = output.shape.num_elements(); + let cube_dim = CubeDim::default(); + let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); + let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); + let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); + + let reduce_group_size = input.shape.dims[dim]; + let n_invocation_per_cube = cube_dim.num_elems(); + let elems_per_thread = + f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; + + let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; + + reduce_dim_shared_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), + dim as u32, + cube_dim.num_elems(), + elems_per_thread, + divisible_shape, + ); + + output +} diff --git a/crates/cubecl-reduce/src/shared/mean_dim.rs b/crates/cubecl-reduce/src/shared/mean_dim.rs new file mode 100644 index 00000000..0b09d917 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/mean_dim.rs @@ -0,0 +1,44 @@ +use crate::{kernel::reduce::MeanDim, JitElement}; +use cubecl::prelude::*; + +use super::base::ReduceDimShared; + +#[cube] +impl ReduceDimShared for MeanDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = 
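
In the shared-memory kernel just above, every unit first folds `elems_per_thread` input values into its own shared-memory slot, and the cube then collapses the slots with a halving tree: the active thread count is divided by two each round and thread `i` folds slot `i + n_threads` into slot `i`, with `sync_units()` between rounds, until slot 0 holds the result. The standalone sketch below replays that tree sequentially; `tree_reduce` and `combine` are made-up names standing in for `write_to_shared`/`read_from_shared`, and it assumes a power-of-two slot count, as the kernel effectively does for `CUBE_DIM`.

    // Sequential replay of the shared-memory tree reduction (illustration only).
    fn tree_reduce(shared: &mut [f32], combine: impl Fn(f32, f32) -> f32) -> f32 {
        let mut n_threads = shared.len();
        while n_threads > 1 {
            n_threads /= 2;
            for unit_pos in 0..n_threads {
                // On the GPU these iterations run in parallel, one per unit,
                // separated from the next round by sync_units().
                shared[unit_pos] = combine(shared[unit_pos], shared[unit_pos + n_threads]);
            }
        }
        shared[0]
    }

    fn main() {
        let mut smem = vec![1.0f32; 256]; // one slot per unit in the cube
        assert_eq!(tree_reduce(&mut smem, |a, b| a + b), 256.0);
    }

Only `combine` changes between sum, product, mean, argmax and argmin, which is exactly what the `ReduceDimShared` trait factors out.
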
EIn; + + /// Initialization for shared algorithm + fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { + let mut value_shared = SharedMemory::new(shared_memory_size); + value_shared[write_position] = EIn::from_int(0); + value_shared + } + + /// How to write to shared memory + fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { + shared_memory[write_position] += value; + } + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { + input[read_position] + } + + /// How to read from shared memory + fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { + shared_memory[read_position] + } + + /// How to assign from shared memory + fn assign_shared( + shared_memory: &SharedMemory, + output: &mut Tensor, + write_position: u32, + shape_reduce_dim: u32, + ) { + let mean = shared_memory[0] / EIn::cast_from(shape_reduce_dim); + output[write_position] = EOut::cast_from(mean); + } +} diff --git a/crates/cubecl-reduce/src/shared/mod.rs b/crates/cubecl-reduce/src/shared/mod.rs new file mode 100644 index 00000000..b11ee5e2 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/mod.rs @@ -0,0 +1,7 @@ +pub(crate) mod argmax; +pub(crate) mod argmin; +pub(crate) mod base; +pub(crate) mod kernel; +pub(crate) mod mean_dim; +pub(crate) mod prod_dim; +pub(crate) mod sum_dim; diff --git a/crates/cubecl-reduce/src/shared/prod_dim.rs b/crates/cubecl-reduce/src/shared/prod_dim.rs new file mode 100644 index 00000000..8041cc68 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/prod_dim.rs @@ -0,0 +1,43 @@ +use crate::{kernel::reduce::ProdDim, JitElement}; +use cubecl::prelude::*; + +use super::base::ReduceDimShared; + +#[cube] +impl ReduceDimShared for ProdDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = EIn; + + /// Initialization for shared algorithm + fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { + let mut value_shared = SharedMemory::new(shared_memory_size); + value_shared[write_position] = EIn::from_int(1); + value_shared + } + + /// How to write to shared memory + fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { + shared_memory[write_position] *= value; + } + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { + input[read_position] + } + + /// How to read from shared memory + fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { + shared_memory[read_position] + } + + /// How to assign from shared memory + fn assign_shared( + shared_memory: &SharedMemory, + output: &mut Tensor, + write_position: u32, + _shape_reduce_dim: u32, + ) { + output[write_position] = EOut::cast_from(shared_memory[0]); + } +} diff --git a/crates/cubecl-reduce/src/shared/sum_dim.rs b/crates/cubecl-reduce/src/shared/sum_dim.rs new file mode 100644 index 00000000..da2b7337 --- /dev/null +++ b/crates/cubecl-reduce/src/shared/sum_dim.rs @@ -0,0 +1,43 @@ +use crate::{kernel::reduce::SumDim, JitElement}; +use cubecl::prelude::*; + +use super::base::ReduceDimShared; + +#[cube] +impl ReduceDimShared for SumDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = EIn; + + /// Initialization for shared algorithm + fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { + let mut value_shared = 
SharedMemory::new(shared_memory_size); + value_shared[write_position] = EIn::from_int(0); + value_shared + } + + /// How to write to shared memory + fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { + shared_memory[write_position] += value; + } + + /// How to read from input in shared algorithm + fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { + input[read_position] + } + + /// How to read from shared memory + fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { + shared_memory[read_position] + } + + /// How to assign from shared memory + fn assign_shared( + shared_memory: &SharedMemory, + output: &mut Tensor, + write_position: u32, + _shape_reduce_dim: u32, + ) { + output[write_position] = EOut::cast_from(shared_memory[0]); + } +} diff --git a/crates/cubecl-reduce/src/subcube/argmax.rs b/crates/cubecl-reduce/src/subcube/argmax.rs new file mode 100644 index 00000000..428bd712 --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/argmax.rs @@ -0,0 +1,54 @@ +use cubecl::{cube, prelude::*}; + +use crate::{kernel::reduce::Argmax, JitElement}; + +use super::base::ReduceDimSubcube; + +#[cube] +impl ReduceDimSubcube for Argmax { + /// The reduction accumulator + type Accumulator = (SharedMemory, SharedMemory); + type Value = (EIn, u32); + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator { + let value_shared = SharedMemory::new(size); + let index_shared = SharedMemory::new(size); + (value_shared, index_shared) + } + + fn init_value() -> Self::Value { + (comptime![EIn::minimum_value()], 0u32) + } + + fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { + (input[pos], i) + } + + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { + let (values, indices) = acc; + (values[pos], indices[pos]) + } + + fn update_value(current: &mut Self::Value, new: Self::Value) { + let (current_val, current_idx) = current; + let (new_val, new_idx) = new; + *current_val = Max::max(*current_val, new_val); + *current_idx = select(*current_val == new_val, new_idx, *current_idx); + } + + fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { + let (val, index) = value; + let (val_smem, index_smem) = acc; + let max = plane_max(val); + + if max == val { + val_smem[write_position] = val; + index_smem[write_position] = index; + } + } + + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { + let (_, indices) = acc; + out[pos] = EOut::cast_from(indices[0]); + } +} diff --git a/crates/cubecl-reduce/src/subcube/argmin.rs b/crates/cubecl-reduce/src/subcube/argmin.rs new file mode 100644 index 00000000..6a002a5d --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/argmin.rs @@ -0,0 +1,54 @@ +use cubecl::{cube, prelude::*}; + +use crate::{kernel::reduce::Argmin, JitElement}; + +use super::base::ReduceDimSubcube; + +#[cube] +impl ReduceDimSubcube for Argmin { + /// The reduction accumulator + type Accumulator = (SharedMemory, SharedMemory); + type Value = (EIn, u32); + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator { + let value_shared = SharedMemory::new(size); + let index_shared = SharedMemory::new(size); + (value_shared, index_shared) + } + + fn init_value() -> Self::Value { + (comptime![EIn::maximum_value()], 0u32) + } + + fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { + (input[pos], i) + } + + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { + let (values, indices) = acc; + (values[pos], 
indices[pos]) + } + + fn update_value(current: &mut Self::Value, new: Self::Value) { + let (current_val, current_idx) = current; + let (new_val, new_idx) = new; + *current_val = Min::min(*current_val, new_val); + *current_idx = select(*current_val == new_val, new_idx, *current_idx); + } + + fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { + let (val, index) = value; + let (val_smem, index_smem) = acc; + let min = plane_min(val); + + if min == val { + val_smem[write_position] = val; + index_smem[write_position] = index; + } + } + + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { + let (_, indices) = acc; + out[pos] = EOut::cast_from(indices[0]); + } +} diff --git a/crates/cubecl-reduce/src/subcube/base.rs b/crates/cubecl-reduce/src/subcube/base.rs new file mode 100644 index 00000000..a700bf84 --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/base.rs @@ -0,0 +1,17 @@ +use cubecl::prelude::*; + +use crate::JitElement; + +#[cube] +pub trait ReduceDimSubcube: Send + Sync + 'static { + type Accumulator: CubeType; + type Value: CubeType; + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator; + fn init_value() -> Self::Value; + fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value; + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value; + fn update_value(current: &mut Self::Value, new: Self::Value); + fn reduce_subcube(acc: &mut Self::Accumulator, pos: u32, value: Self::Value); + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_len: u32); +} diff --git a/crates/cubecl-reduce/src/subcube/kernel.rs b/crates/cubecl-reduce/src/subcube/kernel.rs new file mode 100644 index 00000000..4e783e74 --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/kernel.rs @@ -0,0 +1,134 @@ +use cubecl::{prelude::*, CubeCount, CubeDim, Feature}; + +use crate::{ + kernel::reduce::{init_reduce_output, shared::kernel::reduce_dim_shared, ReduceDimAlgorithm}, + tensor::JitTensor, + JitElement, JitRuntime, +}; + +use super::base::ReduceDimSubcube; + +#[cube(launch)] +pub fn reduce_dim_subcube_kernel< + RD: ReduceDimSubcube, + EIn: JitElement, + EOut: JitElement, +>( + input: &Tensor, + output: &mut Tensor, + #[comptime] dim: u32, + #[comptime] subcube_size: u32, + #[comptime] elems_per_thread: u32, + #[comptime] divisible_shape: bool, +) { + let reduce_group_id = CUBE_POS; + + let stride_reduce_dim_input = input.stride(dim); + let shape_reduce_dim_input = input.shape(dim); + + let should_unroll = elems_per_thread <= 8; + + let warp_id = UNIT_POS / PLANE_DIM; + + let mut shared_memory = RD::init_shared(subcube_size); + + let mut index_offset = 0; + + for i in 0..input.rank() { + let num_block = reduce_group_id / output.stride(i) % output.shape(i); + index_offset += num_block * input.stride(i); + } + + let mut value = RD::init_value(); + + #[unroll(should_unroll)] + for i in 0..elems_per_thread { + let nth = i * CUBE_DIM + UNIT_POS; + let current_pos = nth * stride_reduce_dim_input + index_offset; + + #[allow(clippy::collapsible_else_if)] + if divisible_shape { + let next = RD::read_value(input, current_pos, nth); + RD::update_value(&mut value, next); + } else { + if nth < shape_reduce_dim_input { + let next = RD::read_value(input, current_pos, nth); + RD::update_value(&mut value, next); + } + } + } + + RD::reduce_subcube(&mut shared_memory, warp_id, value); + + sync_units(); + + if UNIT_POS >= PLANE_DIM { + return; + } + + let value = RD::read_from_shared(&shared_memory, UNIT_POS); + 
RD::reduce_subcube(&mut shared_memory, 0, value); + + if UNIT_POS == 0 { + RD::store( + &shared_memory, + output, + reduce_group_id, + shape_reduce_dim_input, + ); + } +} + +/// Executes the shared memory kernel for reduce dim +pub fn reduce_dim_subcube< + RD: ReduceDimAlgorithm, + R: JitRuntime, + EI: JitElement, + EO: JitElement, +>( + input: JitTensor, + dim: usize, +) -> JitTensor { + let topology = input.client.properties().hardware_properties(); + + if !input.client.properties().feature_enabled(Feature::Plane) + || topology.plane_size_min != topology.plane_size_max + { + return reduce_dim_shared::(input, dim); + } + + let subcube_size = topology.plane_size_min; + + let output = init_reduce_output::(&input, dim); + + let num_elems_output = output.shape.num_elements(); + let cube_dim = CubeDim { + x: subcube_size, + y: subcube_size, + z: 1, + }; + let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); + let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); + let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); + + let reduce_group_size = input.shape.dims[dim]; + let n_invocation_per_cube = cube_dim.num_elems(); + let elems_per_thread = + f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; + + let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; + + reduce_dim_subcube_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg::(1), + output.as_tensor_arg::(1), + dim as u32, + subcube_size, + elems_per_thread, + divisible_shape, + ); + + output +} diff --git a/crates/cubecl-reduce/src/subcube/mean_dim.rs b/crates/cubecl-reduce/src/subcube/mean_dim.rs new file mode 100644 index 00000000..63e14de4 --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/mean_dim.rs @@ -0,0 +1,45 @@ +use cubecl::{cube, prelude::*}; + +use crate::{kernel::reduce::MeanDim, JitElement}; + +use super::base::ReduceDimSubcube; + +#[cube] +impl ReduceDimSubcube for MeanDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = EIn; + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator { + SharedMemory::new(size) + } + + fn init_value() -> Self::Value { + comptime![EIn::default()].runtime() + } + + fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { + input[pos] + } + + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { + acc[pos] + } + + fn update_value(current: &mut Self::Value, new: Self::Value) { + *current += new; + } + + fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { + let sum = plane_sum(value); + + if UNIT_POS % PLANE_DIM == 0 { + acc[write_position] = sum; + } + } + + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_length: u32) { + let denom = EIn::cast_from(dim_length); + out[pos] = EOut::cast_from(acc[0] / denom); + } +} diff --git a/crates/cubecl-reduce/src/subcube/mod.rs b/crates/cubecl-reduce/src/subcube/mod.rs new file mode 100644 index 00000000..183c1e2d --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/mod.rs @@ -0,0 +1,7 @@ +pub mod argmax; +pub mod argmin; +pub mod base; +pub mod kernel; +pub mod mean_dim; +pub mod prod_dim; +pub mod sum_dim; diff --git a/crates/cubecl-reduce/src/subcube/prod_dim.rs b/crates/cubecl-reduce/src/subcube/prod_dim.rs new file mode 100644 index 00000000..4c0b71d9 --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/prod_dim.rs @@ -0,0 +1,44 @@ +use cubecl::{cube, prelude::*}; + +use 
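
The subcube kernel above reduces in two stages: each unit folds its private value, a plane-wide primitive (`plane_sum`, `plane_max`, `plane_min`, ...) collapses each plane at once, one partial per plane is parked in shared memory at `warp_id`, and after `sync_units()` the first plane folds those partials into slot 0. The sketch below shows the same two-stage shape sequentially; `subcube_style_sum` and `plane_dim` are illustrative names, not CubeCL calls.

    // Two-stage plane-style reduction, replayed on the CPU (illustration only).
    fn subcube_style_sum(per_unit_values: &[f32], plane_dim: usize) -> f32 {
        // Stage 1: each plane is collapsed to a single partial (plane_sum).
        let partials: Vec<f32> = per_unit_values
            .chunks(plane_dim)
            .map(|plane| plane.iter().sum())
            .collect();
        // Stage 2: the first plane folds the per-plane partials into slot 0.
        partials.iter().sum()
    }

    fn main() {
        // 8 units with plane size 4 -> two planes of 4, then one final fold.
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_eq!(subcube_style_sum(&values, 4), 36.0);
    }

On the host side this path is only taken when `Feature::Plane` is enabled and the minimum and maximum plane sizes match; otherwise the dispatch falls back to `reduce_dim_shared`.
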
crate::{kernel::reduce::ProdDim, JitElement}; + +use super::base::ReduceDimSubcube; + +#[cube] +impl ReduceDimSubcube for ProdDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = EIn; + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator { + SharedMemory::new(size) + } + + fn init_value() -> Self::Value { + comptime![EIn::from_int(1)].runtime() + } + + fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { + input[pos] + } + + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { + acc[pos] + } + + fn update_value(current: &mut Self::Value, new: Self::Value) { + *current *= new; + } + + fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { + let prod = plane_prod(value); + + if UNIT_POS % PLANE_DIM == 0 { + acc[write_position] = prod; + } + } + + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { + out[pos] = EOut::cast_from(acc[0]); + } +} diff --git a/crates/cubecl-reduce/src/subcube/sum_dim.rs b/crates/cubecl-reduce/src/subcube/sum_dim.rs new file mode 100644 index 00000000..3aac1a3c --- /dev/null +++ b/crates/cubecl-reduce/src/subcube/sum_dim.rs @@ -0,0 +1,44 @@ +use cubecl::{cube, prelude::*}; + +use crate::{kernel::reduce::SumDim, JitElement}; + +use super::base::ReduceDimSubcube; + +#[cube] +impl ReduceDimSubcube for SumDim { + /// The reduction accumulator + type Accumulator = SharedMemory; + type Value = EIn; + + fn init_shared(#[comptime] size: u32) -> Self::Accumulator { + SharedMemory::new(size) + } + + fn init_value() -> Self::Value { + comptime![EIn::default()].runtime() + } + + fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { + input[pos] + } + + fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { + acc[pos] + } + + fn update_value(current: &mut Self::Value, new: Self::Value) { + *current += new; + } + + fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { + let sum = plane_sum(value); + + if UNIT_POS % PLANE_DIM == 0 { + acc[write_position] = sum; + } + } + + fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { + out[pos] = EOut::cast_from(acc[0]); + } +} diff --git a/crates/cubecl-reduce/src/sum.rs b/crates/cubecl-reduce/src/sum.rs index 4b964d86..fea80bcc 100644 --- a/crates/cubecl-reduce/src/sum.rs +++ b/crates/cubecl-reduce/src/sum.rs @@ -1,110 +1,15 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct ReduceConfig { - pub line_size: u32, - pub max_num_planes: u32, -} - -/// Compute the sum of all elements of `input` and write it to the first element of `output`. -/// -/// This doesn't reduce values across lines. For a version that does, use [reduce_sum_lined]. -/// -/// This is a work in progress toward a more general multi-dimensional reduce kernel. -#[cube(launch_unchecked)] -pub fn reduce_sum( - input: &Tensor>, - output: &mut Tensor>, - #[comptime] config: ReduceConfig, -) { - reduce_sum_vector(&input.to_slice(), &mut output.to_slice_mut(), config); -} - -/// Compute the sum of all elements of `input` and write it to the first element of `output`. -/// -/// This reduces values across lines. For a version that doesn't, use [reduce_sum]. -/// -/// This is a work in progress toward a more general multi-dimensional reduce kernel. 
-#[cube(launch_unchecked)] -pub fn reduce_sum_lined( - input: &Tensor>, - output: &mut Tensor, - #[comptime] config: ReduceConfig, -) { - let mut tmp = SharedMemory::new_lined(1, config.line_size); - reduce_sum_vector(&input.to_slice(), &mut tmp.to_slice_mut(), config); - reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32); -} - -/// Compute the sum of all elements of `input` and write it to the first element of `output`. -#[cube] -pub fn reduce_sum_vector( - input: &Slice>, - output: &mut SliceMut>, - #[comptime] config: ReduceConfig, -) { - let plane_id = UNIT_POS / PLANE_DIM; - let num_planes = div_ceil(CUBE_DIM, PLANE_DIM); - - // Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration. - let num_iterations = div_ceil(input.len(), CUBE_DIM); - - let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size()); - memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0)); - - // For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results - // into the memory. Thus, after the loop, the reduction of the memory yields the expected output. - for i in 0..num_iterations { - let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE; - let value = select( - index < input.len(), - input[index], - Line::empty(config.line_size).fill(N::from_int(0)), - ); - let sum = plane_sum(value); - if UNIT_POS_PLANE == 0 { - memory[plane_id] += sum; - } - } - - // Make sure that each local sum is completed and written to memory. - sync_units(); - - // Sum each elements in memory - let sum = plane_sum(select( - UNIT_POS_PLANE < num_planes, - memory[UNIT_POS_PLANE], - Line::empty(config.line_size).fill(N::from_int(0)), - )); - if UNIT_POS == 0 { - output[0] = sum; - } -} - -/// For each line, sum all elements and write the result into the corresponding element of output. -#[cube] -pub fn reduce_sum_lines( - input: &Slice>, - output: &mut SliceMut, - #[comptime] length: u32, -) { - if UNIT_POS < length { - let line = input[UNIT_POS]; - - let mut sum = N::from_int(0); - - #[unroll] - for k in 0..line.size() { - sum += line[k]; - } - - output[UNIT_POS] = sum; - } -} - -// Integer division rounded up. -#[cube] -fn div_ceil(a: u32, b: u32) -> u32 { - a / b + ((a % b) > 0) as u32 +use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use burn_tensor::Shape; + +use super::{sum_dim, ReduceStrategy}; + +/// Sum all elements in the input buffer. 
+pub fn sum( + input: JitTensor, + strategy: ReduceStrategy, +) -> JitTensor { + let shape = Shape::new([input.shape.num_elements()]); + let input: JitTensor = + JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); + sum_dim::(input, 0, strategy) } diff --git a/crates/cubecl-reduce/src/tune/base.rs b/crates/cubecl-reduce/src/tune/base.rs new file mode 100644 index 00000000..f52bfd7c --- /dev/null +++ b/crates/cubecl-reduce/src/tune/base.rs @@ -0,0 +1,94 @@ +use burn_tensor::{Element, ElementConversion}; +use cubecl::tune::{local_tuner, tune_with, LocalTuner}; +use cubecl::{tune, Feature}; + +use crate::{ + element::JitElement, + kernel::{ + prng::random_like_uniform, + reduce::{ + naive::kernel::reduce_dim_naive, shared::kernel::reduce_dim_shared, + subcube::kernel::reduce_dim_subcube, ReduceDimAlgorithm, + }, + }, + tensor::JitTensor, + tune_key::JitAutotuneKey, + JitRuntime, JitTuneId, +}; + +use super::create_key; + +/// Set of reduce_dim implementations available for autotune +/// Autotune key is given by concatenating the closest upper power of 2 of +/// dim to reduce, and product of others +#[tune( + operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), + create_key = create_key::, + should_run = should_run +)] +pub fn reduce_dim_operations< + RD: ReduceDimAlgorithm, + R: JitRuntime, + EI: JitElement + Element, + EO: JitElement + Element, +>( + key: JitAutotuneKey, + input: JitTensor, + reduce_dim: usize, +) -> JitTensor { + let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(input, random_bounds.0, random_bounds.1); + + tune_with!(input, reduce_dim) +} + +/// Executes autotune on reduce_dim operation +pub(crate) fn reduce_dim_autotune< + RD: ReduceDimAlgorithm, + R: JitRuntime, + EI: JitElement + Element, + EO: JitElement + Element, +>( + input: JitTensor, + reduce_dim: usize, +) -> JitTensor { + let client = input.client.clone(); + + let id = JitTuneId::new::(&input.device); + + let operation_set = Box::new(ReduceDimOperations::::new(input, reduce_dim)); + + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute(&id, &client, operation_set) +} + +fn should_run< + RD: ReduceDimAlgorithm, + R: JitRuntime, + EI: JitElement + Element, + EO: JitElement + Element, +>( + op: &ReduceDimOperations, + key: &JitAutotuneKey, + index: usize, +) -> bool { + let JitAutotuneKey::ReduceDim(key) = key else { + unreachable!() + }; + + match index { + // Naive + 0 => key.reduce_dim_length <= 8192, + // Shared + 1 => key.reduce_dim_length >= 16, + // Subcube + 2 => { + let props = op.input.client.properties(); + let hardware = props.hardware_properties(); + props.feature_enabled(Feature::Plane) + && hardware.plane_size_min == hardware.plane_size_max + } + _ => true, + } +} diff --git a/crates/cubecl-reduce/src/tune/key.rs b/crates/cubecl-reduce/src/tune/key.rs new file mode 100644 index 00000000..3634022b --- /dev/null +++ b/crates/cubecl-reduce/src/tune/key.rs @@ -0,0 +1,39 @@ +use cubecl::AutotuneKey; +use serde::{Deserialize, Serialize}; + +use burn_tensor::DType; + +use crate::{tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime}; + +/// Autotune key representative of reduce versions +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +pub struct ReduceAutotuneKey { + #[autotune(anchor)] + pub(crate) reduce_dim_length: usize, + #[autotune(anchor)] + pub(crate) reduce_dim_stride: usize, + #[autotune(anchor)] + pub(crate) others_product: usize, + dtype: DType, 
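
The `#[autotune(anchor)]` attributes above exist so that many concrete shapes map to one cache entry: each anchored field is rounded up to the closest upper power of two before it becomes part of the key (as the doc comment on the operation set says), and the `should_run` gates then prune candidates per key (naive only up to 8192 elements on the reduced axis, shared only from 16 up, subcube only with uniform plane support). Below is a small sketch of the anchoring idea; `anchor` is a hypothetical helper, not the code generated by the `AutotuneKey` derive.

    // Hypothetical stand-in for the rounding applied to #[autotune(anchor)] fields.
    fn anchor(value: usize) -> usize {
        value.next_power_of_two()
    }

    fn main() {
        // Reduced-axis lengths 100, 120 and 128 all anchor to 128, so they
        // share a single autotune measurement.
        assert_eq!(anchor(100), 128);
        assert_eq!(anchor(120), 128);
        assert_eq!(anchor(128), 128);
    }
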
+} + +pub(crate) fn create_key( + input: &JitTensor, + reduce_dim: &usize, +) -> JitAutotuneKey { + let dims = &input.shape.dims; + let reduce_dim = *reduce_dim; + + let mut others_product = 1; + for (d, len) in dims.iter().enumerate() { + if d != reduce_dim { + others_product *= len + } + } + JitAutotuneKey::ReduceDim(ReduceAutotuneKey::new( + dims[reduce_dim], + input.strides[reduce_dim], + others_product, + EI::dtype(), + )) +} diff --git a/crates/cubecl-reduce/src/tune/mod.rs b/crates/cubecl-reduce/src/tune/mod.rs new file mode 100644 index 00000000..aee5569b --- /dev/null +++ b/crates/cubecl-reduce/src/tune/mod.rs @@ -0,0 +1,7 @@ +#[cfg(feature = "autotune")] +mod base; +mod key; + +#[cfg(feature = "autotune")] +pub(crate) use base::*; +pub use key::*; From cdd3eaa3fa4a0ea449a2123cea06f84e383251f1 Mon Sep 17 00:00:00 2001 From: maxime Date: Tue, 26 Nov 2024 15:09:50 -0500 Subject: [PATCH 2/9] remove autotune for now --- crates/cubecl-reduce/src/base.rs | 88 ++------------------------------ 1 file changed, 5 insertions(+), 83 deletions(-) diff --git a/crates/cubecl-reduce/src/base.rs b/crates/cubecl-reduce/src/base.rs index 5cb9ebfb..5dbea36c 100644 --- a/crates/cubecl-reduce/src/base.rs +++ b/crates/cubecl-reduce/src/base.rs @@ -1,83 +1,5 @@ -use cubecl::prelude::Numeric; - -#[cfg(feature = "autotune")] -use crate::kernel::reduce::reduce_dim_autotune; -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; - -use super::{ - naive::{base::ReduceDimNaive, kernel::reduce_dim_naive}, - shared::{base::ReduceDimShared, kernel::reduce_dim_shared}, - subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube}, -}; - -#[allow(dead_code)] -pub(crate) trait ReduceDimAlgorithm: - core::fmt::Debug + ReduceDimNaive + ReduceDimShared + ReduceDimSubcube -{ -} - -/// Creates an empty output tensor with reduce output shape -pub fn init_reduce_output( - input: &JitTensor, - reduce_dim: usize, -) -> JitTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; - - empty_device::(input.client.clone(), input.device.clone(), shape_out) -} - -#[derive(Copy, Clone, Debug)] -#[allow(missing_docs)] -pub enum ReduceStrategy { - /// Naive - Naive, - /// Use shared memory as an accumulator - SharedMemory, - /// Use subcube functions - Subcube, - #[cfg(feature = "autotune")] - Autotune, -} - -impl Default for ReduceStrategy { - fn default() -> Self { - // if autotune is enabled, default to autotune - #[cfg(feature = "autotune")] - return ReduceStrategy::Autotune; - - #[cfg(not(feature = "autotune"))] - ReduceStrategy::Naive - } -} - -macro_rules! reduce_operation { - ($name:ident, $ops:ident) => { - #[derive(Debug)] - pub(crate) struct $ops; - - impl ReduceDimAlgorithm for $ops {} - - /// Executes the reduce operation with the given strategy. 
- pub fn $name( - tensor: JitTensor, - dim: usize, - strategy: ReduceStrategy, - ) -> JitTensor { - match strategy { - ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), - #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim), - } - } - }; -} - -// Autotunable reduce operation variants -reduce_operation!(sum_dim, SumDim); -reduce_operation!(mean_dim, MeanDim); -reduce_operation!(prod_dim, ProdDim); -reduce_operation!(argmin, Argmin); -reduce_operation!(argmax, Argmax); +pub struct ArgMax; +pub struct ArgMin; +pub struct MeanDim; +pub struct SumDim; +pub struct ProdDim; From 2095415947522e147544171a9843dbc452554e32 Mon Sep 17 00:00:00 2001 From: maxime Date: Tue, 26 Nov 2024 15:14:14 -0500 Subject: [PATCH 3/9] impl complete test for naive reduce --- crates/cubecl-reduce/src/instructions.rs | 5 + crates/cubecl-reduce/src/lib.rs | 22 +- crates/cubecl-reduce/src/naive/argmax.rs | 12 +- crates/cubecl-reduce/src/naive/argmin.rs | 10 +- crates/cubecl-reduce/src/naive/base.rs | 37 +- crates/cubecl-reduce/src/naive/kernel.rs | 61 --- crates/cubecl-reduce/src/naive/mean_dim.rs | 6 +- crates/cubecl-reduce/src/naive/mod.rs | 13 +- crates/cubecl-reduce/src/naive/prod_dim.rs | 5 +- crates/cubecl-reduce/src/naive/sum_dim.rs | 5 +- crates/cubecl-reduce/src/test.rs | 514 ++++++++++++++------- 11 files changed, 419 insertions(+), 271 deletions(-) create mode 100644 crates/cubecl-reduce/src/instructions.rs delete mode 100644 crates/cubecl-reduce/src/naive/kernel.rs diff --git a/crates/cubecl-reduce/src/instructions.rs b/crates/cubecl-reduce/src/instructions.rs new file mode 100644 index 00000000..5dbea36c --- /dev/null +++ b/crates/cubecl-reduce/src/instructions.rs @@ -0,0 +1,5 @@ +pub struct ArgMax; +pub struct ArgMin; +pub struct MeanDim; +pub struct SumDim; +pub struct ProdDim; diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs index ebc9a5b0..3f194e5f 100644 --- a/crates/cubecl-reduce/src/lib.rs +++ b/crates/cubecl-reduce/src/lib.rs @@ -1,15 +1,15 @@ -mod base; +mod instructions; mod naive; -mod prod; -mod shared; -mod subcube; -mod sum; -mod tune; +// mod prod; +// mod shared; +// mod subcube; +// mod sum; +// mod tune; -pub use base::*; -pub use prod::*; -pub use sum::*; -pub use tune::*; +pub use instructions::*; +// pub use prod::*; +// pub use sum::*; +// pub use tune::*; -#[cfg(export_tests)] +#[cfg(feature = "export_tests")] pub mod test; diff --git a/crates/cubecl-reduce/src/naive/argmax.rs b/crates/cubecl-reduce/src/naive/argmax.rs index ba51ec72..10e03654 100644 --- a/crates/cubecl-reduce/src/naive/argmax.rs +++ b/crates/cubecl-reduce/src/naive/argmax.rs @@ -1,16 +1,16 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::Argmax, JitElement}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; use super::base::ReduceDimNaive; +use crate::ArgMax; #[cube] -impl ReduceDimNaive for Argmax { +impl ReduceDimNaive for ArgMax { type Accumulator = (EI, u32); fn initialize_naive() -> Self::Accumulator { // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (comptime![EI::minimum_value()].runtime(), 0u32) + (comptime![EI::MIN].runtime(), 0u32) } fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { @@ -29,4 +29,6 @@ impl ReduceDimNaive 
for Argmax { let (_, index) = accumulator; output[ABSOLUTE_POS] = EO::cast_from(index); } + + } diff --git a/crates/cubecl-reduce/src/naive/argmin.rs b/crates/cubecl-reduce/src/naive/argmin.rs index 0418d20a..32a0a7e1 100644 --- a/crates/cubecl-reduce/src/naive/argmin.rs +++ b/crates/cubecl-reduce/src/naive/argmin.rs @@ -1,16 +1,16 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::Argmin, JitElement}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; use super::base::ReduceDimNaive; +use crate::ArgMin; #[cube] -impl ReduceDimNaive for Argmin { +impl ReduceDimNaive for ArgMin { type Accumulator = (EI, u32); fn initialize_naive() -> Self::Accumulator { // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (comptime![EI::maximum_value()].runtime(), 0u32) + (comptime![EI::MAX].runtime(), 0u32) } fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { diff --git a/crates/cubecl-reduce/src/naive/base.rs b/crates/cubecl-reduce/src/naive/base.rs index 82194ea3..9cf756ab 100644 --- a/crates/cubecl-reduce/src/naive/base.rs +++ b/crates/cubecl-reduce/src/naive/base.rs @@ -1,4 +1,5 @@ -use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; /// Specifies the reduce dim algorithm in use #[cube] @@ -19,3 +20,37 @@ pub trait ReduceDimNaive: Send + Sync + 'static { shape_reduce_dim: u32, ); } + +#[cube] +pub fn reduce_dim_naive, EI: Numeric, EO: Numeric>( + input: &Tensor, + output: &mut Tensor, + dim: u32, +) { + if ABSOLUTE_POS >= output.len() { + return; + }; + + let mut offset_input = 0; + + for i in 0..input.rank() { + let mut offset_local = ABSOLUTE_POS / output.stride(i); + offset_local %= output.shape(i); + if i != dim { + offset_input += offset_local * input.stride(i); + } + } + + let mut accumulator = RD::initialize_naive(); + + for i in 0..input.shape(dim) { + let index = i * input.stride(dim) + offset_input; + RD::inner_loop_naive( + &mut accumulator, + unsafe { *input.index_unchecked(index) }, + i, + ); + } + + RD::assign_naive::(output, accumulator, input.shape(dim)); +} diff --git a/crates/cubecl-reduce/src/naive/kernel.rs b/crates/cubecl-reduce/src/naive/kernel.rs deleted file mode 100644 index c001edca..00000000 --- a/crates/cubecl-reduce/src/naive/kernel.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{ - element::JitElement, kernel::reduce::init_reduce_output, tensor::JitTensor, JitRuntime, -}; -use cubecl::calculate_cube_count_elemwise; -use cubecl::prelude::*; - -use super::base::ReduceDimNaive; - -#[cube(launch_unchecked)] -pub(crate) fn naive_reduce_dim_kernel, EI: Numeric, EO: Numeric>( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - if ABSOLUTE_POS >= output.len() { - return; - } - - let mut offset_input = 0; - - for i in 0..input.rank() { - let mut offset_local = ABSOLUTE_POS / output.stride(i); - offset_local %= output.shape(i); - if i != dim { - offset_input += offset_local * input.stride(i); - } - } - - let mut accumulator = RD::initialize_naive(); - - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; - RD::inner_loop_naive(&mut accumulator, input[index], i); - } - - RD::assign_naive::(output, accumulator, input.shape(dim)); -} - -/// Executes the naive kernel for reduce dim -pub fn reduce_dim_naive, R: JitRuntime, EI: JitElement, EO: JitElement>( - input: JitTensor, - dim: usize, -) -> JitTensor { - let output = init_reduce_output::(&input, dim); - - let cube_dim = CubeDim::default(); - let cube_count = 
calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - - unsafe { - naive_reduce_dim_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - ScalarArg::new(dim as u32), - ); - } - - output -} diff --git a/crates/cubecl-reduce/src/naive/mean_dim.rs b/crates/cubecl-reduce/src/naive/mean_dim.rs index c3210faf..98e65727 100644 --- a/crates/cubecl-reduce/src/naive/mean_dim.rs +++ b/crates/cubecl-reduce/src/naive/mean_dim.rs @@ -1,8 +1,8 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::MeanDim; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; use super::base::ReduceDimNaive; +use crate::MeanDim; #[cube] impl ReduceDimNaive for MeanDim { diff --git a/crates/cubecl-reduce/src/naive/mod.rs b/crates/cubecl-reduce/src/naive/mod.rs index b11ee5e2..91a3f1ba 100644 --- a/crates/cubecl-reduce/src/naive/mod.rs +++ b/crates/cubecl-reduce/src/naive/mod.rs @@ -1,7 +1,6 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; +pub mod argmax; +pub mod argmin; +pub mod base; +pub mod mean_dim; +pub mod prod_dim; +pub mod sum_dim; diff --git a/crates/cubecl-reduce/src/naive/prod_dim.rs b/crates/cubecl-reduce/src/naive/prod_dim.rs index 5dfedf73..4e73e062 100644 --- a/crates/cubecl-reduce/src/naive/prod_dim.rs +++ b/crates/cubecl-reduce/src/naive/prod_dim.rs @@ -1,6 +1,7 @@ -use cubecl::prelude::*; +use cubecl_core::prelude::*; +use cubecl_core as cubecl; -use crate::kernel::reduce::ProdDim; +use crate::ProdDim; use super::base::ReduceDimNaive; diff --git a/crates/cubecl-reduce/src/naive/sum_dim.rs b/crates/cubecl-reduce/src/naive/sum_dim.rs index 6d16669e..e21e23d6 100644 --- a/crates/cubecl-reduce/src/naive/sum_dim.rs +++ b/crates/cubecl-reduce/src/naive/sum_dim.rs @@ -1,6 +1,7 @@ -use cubecl::prelude::*; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; -use crate::kernel::reduce::SumDim; +use crate::SumDim; use super::base::ReduceDimNaive; diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 345c9d7b..548e86de 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -1,26 +1,23 @@ #![allow(missing_docs)] -use cubecl_core::{prelude::*, Feature}; +use cubecl_core as cubecl; +use cubecl_core::prelude::*; -use crate::sum::{reduce_sum, reduce_sum_lined, ReduceConfig}; +use crate::naive::base::{reduce_dim_naive, ReduceDimNaive}; +use crate::{ArgMax, ArgMin, MeanDim, ProdDim, SumDim}; -#[macro_export] -macro_rules! impl_test_reduce_sum_vector { - ($float:ident, [$(($num_values:expr, $cube_size:expr, $line_size:expr)),*]) => { - ::paste::paste! { - $( - #[test] - pub fn []() { - TestCase::<$float>::sum_vector(32, 32, 1).run::(&Default::default()); - } - )* - } - }; +#[cube(launch_unchecked)] +pub fn naive_reduce_dim_kernel>( + input: &Tensor, + output: &mut Tensor, + dim: u32, +) { + reduce_dim_naive::(input, output, dim) } #[macro_export] macro_rules! testgen_reduce { - ([$($float:ident),*]) => { + ([$($float:ident), *]) => { mod test_reduce { use super::*; ::paste::paste! { @@ -34,209 +31,378 @@ macro_rules! 
testgen_reduce { }; ($float:ident) => { - use super::*; - use cubecl_core::as_type; - use cubecl_core::prelude::Float; - use cubecl_core::CubeCount; use cubecl_reduce::test::TestCase; + use cubecl_core::prelude::CubeCount; - $crate::impl_test_reduce_sum_vector!( + $crate::impl_test_reduce!( $float, [ - (32, 32, 1), - (64, 32, 1), - (100, 32, 1), - (1000, 32, 1), - (2048, 32, 1), - (32, 64, 1), - (64, 64, 1), - (100, 64, 1), - (1000, 64, 1), - (2048, 64, 1), - (32, 1024, 1), - (64, 1024, 1), - (100, 1024, 1), - (1000, 1024, 1), - (2048, 1024, 1), - (32, 32, 2), - (64, 32, 2), - (100, 32, 2), - (1000, 32, 2), - (2048, 32, 2), - (32, 64, 2), - (64, 64, 2), - (100, 64, 2), - (1000, 64, 2), - (2048, 64, 2), - (32, 1024, 2), - (64, 1024, 2), - (100, 1024, 2), - (1000, 1024, 2), - (2048, 1024, 2), - (32, 32, 4), - (64, 32, 4), - (100, 32, 4), - (1000, 32, 4), - (2048, 32, 4), - (32, 64, 4), - (64, 64, 4), - (100, 64, 4), - (1000, 64, 4), - (2048, 64, 4), - (32, 1024, 4), - (64, 1024, 4), - (100, 1024, 4), - (1000, 1024, 4), - (2048, 1024, 4) + { + id: "reduce_columns_small_matrix_row_major", + shape: [4, 8], + stride: [8, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(1, 1, 1), + cube_dim: CubeDim::new(4, 8, 1), + line_size: 1, + }, + { + id: "reduce_columns_large_matrix_row_major", + shape: [8, 256], + stride: [256, 1], + reduce_dim: 1, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "reduce_rows_large_matrix_row_major", + shape: [8, 256], + stride: [256, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "rank_three_tensor", + shape: [16, 16, 16], + stride: [1, 256, 16], + reduce_dim: 2, + cube_count: CubeCount::Static(4, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "rank_three_tensor_unexact_shape", + shape: [11, 12, 13], + stride: [156, 13, 1], + reduce_dim: 1, + cube_count: CubeCount::Static(4, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 1, + }, + { + id: "reduce_rows_large_matrix_row_major_line_size_four", + shape: [8, 256], + stride: [256, 1], + reduce_dim: 0, + cube_count: CubeCount::Static(8, 1, 1), + cube_dim: CubeDim::new(16, 16, 1), + line_size: 4, + } ] ); }; } -#[derive(Debug)] -pub struct TestTensorParts { - pub values: Vec, - pub stride: Vec, - pub shape: Vec, - pub line_size: u8, -} +#[macro_export] +macro_rules! impl_test_reduce { + ( + $float:ident, + [ + $( + { + id: $id:literal, + shape: $shape:expr, + stride: $stride:expr, + reduce_dim: $reduce_dim:expr, + cube_count: $cube_count:expr, + cube_dim: $cube_dim:expr, + line_size: $line_size:expr, + } + ),* + ]) => { + ::paste::paste! 
{ + $( + #[test] + pub fn [< reduce_sum_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_sum_dim_naive::<$float, TestRuntime>(&Default::default()); + } -impl TestTensorParts { - pub fn new_vector(values: Vec) -> Self { - let shape = vec![values.len()]; - Self { - values, - stride: vec![1], - shape, - line_size: 1, - } - } + #[test] + pub fn [< reduce_prod_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_prod_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn range_vector(stop: usize) -> Self { - let values = (0..stop).map(|x| N::new(x as f32)).collect(); - Self::new_vector(values) - } + #[test] + pub fn [< reduce_mean_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_mean_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn zero_vector(size: usize) -> Self { - let values = vec![N::new(0.0); size]; - Self::new_vector(values) - } + #[test] + pub fn [< reduce_argmax_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_argmax_dim_naive::<$float, TestRuntime>(&Default::default()); + } - pub fn with_line_size(mut self, line_size: u8) -> Self { - self.line_size = line_size; - self - } + #[test] + pub fn [< reduce_argmin_dim_naive_ $id >]() { + let test = TestCase { + shape: $shape.into(), + stride: $stride.into(), + reduce_dim: $reduce_dim, + cube_count: $cube_count, + cube_dim: $cube_dim, + line_size:$line_size + }; + test.test_argmin_dim_naive::<$float, TestRuntime>(&Default::default()); + } + )* + } + }; } #[derive(Debug)] -pub struct TestCase { - pub input: TestTensorParts, - pub output: TestTensorParts, - pub expected: Vec, +pub struct TestCase { + pub shape: Vec, + pub stride: Vec, + pub reduce_dim: u32, + pub line_size: u8, pub cube_count: CubeCount, pub cube_dim: CubeDim, - pub sum_dim: u32, - pub reduce_lines: bool, } -impl TestCase { - pub fn new(input: TestTensorParts, output: TestTensorParts, expected: Vec) -> Self { - Self { - input, - output, - expected, - cube_count: CubeCount::Static(1, 1, 1), - cube_dim: CubeDim::new(32, 1, 1), - sum_dim: 0, - reduce_lines: false, - } +impl TestCase { + pub fn test_sum_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_sum_dim(&input_values); + self.run_test::(device, input_values, expected_values) } - /// ASSUMPTION: line_size divide num_values exactly - pub fn sum_vector(num_values: usize, cube_size: u32, line_size: usize) -> Self + pub fn test_prod_dim_naive(&self, device: &R::Device) where - F: Float, + F: Float + CubeElement + std::fmt::Display, + R: Runtime, { - // Compute the sums on the cpu. 
- let values_per_sum = num_values / line_size; - let partial_sum = values_per_sum * (values_per_sum - 1) / 2; - let mut sums = vec![0; line_size]; - for k in 0..line_size { - sums[k] = partial_sum + values_per_sum * k; - } - let sums = sums.into_iter().map(|s| F::new(s as f32)).collect(); - - let mut test = TestCase::new( - // input - TestTensorParts::range_vector(num_values), - // output - TestTensorParts::zero_vector(line_size), - // expected - sums, - ); - test.cube_dim = CubeDim::new(cube_size, 1, 1); - test + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_prod_dim(&input_values); + self.run_test::(device, input_values, expected_values) } - pub fn run(self, device: &R::Device) + pub fn test_mean_dim_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_mean_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn test_argmax_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_argmax_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn test_argmin_dim_naive(&self, device: &R::Device) + where + F: Float + CubeElement + std::fmt::Display, + R: Runtime, + { + let input_values: Vec = self.random_input_values(); + let expected_values = self.cpu_argmin_dim(&input_values); + self.run_test::(device, input_values, expected_values) + } + + pub fn run_test( + &self, + device: &R::Device, + input_values: Vec, + expected_values: Vec, + ) where + I: Numeric + CubeElement + std::fmt::Display, + O: Numeric + CubeElement + std::fmt::Display, + R: Runtime, + K: ReduceDimNaive, { let client = R::client(device); - if !client.properties().feature_enabled(Feature::Plane) { - // Can't execute the test. 
- return; - } - let input_handle = client.create(F::as_bytes(&self.input.values)); - let output_handle = client.create(F::as_bytes(&self.output.values)); + let input_handle = client.create(I::as_bytes(&input_values)); - let config = ReduceConfig { - line_size: self.input.line_size as u32, - max_num_planes: self.cube_dim.num_elems() - / client.properties().hardware_properties().plane_size_min, - }; + let output_handle = + client.create(O::as_bytes(&vec![O::from_int(0); expected_values.len()])); + let mut output_shape = self.shape.clone(); + output_shape[self.reduce_dim as usize] = 1; + let output_stride = self.output_stride(); unsafe { - let input_tensor = TensorArg::from_raw_parts::( + let input_tensor = TensorArg::from_raw_parts::( &input_handle, - &self.input.stride, - &self.input.shape, - self.input.line_size, + &self.stride, + &self.shape, + self.line_size, ); - let output_tensor = TensorArg::from_raw_parts::( + let output_tensor = TensorArg::from_raw_parts::( &output_handle, - &self.output.stride, - &self.output.shape, - self.output.line_size, + &output_stride, + &output_shape, + self.line_size, ); - if self.reduce_lines { - reduce_sum_lined::launch_unchecked::( - &client, - self.cube_count, - self.cube_dim, - input_tensor, - output_tensor, - config, - ); - } else { - reduce_sum::launch_unchecked::( - &client, - self.cube_count, - self.cube_dim, - input_tensor, - output_tensor, - config, - ); - } + naive_reduce_dim_kernel::launch_unchecked::( + &client, + self.cube_count.clone(), + self.cube_dim.clone(), + input_tensor, + output_tensor, + ScalarArg::new(self.reduce_dim.clone()), + ); } let binding = output_handle.binding(); let bytes = client.read_one(binding); - let output_values = F::from_bytes(&bytes); + let output_values = O::from_bytes(&bytes); + + assert_approx_equal_abs(output_values, &expected_values, 1e-9); + } + + fn cpu_sum_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![F::new(0.0); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + expected[output_index] += values[input_index]; + } + expected + } + + fn cpu_prod_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![F::new(1.0); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + expected[output_index] *= values[input_index]; + } + expected + } + + fn cpu_mean_dim(&self, values: &[F]) -> Vec { + self.cpu_sum_dim(values) + .into_iter() + .map(|sum| sum / F::new(self.shape[self.reduce_dim as usize] as f32)) + .collect() + } + + fn cpu_argmax_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + let (best, _) = expected[output_index]; + let candidate = values[input_index]; + if candidate > best { + let coordinate = self.to_input_coordinate(input_index); + expected[output_index] = (candidate, coordinate[self.reduce_dim as usize] as u32); + } + } + expected.into_iter().map(|(_, i)| i).collect() + } + + fn cpu_argmin_dim(&self, values: &[F]) -> Vec { + let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()]; + for input_index in 0..values.len() { + let output_index = self.to_output_index(input_index); + let (best, _) = expected[output_index]; + let candidate = values[input_index]; + if candidate < best { + let coordinate = self.to_input_coordinate(input_index); + expected[output_index] = (candidate, 
coordinate[self.reduce_dim as usize] as u32); + } + } + expected.into_iter().map(|(_, i)| i).collect() + } + + fn num_output_values(&self) -> usize { + self.shape.iter().product::() / self.shape[self.reduce_dim as usize] + } + + fn to_output_index(&self, input_index: usize) -> usize { + let mut coordinate = self.to_input_coordinate(input_index); + coordinate[self.reduce_dim as usize] = 0; + self.from_output_coordinate(coordinate) + } + + fn to_input_coordinate(&self, index: usize) -> Vec { + self.stride + .iter() + .zip(self.shape.iter()) + .map(|(stride, shape)| (index / stride) % shape) + .collect() + } + + fn from_output_coordinate(&self, coordinate: Vec) -> usize { + coordinate + .into_iter() + .zip(self.output_stride().iter()) + .map(|(c, s)| c * s) + .sum() + } + + fn output_stride(&self) -> Vec { + let dim_stride = self.stride[self.reduce_dim as usize]; + let dim_shape = self.shape[self.reduce_dim as usize]; + self.stride + .iter() + .map(|s| match s.cmp(&dim_stride) { + std::cmp::Ordering::Equal => 1, + std::cmp::Ordering::Greater => s / dim_shape, + std::cmp::Ordering::Less => *s, + }) + .collect() + } + + fn random_input_values(&self) -> Vec { + let size = self.shape.iter().product::() * self.line_size as usize; + + fn lcg(seed: &mut u64) -> f32 { + const A: u64 = 1664525; + const C: u64 = 1013904223; + const M: f64 = 2u64.pow(32) as f64; + + *seed = (A.wrapping_mul(*seed).wrapping_add(C)) % (1u64 << 32); + (*seed as f64 / M * 2.0 - 1.0) as f32 + } - assert_approx_equal_abs(output_values, &self.expected, 1e-9); + let mut seed = 123456789; // Not really important for testing. + (0..size).map(|_| F::new(lcg(&mut seed))).collect() } } From 573be8be823ba09dc4910e3a2f8cc0f26781b16e Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 14:40:56 -0500 Subject: [PATCH 4/9] Add line support to naive reduction and test --- crates/cubecl-reduce/src/naive/base.rs | 14 +++++------ crates/cubecl-reduce/src/naive/mean_dim.rs | 17 +++++++------ crates/cubecl-reduce/src/naive/prod_dim.rs | 18 ++++++++------ crates/cubecl-reduce/src/naive/sum_dim.rs | 18 ++++++++------ crates/cubecl-reduce/src/test.rs | 29 ++++++++++++++-------- 5 files changed, 57 insertions(+), 39 deletions(-) diff --git a/crates/cubecl-reduce/src/naive/base.rs b/crates/cubecl-reduce/src/naive/base.rs index 9cf756ab..3e08d6cc 100644 --- a/crates/cubecl-reduce/src/naive/base.rs +++ b/crates/cubecl-reduce/src/naive/base.rs @@ -8,14 +8,14 @@ pub trait ReduceDimNaive: Send + Sync + 'static { type Accumulator: CubeType; /// Initialization for naive algorithm - fn initialize_naive() -> Self::Accumulator; + fn initialize_naive(line_size: u32) -> Self::Accumulator; /// Inner loop for naive algorithm - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32); + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, i: u32); /// Assignation for naive algorithm fn assign_naive( - output: &mut Tensor, + output: &mut Tensor>, accumulator: Self::Accumulator, shape_reduce_dim: u32, ); @@ -23,11 +23,11 @@ pub trait ReduceDimNaive: Send + Sync + 'static { #[cube] pub fn reduce_dim_naive, EI: Numeric, EO: Numeric>( - input: &Tensor, - output: &mut Tensor, + input: &Tensor>, + output: &mut Tensor>, dim: u32, ) { - if ABSOLUTE_POS >= output.len() { + if ABSOLUTE_POS >= output.len() * output.line_size() { return; }; @@ -41,7 +41,7 @@ pub fn reduce_dim_naive, EI: Numeric, EO: Numeric>( } } - let mut accumulator = RD::initialize_naive(); + let mut accumulator = 
RD::initialize_naive(input.line_size()); for i in 0..input.shape(dim) { let index = i * input.stride(dim) + offset_input; diff --git a/crates/cubecl-reduce/src/naive/mean_dim.rs b/crates/cubecl-reduce/src/naive/mean_dim.rs index 98e65727..e070c6ef 100644 --- a/crates/cubecl-reduce/src/naive/mean_dim.rs +++ b/crates/cubecl-reduce/src/naive/mean_dim.rs @@ -6,18 +6,21 @@ use crate::MeanDim; #[cube] impl ReduceDimNaive for MeanDim { - type Accumulator = EI; + type Accumulator = Line; - fn initialize_naive() -> EI { - EI::from_int(0) + fn initialize_naive(line_size: u32) -> Self::Accumulator { + Line::empty(line_size).fill(EI::from_int(0)) } - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { *accumulator += current_value; } - fn assign_naive(output: &mut Tensor, accumulator: EI, shape_reduce_dim: u32) { - let mean = accumulator / EI::cast_from(shape_reduce_dim); - output[ABSOLUTE_POS] = EO::cast_from(mean); + fn assign_naive( + output: &mut Tensor>, + accumulator: Self::Accumulator, + shape_reduce_dim: u32, + ) { + output[ABSOLUTE_POS] = Line::cast_from(accumulator / Line::empty(output.line_size()).fill(EI::cast_from(shape_reduce_dim))); } } diff --git a/crates/cubecl-reduce/src/naive/prod_dim.rs b/crates/cubecl-reduce/src/naive/prod_dim.rs index 4e73e062..8a2f7eef 100644 --- a/crates/cubecl-reduce/src/naive/prod_dim.rs +++ b/crates/cubecl-reduce/src/naive/prod_dim.rs @@ -1,5 +1,5 @@ -use cubecl_core::prelude::*; use cubecl_core as cubecl; +use cubecl_core::prelude::*; use crate::ProdDim; @@ -7,17 +7,21 @@ use super::base::ReduceDimNaive; #[cube] impl ReduceDimNaive for ProdDim { - type Accumulator = EI; + type Accumulator = Line; - fn initialize_naive() -> EI { - EI::from_int(1) + fn initialize_naive(line_size: u32) -> Line { + Line::empty(line_size).fill(EI::from_int(1)) } - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { *accumulator *= current_value; } - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); + fn assign_naive( + output: &mut Tensor>, + accumulator: Self::Accumulator, + _shape_reduce_dim: u32, + ) { + output[ABSOLUTE_POS] = Line::cast_from(accumulator); } } diff --git a/crates/cubecl-reduce/src/naive/sum_dim.rs b/crates/cubecl-reduce/src/naive/sum_dim.rs index e21e23d6..2577bfc8 100644 --- a/crates/cubecl-reduce/src/naive/sum_dim.rs +++ b/crates/cubecl-reduce/src/naive/sum_dim.rs @@ -7,17 +7,21 @@ use super::base::ReduceDimNaive; #[cube] impl ReduceDimNaive for SumDim { - type Accumulator = EI; + type Accumulator = Line; - fn initialize_naive() -> EI { - EI::from_int(0) + fn initialize_naive(line_size: u32) -> Line { + Line::empty(line_size).fill(EI::from_int(0)) } - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { + fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { *accumulator += current_value; } - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } + fn assign_naive( + output: &mut Tensor>, + accumulator: Self::Accumulator, + _shape_reduce_dim: u32, + ) { + output[ABSOLUTE_POS] = Line::cast_from(accumulator); + } } diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 548e86de..b24facdd 100644 
--- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -8,8 +8,8 @@ use crate::{ArgMax, ArgMin, MeanDim, ProdDim, SumDim}; #[cube(launch_unchecked)] pub fn naive_reduce_dim_kernel>( - input: &Tensor, - output: &mut Tensor, + input: &Tensor>, + output: &mut Tensor>, dim: u32, ) { reduce_dim_naive::(input, output, dim) @@ -84,8 +84,8 @@ macro_rules! testgen_reduce { }, { id: "reduce_rows_large_matrix_row_major_line_size_four", - shape: [8, 256], - stride: [256, 1], + shape: [32, 64], + stride: [64, 1], reduce_dim: 0, cube_count: CubeCount::Static(8, 1, 1), cube_dim: CubeDim::new(16, 16, 1), @@ -154,6 +154,8 @@ macro_rules! impl_test_reduce { test.test_mean_dim_naive::<$float, TestRuntime>(&Default::default()); } + // Fix the line issue in argmax before running the test. + #[ignore] #[test] pub fn [< reduce_argmax_dim_naive_ $id >]() { let test = TestCase { @@ -167,6 +169,8 @@ macro_rules! impl_test_reduce { test.test_argmax_dim_naive::<$float, TestRuntime>(&Default::default()); } + // Fix the line issue in argmin before running the test. + #[ignore] #[test] pub fn [< reduce_argmin_dim_naive_ $id >]() { let test = TestCase { @@ -294,7 +298,7 @@ impl TestCase { let bytes = client.read_one(binding); let output_values = O::from_bytes(&bytes); - assert_approx_equal_abs(output_values, &expected_values, 1e-9); + assert_approx_equal_abs(output_values, &expected_values, 1e-7); } fn cpu_sum_dim(&self, values: &[F]) -> Vec { @@ -308,9 +312,9 @@ impl TestCase { fn cpu_prod_dim(&self, values: &[F]) -> Vec { let mut expected = vec![F::new(1.0); self.num_output_values()]; - for input_index in 0..values.len() { - let output_index = self.to_output_index(input_index); - expected[output_index] *= values[input_index]; + for value_index in 0..values.len() { + let output_index = self.to_output_index(value_index); + expected[output_index] *= values[value_index]; } expected } @@ -351,13 +355,15 @@ impl TestCase { } fn num_output_values(&self) -> usize { - self.shape.iter().product::() / self.shape[self.reduce_dim as usize] + self.line_size as usize * self.shape.iter().product::() + / self.shape[self.reduce_dim as usize] } fn to_output_index(&self, input_index: usize) -> usize { - let mut coordinate = self.to_input_coordinate(input_index); + let line_size = self.line_size as usize; + let mut coordinate = self.to_input_coordinate(input_index / line_size); coordinate[self.reduce_dim as usize] = 0; - self.from_output_coordinate(coordinate) + self.from_output_coordinate(coordinate) * line_size + input_index % line_size } fn to_input_coordinate(&self, index: usize) -> Vec { @@ -403,6 +409,7 @@ impl TestCase { let mut seed = 123456789; // Not really important for testing. 
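        // Note (added for clarity, not part of the original patch): together with the LCG defined just
        // above, this fixed seed makes the pseudo-random test inputs in [-1, 1) fully deterministic, so
        // any failure is reproducible across runs.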
(0..size).map(|_| F::new(lcg(&mut seed))).collect() + // (0..size).map(|x| F::new(x as f32)).collect() } } From 59a3d2e87972b525bc36e16a2c480b29ae837cce Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 15:20:49 -0500 Subject: [PATCH 5/9] clean and reorganize code and add doc --- crates/cubecl-reduce/src/base.rs | 5 - crates/cubecl-reduce/src/instructions.rs | 10 +- crates/cubecl-reduce/src/lib.rs | 14 +- crates/cubecl-reduce/src/naive.rs | 213 +++++++++++++++++++ crates/cubecl-reduce/src/naive/argmax.rs | 34 --- crates/cubecl-reduce/src/naive/argmin.rs | 32 --- crates/cubecl-reduce/src/naive/base.rs | 56 ----- crates/cubecl-reduce/src/naive/mean_dim.rs | 26 --- crates/cubecl-reduce/src/naive/mod.rs | 6 - crates/cubecl-reduce/src/naive/prod_dim.rs | 27 --- crates/cubecl-reduce/src/naive/sum_dim.rs | 27 --- crates/cubecl-reduce/src/prod.rs | 15 -- crates/cubecl-reduce/src/shared/argmax.rs | 63 ------ crates/cubecl-reduce/src/shared/argmin.rs | 64 ------ crates/cubecl-reduce/src/shared/base.rs | 35 --- crates/cubecl-reduce/src/shared/kernel.rs | 117 ---------- crates/cubecl-reduce/src/shared/mean_dim.rs | 44 ---- crates/cubecl-reduce/src/shared/mod.rs | 7 - crates/cubecl-reduce/src/shared/prod_dim.rs | 43 ---- crates/cubecl-reduce/src/shared/sum_dim.rs | 43 ---- crates/cubecl-reduce/src/subcube/argmax.rs | 54 ----- crates/cubecl-reduce/src/subcube/argmin.rs | 54 ----- crates/cubecl-reduce/src/subcube/base.rs | 17 -- crates/cubecl-reduce/src/subcube/kernel.rs | 134 ------------ crates/cubecl-reduce/src/subcube/mean_dim.rs | 45 ---- crates/cubecl-reduce/src/subcube/mod.rs | 7 - crates/cubecl-reduce/src/subcube/prod_dim.rs | 44 ---- crates/cubecl-reduce/src/subcube/sum_dim.rs | 44 ---- crates/cubecl-reduce/src/sum.rs | 15 -- crates/cubecl-reduce/src/test.rs | 21 +- crates/cubecl-reduce/src/tune/base.rs | 94 -------- crates/cubecl-reduce/src/tune/key.rs | 39 ---- crates/cubecl-reduce/src/tune/mod.rs | 7 - 33 files changed, 233 insertions(+), 1223 deletions(-) delete mode 100644 crates/cubecl-reduce/src/base.rs create mode 100644 crates/cubecl-reduce/src/naive.rs delete mode 100644 crates/cubecl-reduce/src/naive/argmax.rs delete mode 100644 crates/cubecl-reduce/src/naive/argmin.rs delete mode 100644 crates/cubecl-reduce/src/naive/base.rs delete mode 100644 crates/cubecl-reduce/src/naive/mean_dim.rs delete mode 100644 crates/cubecl-reduce/src/naive/mod.rs delete mode 100644 crates/cubecl-reduce/src/naive/prod_dim.rs delete mode 100644 crates/cubecl-reduce/src/naive/sum_dim.rs delete mode 100644 crates/cubecl-reduce/src/prod.rs delete mode 100644 crates/cubecl-reduce/src/shared/argmax.rs delete mode 100644 crates/cubecl-reduce/src/shared/argmin.rs delete mode 100644 crates/cubecl-reduce/src/shared/base.rs delete mode 100644 crates/cubecl-reduce/src/shared/kernel.rs delete mode 100644 crates/cubecl-reduce/src/shared/mean_dim.rs delete mode 100644 crates/cubecl-reduce/src/shared/mod.rs delete mode 100644 crates/cubecl-reduce/src/shared/prod_dim.rs delete mode 100644 crates/cubecl-reduce/src/shared/sum_dim.rs delete mode 100644 crates/cubecl-reduce/src/subcube/argmax.rs delete mode 100644 crates/cubecl-reduce/src/subcube/argmin.rs delete mode 100644 crates/cubecl-reduce/src/subcube/base.rs delete mode 100644 crates/cubecl-reduce/src/subcube/kernel.rs delete mode 100644 crates/cubecl-reduce/src/subcube/mean_dim.rs delete mode 100644 crates/cubecl-reduce/src/subcube/mod.rs delete mode 100644 crates/cubecl-reduce/src/subcube/prod_dim.rs delete mode 100644 
crates/cubecl-reduce/src/subcube/sum_dim.rs
 delete mode 100644 crates/cubecl-reduce/src/sum.rs
 delete mode 100644 crates/cubecl-reduce/src/tune/base.rs
 delete mode 100644 crates/cubecl-reduce/src/tune/key.rs
 delete mode 100644 crates/cubecl-reduce/src/tune/mod.rs

diff --git a/crates/cubecl-reduce/src/base.rs b/crates/cubecl-reduce/src/base.rs
deleted file mode 100644
index 5dbea36c..00000000
--- a/crates/cubecl-reduce/src/base.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-pub struct ArgMax;
-pub struct ArgMin;
-pub struct MeanDim;
-pub struct SumDim;
-pub struct ProdDim;
diff --git a/crates/cubecl-reduce/src/instructions.rs b/crates/cubecl-reduce/src/instructions.rs
index 5dbea36c..ffc0eef0 100644
--- a/crates/cubecl-reduce/src/instructions.rs
+++ b/crates/cubecl-reduce/src/instructions.rs
@@ -1,5 +1,5 @@
-pub struct ArgMax;
-pub struct ArgMin;
-pub struct MeanDim;
-pub struct SumDim;
-pub struct ProdDim;
+pub struct ReduceArgMax;
+pub struct ReduceArgMin;
+pub struct ReduceMean;
+pub struct ReduceSum;
+pub struct ReduceProd;
diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs
index 3f194e5f..c10badbf 100644
--- a/crates/cubecl-reduce/src/lib.rs
+++ b/crates/cubecl-reduce/src/lib.rs
@@ -1,15 +1,9 @@
 mod instructions;
 mod naive;
-// mod prod;
-// mod shared;
-// mod subcube;
-// mod sum;
-// mod tune;
-
-pub use instructions::*;
-// pub use prod::*;
-// pub use sum::*;
-// pub use tune::*;

 #[cfg(feature = "export_tests")]
 pub mod test;
+
+pub use instructions::*;
+pub use naive::*;
+
diff --git a/crates/cubecl-reduce/src/naive.rs b/crates/cubecl-reduce/src/naive.rs
new file mode 100644
index 00000000..0954376f
--- /dev/null
+++ b/crates/cubecl-reduce/src/naive.rs
@@ -0,0 +1,213 @@
+use cubecl_core as cubecl;
+use cubecl_core::prelude::*;
+
+use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum};
+
+/// An instruction for the [reduce_naive](reduce_naive) algorithm.
+#[cube]
+pub trait ReduceNaiveInstruction: Send + Sync + 'static {
+    /// The reduction accumulator.
+    /// The implementation works on lines. Typically, the accumulator is `Line`
+    /// for some CubePrimitive type `T` instead of simply `T`.
+    type Accumulator: CubeType;
+
+    /// Initialize the accumulator with the neutral element of the reduction.
+    ///
+    /// This may be called many times during the reduction. It is required
+    /// that reducing the initial accumulator any number of times does not change the outcome
+    /// of the reduction. For example, adding 0s to a sum does not change the outcome.
+    fn init_accumulator(line_size: u32) -> Self::Accumulator;
+
+    /// Reduce `current_value` into `accumulator`.
+    fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32);
+
+    /// Write the result of the reduction stored in `accumulator` into `output[index]`.
+    fn write(
+        output: &mut Tensor>,
+        accumulator: Self::Accumulator,
+        index: u32,
+        shape_reduce_dim: u32,
+    );
+}
+
+/// A naive implementation of the reduction algorithm.
+///
+/// Each thread with absolute position P is responsible
+/// for computing the reduction corresponding to index P of the `output`.
+#[cube]
+pub fn reduce_naive, EI: Numeric, EO: Numeric>(
+    input: &Tensor>,
+    output: &mut Tensor>,
+    dim: u32,
+) {
+    if ABSOLUTE_POS >= output.len() * output.line_size() {
+        return;
+    }
+
+    // Compute the first index at which to start the reduction for the current thread.
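+    // Worked example (illustrative only, not part of the original patch; assumes line_size = 1 and a
+    // contiguous output): for an input of shape [4, 8] with strides [8, 1] and dim = 1, the output has
+    // shape [4, 1], so the thread with ABSOLUTE_POS = 3 gets offset_input = 3 * 8 = 24 and reduces
+    // input[24], input[25], ..., input[31].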
+ // First, compute the coordinate corresponding to the ABSOLUTE_POS element of the output tensor + // Then, use the strides of the input tensor to find the index of the same coordinate + // in the input tensor. + let mut offset_input = 0; + for axis in 0..input.rank() { + let coordinate = (ABSOLUTE_POS / output.stride(axis)) % output.shape(axis); + offset_input += coordinate * input.stride(axis); + } + + // Reduce all the lines along `dim` for the previously computed offset. + let mut accumulator = RD::init_accumulator(input.line_size()); + for i in 0..input.shape(dim) { + let index = i * input.stride(dim) + offset_input; + RD::accumulate( + &mut accumulator, + unsafe { *input.index_unchecked(index) }, + i, + ); + } + + // Write the local outcome into output. + RD::write::(output, accumulator, ABSOLUTE_POS, input.shape(dim)); +} + +// Implementations for common instructions. + +#[cube] +impl ReduceNaiveInstruction for ReduceSum { + type Accumulator = Line; + + fn init_accumulator(line_size: u32) -> Line { + Line::empty(line_size).fill(EI::from_int(0)) + } + + fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { + *accumulator += current_value; + } + + fn write( + output: &mut Tensor>, + accumulator: Self::Accumulator, + index: u32, + _shape_reduce_dim: u32, + ) { + output[index] = Line::cast_from(accumulator); + } +} + +#[cube] +impl ReduceNaiveInstruction for ReduceProd { + type Accumulator = Line; + + fn init_accumulator(line_size: u32) -> Line { + Line::empty(line_size).fill(EI::from_int(1)) + } + + fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { + *accumulator *= current_value; + } + + fn write( + output: &mut Tensor>, + accumulator: Self::Accumulator, + index: u32, + _shape_reduce_dim: u32, + ) { + output[index] = Line::cast_from(accumulator); + } +} + +#[cube] +impl ReduceNaiveInstruction for ReduceMean { + type Accumulator = Line; + + fn init_accumulator(line_size: u32) -> Self::Accumulator { + Line::empty(line_size).fill(EI::from_int(0)) + } + + fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { + *accumulator += current_value; + } + + fn write( + output: &mut Tensor>, + accumulator: Self::Accumulator, + index: u32, + shape_reduce_dim: u32, + ) { + output[index] = Line::cast_from( + accumulator / Line::empty(output.line_size()).fill(EI::cast_from(shape_reduce_dim)), + ); + } +} + +#[cube] +impl ReduceNaiveInstruction for ReduceArgMax { + type Accumulator = (Line, Line); + + fn init_accumulator(line_size: u32) -> Self::Accumulator { + ( + // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 + Line::empty(line_size).fill(EI::MIN), + Line::empty(line_size).fill(0u32), + ) + } + + fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { + // FIX: There is an issue when line_size is 1 on wgpu. 
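+    // Descriptive note (added here, not in the original patch): each lane k of the incoming line is
+    // compared against the running maximum for that lane, and the index `i` along the reduced axis is
+    // recorded whenever the lane observes a strictly greater value.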
+ let (max, index) = accumulator; + #[unroll] + for k in 0..current_value.size() { + if current_value[k] > max[k] { + max[k] = current_value[k]; + index[k] = i; + } + } + } + + fn write( + output: &mut Tensor>, + accumulator: Self::Accumulator, + index: u32, + _shape_reduce_dim: u32, + ) { + let (_, position) = accumulator; + output[index] = Line::cast_from(position) + } +} + +#[cube] +impl ReduceNaiveInstruction for ReduceArgMin { + type Accumulator = (Line, Line); + + fn init_accumulator(line_size: u32) -> Self::Accumulator { + ( + // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 + Line::empty(line_size).fill(EI::MAX), + Line::empty(line_size).fill(0u32), + ) + } + + fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { + // FIX: There is an issue when line_size is 1 on wgpu. + let (min, index) = accumulator; + #[unroll] + for k in 0..current_value.size() { + if current_value[k] < min[k] { + min[k] = current_value[k]; + index[k] = i; + } + } + } + + fn write( + output: &mut Tensor>, + accumulator: Self::Accumulator, + index: u32, + _shape_reduce_dim: u32, + ) { + let (_, position) = accumulator; + #[unroll] + for k in 0..output.line_size() { + output[index][k] = EO::cast_from(position[k]); + } + } +} diff --git a/crates/cubecl-reduce/src/naive/argmax.rs b/crates/cubecl-reduce/src/naive/argmax.rs deleted file mode 100644 index 10e03654..00000000 --- a/crates/cubecl-reduce/src/naive/argmax.rs +++ /dev/null @@ -1,34 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use super::base::ReduceDimNaive; -use crate::ArgMax; - -#[cube] -impl ReduceDimNaive for ArgMax { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (comptime![EI::MIN].runtime(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (max, index) = accumulator; - if current_value > *max { - *max = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } - - -} diff --git a/crates/cubecl-reduce/src/naive/argmin.rs b/crates/cubecl-reduce/src/naive/argmin.rs deleted file mode 100644 index 32a0a7e1..00000000 --- a/crates/cubecl-reduce/src/naive/argmin.rs +++ /dev/null @@ -1,32 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use super::base::ReduceDimNaive; -use crate::ArgMin; - -#[cube] -impl ReduceDimNaive for ArgMin { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (comptime![EI::MAX].runtime(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (min, index) = accumulator; - if current_value < *min { - *min = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/cubecl-reduce/src/naive/base.rs b/crates/cubecl-reduce/src/naive/base.rs deleted file mode 100644 index 3e08d6cc..00000000 --- a/crates/cubecl-reduce/src/naive/base.rs +++ /dev/null @@ -1,56 +0,0 @@ 
-use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimNaive: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - - /// Initialization for naive algorithm - fn initialize_naive(line_size: u32) -> Self::Accumulator; - - /// Inner loop for naive algorithm - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, i: u32); - - /// Assignation for naive algorithm - fn assign_naive( - output: &mut Tensor>, - accumulator: Self::Accumulator, - shape_reduce_dim: u32, - ); -} - -#[cube] -pub fn reduce_dim_naive, EI: Numeric, EO: Numeric>( - input: &Tensor>, - output: &mut Tensor>, - dim: u32, -) { - if ABSOLUTE_POS >= output.len() * output.line_size() { - return; - }; - - let mut offset_input = 0; - - for i in 0..input.rank() { - let mut offset_local = ABSOLUTE_POS / output.stride(i); - offset_local %= output.shape(i); - if i != dim { - offset_input += offset_local * input.stride(i); - } - } - - let mut accumulator = RD::initialize_naive(input.line_size()); - - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; - RD::inner_loop_naive( - &mut accumulator, - unsafe { *input.index_unchecked(index) }, - i, - ); - } - - RD::assign_naive::(output, accumulator, input.shape(dim)); -} diff --git a/crates/cubecl-reduce/src/naive/mean_dim.rs b/crates/cubecl-reduce/src/naive/mean_dim.rs deleted file mode 100644 index e070c6ef..00000000 --- a/crates/cubecl-reduce/src/naive/mean_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use super::base::ReduceDimNaive; -use crate::MeanDim; - -#[cube] -impl ReduceDimNaive for MeanDim { - type Accumulator = Line; - - fn initialize_naive(line_size: u32) -> Self::Accumulator { - Line::empty(line_size).fill(EI::from_int(0)) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive( - output: &mut Tensor>, - accumulator: Self::Accumulator, - shape_reduce_dim: u32, - ) { - output[ABSOLUTE_POS] = Line::cast_from(accumulator / Line::empty(output.line_size()).fill(EI::cast_from(shape_reduce_dim))); - } -} diff --git a/crates/cubecl-reduce/src/naive/mod.rs b/crates/cubecl-reduce/src/naive/mod.rs deleted file mode 100644 index 91a3f1ba..00000000 --- a/crates/cubecl-reduce/src/naive/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -pub mod argmax; -pub mod argmin; -pub mod base; -pub mod mean_dim; -pub mod prod_dim; -pub mod sum_dim; diff --git a/crates/cubecl-reduce/src/naive/prod_dim.rs b/crates/cubecl-reduce/src/naive/prod_dim.rs deleted file mode 100644 index 8a2f7eef..00000000 --- a/crates/cubecl-reduce/src/naive/prod_dim.rs +++ /dev/null @@ -1,27 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::ProdDim; - -use super::base::ReduceDimNaive; - -#[cube] -impl ReduceDimNaive for ProdDim { - type Accumulator = Line; - - fn initialize_naive(line_size: u32) -> Line { - Line::empty(line_size).fill(EI::from_int(1)) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { - *accumulator *= current_value; - } - - fn assign_naive( - output: &mut Tensor>, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - output[ABSOLUTE_POS] = Line::cast_from(accumulator); - } -} diff --git a/crates/cubecl-reduce/src/naive/sum_dim.rs b/crates/cubecl-reduce/src/naive/sum_dim.rs deleted file mode 100644 index 
2577bfc8..00000000 --- a/crates/cubecl-reduce/src/naive/sum_dim.rs +++ /dev/null @@ -1,27 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::SumDim; - -use super::base::ReduceDimNaive; - -#[cube] -impl ReduceDimNaive for SumDim { - type Accumulator = Line; - - fn initialize_naive(line_size: u32) -> Line { - Line::empty(line_size).fill(EI::from_int(0)) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive( - output: &mut Tensor>, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - output[ABSOLUTE_POS] = Line::cast_from(accumulator); - } -} diff --git a/crates/cubecl-reduce/src/prod.rs b/crates/cubecl-reduce/src/prod.rs deleted file mode 100644 index 77227bae..00000000 --- a/crates/cubecl-reduce/src/prod.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{prod_dim, ReduceStrategy}; - -/// Multiply all elements in the input buffer. -pub fn prod( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy) -} diff --git a/crates/cubecl-reduce/src/shared/argmax.rs b/crates/cubecl-reduce/src/shared/argmax.rs deleted file mode 100644 index 1685a200..00000000 --- a/crates/cubecl-reduce/src/shared/argmax.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::{kernel::reduce::Argmax, JitElement}; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = comptime![EIn::minimum_value()].runtime(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value > values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/cubecl-reduce/src/shared/argmin.rs b/crates/cubecl-reduce/src/shared/argmin.rs deleted file mode 100644 index ff7826b1..00000000 --- a/crates/cubecl-reduce/src/shared/argmin.rs +++ /dev/null @@ 
-1,64 +0,0 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::Argmin, JitElement}; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = comptime![EIn::maximum_value()].runtime(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value < values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/cubecl-reduce/src/shared/base.rs b/crates/cubecl-reduce/src/shared/base.rs deleted file mode 100644 index bdb70e6a..00000000 --- a/crates/cubecl-reduce/src/shared/base.rs +++ /dev/null @@ -1,35 +0,0 @@ -use cubecl::prelude::*; - -use crate::JitElement; - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimShared: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - type Value: CubeType; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> Self::Accumulator; - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut Self::Accumulator, - write_position: u32, - value: Self::Value, - ); - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> Self::Value; - - /// How to read from shared memory - fn read_from_shared(shared_memory: &Self::Accumulator, read_position: u32) -> Self::Value; - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &Self::Accumulator, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ); -} diff --git a/crates/cubecl-reduce/src/shared/kernel.rs b/crates/cubecl-reduce/src/shared/kernel.rs deleted file mode 100644 index 1b2dcb35..00000000 --- a/crates/cubecl-reduce/src/shared/kernel.rs +++ /dev/null @@ -1,117 +0,0 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, JitRuntime}; - -use super::base::ReduceDimShared; - -#[cube(launch)] -pub fn reduce_dim_shared_kernel< - RD: ReduceDimShared, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] smem_size: 
u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } else { - if nth < shape_reduce_dim_input { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } - } - } - - sync_units(); - - let mut n_threads = CUBE_DIM; - - while n_threads > 1 { - n_threads /= 2; - - if UNIT_POS < n_threads { - let read_pos = n_threads + UNIT_POS; - let read_value = RD::read_from_shared(&shared_memory, read_pos); - RD::write_to_shared(&mut shared_memory, UNIT_POS, read_value); - } - - sync_units(); - } - - if UNIT_POS == 0 { - RD::assign_shared( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_shared< - RD: ReduceDimShared, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> JitTensor { - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim::default(); - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_shared_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - cube_dim.num_elems(), - elems_per_thread, - divisible_shape, - ); - - output -} diff --git a/crates/cubecl-reduce/src/shared/mean_dim.rs b/crates/cubecl-reduce/src/shared/mean_dim.rs deleted file mode 100644 index 0b09d917..00000000 --- a/crates/cubecl-reduce/src/shared/mean_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::{kernel::reduce::MeanDim, JitElement}; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } 
- - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ) { - let mean = shared_memory[0] / EIn::cast_from(shape_reduce_dim); - output[write_position] = EOut::cast_from(mean); - } -} diff --git a/crates/cubecl-reduce/src/shared/mod.rs b/crates/cubecl-reduce/src/shared/mod.rs deleted file mode 100644 index b11ee5e2..00000000 --- a/crates/cubecl-reduce/src/shared/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/cubecl-reduce/src/shared/prod_dim.rs b/crates/cubecl-reduce/src/shared/prod_dim.rs deleted file mode 100644 index 8041cc68..00000000 --- a/crates/cubecl-reduce/src/shared/prod_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::{kernel::reduce::ProdDim, JitElement}; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(1); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] *= value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/cubecl-reduce/src/shared/sum_dim.rs b/crates/cubecl-reduce/src/shared/sum_dim.rs deleted file mode 100644 index da2b7337..00000000 --- a/crates/cubecl-reduce/src/shared/sum_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::{kernel::reduce::SumDim, JitElement}; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from 
shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/cubecl-reduce/src/subcube/argmax.rs b/crates/cubecl-reduce/src/subcube/argmax.rs deleted file mode 100644 index 428bd712..00000000 --- a/crates/cubecl-reduce/src/subcube/argmax.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::{kernel::reduce::Argmax, JitElement}; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::minimum_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Max::max(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let max = plane_max(val); - - if max == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/cubecl-reduce/src/subcube/argmin.rs b/crates/cubecl-reduce/src/subcube/argmin.rs deleted file mode 100644 index 6a002a5d..00000000 --- a/crates/cubecl-reduce/src/subcube/argmin.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::{kernel::reduce::Argmin, JitElement}; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::maximum_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Min::min(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let 
min = plane_min(val); - - if min == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/cubecl-reduce/src/subcube/base.rs b/crates/cubecl-reduce/src/subcube/base.rs deleted file mode 100644 index a700bf84..00000000 --- a/crates/cubecl-reduce/src/subcube/base.rs +++ /dev/null @@ -1,17 +0,0 @@ -use cubecl::prelude::*; - -use crate::JitElement; - -#[cube] -pub trait ReduceDimSubcube: Send + Sync + 'static { - type Accumulator: CubeType; - type Value: CubeType; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator; - fn init_value() -> Self::Value; - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value; - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value; - fn update_value(current: &mut Self::Value, new: Self::Value); - fn reduce_subcube(acc: &mut Self::Accumulator, pos: u32, value: Self::Value); - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_len: u32); -} diff --git a/crates/cubecl-reduce/src/subcube/kernel.rs b/crates/cubecl-reduce/src/subcube/kernel.rs deleted file mode 100644 index 4e783e74..00000000 --- a/crates/cubecl-reduce/src/subcube/kernel.rs +++ /dev/null @@ -1,134 +0,0 @@ -use cubecl::{prelude::*, CubeCount, CubeDim, Feature}; - -use crate::{ - kernel::reduce::{init_reduce_output, shared::kernel::reduce_dim_shared, ReduceDimAlgorithm}, - tensor::JitTensor, - JitElement, JitRuntime, -}; - -use super::base::ReduceDimSubcube; - -#[cube(launch)] -pub fn reduce_dim_subcube_kernel< - RD: ReduceDimSubcube, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] subcube_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let should_unroll = elems_per_thread <= 8; - - let warp_id = UNIT_POS / PLANE_DIM; - - let mut shared_memory = RD::init_shared(subcube_size); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - let mut value = RD::init_value(); - - #[unroll(should_unroll)] - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - let current_pos = nth * stride_reduce_dim_input + index_offset; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } else { - if nth < shape_reduce_dim_input { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } - } - } - - RD::reduce_subcube(&mut shared_memory, warp_id, value); - - sync_units(); - - if UNIT_POS >= PLANE_DIM { - return; - } - - let value = RD::read_from_shared(&shared_memory, UNIT_POS); - RD::reduce_subcube(&mut shared_memory, 0, value); - - if UNIT_POS == 0 { - RD::store( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_subcube< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> JitTensor { - let topology = input.client.properties().hardware_properties(); - - if 
!input.client.properties().feature_enabled(Feature::Plane) - || topology.plane_size_min != topology.plane_size_max - { - return reduce_dim_shared::(input, dim); - } - - let subcube_size = topology.plane_size_min; - - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim { - x: subcube_size, - y: subcube_size, - z: 1, - }; - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_subcube_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - subcube_size, - elems_per_thread, - divisible_shape, - ); - - output -} diff --git a/crates/cubecl-reduce/src/subcube/mean_dim.rs b/crates/cubecl-reduce/src/subcube/mean_dim.rs deleted file mode 100644 index 63e14de4..00000000 --- a/crates/cubecl-reduce/src/subcube/mean_dim.rs +++ /dev/null @@ -1,45 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::{kernel::reduce::MeanDim, JitElement}; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - comptime![EIn::default()].runtime() - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_length: u32) { - let denom = EIn::cast_from(dim_length); - out[pos] = EOut::cast_from(acc[0] / denom); - } -} diff --git a/crates/cubecl-reduce/src/subcube/mod.rs b/crates/cubecl-reduce/src/subcube/mod.rs deleted file mode 100644 index 183c1e2d..00000000 --- a/crates/cubecl-reduce/src/subcube/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod argmax; -pub mod argmin; -pub mod base; -pub mod kernel; -pub mod mean_dim; -pub mod prod_dim; -pub mod sum_dim; diff --git a/crates/cubecl-reduce/src/subcube/prod_dim.rs b/crates/cubecl-reduce/src/subcube/prod_dim.rs deleted file mode 100644 index 4c0b71d9..00000000 --- a/crates/cubecl-reduce/src/subcube/prod_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::{kernel::reduce::ProdDim, JitElement}; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - comptime![EIn::from_int(1)].runtime() - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - 
input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current *= new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let prod = plane_prod(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = prod; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/cubecl-reduce/src/subcube/sum_dim.rs b/crates/cubecl-reduce/src/subcube/sum_dim.rs deleted file mode 100644 index 3aac1a3c..00000000 --- a/crates/cubecl-reduce/src/subcube/sum_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::{kernel::reduce::SumDim, JitElement}; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - comptime![EIn::default()].runtime() - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/cubecl-reduce/src/sum.rs b/crates/cubecl-reduce/src/sum.rs deleted file mode 100644 index fea80bcc..00000000 --- a/crates/cubecl-reduce/src/sum.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{sum_dim, ReduceStrategy}; - -/// Sum all elements in the input buffer. 
-pub fn sum( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy) -} diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index b24facdd..35c7cb89 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -3,16 +3,17 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::naive::base::{reduce_dim_naive, ReduceDimNaive}; -use crate::{ArgMax, ArgMin, MeanDim, ProdDim, SumDim}; +use crate::{ + reduce_naive, ReduceArgMax, ReduceArgMin, ReduceNaiveInstruction, ReduceMean, ReduceProd, ReduceSum, +}; #[cube(launch_unchecked)] -pub fn naive_reduce_dim_kernel>( +pub fn naive_reduce_dim_kernel>( input: &Tensor>, output: &mut Tensor>, dim: u32, ) { - reduce_dim_naive::(input, output, dim) + reduce_naive::(input, output, dim) } #[macro_export] @@ -206,7 +207,7 @@ impl TestCase { { let input_values: Vec = self.random_input_values(); let expected_values = self.cpu_sum_dim(&input_values); - self.run_test::(device, input_values, expected_values) + self.run_test::(device, input_values, expected_values) } pub fn test_prod_dim_naive(&self, device: &R::Device) @@ -216,7 +217,7 @@ impl TestCase { { let input_values: Vec = self.random_input_values(); let expected_values = self.cpu_prod_dim(&input_values); - self.run_test::(device, input_values, expected_values) + self.run_test::(device, input_values, expected_values) } pub fn test_mean_dim_naive(&self, device: &R::Device) @@ -226,7 +227,7 @@ impl TestCase { { let input_values: Vec = self.random_input_values(); let expected_values = self.cpu_mean_dim(&input_values); - self.run_test::(device, input_values, expected_values) + self.run_test::(device, input_values, expected_values) } pub fn test_argmax_dim_naive(&self, device: &R::Device) @@ -236,7 +237,7 @@ impl TestCase { { let input_values: Vec = self.random_input_values(); let expected_values = self.cpu_argmax_dim(&input_values); - self.run_test::(device, input_values, expected_values) + self.run_test::(device, input_values, expected_values) } pub fn test_argmin_dim_naive(&self, device: &R::Device) @@ -246,7 +247,7 @@ impl TestCase { { let input_values: Vec = self.random_input_values(); let expected_values = self.cpu_argmin_dim(&input_values); - self.run_test::(device, input_values, expected_values) + self.run_test::(device, input_values, expected_values) } pub fn run_test( @@ -258,7 +259,7 @@ impl TestCase { I: Numeric + CubeElement + std::fmt::Display, O: Numeric + CubeElement + std::fmt::Display, R: Runtime, - K: ReduceDimNaive, + K: ReduceNaiveInstruction, { let client = R::client(device); diff --git a/crates/cubecl-reduce/src/tune/base.rs b/crates/cubecl-reduce/src/tune/base.rs deleted file mode 100644 index f52bfd7c..00000000 --- a/crates/cubecl-reduce/src/tune/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, tune_with, LocalTuner}; -use cubecl::{tune, Feature}; - -use crate::{ - element::JitElement, - kernel::{ - prng::random_like_uniform, - reduce::{ - naive::kernel::reduce_dim_naive, shared::kernel::reduce_dim_shared, - subcube::kernel::reduce_dim_subcube, ReduceDimAlgorithm, - }, - }, - tensor::JitTensor, - tune_key::JitAutotuneKey, - JitRuntime, JitTuneId, -}; - -use super::create_key; - -/// Set of reduce_dim implementations available for 
autotune -/// Autotune key is given by concatenating the closest upper power of 2 of -/// dim to reduce, and product of others -#[tune( - operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), - create_key = create_key::, - should_run = should_run -)] -pub fn reduce_dim_operations< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - key: JitAutotuneKey, - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(input, random_bounds.0, random_bounds.1); - - tune_with!(input, reduce_dim) -} - -/// Executes autotune on reduce_dim operation -pub(crate) fn reduce_dim_autotune< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let client = input.client.clone(); - - let id = JitTuneId::new::(&input.device); - - let operation_set = Box::new(ReduceDimOperations::::new(input, reduce_dim)); - - static TUNER: LocalTuner = local_tuner!(); - - TUNER.execute(&id, &client, operation_set) -} - -fn should_run< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - op: &ReduceDimOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let JitAutotuneKey::ReduceDim(key) = key else { - unreachable!() - }; - - match index { - // Naive - 0 => key.reduce_dim_length <= 8192, - // Shared - 1 => key.reduce_dim_length >= 16, - // Subcube - 2 => { - let props = op.input.client.properties(); - let hardware = props.hardware_properties(); - props.feature_enabled(Feature::Plane) - && hardware.plane_size_min == hardware.plane_size_max - } - _ => true, - } -} diff --git a/crates/cubecl-reduce/src/tune/key.rs b/crates/cubecl-reduce/src/tune/key.rs deleted file mode 100644 index 3634022b..00000000 --- a/crates/cubecl-reduce/src/tune/key.rs +++ /dev/null @@ -1,39 +0,0 @@ -use cubecl::AutotuneKey; -use serde::{Deserialize, Serialize}; - -use burn_tensor::DType; - -use crate::{tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime}; - -/// Autotune key representative of reduce versions -#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -pub struct ReduceAutotuneKey { - #[autotune(anchor)] - pub(crate) reduce_dim_length: usize, - #[autotune(anchor)] - pub(crate) reduce_dim_stride: usize, - #[autotune(anchor)] - pub(crate) others_product: usize, - dtype: DType, -} - -pub(crate) fn create_key( - input: &JitTensor, - reduce_dim: &usize, -) -> JitAutotuneKey { - let dims = &input.shape.dims; - let reduce_dim = *reduce_dim; - - let mut others_product = 1; - for (d, len) in dims.iter().enumerate() { - if d != reduce_dim { - others_product *= len - } - } - JitAutotuneKey::ReduceDim(ReduceAutotuneKey::new( - dims[reduce_dim], - input.strides[reduce_dim], - others_product, - EI::dtype(), - )) -} diff --git a/crates/cubecl-reduce/src/tune/mod.rs b/crates/cubecl-reduce/src/tune/mod.rs deleted file mode 100644 index aee5569b..00000000 --- a/crates/cubecl-reduce/src/tune/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(feature = "autotune")] -mod base; -mod key; - -#[cfg(feature = "autotune")] -pub(crate) use base::*; -pub use key::*; From 2ecb00e8198b78f110d1de8e6cf4e2717992a27a Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 15:26:11 -0500 Subject: [PATCH 6/9] Add comments to test --- crates/cubecl-reduce/src/test.rs | 14 ++++++++++++-- 1 file changed, 12 
insertions(+), 2 deletions(-) diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 35c7cb89..f48ae3d4 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -4,9 +4,11 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use crate::{ - reduce_naive, ReduceArgMax, ReduceArgMin, ReduceNaiveInstruction, ReduceMean, ReduceProd, ReduceSum, + reduce_naive, ReduceArgMax, ReduceArgMin, ReduceMean, ReduceNaiveInstruction, ReduceProd, + ReduceSum, }; +// Simple kernel to launch tests. #[cube(launch_unchecked)] pub fn naive_reduce_dim_kernel>( input: &Tensor>, @@ -16,8 +18,10 @@ pub fn naive_reduce_dim_kernel(input, output, dim) } +// This macro generate all the tests. #[macro_export] macro_rules! testgen_reduce { + // Generate all the tests for a list of types. ([$($float:ident), *]) => { mod test_reduce { use super::*; @@ -31,6 +35,7 @@ macro_rules! testgen_reduce { } }; + // Generate all the tests for a specific float type. ($float:ident) => { use cubecl_reduce::test::TestCase; use cubecl_core::prelude::CubeCount; @@ -97,6 +102,10 @@ macro_rules! testgen_reduce { }; } +// For a given tensor description and cube settings +// run the tests for `ReduceSum`, `ReduceProd`, `ReduceMean`, `ReduceArgMax` and `ReduceArgMin` +// for all implementations. +// For each test, a reference reduction is computed on the CPU to compare the outcome of the kernel. #[macro_export] macro_rules! impl_test_reduce { ( @@ -265,6 +274,8 @@ impl TestCase { let input_handle = client.create(I::as_bytes(&input_values)); + // Zero initialize a tensor with the same shape as input + // except for the `self.reduce_dim` axis where the shape is 1. let output_handle = client.create(O::as_bytes(&vec![O::from_int(0); expected_values.len()])); let mut output_shape = self.shape.clone(); @@ -410,7 +421,6 @@ impl TestCase { let mut seed = 123456789; // Not really important for testing. (0..size).map(|_| F::new(lcg(&mut seed))).collect() - // (0..size).map(|x| F::new(x as f32)).collect() } } From 24e7d9111a3943542729e9a5ec4f7d12007752a1 Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 15:34:10 -0500 Subject: [PATCH 7/9] run cargo fmt --- crates/cubecl-reduce/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/cubecl-reduce/src/lib.rs b/crates/cubecl-reduce/src/lib.rs index c10badbf..dcd0be2f 100644 --- a/crates/cubecl-reduce/src/lib.rs +++ b/crates/cubecl-reduce/src/lib.rs @@ -6,4 +6,3 @@ pub mod test; pub use instructions::*; pub use naive::*; - From 7feb03461640a49651b47986b6f91944a68db984 Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 16:27:51 -0500 Subject: [PATCH 8/9] Fix ArgMin and ArgMax and unlock tests --- crates/cubecl-reduce/src/naive.rs | 41 +++++++++++++++++++------------ crates/cubecl-reduce/src/test.rs | 8 ++---- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/crates/cubecl-reduce/src/naive.rs b/crates/cubecl-reduce/src/naive.rs index 0954376f..0bd81f42 100644 --- a/crates/cubecl-reduce/src/naive.rs +++ b/crates/cubecl-reduce/src/naive.rs @@ -152,13 +152,19 @@ impl ReduceNaiveInstruction for ReduceArgMax { } fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { - // FIX: There is an issue when line_size is 1 on wgpu. 
let (max, index) = accumulator; - #[unroll] - for k in 0..current_value.size() { - if current_value[k] > max[k] { - max[k] = current_value[k]; - index[k] = i; + if comptime!(current_value.size() > 1) { + #[unroll] + for k in 0..current_value.size() { + if current_value[k] > max[k] { + max[k] = current_value[k]; + index[k] = i; + } + } + } else { + if current_value > *max { + *max = current_value; + *index = Line::new(i); } } } @@ -187,13 +193,19 @@ impl ReduceNaiveInstruction for ReduceArgMin { } fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { - // FIX: There is an issue when line_size is 1 on wgpu. let (min, index) = accumulator; - #[unroll] - for k in 0..current_value.size() { - if current_value[k] < min[k] { - min[k] = current_value[k]; - index[k] = i; + if comptime!(current_value.size() > 1) { + #[unroll] + for k in 0..current_value.size() { + if current_value[k] < min[k] { + min[k] = current_value[k]; + index[k] = i; + } + } + } else { + if current_value < *min { + *min = current_value; + *index = Line::new(i); } } } @@ -205,9 +217,6 @@ impl ReduceNaiveInstruction for ReduceArgMin { _shape_reduce_dim: u32, ) { let (_, position) = accumulator; - #[unroll] - for k in 0..output.line_size() { - output[index][k] = EO::cast_from(position[k]); - } + output[index] = Line::cast_from(position) } } diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index f48ae3d4..1a5f75ac 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -164,8 +164,6 @@ macro_rules! impl_test_reduce { test.test_mean_dim_naive::<$float, TestRuntime>(&Default::default()); } - // Fix the line issue in argmax before running the test. - #[ignore] #[test] pub fn [< reduce_argmax_dim_naive_ $id >]() { let test = TestCase { @@ -179,8 +177,6 @@ macro_rules! impl_test_reduce { test.test_argmax_dim_naive::<$float, TestRuntime>(&Default::default()); } - // Fix the line issue in argmin before running the test. 
- #[ignore] #[test] pub fn [< reduce_argmin_dim_naive_ $id >]() { let test = TestCase { @@ -345,7 +341,7 @@ impl TestCase { let (best, _) = expected[output_index]; let candidate = values[input_index]; if candidate > best { - let coordinate = self.to_input_coordinate(input_index); + let coordinate = self.to_input_coordinate(input_index / self.line_size as usize); expected[output_index] = (candidate, coordinate[self.reduce_dim as usize] as u32); } } @@ -359,7 +355,7 @@ impl TestCase { let (best, _) = expected[output_index]; let candidate = values[input_index]; if candidate < best { - let coordinate = self.to_input_coordinate(input_index); + let coordinate = self.to_input_coordinate(input_index / self.line_size as usize); expected[output_index] = (candidate, coordinate[self.reduce_dim as usize] as u32); } } From 3c8b090765d004c53a90abbbd5083d533d61db72 Mon Sep 17 00:00:00 2001 From: maxime Date: Wed, 27 Nov 2024 20:02:57 -0500 Subject: [PATCH 9/9] add clippy exception for comptime if --- crates/cubecl-reduce/src/naive.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/cubecl-reduce/src/naive.rs b/crates/cubecl-reduce/src/naive.rs index 0bd81f42..d655a847 100644 --- a/crates/cubecl-reduce/src/naive.rs +++ b/crates/cubecl-reduce/src/naive.rs @@ -153,6 +153,7 @@ impl ReduceNaiveInstruction for ReduceArgMax { fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { let (max, index) = accumulator; + #[allow(clippy::collapsible_else_if)] if comptime!(current_value.size() > 1) { #[unroll] for k in 0..current_value.size() { @@ -194,6 +195,7 @@ impl ReduceNaiveInstruction for ReduceArgMin { fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { let (min, index) = accumulator; + #[allow(clippy::collapsible_else_if)] if comptime!(current_value.size() > 1) { #[unroll] for k in 0..current_value.size() {