From b99e7533820f127b3ecb9f118456742a00311732 Mon Sep 17 00:00:00 2001 From: AntiAnimeGeneral <119280876+AntiAnimeGeneral@users.noreply.github.com> Date: Tue, 20 Aug 2024 02:32:49 +0800 Subject: [PATCH] Impl SyncStorage (#72) --- crates/cubecl-core/src/frontend/synchronization.rs | 10 ++++++++++ crates/cubecl-core/src/ir/synchronization.rs | 1 + crates/cubecl-cuda/src/compiler/base.rs | 1 + crates/cubecl-cuda/src/compiler/instruction.rs | 2 ++ crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs | 3 +++ crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs | 4 +++- 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-core/src/frontend/synchronization.rs b/crates/cubecl-core/src/frontend/synchronization.rs index a47967e4..c4f64cd5 100644 --- a/crates/cubecl-core/src/frontend/synchronization.rs +++ b/crates/cubecl-core/src/frontend/synchronization.rs @@ -10,3 +10,13 @@ pub mod sync_units { context.register(Synchronization::SyncUnits) } } + +pub fn sync_storage() {} + +pub mod sync_storage { + use super::*; + + pub fn __expand(context: &mut CubeContext) { + context.register(Synchronization::SyncStorage) + } +} diff --git a/crates/cubecl-core/src/ir/synchronization.rs b/crates/cubecl-core/src/ir/synchronization.rs index 1db20c9b..9822b9ff 100644 --- a/crates/cubecl-core/src/ir/synchronization.rs +++ b/crates/cubecl-core/src/ir/synchronization.rs @@ -6,4 +6,5 @@ use serde::{Deserialize, Serialize}; pub enum Synchronization { // Synchronizize units in a cube. SyncUnits, + SyncStorage } diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 6a991dd7..0fc97760 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -135,6 +135,7 @@ impl CudaCompiler { gpu::Operation::Branch(val) => self.compile_branch(instructions, val), gpu::Operation::Synchronization(val) => match val { gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads), + gpu::Synchronization::SyncStorage => instructions.push(Instruction::ThreadFence), }, gpu::Operation::Subcube(op) => { self.wrap_size_checked = true; diff --git a/crates/cubecl-cuda/src/compiler/instruction.rs b/crates/cubecl-cuda/src/compiler/instruction.rs index 97b35bba..e733f7fa 100644 --- a/crates/cubecl-cuda/src/compiler/instruction.rs +++ b/crates/cubecl-cuda/src/compiler/instruction.rs @@ -114,6 +114,7 @@ pub enum Instruction { out: Variable, }, SyncThreads, + ThreadFence, Ceil(UnaryInstruction), Floor(UnaryInstruction), Wrap(WarpInstruction), @@ -264,6 +265,7 @@ for (uint {i} = {start}; {i} < {end}; {increment}) {{ out, } => Clamp::format(f, input, min_value, max_value, out), Instruction::SyncThreads => f.write_str("__syncthreads();\n"), + Instruction::ThreadFence => f.write_str("__threadfence();\n"), Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), Instruction::SliceLength { input, out } => { diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index d1a6f6da..345bd336 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -398,6 +398,9 @@ impl WgslCompiler { match synchronization { cube::Synchronization::SyncUnits => { instructions.push(wgsl::Instruction::WorkgroupBarrier) + }, + cube::Synchronization::SyncStorage=> { + instructions.push(wgsl::Instruction::StorageBarrier) } }; } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 87819b55..39a47cce 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -44,6 +44,7 @@ pub enum Instruction { Return, Break, WorkgroupBarrier, + StorageBarrier, // Index handles casting to correct local variable. Index { lhs: Variable, @@ -578,6 +579,7 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ Instruction::Return => f.write_str("return;\n"), Instruction::Break => f.write_str("break;\n"), Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"), + Instruction::StorageBarrier => f.write_str("storageBarrier();\n"), Instruction::Length { var, out } => match var { Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")), _ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")), @@ -615,7 +617,7 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ f.write_fmt(format_args!("{out} = atomicLoad({input});\n")) } Instruction::AtomicStore { input, out } => { - f.write_fmt(format_args!("{out} = atomicStore({input});\n")) + f.write_fmt(format_args!("atomicStore({out},{input});\n")) } Instruction::AtomicSwap { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = atomicExchange({lhs}, {rhs});"))