Skip to content

Commit

Permalink
Add if as a value expression (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Sep 13, 2024
1 parent 2b9cc83 commit a8cbd14
Show file tree
Hide file tree
Showing 17 changed files with 303 additions and 57 deletions.
122 changes: 114 additions & 8 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use num_traits::NumCast;
use crate::frontend::{CubeContext, ExpandElement};
use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop};

use super::{CubeType, ExpandElementTyped, Int, Numeric};
use super::{assign, CubePrimitive, CubeType, ExpandElementTyped, Int, Numeric};

/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and
/// `Sequence`.
Expand Down Expand Up @@ -46,7 +46,10 @@ impl<I: Int> RangeExpand<I> {
}
}

pub fn __expand_step_by(self, n: impl Into<ExpandElementTyped<u32>>) -> SteppedRangeExpand<I> {
pub fn __expand_step_by_method(
self,
n: impl Into<ExpandElementTyped<u32>>,
) -> SteppedRangeExpand<I> {
SteppedRangeExpand {
start: self.start,
end: self.end,
Expand Down Expand Up @@ -182,23 +185,40 @@ impl<I: Int + Into<ExpandElement>> Iterable<I> for SteppedRangeExpand<I> {
/// integer range. Equivalent to:
///
/// ```ignore
/// for i in start..end { ... }
/// start..end
/// ```
pub fn range<T: Int>(start: T, end: T) -> impl Iterator<Item = T> {
let start: i64 = start.to_i64().unwrap();
let end: i64 = end.to_i64().unwrap();
(start..end).map(<T as NumCast>::from).map(Option::unwrap)
}

pub mod range {
use crate::prelude::{CubeContext, ExpandElementTyped, Int};

use super::RangeExpand;

pub fn expand<I: Int>(
_context: &mut CubeContext,
start: ExpandElementTyped<I>,
end: ExpandElementTyped<I>,
) -> RangeExpand<I> {
RangeExpand {
start,
end,
inclusive: false,
}
}
}

/// Stepped range. Equivalent to:
///
/// ```ignore
/// for i in (start..end).step_by(step) { ... }
/// (start..end).step_by(step)
/// ```
pub fn range_stepped<I: Int>(start: I, end: I, step: I) -> impl Iterator<Item = I>
where
RangeExpand<I>: Iterator,
{
///
/// Allows using any integer for the step, instead of just usize
pub fn range_stepped<I: Int>(start: I, end: I, step: impl Int) -> impl Iterator<Item = I> {
let start = start.to_i64().unwrap();
let end = end.to_i64().unwrap();
let step = step.to_usize().unwrap();
Expand All @@ -208,6 +228,26 @@ where
.map(Option::unwrap)
}

pub mod range_stepped {
use crate::prelude::{CubeContext, ExpandElementTyped, Int};

use super::SteppedRangeExpand;

pub fn expand<I: Int>(
_context: &mut CubeContext,
start: ExpandElementTyped<I>,
end: ExpandElementTyped<I>,
step: ExpandElementTyped<u32>,
) -> SteppedRangeExpand<I> {
SteppedRangeExpand {
start,
end,
step,
inclusive: false,
}
}
}

pub fn for_expand<I: Numeric>(
context: &mut CubeContext,
range: impl Iterable<I>,
Expand Down Expand Up @@ -301,6 +341,72 @@ pub fn if_else_expand(
}
}

pub enum IfElseExprExpand<C: CubeType> {
ComptimeThen(ExpandElementTyped<C>),
ComptimeElse,
Runtime {
runtime_cond: ExpandElement,
out: ExpandElementTyped<C>,
then_child: CubeContext,
},
}

impl<C: CubePrimitive> IfElseExprExpand<C> {
pub fn or_else(
self,
context: &mut CubeContext,
else_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
) -> ExpandElementTyped<C> {
match self {
Self::Runtime {
runtime_cond,
out,
then_child,
} => {
let mut else_child = context.child();
let ret = else_block(&mut else_child);
assign::expand(&mut else_child, ret, out.clone());

context.register(Branch::IfElse(IfElse {
cond: *runtime_cond,
scope_if: then_child.into_scope(),
scope_else: else_child.into_scope(),
}));
out
}
Self::ComptimeElse => else_block(context),
Self::ComptimeThen(ret) => ret,
}
}
}

pub fn if_else_expr_expand<C: CubePrimitive>(
context: &mut CubeContext,
runtime_cond: ExpandElement,
then_block: impl FnOnce(&mut CubeContext) -> ExpandElementTyped<C>,
) -> IfElseExprExpand<C> {
let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
match comptime_cond {
Some(true) => {
let ret = then_block(context);
IfElseExprExpand::ComptimeThen(ret)
}
Some(false) => IfElseExprExpand::ComptimeElse,
None => {
let mut then_child = context.child();
let ret = then_block(&mut then_child);
let out: ExpandElementTyped<C> = context.create_local(ret.expand.item()).into();
assign::expand(&mut then_child, ret, out.clone());

IfElseExprExpand::Runtime {
runtime_cond,
out,
then_child,
}
}
}
}

pub fn break_expand(context: &mut CubeContext) {
context.register(Branch::Break);
}
Expand Down
14 changes: 14 additions & 0 deletions crates/cubecl-core/src/frontend/const_expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ pub trait OptionExt<T: CubeType> {
_context: &mut CubeContext,
other: impl FnOnce(&mut CubeContext) -> T::ExpandType,
) -> T::ExpandType;

fn __expand_unwrap_or_method(
self,
_context: &mut CubeContext,
other: T::ExpandType,
) -> T::ExpandType;
}

impl<T: CubeType + Into<T::ExpandType>> OptionExt<T> for Option<T> {
Expand All @@ -16,4 +22,12 @@ impl<T: CubeType + Into<T::ExpandType>> OptionExt<T> for Option<T> {
) -> <T as CubeType>::ExpandType {
self.map(Into::into).unwrap_or_else(|| other(context))
}

fn __expand_unwrap_or_method(
self,
_context: &mut CubeContext,
other: <T as CubeType>::ExpandType,
) -> <T as CubeType>::ExpandType {
self.map(Into::into).unwrap_or(other)
}
}
9 changes: 3 additions & 6 deletions crates/cubecl-core/src/frontend/element/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,10 @@ pub trait BitCast: CubePrimitive {
unexpanded!()
}

fn __expand_bitcast_from<From>(
fn __expand_bitcast_from<From: CubePrimitive>(
context: &mut CubeContext,
value: From,
) -> <Self as CubeType>::ExpandType
where
From: Into<ExpandElement>,
{
value: ExpandElementTyped<From>,
) -> <Self as CubeType>::ExpandType {
let value: ExpandElement = value.into();
let var: Variable = *value;
let new_var = context.create_local(Item::vectorized(
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod sequence;
mod subcube;
mod topology;

pub use branch::{RangeExpand, SteppedRangeExpand};
pub use branch::{range, range_stepped, RangeExpand, SteppedRangeExpand};
pub use const_expand::*;
pub use context::*;
pub use element::*;
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/ir/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ impl ScopeProcessing {
sanitize_constant_scalar_ref_elem(&mut op.cond, Elem::Bool);
}
Branch::RangeLoop(op) => {
sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt);
sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt);
sanitize_constant_scalar_ref_var(&mut op.end, &op.start);
sanitize_constant_scalar_ref_var(&mut op.i, &op.start);
if let Some(step) = &mut op.step {
sanitize_constant_scalar_ref_elem(step, Elem::UInt);
}
Expand Down
3 changes: 1 addition & 2 deletions crates/cubecl-core/src/ir/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ impl Scope {
pub fn process(&mut self) -> ScopeProcessing {
self.undeclared += self.locals.len() as u16;

let mut variables = Vec::new();
core::mem::swap(&mut self.locals, &mut variables);
let mut variables = core::mem::take(&mut self.locals);

for var in self.matrices.drain(..) {
variables.push(var);
Expand Down
58 changes: 58 additions & 0 deletions crates/cubecl-core/tests/frontend/if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ pub fn elsif<F: Float>(lhs: F) {
}
}

#[cube]
pub fn elsif_assign<F: Float>(lhs: F) {
let _ = if lhs < F::new(0.) {
lhs + F::new(2.)
} else if lhs > F::new(0.) {
lhs + F::new(1.)
} else {
lhs + F::new(0.)
};
}

mod tests {
use cubecl_core::{
cpa,
Expand Down Expand Up @@ -87,6 +98,21 @@ mod tests {
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
}

#[test]
fn cube_elsif_assign_test() {
let mut context = CubeContext::root();

let lhs = context.create_local(Item::new(ElemType::as_elem()));

elsif_assign::expand::<ElemType>(&mut context, lhs.into());
let scope = context.into_scope();

assert_eq!(
format!("{:#?}", scope.operations),
inline_macro_ref_elsif_assign()
);
}

fn inline_macro_ref_if() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
Expand Down Expand Up @@ -149,4 +175,36 @@ mod tests {

format!("{:?}", scope.operations)
}

fn inline_macro_ref_elsif_assign() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);

let mut scope = context.into_scope();
let lhs: Variable = lhs.into();
let cond1 = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
let out = scope.create_local(item);
let cond2 = scope.create_local(Item::new(Elem::Bool));
let out2 = scope.create_local(item);

cpa!(scope, cond1 = lhs < 0f32);
cpa!(&mut scope, if(cond1).then(|scope| {
cpa!(scope, y = lhs + 2.0f32);
cpa!(scope, out = y);
}).else(|mut scope|{
cpa!(scope, cond2 = lhs > 0f32);
cpa!(&mut scope, if(cond2).then(|scope| {
cpa!(scope, y = lhs + 1.0f32);
cpa!(scope, out2 = y);
}).else(|scope|{
cpa!(scope, lhs = lhs + 0.0f32);
cpa!(scope, out2 = lhs);
}));
cpa!(scope, out = out2);
}));

format!("{:#?}", scope.operations)
}
}
9 changes: 6 additions & 3 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,10 @@ impl CudaCompiler {
}),
gpu::Operator::Index(op) => {
if let ExecutionMode::Checked = self.strategy {
if has_length(&op.lhs) {
// Since atomics must be declared inline (for `wgpu` compatibility), we need to
// disable runtime checks for them. Otherwise the variable would be declared
// inside the `if` scope.
if has_length(&op.lhs) && !op.lhs.item().elem.is_atomic() {
self.compile_procedure(
instructions,
gpu::Procedure::CheckedIndex(gpu::CheckedIndex {
Expand Down Expand Up @@ -749,11 +752,11 @@ impl CudaCompiler {
gpu::IntKind::I64 => panic!("i64 isn't supported yet"),
},
gpu::Elem::AtomicInt(kind) => match kind {
gpu::IntKind::I32 => super::Elem::I32,
gpu::IntKind::I32 => super::Elem::Atomic(super::AtomicKind::I32),
gpu::IntKind::I64 => panic!("atomic<i64> isn't supported yet"),
},
gpu::Elem::UInt => super::Elem::U32,
gpu::Elem::AtomicUInt => super::Elem::U32,
gpu::Elem::AtomicUInt => super::Elem::Atomic(super::AtomicKind::U32),
gpu::Elem::Bool => super::Elem::Bool,
}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-cuda/src/compiler/binary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Component, Variable};
use super::{Component, Elem, Variable};
use std::fmt::{Display, Formatter};

pub trait Binary {
Expand Down Expand Up @@ -259,6 +259,8 @@ impl Binary for Index {
f.write_fmt(format_args!("{out} = {}({lhs}[{rhs}]);\n", item_out.elem))?;
}
Ok(())
} else if let Elem::Atomic(inner) = item_out.elem {
f.write_fmt(format_args!("{inner}* {out} = &{lhs}[{rhs}];\n"))
} else {
f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n"))
}
Expand Down
19 changes: 19 additions & 0 deletions crates/cubecl-cuda/src/compiler/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ pub enum Elem {
I32,
U32,
Bool,
Atomic(AtomicKind),
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AtomicKind {
I32,
U32,
}

impl Display for AtomicKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AtomicKind::I32 => f.write_str("int"),
AtomicKind::U32 => f.write_str("uint"),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
Expand All @@ -33,6 +49,7 @@ impl Display for Elem {
Elem::I32 => f.write_str("int"),
Elem::U32 => f.write_str("uint"),
Elem::Bool => f.write_str("bool"),
Elem::Atomic(inner) => inner.fmt(f),
}
}
}
Expand Down Expand Up @@ -470,6 +487,8 @@ impl Elem {
Self::I32 => core::mem::size_of::<i32>(),
Self::U32 => core::mem::size_of::<u32>(),
Self::Bool => core::mem::size_of::<bool>(),
Self::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
Self::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
}
}
}
Loading

0 comments on commit a8cbd14

Please sign in to comment.