Skip to content

Commit

Permalink
Clean up macro and optimize branch operations (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Sep 11, 2024
1 parent 7a86f9a commit f1164b0
Show file tree
Hide file tree
Showing 16 changed files with 377 additions and 527 deletions.
77 changes: 41 additions & 36 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,33 +246,57 @@ pub fn if_expand(
}
}

pub enum IfElseExpand {
ComptimeThen,
ComptimeElse,
Runtime {
runtime_cond: ExpandElement,
then_child: CubeContext,
},
}

impl IfElseExpand {
pub fn or_else(self, context: &mut CubeContext, else_block: impl FnOnce(&mut CubeContext)) {
match self {
Self::Runtime {
runtime_cond,
then_child,
} => {
let mut else_child = context.child();
else_block(&mut else_child);

context.register(Branch::IfElse(IfElse {
cond: *runtime_cond,
scope_if: then_child.into_scope(),
scope_else: else_child.into_scope(),
}));
}
Self::ComptimeElse => else_block(context),
Self::ComptimeThen => (),
}
}
}

pub fn if_else_expand(
context: &mut CubeContext,
runtime_cond: ExpandElement,
then_block: impl FnOnce(&mut CubeContext),
else_block: impl FnOnce(&mut CubeContext),
) {
) -> IfElseExpand {
let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool());
match comptime_cond {
Some(cond) => {
if cond {
then_block(context);
} else {
else_block(context);
}
Some(true) => {
then_block(context);
IfElseExpand::ComptimeThen
}
Some(false) => IfElseExpand::ComptimeElse,
None => {
let mut then_child = context.child();
then_block(&mut then_child);

let mut else_child = context.child();
else_block(&mut else_child);

context.register(Branch::IfElse(IfElse {
cond: *runtime_cond,
scope_if: then_child.into_scope(),
scope_else: else_child.into_scope(),
}));
IfElseExpand::Runtime {
runtime_cond,
then_child,
}
}
}
}
Expand All @@ -285,28 +309,9 @@ pub fn return_expand(context: &mut CubeContext) {
context.register(Branch::Return);
}

pub fn loop_expand<FB>(context: &mut CubeContext, mut block: FB)
where
FB: FnMut(&mut CubeContext),
{
let mut inside_loop = context.child();

block(&mut inside_loop);
context.register(Branch::Loop(Loop {
scope: inside_loop.into_scope(),
}));
}

pub fn while_loop_expand(
context: &mut CubeContext,
mut cond_fn: impl FnMut(&mut CubeContext) -> ExpandElementTyped<bool>,
block: impl FnOnce(&mut CubeContext),
) {
pub fn loop_expand(context: &mut CubeContext, block: impl FnOnce(&mut CubeContext)) {
let mut inside_loop = context.child();

let cond: ExpandElement = cond_fn(&mut inside_loop).into();
if_expand(&mut inside_loop, cond, break_expand);

block(&mut inside_loop);
context.register(Branch::Loop(Loop {
scope: inside_loop.into_scope(),
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/tests/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ mod tests {
cpa!(&mut scope, if(runtime_cond).then(|scope| {
cpa!(scope, y = x + 5.0f32);
}).else(|scope| {
cpa!(scope, y = x - 6.0f32);
cpa!(scope, x = x - 6.0f32);
}));
};

Expand All @@ -368,7 +368,7 @@ mod tests {
if comptime_cond {
cpa!(scope, y = x + 5.0f32);
} else {
cpa!(scope, y = x - 6.0f32);
cpa!(scope, x = x - 6.0f32);
}
}));

Expand Down
8 changes: 4 additions & 4 deletions crates/cubecl-core/tests/frontend/if.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mod tests {
frontend::{CubeContext, CubePrimitive},
ir::{Elem, Item, Variable},
};
use pretty_assertions::assert_eq;

use super::*;

Expand Down Expand Up @@ -94,11 +95,10 @@ mod tests {
let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let y = scope.create_local(item);

cpa!(scope, cond = lhs > 0f32);
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
cpa!(scope, lhs = lhs + 4.0f32);
}));

format!("{:?}", scope.operations)
Expand All @@ -118,7 +118,7 @@ mod tests {
cpa!(&mut scope, if(cond).then(|scope| {
cpa!(scope, y = lhs + 4.0f32);
}).else(|scope|{
cpa!(scope, y = lhs - 5.0f32);
cpa!(scope, lhs = lhs - 5.0f32);
}));

format!("{:?}", scope.operations)
Expand All @@ -143,7 +143,7 @@ mod tests {
cpa!(&mut scope, if(cond2).then(|scope| {
cpa!(scope, y = lhs + 1.0f32);
}).else(|scope|{
cpa!(scope, y = lhs + 0.0f32);
cpa!(scope, lhs = lhs + 0.0f32);
}));
}));

Expand Down
7 changes: 3 additions & 4 deletions crates/cubecl-core/tests/frontend/loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ mod tests {
cpa,
ir::{Branch, Elem, Item, Variable},
};
use pretty_assertions::assert_eq;

type ElemType = i32;

Expand Down Expand Up @@ -87,7 +88,6 @@ mod tests {
let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let rhs = scope.create_local(item);

cpa!(
&mut scope,
Expand All @@ -98,7 +98,7 @@ mod tests {
scope.register(Branch::Break)
}));

cpa!(scope, rhs = lhs % 1i32);
cpa!(scope, lhs = lhs % 1i32);
})
);

Expand All @@ -113,7 +113,6 @@ mod tests {
let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let rhs = scope.create_local(item);

cpa!(
&mut scope,
Expand All @@ -126,7 +125,7 @@ mod tests {
}
}));

cpa!(scope, rhs = lhs % 1i32);
cpa!(scope, lhs = lhs % 1i32);
})
);

Expand Down
42 changes: 16 additions & 26 deletions crates/cubecl-macros/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::{rc::Rc, sync::atomic::AtomicUsize};

use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{
AngleBracketedGenericArguments, Ident, Lit, Member, Pat, Path, PathArguments, PathSegment, Type,
};

use crate::{operator::Operator, scope::Context, statement::Statement};
use crate::{
operator::Operator,
scope::{Context, ManagedVar, Scope},
statement::Statement,
};

#[derive(Clone, Debug)]
pub enum Expression {
Expand All @@ -21,18 +23,7 @@ pub enum Expression {
operator: Operator,
ty: Option<Type>,
},
Variable {
name: Ident,
is_ref: bool,
is_mut: bool,
use_count: Rc<AtomicUsize>,
ty: Option<Type>,
},
ConstVariable {
name: Ident,
use_count: Rc<AtomicUsize>,
ty: Option<Type>,
},
Variable(ManagedVar),
FieldAccess {
base: Box<Expression>,
field: Member,
Expand Down Expand Up @@ -68,6 +59,7 @@ pub enum Expression {
Closure {
params: Vec<Pat>,
body: Box<Expression>,
scope: Scope,
},
Cast {
from: Box<Expression>,
Expand All @@ -88,10 +80,7 @@ pub enum Expression {
var_name: syn::Ident,
var_ty: Option<syn::Type>,
block: Block,
},
WhileLoop {
condition: Box<Expression>,
block: Block,
scope: Scope,
},
Loop(Block),
If {
Expand Down Expand Up @@ -142,7 +131,7 @@ pub enum Expression {
},
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct Block {
pub inner: Vec<Statement>,
pub ret: Option<Box<Expression>>,
Expand All @@ -154,8 +143,7 @@ impl Expression {
match self {
Expression::Binary { ty, .. } => ty.clone(),
Expression::Unary { ty, .. } => ty.clone(),
Expression::Variable { ty, .. } => ty.clone(),
Expression::ConstVariable { ty, .. } => ty.clone(),
Expression::Variable(var) => var.ty.clone(),
Expression::Literal { ty, .. } => Some(ty.clone()),
Expression::Assignment { ty, .. } => ty.clone(),
Expression::Verbatim { .. } => None,
Expand All @@ -169,7 +157,6 @@ impl Expression {
Expression::MethodCall { .. } => None,
Expression::Path { .. } => None,
Expression::Range { start, .. } => start.ty(),
Expression::WhileLoop { .. } => None,
Expression::Loop { .. } => None,
Expression::If { then_block, .. } => then_block.ty.clone(),
Expression::Return { expr, .. } => expr.as_ref().and_then(|expr| expr.ty()),
Expand All @@ -193,7 +180,7 @@ impl Expression {
Expression::Path { .. } => true,
Expression::Verbatim { .. } => true,
Expression::VerbatimTerminated { .. } => true,
Expression::ConstVariable { .. } => true,
Expression::Variable(var) => var.is_const,
Expression::FieldAccess { base, .. } => base.is_const(),
Expression::Reference { inner } => inner.is_const(),
Expression::Array { elements, .. } => elements.iter().all(|it| it.is_const()),
Expand All @@ -211,7 +198,11 @@ impl Expression {
Expression::Literal { value, .. } => Some(quote![#value]),
Expression::Verbatim { tokens, .. } => Some(tokens.clone()),
Expression::VerbatimTerminated { tokens, .. } => Some(tokens.clone()),
Expression::ConstVariable { name, .. } => Some(quote![#name.clone()]),
Expression::Variable(ManagedVar {
name,
is_const: true,
..
}) => Some(quote![#name.clone()]),
Expression::Path { path, .. } => Some(quote![#path]),
Expression::Array { elements, .. } => {
let elements = elements
Expand Down Expand Up @@ -248,7 +239,6 @@ impl Expression {
Expression::If { then_block, .. } => then_block.ret.is_some(),
Expression::Block(block) => block.ret.is_some(),
Expression::ForLoop { .. } => false,
Expression::WhileLoop { .. } => false,
Expression::Loop { .. } => false,
Expression::VerbatimTerminated { .. } => false,
_ => true,
Expand Down
Loading

0 comments on commit f1164b0

Please sign in to comment.