Skip to content

Commit

Permalink
fix sync (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiAnimeGeneral authored Sep 9, 2024
1 parent f086ffa commit c8e256e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 2 deletions.
13 changes: 13 additions & 0 deletions crates/cubecl-core/src/frontend/synchronization.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
use crate::frontend::CubeContext;
use crate::ir::Synchronization;
// Among all backends, the memory order guarantee of WebGPU is the weakest
// So Cubecl's memory order cannot be stronger than that of WebGPU

/// # Coordinates the following among all invocations in the current cube:
///
/// * Memory writes to variables in cube address space(shared memory) complete,
/// e.g. writes that were initiated actually land in the cube address space memory.
///
/// * Then all the invocations in the cube wait for each other to arrive at the barrier, i.e. this step.
///
/// * Then all the invocations int he cube begin executing after the barrier, and any writes to cube address space that were made before the barrier are now visible to any invocation in this cube.
pub fn sync_units() {}

pub mod sync_units {
Expand All @@ -11,6 +21,9 @@ pub mod sync_units {
}
}

/// * Sync_storage is the same but change "cube address space(shared memory)" to "storage address space(input args)". But the set of invocations that are collaborating is still only the invocations in the same cube.
///
/// * There is no guarantee about using barriers alone to make the writes to storage buffer in one cube become visible to invocations in a different cube.
pub fn sync_storage() {}

pub mod sync_storage {
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl CudaCompiler {
gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
gpu::Operation::Synchronization(val) => match val {
gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
gpu::Synchronization::SyncStorage => instructions.push(Instruction::ThreadFence),
gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
},
gpu::Operation::Subcube(op) => {
self.wrap_size_checked = true;
Expand Down
7 changes: 6 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct WgslCompiler {
local_invocation_id: bool,
global_invocation_id: bool,
workgroup_id: bool,
subgroup_size: bool,
rank: bool,
id: bool,
stride: bool,
Expand Down Expand Up @@ -90,6 +91,7 @@ impl WgslCompiler {
|| self.num_workgroup_no_axis
|| self.workgroup_id_no_axis,
workgroup_id: self.workgroup_id || self.workgroup_id_no_axis,
subgroup_size: self.subgroup_size,
body,
extensions,
num_workgroups_no_axis: self.num_workgroup_no_axis,
Expand Down Expand Up @@ -256,7 +258,10 @@ impl WgslCompiler {
self.num_workgroup_no_axis = true;
wgsl::Variable::NumWorkgroups
}
cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize,
cube::Variable::SubcubeDim => {
self.subgroup_size = true;
wgsl::Variable::SubgroupSize
}
cube::Variable::Matrix { .. } => {
panic!("Cooperative matrix-multiply and accumulate not supported.")
}
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct ComputeShader {
pub local_invocation_id: bool,
pub num_workgroups: bool,
pub workgroup_id: bool,
pub subgroup_size: bool,
pub num_workgroups_no_axis: bool,
pub workgroup_id_no_axis: bool,
pub workgroup_size_no_axis: bool,
Expand Down Expand Up @@ -136,6 +137,9 @@ fn main(
if self.workgroup_id {
f.write_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n")?;
}
if self.subgroup_size {
f.write_str(" @builtin(subgroup_size) subgroup_size: u32,\n")?;
}

// Open body
f.write_fmt(format_args!(") {{"))?;
Expand Down

0 comments on commit c8e256e

Please sign in to comment.