diff --git a/crates/cubecl-std/src/reduce/sum.rs b/crates/cubecl-std/src/reduce/sum.rs index 5b577ec6..53259e61 100644 --- a/crates/cubecl-std/src/reduce/sum.rs +++ b/crates/cubecl-std/src/reduce/sum.rs @@ -37,6 +37,8 @@ pub fn reduce_sum_lined( reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32); } + + /// Compute the sum of all elements of `input` and write it to the first element of `output`. #[cube] pub fn reduce_sum_vector( @@ -45,11 +47,10 @@ pub fn reduce_sum_vector( #[comptime] config: ReduceConfig, ) { let plane_id = UNIT_POS / PLANE_DIM; - let num_planes = CUBE_DIM / PLANE_DIM; + let num_planes = div_ceil(CUBE_DIM, PLANE_DIM); - // This is an integer division rounded up. It computes the number of required iterations - // to reduce all lines when reducing CUBE_DIM lines per iteration. - let num_iterations = input.len() / CUBE_DIM + (input.len() % CUBE_DIM > 0) as u32; + // Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration. + let num_iterations = div_ceil(input.len(), CUBE_DIM); let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size()); memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0)); @@ -103,3 +104,9 @@ pub fn reduce_sum_lines( output[UNIT_POS] = sum; } } + +// Integer division rounded up. +#[cube] +fn div_ceil(a: u32, b: u32) -> u32 { + a / b + ((a % b) > 0) as u32 +} diff --git a/crates/cubecl-std/src/reduce/test.rs b/crates/cubecl-std/src/reduce/test.rs index 07337e51..0186c683 100644 --- a/crates/cubecl-std/src/reduce/test.rs +++ b/crates/cubecl-std/src/reduce/test.rs @@ -28,7 +28,7 @@ macro_rules! testgen_reduce { #[test] pub fn reduce_sum_vector_single_plane_line_size_four() { - let mut test = TestCase::new( + let test = TestCase::new( // input TestTensorParts::new_vector((0..32).collect()).with_line_size(4), // output