Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query num planes #294

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/cubecl-core/src/frontend/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ The total amount of working units in a plane.
"
);

constant!(
UNIT_POS_PLANE,
crate::ir::Builtin::UnitPosPlane,
r"
The relative position of the working unit inside the plane, without regards to cube dimensions.
"
);

constant!(
UNIT_POS,
crate::ir::Builtin::UnitPos,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub enum Builtin {
CubeCountY,
CubeCountZ,
PlaneDim,
UnitPosPlane,
AbsolutePos,
AbsolutePosX,
AbsolutePosY,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ impl<D: Dialect> CppCompiler<D> {
Variable::GridDimGlobal
}
gpu::Builtin::PlaneDim => Variable::WarpSize,
gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp,
},
gpu::VariableKind::LocalArray { id, depth, length } => {
let item = self.compile_item(item);
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-cpp/src/shared/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl<D: Dialect> Component<D> for Variable<D> {
Variable::GridDimZ => Item::scalar(Elem::U32),
Variable::LocalArray(_, e, _, _) => *e,
Variable::WarpSize => Item::scalar(Elem::U32),
Variable::ThreadIdxWarp => Item::scalar(Elem::U32),
Variable::WmmaFragment {
id: _,
frag,
Expand All @@ -163,6 +164,7 @@ impl<D: Dialect> Component<D> for Variable<D> {
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Variable<D: Dialect> {
WarpSize,
ThreadIdxWarp,
GlobalInputArray(u16, Item<D>),
GlobalOutputArray(u16, Item<D>),
GlobalScalar(u16, Elem<D>, gpu::Elem),
Expand Down Expand Up @@ -285,6 +287,7 @@ impl<D: Dialect> Display for Variable<D> {
write!(f, "l_arr_{}_{}", id, depth)
}
Variable::WarpSize => f.write_str("warpSize"),
Variable::ThreadIdxWarp => f.write_str("threadIdxGlobal % warpSize"),
Variable::WmmaFragment {
id: index,
frag: _,
Expand Down Expand Up @@ -416,6 +419,7 @@ impl<D: Dialect> Variable<D> {
Variable::GridDimZ => true,
Variable::LocalArray(_, _, _, _) => false,
Variable::WarpSize => true,
Variable::ThreadIdxWarp => true,
Variable::WmmaFragment { .. } => false,
Variable::BlockIdxGlobal => true,
Variable::BlockDimGlobal => true,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-hip/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
cubecl-linalg = { path = "../cubecl-linalg", version = "0.4.0", features = [
"export_tests",
] }
cubecl-std = { path = "../cubecl-std", version = "0.4.0" }
pretty_assertions = { workspace = true }
1 change: 1 addition & 0 deletions crates/cubecl-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ mod tests {

cubecl_core::testgen_all!();
cubecl_linalg::testgen_cmma_matmul!();
cubecl_std::testgen_reduce!();
}
3 changes: 2 additions & 1 deletion crates/cubecl-macros/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use syn::{parse_quote, Ident, Type};

use crate::parse::kernel::KernelParam;

pub const KEYWORDS: [&str; 21] = [
pub const KEYWORDS: [&str; 22] = [
"ABSOLUTE_POS",
"ABSOLUTE_POS_X",
"ABSOLUTE_POS_Y",
Expand All @@ -31,6 +31,7 @@ pub const KEYWORDS: [&str; 21] = [
"CUBE_COUNT_Y",
"CUBE_COUNT_Z",
"PLANE_DIM",
"UNIT_POS_PLANE",
];

pub type Scope = usize;
Expand Down
11 changes: 11 additions & 0 deletions crates/cubecl-spirv/src/globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,17 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
});
Variable::SubgroupSize(id)
}
Builtin::UnitPosPlane => {
let id = self.get_or_insert_global(Globals::SubgroupInvocationId, |b| {
let id = b.load_builtin(
BuiltIn::SubgroupLocalInvocationId,
Item::Scalar(Elem::Int(32, false)),
);
b.debug_name(id, "UNIT_POS_PLANE");
id
});
Variable::SubgroupSize(id)
}
Builtin::CubePos => {
let id = self.get_or_insert_global(Globals::WorkgroupIndex, |b| {
let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b);
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-spirv/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ pub enum Globals {
NumWorkgroupsY,
NumWorkgroupsZ,
SubgroupSize,
SubgroupInvocationId,

Metadata(u32),
}
Expand Down
36 changes: 20 additions & 16 deletions crates/cubecl-std/src/reduce/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use cubecl_core::prelude::*;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ReduceConfig {
pub line_size: u32,
pub plane_size: u32,
pub num_planes: u32,
pub max_num_planes: u32,
}

/// Compute the sum of all elements of `input` and write it to the first element of `output`.
Expand Down Expand Up @@ -45,28 +44,27 @@ pub fn reduce_sum_vector<N: Numeric>(
output: &mut SliceMut<Line<N>>,
#[comptime] config: ReduceConfig,
) {
// How many lines accounted in each iteration.
let block_size = config.plane_size * config.num_planes;
let plane_id = UNIT_POS / PLANE_DIM;
let num_planes = div_ceil(CUBE_DIM, PLANE_DIM);

// This is an integer division rounded up.
let num_blocks = input.len() / block_size + (input.len() % block_size > 0) as u32;
// Compute the number of required iterations to reduce all lines when reducing CUBE_DIM lines per iteration.
let num_iterations = div_ceil(input.len(), CUBE_DIM);

let mut memory = SharedMemory::new_lined(config.num_planes, config.line_size);
let mut memory = SharedMemory::new_lined(config.max_num_planes, input[0].size());
memory[plane_id] = Line::empty(config.line_size).fill(N::from_int(0));

memory[UNIT_POS_Y] = Line::empty(config.line_size).fill(N::from_int(0));

// For each block, we reduce each group of plane_size lines to a single line. Then, we accumulate the results
// For each iteration, each plane reduces PLANE_DIM lines into a single line. Then, we accumulate the results
// into the memory. Thus, after the loop, the reduction of the memory yields the expected output.
for i in 0..num_blocks {
let index = i * block_size + UNIT_POS_Y * config.plane_size + UNIT_POS_X;
for i in 0..num_iterations {
let index = i * CUBE_DIM + plane_id * PLANE_DIM + UNIT_POS_PLANE;
let value = select(
index < input.len(),
input[index],
Line::empty(config.line_size).fill(N::from_int(0)),
);
let sum = plane_sum(value);
if UNIT_POS_X == 0 {
memory[UNIT_POS_Y] += sum;
if UNIT_POS_PLANE == 0 {
memory[plane_id] += sum;
}
}

Expand All @@ -75,8 +73,8 @@ pub fn reduce_sum_vector<N: Numeric>(

// Sum each elements in memory
let sum = plane_sum(select(
UNIT_POS_X < config.num_planes,
memory[UNIT_POS_X],
UNIT_POS_PLANE < num_planes,
memory[UNIT_POS_PLANE],
Line::empty(config.line_size).fill(N::from_int(0)),
));
if UNIT_POS == 0 {
Expand Down Expand Up @@ -104,3 +102,9 @@ pub fn reduce_sum_lines<N: Numeric>(
output[UNIT_POS] = sum;
}
}

// Integer division rounded up.
#[cube]
fn div_ceil(a: u32, b: u32) -> u32 {
a / b + ((a % b) > 0) as u32
}
16 changes: 8 additions & 8 deletions crates/cubecl-std/src/reduce/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! testgen_reduce {

#[test]
pub fn reduce_sum_vector_single_plane_line_size_four() {
let mut test = TestCase::new(
let test = TestCase::new(
// input
TestTensorParts::new_vector((0..32).collect()).with_line_size(4),
// output
Expand Down Expand Up @@ -76,7 +76,7 @@ macro_rules! testgen_reduce {
// expected
vec![8128],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}

Expand All @@ -103,7 +103,7 @@ macro_rules! testgen_reduce {
// expected
vec![4950],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}

Expand All @@ -117,7 +117,7 @@ macro_rules! testgen_reduce {
// expected
vec![4950],
);
test.cube_dim = CubeDim::new(32, 4, 1);
test.cube_dim = CubeDim::new(128, 1, 1);
test.reduce_lines = true;
impl_reduce_sum_test::<TestRuntime, u32>(&Default::default(), test);
}
Expand All @@ -133,7 +133,7 @@ macro_rules! testgen_reduce {
vec![523776.0],
);
test.tolerance = Some(1e-9);
test.cube_dim = CubeDim::new(32, 8, 1);
test.cube_dim = CubeDim::new(256, 1, 1);
impl_reduce_sum_test::<TestRuntime, f32>(&Default::default(), test);
}

Expand All @@ -148,7 +148,7 @@ macro_rules! testgen_reduce {
vec![8128.0],
);
test.tolerance = Some(1e-9);
test.cube_dim = CubeDim::new(32, 8, 1);
test.cube_dim = CubeDim::new(256, 1, 1);
impl_reduce_sum_test::<TestRuntime, f32>(&Default::default(), test);
}
};
Expand Down Expand Up @@ -221,8 +221,8 @@ pub fn impl_reduce_sum_test<R: Runtime, N: Numeric + CubeElement + std::fmt::Dis

let config = ReduceConfig {
line_size: test.input.line_size as u32,
plane_size: test.cube_dim.x,
num_planes: test.cube_dim.y,
max_num_planes: test.cube_dim.num_elems()
/ client.properties().hardware_properties().plane_size_min,
};

unsafe {
Expand Down
6 changes: 5 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Display;

#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
SubgroupSize,
GlobalInputArray(u16, Item),
GlobalOutputArray(u16, Item),
GlobalScalar(u16, Elem, cube::Elem),
Expand Down Expand Up @@ -55,6 +54,8 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
SubgroupSize,
SubgroupInvocationId,
}

#[derive(Debug, Clone, PartialEq, Eq, Copy)]
Expand Down Expand Up @@ -117,6 +118,7 @@ impl Variable {
Variable::WorkgroupSize => true,
Variable::NumWorkgroups => true,
Variable::SubgroupSize => true,
Variable::SubgroupInvocationId => true,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
Expand Down Expand Up @@ -175,6 +177,7 @@ impl Variable {
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Self::SubgroupSize => Item::Scalar(Elem::U32),
Self::SubgroupInvocationId => Item::Scalar(Elem::U32),
}
}
pub fn elem(&self) -> Elem {
Expand Down Expand Up @@ -340,6 +343,7 @@ impl Display for Variable {
Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"),
Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"),
Variable::SubgroupSize => f.write_str("subgroup_size"),
Variable::SubgroupInvocationId => f.write_str("subgroup_invocation_id"),
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct WgslCompiler {
global_invocation_id: bool,
workgroup_id: bool,
subgroup_size: bool,
subgroup_invocation_id: bool,
id: bool,
num_workgroups: bool,
workgroup_id_no_axis: bool,
Expand Down Expand Up @@ -264,6 +265,7 @@ impl WgslCompiler {
|| self.workgroup_id_no_axis,
workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
subgroup_size: self.subgroup_size,
subgroup_invocation_id: self.subgroup_invocation_id,
body,
extensions,
num_workgroups_no_axis: self.num_workgroup_no_axis,
Expand Down Expand Up @@ -448,6 +450,10 @@ impl WgslCompiler {
self.subgroup_size = true;
wgsl::Variable::SubgroupSize
}
cube::Builtin::UnitPosPlane => {
self.subgroup_invocation_id = true;
wgsl::Variable::SubgroupInvocationId
}
},
cube::VariableKind::Matrix { .. } => {
panic!("Cooperative matrix-multiply and accumulate not supported.")
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub struct ComputeShader {
pub num_workgroups: bool,
pub workgroup_id: bool,
pub subgroup_size: bool,
pub subgroup_invocation_id: bool,
pub num_workgroups_no_axis: bool,
pub workgroup_id_no_axis: bool,
pub workgroup_size_no_axis: bool,
Expand Down Expand Up @@ -170,6 +171,9 @@ fn main(
if self.subgroup_size {
f.write_str(" @builtin(subgroup_size) subgroup_size: u32,\n")?;
}
if self.subgroup_invocation_id {
f.write_str(" @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,\n")?;
}

// Open body
f.write_str(") {\n")?;
Expand Down
Loading