diff --git a/crates/cubecl-core/src/frontend/topology.rs b/crates/cubecl-core/src/frontend/topology.rs index 10c6da024..a33c242a9 100644 --- a/crates/cubecl-core/src/frontend/topology.rs +++ b/crates/cubecl-core/src/frontend/topology.rs @@ -30,6 +30,14 @@ The total amount of working units in a plane. " ); +constant!( + UNIT_POS_PLANE, + crate::ir::Builtin::UnitPosPlane, + r" +The relative position of the working unit inside the plane, without regards to cube dimensions. +" +); + constant!( UNIT_POS, crate::ir::Builtin::UnitPos, diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index 7fb757e3d..e3b1f5b65 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -71,6 +71,7 @@ pub enum Builtin { CubeCountY, CubeCountZ, PlaneDim, + UnitPosPlane, AbsolutePos, AbsolutePosX, AbsolutePosY, diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 7e53611ba..706913afc 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -854,6 +854,7 @@ impl CppCompiler { Variable::GridDimGlobal } gpu::Builtin::PlaneDim => Variable::WarpSize, + gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp, }, gpu::VariableKind::LocalArray { id, depth, length } => { let item = self.compile_item(item); diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 88f93968f..5d774cbe7 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -143,6 +143,7 @@ impl Component for Variable { Variable::GridDimZ => Item::scalar(Elem::U32), Variable::LocalArray(_, e, _, _) => *e, Variable::WarpSize => Item::scalar(Elem::U32), + Variable::ThreadIdxWarp => Item::scalar(Elem::U32), Variable::WmmaFragment { id: _, frag, @@ -163,6 +164,7 @@ impl Component for Variable { #[derive(Debug, Clone, Copy, PartialEq)] pub enum Variable { WarpSize, + ThreadIdxWarp, GlobalInputArray(u16, Item), GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, gpu::Elem), @@ -285,6 +287,7 @@ impl Display for Variable { write!(f, "l_arr_{}_{}", id, depth) } Variable::WarpSize => f.write_str("warpSize"), + Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"), Variable::WmmaFragment { id: index, frag: _, @@ -416,6 +419,7 @@ impl Variable { Variable::GridDimZ => true, Variable::LocalArray(_, _, _, _) => false, Variable::WarpSize => true, + Variable::ThreadIdxWarp => true, Variable::WmmaFragment { .. } => false, Variable::BlockIdxGlobal => true, Variable::BlockDimGlobal => true, diff --git a/crates/cubecl-hip/Cargo.toml b/crates/cubecl-hip/Cargo.toml index f2249a5ca..e3d89ec32 100644 --- a/crates/cubecl-hip/Cargo.toml +++ b/crates/cubecl-hip/Cargo.toml @@ -43,4 +43,5 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", features = [ "export_tests", ] } +cubecl-std = { path = "../cubecl-std", version = "0.4.0" } pretty_assertions = { workspace = true } diff --git a/crates/cubecl-hip/src/lib.rs b/crates/cubecl-hip/src/lib.rs index 4acc4434f..59c382615 100644 --- a/crates/cubecl-hip/src/lib.rs +++ b/crates/cubecl-hip/src/lib.rs @@ -27,4 +27,5 @@ mod tests { cubecl_core::testgen_all!(); cubecl_linalg::testgen_cmma_matmul!(); + cubecl_std::testgen_reduce!(); } diff --git a/crates/cubecl-macros/src/scope.rs b/crates/cubecl-macros/src/scope.rs index 3f9a0696d..95cfa17a9 100644 --- a/crates/cubecl-macros/src/scope.rs +++ b/crates/cubecl-macros/src/scope.rs @@ -9,7 +9,7 @@ use syn::{parse_quote, Ident, Type}; use crate::parse::kernel::KernelParam; -pub const KEYWORDS: [&str; 21] = [ +pub const KEYWORDS: [&str; 22] = [ "ABSOLUTE_POS", "ABSOLUTE_POS_X", "ABSOLUTE_POS_Y", @@ -31,6 +31,7 @@ pub const KEYWORDS: [&str; 21] = [ "CUBE_COUNT_Y", "CUBE_COUNT_Z", "PLANE_DIM", + "UNIT_POS_PLANE", ]; pub type Scope = usize; diff --git a/crates/cubecl-spirv/src/globals.rs b/crates/cubecl-spirv/src/globals.rs index 6178c1e0d..633a018bf 100644 --- a/crates/cubecl-spirv/src/globals.rs +++ b/crates/cubecl-spirv/src/globals.rs @@ -113,6 +113,17 @@ impl SpirvCompiler { }); Variable::SubgroupSize(id) } + Builtin::UnitPosPlane => { + let id = self.get_or_insert_global(Globals::SubgroupInvocationId, |b| { + let id = b.load_builtin( + BuiltIn::SubgroupLocalInvocationId, + Item::Scalar(Elem::Int(32, false)), + ); + b.debug_name(id, "UNIT_POS_PLANE"); + id + }); + Variable::SubgroupSize(id) + } Builtin::CubePos => { let id = self.get_or_insert_global(Globals::WorkgroupIndex, |b| { let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b); diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 46a06dc75..f748f109f 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -340,6 +340,7 @@ pub enum Globals { NumWorkgroupsY, NumWorkgroupsZ, SubgroupSize, + SubgroupInvocationId, Metadata(u32), } diff --git a/crates/cubecl-std/src/reduce/sum.rs b/crates/cubecl-std/src/reduce/sum.rs index ed4677490..4b964d863 100644 --- a/crates/cubecl-std/src/reduce/sum.rs +++ b/crates/cubecl-std/src/reduce/sum.rs @@ -4,8 +4,7 @@ use cubecl_core::prelude::*; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ReduceConfig { pub line_size: u32, - pub plane_size: u32, - pub num_planes: u32, + pub max_num_planes: u32, } /// Compute the sum of all elements of `input` and write it to the first element of `output`. @@ -45,28 +44,27 @@ pub fn reduce_sum_vector( output: &mut SliceMut>, #[comptime] config: ReduceConfig, ) { - // How many lines accounted in each iteration. - let block_size = config.plane_size * config.num_planes; + let plane_id = UNIT_POS / PLANE_DIM; + let num_planes = div_ceil(CUBE_DIM, PLANE_DIM); - // This is an integer division rounded up. - let num_blocks = input.len() / block_size + (input.len() % block_size > 0) as u32; + // Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration. + let num_iterations = div_ceil(input.len(), CUBE_DIM); - let mut memory = SharedMemory::new_lined(config.num_planes, config.line_size); + let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size()); + memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0)); - memory[UNIT_POS_Y] = Line::empty(config.line_size).fill(N::from_int(0)); - - // For each block, we reduce each group of plane_size lines to a single line. Then, we accumulate the results + // For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results // into the memory. Thus, after the loop, the reduction of the memory yields the expected output. - for i in 0..num_blocks { - let index = i * block_size + UNIT_POS_Y * config.plane_size + UNIT_POS_X; + for i in 0..num_iterations { + let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE; let value = select( index < input.len(), input[index], Line::empty(config.line_size).fill(N::from_int(0)), ); let sum = plane_sum(value); - if UNIT_POS_X == 0 { - memory[UNIT_POS_Y] += sum; + if UNIT_POS_PLANE == 0 { + memory[plane_id] += sum; } } @@ -75,8 +73,8 @@ pub fn reduce_sum_vector( // Sum each elements in memory let sum = plane_sum(select( - UNIT_POS_X < config.num_planes, - memory[UNIT_POS_X], + UNIT_POS_PLANE < num_planes, + memory[UNIT_POS_PLANE], Line::empty(config.line_size).fill(N::from_int(0)), )); if UNIT_POS == 0 { @@ -104,3 +102,9 @@ pub fn reduce_sum_lines( output[UNIT_POS] = sum; } } + +// Integer division rounded up. +#[cube] +fn div_ceil(a: u32, b: u32) -> u32 { + a / b + ((a % b) > 0) as u32 +} diff --git a/crates/cubecl-std/src/reduce/test.rs b/crates/cubecl-std/src/reduce/test.rs index 1790eb7f7..0186c683c 100644 --- a/crates/cubecl-std/src/reduce/test.rs +++ b/crates/cubecl-std/src/reduce/test.rs @@ -28,7 +28,7 @@ macro_rules! testgen_reduce { #[test] pub fn reduce_sum_vector_single_plane_line_size_four() { - let mut test = TestCase::new( + let test = TestCase::new( // input TestTensorParts::new_vector((0..32).collect()).with_line_size(4), // output @@ -76,7 +76,7 @@ macro_rules! testgen_reduce { // expected vec![8128], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -103,7 +103,7 @@ macro_rules! testgen_reduce { // expected vec![4950], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -117,7 +117,7 @@ macro_rules! testgen_reduce { // expected vec![4950], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); test.reduce_lines = true; impl_reduce_sum_test::(&Default::default(), test); } @@ -133,7 +133,7 @@ macro_rules! testgen_reduce { vec![523776.0], ); test.tolerance = Some(1e-9); - test.cube_dim = CubeDim::new(32, 8, 1); + test.cube_dim = CubeDim::new(256, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -148,7 +148,7 @@ macro_rules! testgen_reduce { vec![8128.0], ); test.tolerance = Some(1e-9); - test.cube_dim = CubeDim::new(32, 8, 1); + test.cube_dim = CubeDim::new(256, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } }; @@ -221,8 +221,8 @@ pub fn impl_reduce_sum_test true, Variable::NumWorkgroups => true, Variable::SubgroupSize => true, + Variable::SubgroupInvocationId => true, } } pub fn index(&self, index: usize) -> IndexedVariable { @@ -175,6 +177,7 @@ impl Variable { Self::NumWorkgroupsY => Item::Scalar(Elem::U32), Self::NumWorkgroupsZ => Item::Scalar(Elem::U32), Self::SubgroupSize => Item::Scalar(Elem::U32), + Self::SubgroupInvocationId => Item::Scalar(Elem::U32), } } pub fn elem(&self) -> Elem { @@ -340,6 +343,7 @@ impl Display for Variable { Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"), Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"), Variable::SubgroupSize => f.write_str("subgroup_size"), + Variable::SubgroupInvocationId => f.write_str("subgroup_invocation_id"), } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 41f49a6fc..a7355ec75 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -34,6 +34,7 @@ pub struct WgslCompiler { global_invocation_id: bool, workgroup_id: bool, subgroup_size: bool, + subgroup_invocation_id: bool, id: bool, num_workgroups: bool, workgroup_id_no_axis: bool, @@ -277,6 +278,7 @@ impl WgslCompiler { || self.workgroup_id_no_axis, workgroup_id: self.workgroup_id || self.workgroup_id_no_axis, subgroup_size: self.subgroup_size, + subgroup_invocation_id: self.subgroup_invocation_id, body, extensions, num_workgroups_no_axis: self.num_workgroup_no_axis, @@ -462,6 +464,10 @@ impl WgslCompiler { self.subgroup_size = true; wgsl::Variable::SubgroupSize } + cube::Builtin::UnitPosPlane => { + self.subgroup_invocation_id = true; + wgsl::Variable::SubgroupInvocationId + } }, cube::VariableKind::Matrix { .. } => { panic!("Cooperative matrix-multiply and accumulate not supported.") diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs index 9dc0c089b..1cd0333df 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs @@ -83,6 +83,7 @@ pub struct ComputeShader { pub num_workgroups: bool, pub workgroup_id: bool, pub subgroup_size: bool, + pub subgroup_invocation_id: bool, pub num_workgroups_no_axis: bool, pub workgroup_id_no_axis: bool, pub workgroup_size_no_axis: bool, @@ -171,6 +172,9 @@ fn {}( if self.subgroup_size { f.write_str(" @builtin(subgroup_size) subgroup_size: u32,\n")?; } + if self.subgroup_invocation_id { + f.write_str(" @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n")?; + } // Open body f.write_str(") {\n")?;