Skip to content

Commit

Permalink
fix ceiling division for num_planes
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtremblay committed Nov 22, 2024
1 parent 30bede6 commit d95160d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
15 changes: 11 additions & 4 deletions crates/cubecl-std/src/reduce/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ pub fn reduce_sum_lined<N: Numeric>(
reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32);
}



/// Compute the sum of all elements of `input` and write it to the first element of `output`.
#[cube]
pub fn reduce_sum_vector<N: Numeric>(
Expand All @@ -45,11 +47,10 @@ pub fn reduce_sum_vector<N: Numeric>(
#[comptime] config: ReduceConfig,
) {
let plane_id = UNIT_POS / PLANE_DIM;
let num_planes = CUBE_DIM / PLANE_DIM;
let num_planes = div_ceil(CUBE_DIM, PLANE_DIM);

// This is an integer division rounded up. It computes the number of required iterations
// to reduce all lines when reducing CUBE_DIM lines per iteration.
let num_iterations = input.len() / CUBE_DIM + (input.len() % CUBE_DIM > 0) as u32;
// Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration.
let num_iterations = div_ceil(input.len(), CUBE_DIM);

let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size());
memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0));
Expand Down Expand Up @@ -103,3 +104,9 @@ pub fn reduce_sum_lines<N: Numeric>(
output[UNIT_POS] = sum;
}
}

/// Integer division of `a` by `b`, rounded up to the next whole number.
///
/// The result is the truncated quotient `a / b`, plus one whenever the
/// division leaves a non-zero remainder.
///
/// NOTE(review): `b` is presumably never zero for the values passed in — confirm.
#[cube]
fn div_ceil(a: u32, b: u32) -> u32 {
    let quotient = a / b;
    // Computed from the remainder instead of `(a + b - 1) / b`, so the
    // intermediate sum `a + b - 1` is never formed.
    let has_remainder = (a % b) > 0;
    quotient + has_remainder as u32
}
2 changes: 1 addition & 1 deletion crates/cubecl-std/src/reduce/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! testgen_reduce {

#[test]
pub fn reduce_sum_vector_single_plane_line_size_four() {
let mut test = TestCase::new(
let test = TestCase::new(
// input
TestTensorParts::new_vector((0..32).collect()).with_line_size(4),
// output
Expand Down

0 comments on commit d95160d

Please sign in to comment.