Skip to content

Commit

Permalink
Impl SyncStorage (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiAnimeGeneral authored Aug 19, 2024
1 parent ccde038 commit b99e753
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 1 deletion.
10 changes: 10 additions & 0 deletions crates/cubecl-core/src/frontend/synchronization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,13 @@ pub mod sync_units {
context.register(Synchronization::SyncUnits)
}
}

/// Frontend entry point for a storage synchronization.
///
/// This function has an empty body; the operation itself is recorded by
/// `sync_storage::__expand`, which registers `Synchronization::SyncStorage`
/// on the `CubeContext`.
// NOTE(review): presumably calls to this function are rewritten to the
// `__expand` counterpart when a kernel is expanded — confirm against the macro.
pub fn sync_storage() {}

/// Expansion counterpart of the `sync_storage()` function.
pub mod sync_storage {
use super::*;

/// Registers a `Synchronization::SyncStorage` operation on `context`.
///
/// In the backends touched by the same change, the CUDA compiler lowers this
/// operation to `__threadfence();` and the WGSL compiler to `storageBarrier();`.
pub fn __expand(context: &mut CubeContext) {
context.register(Synchronization::SyncStorage)
}
}
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/synchronization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ use serde::{Deserialize, Serialize};
pub enum Synchronization {
    /// Synchronize units in a cube.
    SyncUnits,
    /// Storage synchronization.
    ///
    /// The CUDA backend emits `__threadfence();` for this variant and the
    /// WGSL backend emits `storageBarrier();`.
    SyncStorage,
}
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ impl CudaCompiler {
gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
gpu::Operation::Synchronization(val) => match val {
gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
gpu::Synchronization::SyncStorage => instructions.push(Instruction::ThreadFence),
},
gpu::Operation::Subcube(op) => {
self.wrap_size_checked = true;
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub enum Instruction {
out: Variable,
},
SyncThreads,
ThreadFence,
Ceil(UnaryInstruction),
Floor(UnaryInstruction),
Wrap(WarpInstruction),
Expand Down Expand Up @@ -264,6 +265,7 @@ for (uint {i} = {start}; {i} < {end}; {increment}) {{
out,
} => Clamp::format(f, input, min_value, max_value, out),
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
Instruction::ThreadFence => f.write_str("__threadfence();\n"),
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
Instruction::SliceLength { input, out } => {
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ impl WgslCompiler {
match synchronization {
cube::Synchronization::SyncUnits => {
instructions.push(wgsl::Instruction::WorkgroupBarrier)
},
cube::Synchronization::SyncStorage=> {
instructions.push(wgsl::Instruction::StorageBarrier)
}
};
}
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ pub enum Instruction {
Return,
Break,
WorkgroupBarrier,
StorageBarrier,
// Index handles casting to correct local variable.
Index {
lhs: Variable,
Expand Down Expand Up @@ -578,6 +579,7 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
Instruction::Return => f.write_str("return;\n"),
Instruction::Break => f.write_str("break;\n"),
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
Instruction::StorageBarrier => f.write_str("storageBarrier();\n"),
Instruction::Length { var, out } => match var {
Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")),
_ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")),
Expand Down Expand Up @@ -615,7 +617,7 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{
f.write_fmt(format_args!("{out} = atomicLoad({input});\n"))
}
Instruction::AtomicStore { input, out } => {
f.write_fmt(format_args!("{out} = atomicStore({input});\n"))
f.write_fmt(format_args!("atomicStore({out},{input});\n"))
}
Instruction::AtomicSwap { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = atomicExchange({lhs}, {rhs});"))
Expand Down

0 comments on commit b99e753

Please sign in to comment.