From 67e845f40df7f3883c906d94c6451a7d01d4134e Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 2 Dec 2024 09:41:26 -0500 Subject: [PATCH] Weekly chores (#332) --- crates/cubecl-core/src/ir/kernel.rs | 21 ++++ crates/cubecl-reduce/src/instructions.rs | 14 +-- crates/cubecl-reduce/src/naive.rs | 94 ++++++-------- crates/cubecl-reduce/src/shared.rs | 12 +- crates/cubecl-reduce/src/test.rs | 152 +++++++++++------------ crates/cubecl-runtime/src/server.rs | 22 ++++ 6 files changed, 172 insertions(+), 143 deletions(-) diff --git a/crates/cubecl-core/src/ir/kernel.rs b/crates/cubecl-core/src/ir/kernel.rs index 53990a2cc..1d5ba011e 100644 --- a/crates/cubecl-core/src/ir/kernel.rs +++ b/crates/cubecl-core/src/ir/kernel.rs @@ -304,6 +304,27 @@ pub struct CubeDim { } impl CubeDim { + /// Create a new cube dim with x = y = z = 1. + pub fn new_single() -> Self { + Self { x: 1, y: 1, z: 1 } + } + + /// Create a new cube dim with the given x, and y = z = 1. + pub fn new_1d(x: u32) -> Self { + Self { x, y: 1, z: 1 } + } + + /// Create a new cube dim with the given x and y, and z = 1. + pub fn new_2d(x: u32, y: u32) -> Self { + Self { x, y, z: 1 } + } + + /// Create a new cube dim with the given x, y and z. + /// This is equivalent to the [new](CubeDim::new) function. + pub fn new_3d(x: u32, y: u32, z: u32) -> Self { + Self { x, y, z } + } + pub fn num_elems(&self) -> u32 { self.x * self.y * self.z } diff --git a/crates/cubecl-reduce/src/instructions.rs b/crates/cubecl-reduce/src/instructions.rs index f08200b51..53f3a7238 100644 --- a/crates/cubecl-reduce/src/instructions.rs +++ b/crates/cubecl-reduce/src/instructions.rs @@ -2,10 +2,10 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; /// Compute the coordinate of the maximum item returning the smallest coordinate in case of equality. 
-pub struct ReduceArgMax; +pub struct ArgMax; #[cube] -impl ReduceArgMax { +impl ArgMax { /// Compare two pairs of items and coordinates and return a new pair /// where each element in the lines is the maximal item with its coordinate. /// In case of equality, the lowest coordinate is selected. @@ -27,10 +27,10 @@ impl ReduceArgMax { } /// Compute the coordinate of the minimum item returning the smallest coordinate in case of equality. -pub struct ReduceArgMin; +pub struct ArgMin; #[cube] -impl ReduceArgMin { +impl ArgMin { /// Compare two pairs of items and coordinates and return a new pair /// where each element in the lines is the minimal item with its coordinate. /// In case of equality, the lowest coordinate is selected. @@ -51,6 +51,6 @@ impl ReduceArgMin { } } -pub struct ReduceMean; -pub struct ReduceSum; -pub struct ReduceProd; +pub struct Mean; +pub struct Sum; +pub struct Prod; diff --git a/crates/cubecl-reduce/src/naive.rs b/crates/cubecl-reduce/src/naive.rs index d59e36980..a10de6a22 100644 --- a/crates/cubecl-reduce/src/naive.rs +++ b/crates/cubecl-reduce/src/naive.rs @@ -1,7 +1,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum}; +use crate::{ArgMax, ArgMin, Mean, Prod, Sum}; /// An instruction for the [reduce_naive](reduce_naive) algorithm. #[cube] @@ -16,10 +16,10 @@ pub trait ReduceNaiveInstruction: Send + Sync + 'static { /// This could be called many time during reduction. It is required /// that reducing the initial accumulator any number of times do not change the outcome /// of the reduction. For example, adding 0s in a sum do not change the outcome. - fn init_accumulator(line_size: u32) -> Self::Accumulator; + fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator; /// Reduce `current_value` into `accumulator`. 
- fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32); + fn accumulate(accumulator: &mut Self::Accumulator, item: Line, coordinate: u32); /// Write the result of the reduction stored in `accumulator` into `output[index]`. fn write( @@ -56,12 +56,12 @@ pub fn reduce_naive, EI: Numeric, EO: Numeric>( // Reduce all the lines along `dim` for the previously computed offset. let mut accumulator = RD::init_accumulator(input.line_size()); - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; + for coordinate in 0..input.shape(dim) { + let index = coordinate * input.stride(dim) + offset_input; RD::accumulate( &mut accumulator, unsafe { *input.index_unchecked(index) }, - i, + coordinate, ); } @@ -72,10 +72,10 @@ pub fn reduce_naive, EI: Numeric, EO: Numeric>( // Implementations for common instructions. #[cube] -impl ReduceNaiveInstruction for ReduceSum { +impl ReduceNaiveInstruction for Sum { type Accumulator = Line; - fn init_accumulator(line_size: u32) -> Line { + fn init_accumulator(#[comptime] line_size: u32) -> Line { Line::empty(line_size).fill(EI::from_int(0)) } @@ -94,15 +94,15 @@ impl ReduceNaiveInstruction for ReduceSum { } #[cube] -impl ReduceNaiveInstruction for ReduceProd { +impl ReduceNaiveInstruction for Prod { type Accumulator = Line; - fn init_accumulator(line_size: u32) -> Line { + fn init_accumulator(#[comptime] line_size: u32) -> Line { Line::empty(line_size).fill(EI::from_int(1)) } - fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { - *accumulator *= current_value; + fn accumulate(accumulator: &mut Self::Accumulator, item: Line, _coordinate: u32) { + *accumulator *= item; } fn write( @@ -116,15 +116,15 @@ impl ReduceNaiveInstruction for ReduceProd { } #[cube] -impl ReduceNaiveInstruction for ReduceMean { +impl ReduceNaiveInstruction for Mean { type Accumulator = Line; - fn init_accumulator(line_size: u32) -> Self::Accumulator { + fn init_accumulator(#[comptime] 
line_size: u32) -> Self::Accumulator { Line::empty(line_size).fill(EI::from_int(0)) } - fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, _i: u32) { - *accumulator += current_value; + fn accumulate(accumulator: &mut Self::Accumulator, item: Line, _coordinate: u32) { + *accumulator += item; } fn write( @@ -140,10 +140,10 @@ impl ReduceNaiveInstruction for ReduceMean { } #[cube] -impl ReduceNaiveInstruction for ReduceArgMax { +impl ReduceNaiveInstruction for ArgMax { type Accumulator = (Line, Line); - fn init_accumulator(line_size: u32) -> Self::Accumulator { + fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator { ( // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 Line::empty(line_size).fill(EI::MIN), @@ -151,23 +151,16 @@ impl ReduceNaiveInstruction for ReduceArgMax { ) } - fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { - let (max, index) = accumulator; - #[allow(clippy::collapsible_else_if)] - if comptime!(current_value.size() > 1) { - #[unroll] - for k in 0..current_value.size() { - if current_value[k] > max[k] { - max[k] = current_value[k]; - index[k] = i; - } - } - } else { - if current_value > *max { - *max = current_value; - *index = Line::new(i); - } - } + fn accumulate(accumulator: &mut Self::Accumulator, item: Line, coordinate: u32) { + let (acc_item, acc_coordinate) = accumulator; + let (new_item, new_coordinate) = Self::choose_argmax( + *acc_item, + *acc_coordinate, + item, + Line::empty(item.size()).fill(coordinate), + ); + accumulator.0 = new_item; + accumulator.1 = new_coordinate; } fn write( @@ -182,10 +175,10 @@ impl ReduceNaiveInstruction for ReduceArgMax { } #[cube] -impl ReduceNaiveInstruction for ReduceArgMin { +impl ReduceNaiveInstruction for ArgMin { type Accumulator = (Line, Line); - fn init_accumulator(line_size: u32) -> Self::Accumulator { + fn init_accumulator(#[comptime] line_size: u32) -> Self::Accumulator 
{ ( // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 Line::empty(line_size).fill(EI::MAX), @@ -193,23 +186,16 @@ impl ReduceNaiveInstruction for ReduceArgMin { ) } - fn accumulate(accumulator: &mut Self::Accumulator, current_value: Line, i: u32) { - let (min, index) = accumulator; - #[allow(clippy::collapsible_else_if)] - if comptime!(current_value.size() > 1) { - #[unroll] - for k in 0..current_value.size() { - if current_value[k] < min[k] { - min[k] = current_value[k]; - index[k] = i; - } - } - } else { - if current_value < *min { - *min = current_value; - *index = Line::new(i); - } - } + fn accumulate(accumulator: &mut Self::Accumulator, item: Line, coordinate: u32) { + let (acc_item, acc_coordinate) = accumulator; + let (new_item, new_coordinate) = Self::choose_argmin( + *acc_item, + *acc_coordinate, + item, + Line::empty(item.size()).fill(coordinate), + ); + accumulator.0 = new_item; + accumulator.1 = new_coordinate; } fn write( diff --git a/crates/cubecl-reduce/src/shared.rs b/crates/cubecl-reduce/src/shared.rs index 9a08eeed1..df2575eaf 100644 --- a/crates/cubecl-reduce/src/shared.rs +++ b/crates/cubecl-reduce/src/shared.rs @@ -1,7 +1,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::{ReduceArgMax, ReduceArgMin, ReduceMean, ReduceProd, ReduceSum}; +use crate::{ArgMax, ArgMin, Mean, Prod, Sum}; /// An instruction for the [reduce_shared](reduce_shared) algorithm. #[cube] @@ -139,7 +139,7 @@ fn div_ceil(a: u32, b: u32) -> u32 { // Implementations for common instructions. 
#[cube] -impl ReduceSharedInstruction for ReduceSum { +impl ReduceSharedInstruction for Sum { type Accumulator = SharedMemory>; fn create_accumulator( @@ -183,7 +183,7 @@ impl ReduceSharedInstruction for ReduceSum { } #[cube] -impl ReduceSharedInstruction for ReduceProd { +impl ReduceSharedInstruction for Prod { type Accumulator = SharedMemory>; fn create_accumulator( @@ -227,7 +227,7 @@ impl ReduceSharedInstruction for ReduceProd { } #[cube] -impl ReduceSharedInstruction for ReduceMean { +impl ReduceSharedInstruction for Mean { type Accumulator = SharedMemory>; fn create_accumulator( @@ -272,7 +272,7 @@ impl ReduceSharedInstruction for ReduceMean { } #[cube] -impl ReduceSharedInstruction for ReduceArgMax { +impl ReduceSharedInstruction for ArgMax { type Accumulator = (SharedMemory>, SharedMemory>); fn create_accumulator( @@ -340,7 +340,7 @@ impl ReduceSharedInstruction for ReduceArgMax { } #[cube] -impl ReduceSharedInstruction for ReduceArgMin { +impl ReduceSharedInstruction for ArgMin { type Accumulator = (SharedMemory>, SharedMemory>); fn create_accumulator( diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 5fecb4343..e6c4e00aa 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -9,8 +9,8 @@ use rand::{ }; use crate::{ - reduce_naive, reduce_shared, ReduceArgMax, ReduceArgMin, ReduceMean, ReduceNaiveInstruction, - ReduceProd, ReduceSharedInstruction, ReduceSum, + reduce_naive, reduce_shared, ArgMax, ArgMin, Mean, Prod, ReduceNaiveInstruction, + ReduceSharedInstruction, Sum, }; // All random values generated for tests will be in the set @@ -72,8 +72,8 @@ macro_rules! testgen_reduce { shape: [4, 8], stride: [8, 1], reduce_dim: 0, - cube_count: CubeCount::Static(1, 1, 1), - cube_dim: CubeDim::new(4, 8, 1), + cube_count: CubeCount::new_single(), + cube_dim: CubeDim::new_2d(4, 8), line_size: 1, }, { @@ -81,8 +81,8 @@ macro_rules! 
testgen_reduce { shape: [8, 256], stride: [256, 1], reduce_dim: 1, - cube_count: CubeCount::Static(8, 1, 1), - cube_dim: CubeDim::new(16, 16, 1), + cube_count: CubeCount::new_1d(8), + cube_dim: CubeDim::new_2d(16, 16), line_size: 1, }, { @@ -90,8 +90,8 @@ macro_rules! testgen_reduce { shape: [8, 256], stride: [256, 1], reduce_dim: 0, - cube_count: CubeCount::Static(8, 1, 1), - cube_dim: CubeDim::new(16, 16, 1), + cube_count: CubeCount::new_1d(8), + cube_dim: CubeDim::new_2d(16, 16), line_size: 1, }, { @@ -99,8 +99,8 @@ macro_rules! testgen_reduce { shape: [16, 16, 16], stride: [1, 256, 16], reduce_dim: 2, - cube_count: CubeCount::Static(4, 1, 1), - cube_dim: CubeDim::new(16, 16, 1), + cube_count: CubeCount::new_1d(4), + cube_dim: CubeDim::new_2d(16, 16), line_size: 1, }, { @@ -108,8 +108,8 @@ macro_rules! testgen_reduce { shape: [11, 12, 13], stride: [156, 13, 1], reduce_dim: 1, - cube_count: CubeCount::Static(4, 1, 1), - cube_dim: CubeDim::new(16, 16, 1), + cube_count: CubeCount::new_1d(4), + cube_dim: CubeDim::new_2d(16, 16), line_size: 1, }, { @@ -117,8 +117,8 @@ macro_rules! testgen_reduce { shape: [32, 64], stride: [64, 1], reduce_dim: 0, - cube_count: CubeCount::Static(8, 1, 1), - cube_dim: CubeDim::new(16, 16, 1), + cube_count: CubeCount::new_1d(8), + cube_dim: CubeDim::new_2d(16, 16), line_size: 4, } ] @@ -133,8 +133,8 @@ macro_rules! testgen_reduce { shape: [4, 8], stride: [8, 1], reduce_dim: 0, - cube_count: CubeCount::Static(8, 1, 1), - cube_dim: CubeDim::new(2, 1, 1), + cube_count: CubeCount::new_1d(8), + cube_dim: CubeDim::new_1d(2), line_size: 1, }, { @@ -142,8 +142,8 @@ macro_rules! testgen_reduce { shape: [8, 256], stride: [256, 1], reduce_dim: 1, - cube_count: CubeCount::Static(8, 1, 1), - cube_dim: CubeDim::new(16, 1, 1), + cube_count: CubeCount::new_1d(8), + cube_dim: CubeDim::new_1d(16), line_size: 1, }, { @@ -151,8 +151,8 @@ macro_rules! 
testgen_reduce { shape: [16, 256], stride: [256, 1], reduce_dim: 0, - cube_count: CubeCount::Static(256, 1, 1), - cube_dim: CubeDim::new(5, 1, 1), + cube_count: CubeCount::new_1d(256), + cube_dim: CubeDim::new_1d(5), line_size: 1, }, { @@ -160,8 +160,8 @@ macro_rules! testgen_reduce { shape: [16, 16, 16], stride: [1, 256, 16], reduce_dim: 2, - cube_count: CubeCount::Static(16, 16, 1), - cube_dim: CubeDim::new(4, 1, 1), + cube_count: CubeCount::new_2d(16, 16), + cube_dim: CubeDim::new_1d(4), line_size: 1, }, { @@ -169,8 +169,8 @@ macro_rules! testgen_reduce { shape: [11, 12, 13], stride: [156, 13, 1], reduce_dim: 1, - cube_count: CubeCount::Static(11, 1, 13), - cube_dim: CubeDim::new(2, 1, 1), + cube_count: CubeCount::new_2d(11, 13), + cube_dim: CubeDim::new_1d(2), line_size: 1, }, { @@ -178,8 +178,8 @@ macro_rules! testgen_reduce { shape: [32, 64], stride: [64, 1], reduce_dim: 0, - cube_count: CubeCount::Static(64, 1, 1), - cube_dim: CubeDim::new(8, 1, 1), + cube_count: CubeCount::new_1d(64), + cube_dim: CubeDim::new_1d(8), line_size: 4, } ] @@ -212,7 +212,7 @@ macro_rules! impl_test_reduce { ::paste::paste! { $( #[test] - pub fn [< reduce_sum_dim_ $kind _ $id >]() { + pub fn [< reduce_sum_ $kind _ $id >]() { let test = TestCase { shape: $shape.into(), stride: $stride.into(), @@ -221,11 +221,11 @@ macro_rules! impl_test_reduce { cube_dim: $cube_dim, line_size:$line_size }; - test.[< test_sum_dim_ $kind >]::<$float, TestRuntime>(&Default::default()); + test.[< test_sum_ $kind >]::<$float, TestRuntime>(&Default::default()); } #[test] - pub fn [< reduce_prod_dim_ $kind _ $id >]() { + pub fn [< reduce_prod_ $kind _ $id >]() { let test = TestCase { shape: $shape.into(), stride: $stride.into(), @@ -234,11 +234,11 @@ macro_rules! 
impl_test_reduce { cube_dim: $cube_dim, line_size:$line_size }; - test.[< test_prod_dim_ $kind >]::<$float, TestRuntime>(&Default::default()); + test.[< test_prod_ $kind >]::<$float, TestRuntime>(&Default::default()); } #[test] - pub fn [< reduce_mean_dim_ $kind _ $id >]() { + pub fn [< reduce_mean_ $kind _ $id >]() { let test = TestCase { shape: $shape.into(), stride: $stride.into(), @@ -247,11 +247,11 @@ macro_rules! impl_test_reduce { cube_dim: $cube_dim, line_size:$line_size }; - test.[< test_mean_dim_ $kind >]::<$float, TestRuntime>(&Default::default()); + test.[< test_mean_ $kind >]::<$float, TestRuntime>(&Default::default()); } #[test] - pub fn [< reduce_argmax_dim_ $kind _ $id >]() { + pub fn [< reduce_argmax_ $kind _ $id >]() { let test = TestCase { shape: $shape.into(), stride: $stride.into(), @@ -260,11 +260,11 @@ macro_rules! impl_test_reduce { cube_dim: $cube_dim, line_size:$line_size }; - test.[< test_argmax_dim_ $kind >]::<$float, TestRuntime>(&Default::default()); + test.[< test_argmax_ $kind >]::<$float, TestRuntime>(&Default::default()); } #[test] - pub fn [< reduce_argmin_dim_ $kind _ $id >]() { + pub fn [< reduce_argmin_ $kind _ $id >]() { let test = TestCase { shape: $shape.into(), stride: $stride.into(), @@ -273,7 +273,7 @@ macro_rules! 
impl_test_reduce { cube_dim: $cube_dim, line_size:$line_size }; - test.[< test_argmin_dim_ $kind >]::<$float, TestRuntime>(&Default::default()); + test.[< test_argmin_ $kind >]::<$float, TestRuntime>(&Default::default()); } )* } @@ -291,54 +291,54 @@ pub struct TestCase { } impl TestCase { - pub fn test_sum_dim_naive(&self, device: &R::Device) + pub fn test_sum_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_sum_dim(&input_values); - self.run_test_naive::(device, input_values, expected_values) + let expected_values = self.cpu_sum(&input_values); + self.run_test_naive::(device, input_values, expected_values) } - pub fn test_prod_dim_naive(&self, device: &R::Device) + pub fn test_prod_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_prod_dim(&input_values); - self.run_test_naive::(device, input_values, expected_values) + let expected_values = self.cpu_prod(&input_values); + self.run_test_naive::(device, input_values, expected_values) } - pub fn test_mean_dim_naive(&self, device: &R::Device) + pub fn test_mean_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_mean_dim(&input_values); - self.run_test_naive::(device, input_values, expected_values) + let expected_values = self.cpu_mean(&input_values); + self.run_test_naive::(device, input_values, expected_values) } - pub fn test_argmax_dim_naive(&self, device: &R::Device) + pub fn test_argmax_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_argmax_dim(&input_values); - self.run_test_naive::(device, 
input_values, expected_values) + let expected_values = self.cpu_argmax(&input_values); + self.run_test_naive::(device, input_values, expected_values) } - pub fn test_argmin_dim_naive(&self, device: &R::Device) + pub fn test_argmin_naive(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_argmin_dim(&input_values); - self.run_test_naive::(device, input_values, expected_values) + let expected_values = self.cpu_argmin(&input_values); + self.run_test_naive::(device, input_values, expected_values) } pub fn run_test_naive( @@ -395,54 +395,54 @@ impl TestCase { assert_approx_equal_abs(output_values, &expected_values, 1.0 / (PRECISION as f32)); } - pub fn test_sum_dim_shared(&self, device: &R::Device) + pub fn test_sum_shared(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_sum_dim(&input_values); - self.run_test_shared::(device, input_values, expected_values) + let expected_values = self.cpu_sum(&input_values); + self.run_test_shared::(device, input_values, expected_values) } - pub fn test_prod_dim_shared(&self, device: &R::Device) + pub fn test_prod_shared(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_prod_dim(&input_values); - self.run_test_shared::(device, input_values, expected_values) + let expected_values = self.cpu_prod(&input_values); + self.run_test_shared::(device, input_values, expected_values) } - pub fn test_mean_dim_shared(&self, device: &R::Device) + pub fn test_mean_shared(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_mean_dim(&input_values); - 
self.run_test_shared::(device, input_values, expected_values) + let expected_values = self.cpu_mean(&input_values); + self.run_test_shared::(device, input_values, expected_values) } - pub fn test_argmax_dim_shared(&self, device: &R::Device) + pub fn test_argmax_shared(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_argmax_dim(&input_values); - self.run_test_shared::(device, input_values, expected_values) + let expected_values = self.cpu_argmax(&input_values); + self.run_test_shared::(device, input_values, expected_values) } - pub fn test_argmin_dim_shared(&self, device: &R::Device) + pub fn test_argmin_shared(&self, device: &R::Device) where F: Float + CubeElement + std::fmt::Display, R: Runtime, { let input_values: Vec = self.random_input_values(); - let expected_values = self.cpu_argmin_dim(&input_values); - self.run_test_shared::(device, input_values, expected_values) + let expected_values = self.cpu_argmin(&input_values); + self.run_test_shared::(device, input_values, expected_values) } pub fn run_test_shared( @@ -503,7 +503,7 @@ impl TestCase { assert_approx_equal_abs(output_values, &expected_values, 1.0 / (PRECISION as f32)); } - fn cpu_sum_dim(&self, values: &[F]) -> Vec { + fn cpu_sum(&self, values: &[F]) -> Vec { let mut expected = vec![F::new(0.0); self.num_output_values()]; #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { @@ -513,7 +513,7 @@ impl TestCase { expected } - fn cpu_prod_dim(&self, values: &[F]) -> Vec { + fn cpu_prod(&self, values: &[F]) -> Vec { let mut expected = vec![F::new(1.0); self.num_output_values()]; #[allow(clippy::needless_range_loop)] for value_index in 0..values.len() { @@ -523,14 +523,14 @@ impl TestCase { expected } - fn cpu_mean_dim(&self, values: &[F]) -> Vec { - self.cpu_sum_dim(values) + fn cpu_mean(&self, values: &[F]) -> Vec { + self.cpu_sum(values) .into_iter() 
.map(|sum| sum / F::new(self.shape[self.reduce_dim as usize] as f32)) .collect() } - fn cpu_argmax_dim(&self, values: &[F]) -> Vec { + fn cpu_argmax(&self, values: &[F]) -> Vec { let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()]; #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { @@ -545,7 +545,7 @@ impl TestCase { expected.into_iter().map(|(_, i)| i).collect() } - fn cpu_argmin_dim(&self, values: &[F]) -> Vec { + fn cpu_argmin(&self, values: &[F]) -> Vec { let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()]; #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { @@ -590,13 +590,13 @@ impl TestCase { } fn output_stride(&self) -> Vec { - let dim_stride = self.stride[self.reduce_dim as usize]; - let dim_shape = self.shape[self.reduce_dim as usize]; + let stride = self.stride[self.reduce_dim as usize]; + let shape = self.shape[self.reduce_dim as usize]; self.stride .iter() - .map(|s| match s.cmp(&dim_stride) { + .map(|s| match s.cmp(&stride) { std::cmp::Ordering::Equal => 1, - std::cmp::Ordering::Greater => s / dim_shape, + std::cmp::Ordering::Greater => s / shape, std::cmp::Ordering::Less => *s, }) .collect() diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index f25defff6..ad0cd30da 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -178,6 +178,28 @@ pub enum CubeCount { Dynamic(Binding), } +impl CubeCount { + /// Create a new static cube count with the given x = y = z = 1. + pub fn new_single() -> Self { + CubeCount::Static(1, 1, 1) + } + + /// Create a new static cube count with the given x, and y = z = 1. + pub fn new_1d(x: u32) -> Self { + CubeCount::Static(x, 1, 1) + } + + /// Create a new static cube count with the given x and y, and z = 1. + pub fn new_2d(x: u32, y: u32) -> Self { + CubeCount::Static(x, y, 1) + } + + /// Create a new static cube count with the given x, y and z. 
+ pub fn new_3d(x: u32, y: u32, z: u32) -> Self { + CubeCount::Static(x, y, z) + } +} + impl Debug for CubeCount { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self {