From f7de7c444dc74c7d72070179930173d2b46844bb Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 11:11:41 -0500 Subject: [PATCH 1/8] Introduce a new UnitPosPlane constant --- crates/cubecl-core/src/frontend/topology.rs | 8 ++++++++ crates/cubecl-core/src/ir/variable.rs | 1 + crates/cubecl-cpp/src/shared/base.rs | 1 + crates/cubecl-cpp/src/shared/element.rs | 4 ++++ crates/cubecl-macros/src/scope.rs | 3 ++- crates/cubecl-wgpu/src/compiler/wgsl/base.rs | 6 +++++- crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs | 6 ++++++ crates/cubecl-wgpu/src/compiler/wgsl/shader.rs | 4 ++++ 8 files changed, 31 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-core/src/frontend/topology.rs b/crates/cubecl-core/src/frontend/topology.rs index 10c6da024..a33c242a9 100644 --- a/crates/cubecl-core/src/frontend/topology.rs +++ b/crates/cubecl-core/src/frontend/topology.rs @@ -30,6 +30,14 @@ The total amount of working units in a plane. " ); +constant!( + UNIT_POS_PLANE, + crate::ir::Builtin::UnitPosPlane, + r" +The relative position of the working unit inside the plane, without regards to cube dimensions. +" +); + constant!( UNIT_POS, crate::ir::Builtin::UnitPos, diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index 7fb757e3d..e3b1f5b65 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -71,6 +71,7 @@ pub enum Builtin { CubeCountY, CubeCountZ, PlaneDim, + UnitPosPlane, AbsolutePos, AbsolutePosX, AbsolutePosY, diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index 7788f5d52..08d0cd685 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -837,6 +837,7 @@ impl CppCompiler { Variable::GridDimGlobal } gpu::Builtin::PlaneDim => Variable::WarpSize, + gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp, }, gpu::VariableKind::LocalArray { id, depth, length } => { let item = self.compile_item(item); diff --git a/crates/cubecl-cpp/src/shared/element.rs b/crates/cubecl-cpp/src/shared/element.rs index 88f93968f..5d774cbe7 100644 --- a/crates/cubecl-cpp/src/shared/element.rs +++ b/crates/cubecl-cpp/src/shared/element.rs @@ -143,6 +143,7 @@ impl Component for Variable { Variable::GridDimZ => Item::scalar(Elem::U32), Variable::LocalArray(_, e, _, _) => *e, Variable::WarpSize => Item::scalar(Elem::U32), + Variable::ThreadIdxWarp => Item::scalar(Elem::U32), Variable::WmmaFragment { id: _, frag, @@ -163,6 +164,7 @@ impl Component for Variable { #[derive(Debug, Clone, Copy, PartialEq)] pub enum Variable { WarpSize, + ThreadIdxWarp, GlobalInputArray(u16, Item), GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, gpu::Elem), @@ -285,6 +287,7 @@ impl Display for Variable { write!(f, "l_arr_{}_{}", id, depth) } Variable::WarpSize => f.write_str("warpSize"), + Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"), Variable::WmmaFragment { id: index, frag: _, @@ -416,6 +419,7 @@ impl Variable { Variable::GridDimZ => true, Variable::LocalArray(_, _, _, _) => false, Variable::WarpSize => true, + Variable::ThreadIdxWarp => true, Variable::WmmaFragment { .. } => false, Variable::BlockIdxGlobal => true, Variable::BlockDimGlobal => true, diff --git a/crates/cubecl-macros/src/scope.rs b/crates/cubecl-macros/src/scope.rs index 3f9a0696d..95cfa17a9 100644 --- a/crates/cubecl-macros/src/scope.rs +++ b/crates/cubecl-macros/src/scope.rs @@ -9,7 +9,7 @@ use syn::{parse_quote, Ident, Type}; use crate::parse::kernel::KernelParam; -pub const KEYWORDS: [&str; 21] = [ +pub const KEYWORDS: [&str; 22] = [ "ABSOLUTE_POS", "ABSOLUTE_POS_X", "ABSOLUTE_POS_Y", @@ -31,6 +31,7 @@ pub const KEYWORDS: [&str; 21] = [ "CUBE_COUNT_Y", "CUBE_COUNT_Z", "PLANE_DIM", + "UNIT_POS_PLANE", ]; pub type Scope = usize; diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs index 990e4149a..0b7d1dfc1 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/base.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/base.rs @@ -3,7 +3,6 @@ use std::fmt::Display; #[derive(Debug, Clone, PartialEq)] pub enum Variable { - SubgroupSize, GlobalInputArray(u16, Item), GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, cube::Elem), @@ -55,6 +54,8 @@ pub enum Variable { NumWorkgroupsX, NumWorkgroupsY, NumWorkgroupsZ, + SubgroupSize, + SubgroupInvocationId, } #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -117,6 +118,7 @@ impl Variable { Variable::WorkgroupSize => true, Variable::NumWorkgroups => true, Variable::SubgroupSize => true, + Variable::SubgroupInvocationId => true, } } pub fn index(&self, index: usize) -> IndexedVariable { @@ -175,6 +177,7 @@ impl Variable { Self::NumWorkgroupsY => Item::Scalar(Elem::U32), Self::NumWorkgroupsZ => Item::Scalar(Elem::U32), Self::SubgroupSize => Item::Scalar(Elem::U32), + Self::SubgroupInvocationId => Item::Scalar(Elem::U32), } } pub fn elem(&self) -> Elem { @@ -340,6 +343,7 @@ impl Display for Variable { Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"), Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"), Variable::SubgroupSize => f.write_str("subgroup_size"), + Variable::SubgroupInvocationId => f.write_str("subgroup_invocation_id"), } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 1b1a57c4b..1ca0e1dcf 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -31,6 +31,7 @@ pub struct WgslCompiler { global_invocation_id: bool, workgroup_id: bool, subgroup_size: bool, + subgroup_invocation_id: bool, id: bool, num_workgroups: bool, workgroup_id_no_axis: bool, @@ -264,6 +265,7 @@ impl WgslCompiler { || self.workgroup_id_no_axis, workgroup_id: self.workgroup_id || self.workgroup_id_no_axis, subgroup_size: self.subgroup_size, + subgroup_invocation_id: self.subgroup_invocation_id, body, extensions, num_workgroups_no_axis: self.num_workgroup_no_axis, @@ -448,6 +450,10 @@ impl WgslCompiler { self.subgroup_size = true; wgsl::Variable::SubgroupSize } + cube::Builtin::UnitPosPlane => { + self.subgroup_invocation_id = true; + wgsl::Variable::SubgroupInvocationId + } }, cube::VariableKind::Matrix { .. } => { panic!("Cooperative matrix-multiply and accumulate not supported.") diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs index a924a7edb..3fdeeb121 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/shader.rs @@ -83,6 +83,7 @@ pub struct ComputeShader { pub num_workgroups: bool, pub workgroup_id: bool, pub subgroup_size: bool, + pub subgroup_invocation_id: bool, pub num_workgroups_no_axis: bool, pub workgroup_id_no_axis: bool, pub workgroup_size_no_axis: bool, @@ -170,6 +171,9 @@ fn main( if self.subgroup_size { f.write_str(" @builtin(subgroup_size) subgroup_size: u32,\n")?; } + if self.subgroup_invocation_id { + f.write_str(" @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n")?; + } // Open body f.write_str(") {\n")?; From 5a558d6e83fda9573779e0529dc168bfb21636f4 Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 11:27:35 -0500 Subject: [PATCH 2/8] Update reduce_sum to a plane base implementation --- crates/cubecl-std/src/reduce/sum.rs | 31 ++++++++++++++-------------- crates/cubecl-std/src/reduce/test.rs | 14 ++++++------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/crates/cubecl-std/src/reduce/sum.rs b/crates/cubecl-std/src/reduce/sum.rs index ed4677490..5b577ec67 100644 --- a/crates/cubecl-std/src/reduce/sum.rs +++ b/crates/cubecl-std/src/reduce/sum.rs @@ -4,8 +4,7 @@ use cubecl_core::prelude::*; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ReduceConfig { pub line_size: u32, - pub plane_size: u32, - pub num_planes: u32, + pub max_num_planes: u32, } /// Compute the sum of all elements of `input` and write it to the first element of `output`. @@ -45,28 +44,28 @@ pub fn reduce_sum_vector( output: &mut SliceMut>, #[comptime] config: ReduceConfig, ) { - // How many lines accounted in each iteration. - let block_size = config.plane_size * config.num_planes; + let plane_id = UNIT_POS / PLANE_DIM; + let num_planes = CUBE_DIM / PLANE_DIM; - // This is an integer division rounded up. - let num_blocks = input.len() / block_size + (input.len() % block_size > 0) as u32; + // This is an integer division rounded up. It computes the number of required iterations + // to reduce all lines when reducing CUBE_DIM lines per iteration. + let num_iterations = input.len() / CUBE_DIM + (input.len() % CUBE_DIM > 0) as u32; - let mut memory = SharedMemory::new_lined(config.num_planes, config.line_size); + let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size()); + memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0)); - memory[UNIT_POS_Y] = Line::empty(config.line_size).fill(N::from_int(0)); - - // For each block, we reduce each group of plane_size lines to a single line. Then, we accumulate the results + // For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results // into the memory. Thus, after the loop, the reduction of the memory yields the expected output. - for i in 0..num_blocks { - let index = i * block_size + UNIT_POS_Y * config.plane_size + UNIT_POS_X; + for i in 0..num_iterations { + let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE; let value = select( index < input.len(), input[index], Line::empty(config.line_size).fill(N::from_int(0)), ); let sum = plane_sum(value); - if UNIT_POS_X == 0 { - memory[UNIT_POS_Y] += sum; + if UNIT_POS_PLANE == 0 { + memory[plane_id] += sum; } } @@ -75,8 +74,8 @@ pub fn reduce_sum_vector( // Sum each elements in memory let sum = plane_sum(select( - UNIT_POS_X < config.num_planes, - memory[UNIT_POS_X], + UNIT_POS_PLANE < num_planes, + memory[UNIT_POS_PLANE], Line::empty(config.line_size).fill(N::from_int(0)), )); if UNIT_POS == 0 { diff --git a/crates/cubecl-std/src/reduce/test.rs b/crates/cubecl-std/src/reduce/test.rs index 1790eb7f7..07337e514 100644 --- a/crates/cubecl-std/src/reduce/test.rs +++ b/crates/cubecl-std/src/reduce/test.rs @@ -76,7 +76,7 @@ macro_rules! testgen_reduce { // expected vec![8128], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -103,7 +103,7 @@ macro_rules! testgen_reduce { // expected vec![4950], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -117,7 +117,7 @@ macro_rules! testgen_reduce { // expected vec![4950], ); - test.cube_dim = CubeDim::new(32, 4, 1); + test.cube_dim = CubeDim::new(128, 1, 1); test.reduce_lines = true; impl_reduce_sum_test::(&Default::default(), test); } @@ -133,7 +133,7 @@ macro_rules! testgen_reduce { vec![523776.0], ); test.tolerance = Some(1e-9); - test.cube_dim = CubeDim::new(32, 8, 1); + test.cube_dim = CubeDim::new(256, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } @@ -148,7 +148,7 @@ macro_rules! testgen_reduce { vec![8128.0], ); test.tolerance = Some(1e-9); - test.cube_dim = CubeDim::new(32, 8, 1); + test.cube_dim = CubeDim::new(256, 1, 1); impl_reduce_sum_test::(&Default::default(), test); } }; @@ -221,8 +221,8 @@ pub fn impl_reduce_sum_test Date: Fri, 22 Nov 2024 11:27:42 -0500 Subject: [PATCH 3/8] add reduce test to HIP --- crates/cubecl-hip/Cargo.toml | 1 + crates/cubecl-hip/src/lib.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/crates/cubecl-hip/Cargo.toml b/crates/cubecl-hip/Cargo.toml index 2944e5b0f..22d10c0b2 100644 --- a/crates/cubecl-hip/Cargo.toml +++ b/crates/cubecl-hip/Cargo.toml @@ -43,4 +43,5 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", features = [ "export_tests", ] } +cubecl-std = { path = "../cubecl-linalg", version = "0.4.0" } pretty_assertions = { workspace = true } diff --git a/crates/cubecl-hip/src/lib.rs b/crates/cubecl-hip/src/lib.rs index 5d69d790b..a89547a49 100644 --- a/crates/cubecl-hip/src/lib.rs +++ b/crates/cubecl-hip/src/lib.rs @@ -25,4 +25,5 @@ mod tests { cubecl_core::testgen_all!(); cubecl_linalg::testgen_cmma_matmul!(); + cubecl_std::testgen_reduce!(); } From 30bede64a5cd829b12461463b001772b7352a7fc Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 11:31:59 -0500 Subject: [PATCH 4/8] Fix cargo toml --- crates/cubecl-hip/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-hip/Cargo.toml b/crates/cubecl-hip/Cargo.toml index 22d10c0b2..46b268e18 100644 --- a/crates/cubecl-hip/Cargo.toml +++ b/crates/cubecl-hip/Cargo.toml @@ -43,5 +43,5 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", features = [ "export_tests", ] } -cubecl-std = { path = "../cubecl-linalg", version = "0.4.0" } +cubecl-std = { path = "../cubecl-std", version = "0.4.0" } pretty_assertions = { workspace = true } From d95160d0421067fa09359559ed14da981cacca0b Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 11:48:57 -0500 Subject: [PATCH 5/8] fix ceiling division for num_planes --- crates/cubecl-std/src/reduce/sum.rs | 15 +++++++++++---- crates/cubecl-std/src/reduce/test.rs | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-std/src/reduce/sum.rs b/crates/cubecl-std/src/reduce/sum.rs index 5b577ec67..53259e61e 100644 --- a/crates/cubecl-std/src/reduce/sum.rs +++ b/crates/cubecl-std/src/reduce/sum.rs @@ -37,6 +37,8 @@ pub fn reduce_sum_lined( reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32); } + + /// Compute the sum of all elements of `input` and write it to the first element of `output`. #[cube] pub fn reduce_sum_vector( @@ -45,11 +47,10 @@ pub fn reduce_sum_vector( #[comptime] config: ReduceConfig, ) { let plane_id = UNIT_POS / PLANE_DIM; - let num_planes = CUBE_DIM / PLANE_DIM; + let num_planes = div_ceil(CUBE_DIM, PLANE_DIM); - // This is an integer division rounded up. It computes the number of required iterations - // to reduce all lines when reducing CUBE_DIM lines per iteration. - let num_iterations = input.len() / CUBE_DIM + (input.len() % CUBE_DIM > 0) as u32; + // Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration. + let num_iterations = div_ceil(input.len(), CUBE_DIM); let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size()); memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0)); @@ -103,3 +104,9 @@ pub fn reduce_sum_lines( output[UNIT_POS] = sum; } } + +// Integer division rounded up. +#[cube] +fn div_ceil(a: u32, b: u32) -> u32 { + a / b + ((a % b) > 0) as u32 +} diff --git a/crates/cubecl-std/src/reduce/test.rs b/crates/cubecl-std/src/reduce/test.rs index 07337e514..0186c683c 100644 --- a/crates/cubecl-std/src/reduce/test.rs +++ b/crates/cubecl-std/src/reduce/test.rs @@ -28,7 +28,7 @@ macro_rules! testgen_reduce { #[test] pub fn reduce_sum_vector_single_plane_line_size_four() { - let mut test = TestCase::new( + let test = TestCase::new( // input TestTensorParts::new_vector((0..32).collect()).with_line_size(4), // output From 0253adcf042df00f29ac254af7f4e5adecdb99d9 Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 12:15:24 -0500 Subject: [PATCH 6/8] add spirv support --- crates/cubecl-spirv/src/globals.rs | 11 +++++++++++ crates/cubecl-spirv/src/variable.rs | 1 + 2 files changed, 12 insertions(+) diff --git a/crates/cubecl-spirv/src/globals.rs b/crates/cubecl-spirv/src/globals.rs index 6178c1e0d..61a8586ba 100644 --- a/crates/cubecl-spirv/src/globals.rs +++ b/crates/cubecl-spirv/src/globals.rs @@ -113,6 +113,17 @@ impl SpirvCompiler { }); Variable::SubgroupSize(id) } + Builtin::UnitPosPlane => { + let id = self.get_or_insert_global(Globals::SubgroupInvocationId, |b| { + let id = b.load_builtin( + BuiltIn::SubgroupLocalInvocationId, + Item::Scalar(Elem::Int(32, false)), + ); + b.debug_name(id, "PLANE_DIM"); + id + }); + Variable::SubgroupSize(id) + } Builtin::CubePos => { let id = self.get_or_insert_global(Globals::WorkgroupIndex, |b| { let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b); diff --git a/crates/cubecl-spirv/src/variable.rs b/crates/cubecl-spirv/src/variable.rs index 46a06dc75..f748f109f 100644 --- a/crates/cubecl-spirv/src/variable.rs +++ b/crates/cubecl-spirv/src/variable.rs @@ -340,6 +340,7 @@ pub enum Globals { NumWorkgroupsY, NumWorkgroupsZ, SubgroupSize, + SubgroupInvocationId, Metadata(u32), } From e0a1fe8dff914fd89ab7973060adda24c3da77b8 Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 12:21:07 -0500 Subject: [PATCH 7/8] Change debug name for UNIT_POS_PLANE --- crates/cubecl-spirv/src/globals.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-spirv/src/globals.rs b/crates/cubecl-spirv/src/globals.rs index 61a8586ba..633a018bf 100644 --- a/crates/cubecl-spirv/src/globals.rs +++ b/crates/cubecl-spirv/src/globals.rs @@ -119,7 +119,7 @@ impl SpirvCompiler { BuiltIn::SubgroupLocalInvocationId, Item::Scalar(Elem::Int(32, false)), ); - b.debug_name(id, "PLANE_DIM"); + b.debug_name(id, "UNIT_POS_PLANE"); id }); Variable::SubgroupSize(id) From d1778affd04b7879009380c462ed05c1885889d0 Mon Sep 17 00:00:00 2001 From: maxime Date: Fri, 22 Nov 2024 12:21:13 -0500 Subject: [PATCH 8/8] RUN cargo fmt --- crates/cubecl-std/src/reduce/sum.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crates/cubecl-std/src/reduce/sum.rs b/crates/cubecl-std/src/reduce/sum.rs index 53259e61e..4b964d863 100644 --- a/crates/cubecl-std/src/reduce/sum.rs +++ b/crates/cubecl-std/src/reduce/sum.rs @@ -37,8 +37,6 @@ pub fn reduce_sum_lined( reduce_sum_lines(&tmp.to_slice(), &mut output.to_slice_mut(), 1_u32); } - - /// Compute the sum of all elements of `input` and write it to the first element of `output`. #[cube] pub fn reduce_sum_vector( @@ -105,7 +103,7 @@ pub fn reduce_sum_lines( } } -// Integer division rounded up. +// Integer division rounded up. #[cube] fn div_ceil(a: u32, b: u32) -> u32 { a / b + ((a % b) > 0) as u32