diff --git a/crates/cubecl-core/src/ir/procedure/base.rs b/crates/cubecl-core/src/ir/procedure/base.rs index 9fec84ad..a7ca3627 100644 --- a/crates/cubecl-core/src/ir/procedure/base.rs +++ b/crates/cubecl-core/src/ir/procedure/base.rs @@ -1,6 +1,6 @@ use super::{ - CheckedIndex, CheckedIndexAssign, ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, - ReadGlobalWithLayout, WriteGlobal, + CheckedIndex, CheckedIndexAssign, ConditionalAssign, EarlyReturn, IndexOffsetGlobalWithLayout, + ReadGlobal, ReadGlobalWithLayout, WriteGlobal, }; use crate::ir::Vectorization; use serde::{Deserialize, Serialize}; @@ -17,6 +17,7 @@ pub enum Procedure { CheckedIndex(CheckedIndex), CheckedIndexAssign(CheckedIndexAssign), ConditionalAssign(ConditionalAssign), + EarlyReturn(EarlyReturn), } impl Procedure { @@ -37,6 +38,7 @@ impl Procedure { Procedure::ConditionalAssign(proc) => { Procedure::ConditionalAssign(proc.vectorize(vectorization)) } + Procedure::EarlyReturn(proc) => Procedure::EarlyReturn(proc.vectorize(vectorization)), } } } diff --git a/crates/cubecl-core/src/ir/procedure/early_return.rs b/crates/cubecl-core/src/ir/procedure/early_return.rs new file mode 100644 index 00000000..8b43ef26 --- /dev/null +++ b/crates/cubecl-core/src/ir/procedure/early_return.rs @@ -0,0 +1,35 @@ +use crate::ir::{macros::cpa, Branch, Elem, Item, Scope, Variable, Vectorization}; +use serde::{Deserialize, Serialize}; + +/// Early-return guard: expands to `if position >= len(global) { return }` in the given scope. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[allow(missing_docs)] +pub struct EarlyReturn { + pub global: Variable, + pub position: Variable, +} + +impl EarlyReturn { + #[allow(missing_docs)] + pub fn expand(self, scope: &mut Scope) { + let variable = self.global; + let index = self.position; + + let array_len = scope.create_local(Item::new(Elem::UInt)); + let outside_bound = scope.create_local(Item::new(Elem::Bool)); + + cpa!(scope, array_len = len(variable)); + cpa!(scope, outside_bound = index >= 
array_len); + + cpa!(scope, if(outside_bound).then(|scope| { + scope.register(Branch::Return); + })); + } + + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { + Self { + global: self.global.vectorize(vectorization), + position: self.position.vectorize(vectorization), + } + } +} diff --git a/crates/cubecl-core/src/ir/procedure/mod.rs b/crates/cubecl-core/src/ir/procedure/mod.rs index a537fc04..40aba002 100644 --- a/crates/cubecl-core/src/ir/procedure/mod.rs +++ b/crates/cubecl-core/src/ir/procedure/mod.rs @@ -1,11 +1,13 @@ mod assign; mod base; +mod early_return; mod index; mod read; mod write; pub use assign::*; pub use base::*; +pub use early_return::*; pub use index::*; pub use read::*; pub use write::*; diff --git a/crates/cubecl-core/src/ir/scope.rs b/crates/cubecl-core/src/ir/scope.rs index e97efba1..1e6a6442 100644 --- a/crates/cubecl-core/src/ir/scope.rs +++ b/crates/cubecl-core/src/ir/scope.rs @@ -315,6 +315,17 @@ impl Scope { let mut operations = Vec::new(); + if let Some((_input, global, position)) = self.writes_global.first() { + if self.depth == 0 { + operations.push(Operation::Procedure(Procedure::EarlyReturn( + super::EarlyReturn { + global: *global, + position: *position, + }, + ))) + } + } + for (input, strategy, local, position) in self.reads_global.drain(..) 
{ match strategy { ReadingStrategy::OutputLayout => { diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index c63cd863..5966a1ba 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -314,6 +314,10 @@ impl CudaCompiler { proc.expand(scope); compile(scope); } + gpu::Procedure::EarlyReturn(proc) => { + proc.expand(scope); + compile(scope); + } } } diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 19ea2a83..fc2f10cc 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -434,6 +434,10 @@ impl WgslCompiler { proc.expand(scope); compile(scope); } + cube::Procedure::EarlyReturn(proc) => { + proc.expand(scope); + compile(scope); + } } }