Skip to content

Commit

Permalink
Query num planes (#294)
Browse files Browse the repository at this point in the history
* Introduce a new UnitPosPlane constant

* Update reduce_sum to a plane base implementation

* add reduce test to HIP

* Fix cargo toml

* fix ceiling division for num_planes

* add spirv support

* Change debug name for UNIT_POS_PLANE

* RUN cargo fmt
  • Loading branch information
maxtremblay authored Nov 22, 2024
1 parent ee66c6b commit 4ce794b
Show file tree
Hide file tree
Showing 14 changed files with 73 additions and 26 deletions.
8 changes: 8 additions & 0 deletions crates/cubecl-core/src/frontend/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ The total amount of working units in a plane.
"
);

constant!(
UNIT_POS_PLANE,
crate::ir::Builtin::UnitPosPlane,
r"
The relative position of the working unit inside the plane, without regards to cube dimensions.
"
);

constant!(
UNIT_POS,
crate::ir::Builtin::UnitPos,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub enum Builtin {
CubeCountY,
CubeCountZ,
PlaneDim,
UnitPosPlane,
AbsolutePos,
AbsolutePosX,
AbsolutePosY,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,7 @@ impl<D: Dialect> CppCompiler<D> {
Variable::GridDimGlobal
}
gpu::Builtin::PlaneDim => Variable::WarpSize,
gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp,
},
gpu::VariableKind::LocalArray { id, depth, length } => {
let item = self.compile_item(item);
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-cpp/src/shared/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl<D: Dialect> Component<D> for Variable<D> {
Variable::GridDimZ => Item::scalar(Elem::U32),
Variable::LocalArray(_, e, _, _) => *e,
Variable::WarpSize => Item::scalar(Elem::U32),
Variable::ThreadIdxWarp => Item::scalar(Elem::U32),
Variable::WmmaFragment {
id: _,
frag,
Expand All @@ -163,6 +164,7 @@ impl<D: Dialect> Component<D> for Variable<D> {
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Variable<D: Dialect> {
WarpSize,
ThreadIdxWarp,
GlobalInputArray(u16, Item<D>),
GlobalOutputArray(u16, Item<D>),
GlobalScalar(u16, Elem<D>, gpu::Elem),
Expand Down Expand Up @@ -285,6 +287,7 @@ impl<D: Dialect> Display for Variable<D> {
write!(f, "l_arr_{}_{}", id, depth)
}
Variable::WarpSize => f.write_str("warpSize"),
Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"),
Variable::WmmaFragment {
id: index,
frag: _,
Expand Down Expand Up @@ -416,6 +419,7 @@ impl<D: Dialect> Variable<D> {
Variable::GridDimZ => true,
Variable::LocalArray(_, _, _, _) => false,
Variable::WarpSize => true,
Variable::ThreadIdxWarp => true,
Variable::WmmaFragment { .. } => false,
Variable::BlockIdxGlobal => true,
Variable::BlockDimGlobal => true,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-hip/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", features = [
"export_tests",
] }
cubecl-std = { path = "../cubecl-std", version = "0.4.0" }
pretty_assertions = { workspace = true }
1 change: 1 addition & 0 deletions crates/cubecl-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ mod tests {

cubecl_core::testgen_all!();
cubecl_linalg::testgen_cmma_matmul!();
cubecl_std::testgen_reduce!();
}
3 changes: 2 additions & 1 deletion crates/cubecl-macros/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use syn::{parse_quote, Ident, Type};

use crate::parse::kernel::KernelParam;

pub const KEYWORDS: [&str; 21] = [
pub const KEYWORDS: [&str; 22] = [
"ABSOLUTE_POS",
"ABSOLUTE_POS_X",
"ABSOLUTE_POS_Y",
Expand All @@ -31,6 +31,7 @@ pub const KEYWORDS: [&str; 21] = [
"CUBE_COUNT_Y",
"CUBE_COUNT_Z",
"PLANE_DIM",
"UNIT_POS_PLANE",
];

pub type Scope = usize;
Expand Down
11 changes: 11 additions & 0 deletions crates/cubecl-spirv/src/globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
});
Variable::SubgroupSize(id)
}
Builtin::UnitPosPlane => {
let id = self.get_or_insert_global(Globals::SubgroupInvocationId, |b| {
let id = b.load_builtin(
BuiltIn::SubgroupLocalInvocationId,
Item::Scalar(Elem::Int(32, false)),
);
b.debug_name(id, "UNIT_POS_PLANE");
id
});
Variable::SubgroupSize(id)
}
Builtin::CubePos => {
let id = self.get_or_insert_global(Globals::WorkgroupIndex, |b| {
let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b);
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-spirv/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ pub enum Globals {
NumWorkgroupsY,
NumWorkgroupsZ,
SubgroupSize,
SubgroupInvocationId,

Metadata(u32),
}
Expand Down
36 changes: 20 additions & 16 deletions crates/cubecl-std/src/reduce/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use cubecl_core::prelude::*;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ReduceConfig {
pub line_size: u32,
pub plane_size: u32,
pub num_planes: u32,
pub max_num_planes: u32,
}

/// Compute the sum of all elements of `input` and write it to the first element of `output`.
Expand Down Expand Up @@ -45,28 +44,27 @@ pub fn reduce_sum_vector<N: Numeric>(
output: &mut SliceMut<Line<N>>,
#[comptime] config: ReduceConfig,
) {
// How many lines accounted in each iteration.
let block_size = config.plane_size * config.num_planes;
let plane_id = UNIT_POS / PLANE_DIM;
let num_planes = div_ceil(CUBE_DIM, PLANE_DIM);

// This is an integer division rounded up.
let num_blocks = input.len() / block_size + (input.len() % block_size > 0) as u32;
// Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration.
let num_iterations = div_ceil(input.len(), CUBE_DIM);

let mut memory = SharedMemory::new_lined(config.num_planes, config.line_size);
let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size());
memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0));

memory[UNIT_POS_Y] = Line::empty(config.line_size).fill(N::from_int(0));

// For each block, we reduce each group of plane_size lines to a single line. Then, we accumulate the results
// For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results
// into the memory. Thus, after the loop, the reduction of the memory yields the expected output.
for i in 0..num_blocks {
let index = i * block_size + UNIT_POS_Y * config.plane_size + UNIT_POS_X;
for i in 0..num_iterations {
let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE;
let value = select(
index < input.len(),
input[index],
Line::empty(config.line_size).fill(N::from_int(0)),
);
let sum = plane_sum(value);
if UNIT_POS_X == 0 {
memory[UNIT_POS_Y] += sum;
if UNIT_POS_PLANE == 0 {
memory[plane_id] += sum;
}
}

Expand All @@ -75,8 +73,8 @@ pub fn reduce_sum_vector<N: Numeric>(

// Sum each elements in memory
let sum = plane_sum(select(
UNIT_POS_X < config.num_planes,
memory[UNIT_POS_X],
UNIT_POS_PLANE < num_planes,
memory[UNIT_POS_PLANE],
Line::empty(config.line_size).fill(N::from_int(0)),
));
if UNIT_POS == 0 {
Expand Down Expand Up @@ -104,3 +102,9 @@ pub fn reduce_sum_lines<N: Numeric>(
output[UNIT_POS] = sum;
}
}

// Integer division rounded up.
#[cube]
fn div_ceil(a: u32, b: u32) -> u32 {
a / b + ((a % b) > 0) as u32
}
16 changes: 8 additions & 8 deletions crates/cubecl-std/src/reduce/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! testgen_reduce {

#[test]
pub fn reduce_sum_vector_single_plane_line_size_four() {
let mut test = TestCase::new(
let test = TestCase::new(
// input
TestTensorParts::new_vector((0..32).collect()).with_line_size(4),
// output
Expand Down Expand Up @@ -76,7 +76,7 @@ macro_rules! testgen_reduce {
// expected
vec![8128],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}

Expand All @@ -103,7 +103,7 @@ macro_rules! testgen_reduce {
// expected
vec![4950],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}

Expand All @@ -117,7 +117,7 @@ macro_rules! testgen_reduce {
// expected
vec![4950],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
test.reduce_lines = true;
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}
Expand All @@ -133,7 +133,7 @@ macro_rules! testgen_reduce {
vec![523776.0],
);
test.tolerance = Some(1e-9);
test.cube_dim = CubeDim::new(32, 8, 1);
test.cube_dim = CubeDim::new(256, 1, 1);
impl_reduce_sum_test::<TestRuntime, f32>(&Default::default(), test);
}

Expand All @@ -148,7 +148,7 @@ macro_rules! testgen_reduce {
vec![8128.0],
);
test.tolerance = Some(1e-9);
test.cube_dim = CubeDim::new(32, 8, 1);
test.cube_dim = CubeDim::new(256, 1, 1);
impl_reduce_sum_test::<TestRuntime, f32>(&Default::default(), test);
}
};
Expand Down Expand Up @@ -221,8 +221,8 @@ pub fn impl_reduce_sum_test<R: Runtime, N: Numeric + CubeElement + std::fmt::Dis

let config = ReduceConfig {
line_size: test.input.line_size as u32,
plane_size: test.cube_dim.x,
num_planes: test.cube_dim.y,
max_num_planes: test.cube_dim.num_elems()
/ client.properties().hardware_properties().plane_size_min,
};

unsafe {
Expand Down
6 changes: 5 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Display;

#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
SubgroupSize,
GlobalInputArray(u16, Item),
GlobalOutputArray(u16, Item),
GlobalScalar(u16, Elem, cube::Elem),
Expand Down Expand Up @@ -55,6 +54,8 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
SubgroupSize,
SubgroupInvocationId,
}

#[derive(Debug, Clone, PartialEq, Eq, Copy)]
Expand Down Expand Up @@ -117,6 +118,7 @@ impl Variable {
Variable::WorkgroupSize => true,
Variable::NumWorkgroups => true,
Variable::SubgroupSize => true,
Variable::SubgroupInvocationId => true,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
Expand Down Expand Up @@ -175,6 +177,7 @@ impl Variable {
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Self::SubgroupSize => Item::Scalar(Elem::U32),
Self::SubgroupInvocationId => Item::Scalar(Elem::U32),
}
}
pub fn elem(&self) -> Elem {
Expand Down Expand Up @@ -340,6 +343,7 @@ impl Display for Variable {
Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"),
Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"),
Variable::SubgroupSize => f.write_str("subgroup_size"),
Variable::SubgroupInvocationId => f.write_str("subgroup_invocation_id"),
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct WgslCompiler {
global_invocation_id: bool,
workgroup_id: bool,
subgroup_size: bool,
subgroup_invocation_id: bool,
id: bool,
num_workgroups: bool,
workgroup_id_no_axis: bool,
Expand Down Expand Up @@ -277,6 +278,7 @@ impl WgslCompiler {
|| self.workgroup_id_no_axis,
workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
subgroup_size: self.subgroup_size,
subgroup_invocation_id: self.subgroup_invocation_id,
body,
extensions,
num_workgroups_no_axis: self.num_workgroup_no_axis,
Expand Down Expand Up @@ -462,6 +464,10 @@ impl WgslCompiler {
self.subgroup_size = true;
wgsl::Variable::SubgroupSize
}
cube::Builtin::UnitPosPlane => {
self.subgroup_invocation_id = true;
wgsl::Variable::SubgroupInvocationId
}
},
cube::VariableKind::Matrix { .. } => {
panic!("Cooperative matrix-multiply and accumulate not supported.")
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub struct ComputeShader {
pub num_workgroups: bool,
pub workgroup_id: bool,
pub subgroup_size: bool,
pub subgroup_invocation_id: bool,
pub num_workgroups_no_axis: bool,
pub workgroup_id_no_axis: bool,
pub workgroup_size_no_axis: bool,
Expand Down Expand Up @@ -171,6 +172,9 @@ fn {}(
if self.subgroup_size {
f.write_str(" @builtin(subgroup_size) subgroup_size: u32,\n")?;
}
if self.subgroup_invocation_id {
f.write_str(" @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n")?;
}

// Open body
f.write_str(") {\n")?;
Expand Down

0 comments on commit 4ce794b

Please sign in to comment.