Skip to content

Commit

Permalink
Implement explicit return
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Aug 25, 2024
1 parent 3bee11f commit 84a8506
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 16 deletions.
21 changes: 21 additions & 0 deletions crates/cubecl-core/src/new_ir/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,24 @@ where
None
}
}

#[derive(new)]
pub struct Return<Type: SquareType, Ret: Expr<Output = Type>>(pub Option<Ret>);

impl<Type: SquareType, Ret: Expr<Output = Type>> Expr for Return<Type, Ret> {
type Output = Ret;

fn expression_untyped(&self) -> Expression {
Expression::Return {
expr: self
.0
.as_ref()
.map(|it| it.expression_untyped())
.map(Box::new),
}
}

fn vectorization(&self) -> Option<NonZero<u8>> {
None
}
}
7 changes: 7 additions & 0 deletions crates/cubecl-core/src/new_ir/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ pub enum Expression {
then_block: Box<Expression>,
else_branch: Option<Box<Expression>>,
},
Return {
expr: Option<Box<Expression>>,
},

/// A range used in for loops. Currently doesn't exist at runtime, so can be ignored in codegen.
/// This only exists to pass the range down to the for loop it applies to
__Range(Range),
Expand Down Expand Up @@ -115,6 +119,9 @@ impl Expression {
Expression::WhileLoop { .. } => Elem::Unit,
Expression::Loop { .. } => Elem::Unit,
Expression::If { then_block, .. } => then_block.ir_type(),
Expression::Return { expr } => {
expr.as_ref().map(|it| it.ir_type()).unwrap_or(Elem::Unit)
}
}
}

Expand Down
7 changes: 6 additions & 1 deletion crates/cubecl-macros-2/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ pub enum Expression {
else_branch: Option<Box<Expression>>,
span: Span,
},

Return {
expr: Option<Box<Expression>>,
ty: Type,
span: Span,
},
Range {
start: Box<Expression>,
end: Box<Expression>,
Expand Down Expand Up @@ -151,6 +155,7 @@ impl Expression {
Expression::WhileLoop { .. } => None,
Expression::Loop { .. } => None,
Expression::If { then_block, .. } => then_block.ty(),
Expression::Return { expr, .. } => expr.as_ref().and_then(|expr| expr.ty()),
}
}

Expand Down
12 changes: 11 additions & 1 deletion crates/cubecl-macros-2/src/generate/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ impl ToTokens for Expression {
let else_branch = else_branch
.as_ref()
.map(|it| quote![Some(#it)])
.unwrap_or_else(|| quote![None]);
.unwrap_or_else(|| quote![None::<()>]);
quote_spanned! {*span=>
#if_ty::new(#condition, #then_block, #else_branch)
}
Expand All @@ -244,6 +244,16 @@ impl ToTokens for Expression {
#range::new(#start, #end, #inclusive)
}
}
Expression::Return { expr, ty, span } => {
let ret_ty = ir_type("Return");
let ret_expr = expr
.as_ref()
.map(|it| quote![Some(#it)])
.unwrap_or_else(|| quote![None]);
quote_spanned! {*span=>
#ret_ty::<#ty, _>::new(#ret_expr)
}
}
};

tokens.extend(out);
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-macros-2/src/generate/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ impl ToTokens for Kernel {
let vis = &self.visibility;
let name = &self.name;
let generics = &self.generics;
let global_vars = Context::default().current_scope().generate_vars();
let global_vars = Context::new(self.returns.clone())
.current_scope()
.generate_vars();
let block = &self.block;
let return_type = &self.returns;
let args = transform_args(&self.parameters);
Expand Down
32 changes: 25 additions & 7 deletions crates/cubecl-macros-2/src/parse/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,22 +172,40 @@ impl Expression {
}
Expr::Group(group) => Expression::from_expr(*group.expr, context)?,
Expr::Paren(paren) => Expression::from_expr(*paren.expr, context)?,
Expr::Return(ret) => Expression::Return {
span: ret.span(),
expr: ret
.expr
.map(|expr| Expression::from_expr(*expr, context))
.transpose()?
.map(Box::new),
ty: context.return_type.clone(),
},
Expr::Index(_) => todo!("index"),
Expr::Infer(_) => todo!("infer"),
Expr::Let(_) => todo!("let"),

Expr::Macro(_) => todo!("macro"),
Expr::Match(_) => todo!("match"),
Expr::Reference(_) => todo!("reference"),
Expr::Repeat(_) => todo!("repeat"),
Expr::Return(_) => todo!("return"),
Expr::Struct(_) => todo!("struct"),
Expr::Try(_) => todo!("try"),
Expr::TryBlock(_) => todo!("try_block"),
Expr::Tuple(_) => todo!("tuple"),
Expr::Unsafe(_) => todo!("unsafe"),
Expr::Verbatim(_) => todo!("verbatim"),
_ => Err(syn::Error::new_spanned(expr, "Unsupported expression"))?,
Expr::Unsafe(unsafe_expr) => {
context.with_scope(|context| parse_block(unsafe_expr.block, context))?
}
Expr::Verbatim(verbatim) => Expression::Verbatim { tokens: verbatim },
Expr::Try(_) => Err(syn::Error::new_spanned(
expr,
"? Operator is not supported in kernels",
))?,
Expr::TryBlock(_) => Err(syn::Error::new_spanned(
expr,
"try_blocks is unstable and not supported in kernels",
))?,
e => Err(syn::Error::new_spanned(
expr,
format!("Unsupported expression {e:?}"),
))?,
};
Ok(result)
}
Expand Down
3 changes: 1 addition & 2 deletions crates/cubecl-macros-2/src/parse/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ pub struct Kernel {

impl Kernel {
pub fn from_item_fn(function: ItemFn) -> syn::Result<Self> {
let mut context = Context::default();

let name = function.sig.ident;
let vis = function.vis;
let generics = function.sig.generics;
let returns = match function.sig.output {
syn::ReturnType::Default => syn::parse2(quote![()]).unwrap(),
syn::ReturnType::Type(_, ty) => *ty,
};
let mut context = Context::new(returns.clone());
let parameters = function
.sig
.inputs
Expand Down
15 changes: 11 additions & 4 deletions crates/cubecl-macros-2/src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ pub const KEYWORDS: [&str; 21] = [
];

pub struct Context {
pub return_type: Type,
scopes: Vec<Scope>,
// Allows for global variable analysis
scope_history: Vec<Scope>,
}

impl Default for Context {
fn default() -> Self {
impl Context {
pub fn new(return_type: Type) -> Self {
let mut root_scope = Scope::default();
root_scope.variables.extend(KEYWORDS.iter().map(|it| {
let name = format_ident!("{it}");
Expand All @@ -50,13 +51,12 @@ impl Default for Context {
}
}));
Self {
return_type,
scopes: vec![root_scope],
scope_history: Default::default(),
}
}
}

impl Context {
pub fn push_variable(&mut self, name: Ident, ty: Option<Type>, is_const: bool) {
self.scopes
.last_mut()
Expand All @@ -74,6 +74,13 @@ impl Context {
self.scope_history.push(scope);
}

pub fn with_scope<T>(&mut self, with: impl FnOnce(&mut Self) -> T) -> T {
self.push_scope();
let res = with(self);
self.pop_scope();
res
}

pub fn restore_scope(&mut self) {
let scope = self.scope_history.pop();
if let Some(scope) = scope {
Expand Down
29 changes: 29 additions & 0 deletions crates/cubecl-macros-2/tests/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,32 @@ fn chained_if() {

assert_eq!(expanded, expected);
}

#[test]
fn explicit_return() {
#[allow(unused)]
#[cube2]
fn if_returns(cond: bool) -> u32 {
if cond {
return 10;
}
1
}

let expanded = if_returns::expand(Variable::new("cond", None)).expression_untyped();
let expected = block(
vec![expr(Expression::If {
condition: var("cond", Elem::Bool),
then_block: Box::new(block(
vec![expr(Expression::Return {
expr: Some(Box::new(lit(10u32))),
})],
None,
)),
else_branch: None,
})],
Some(lit(1u32)),
);

assert_eq!(expanded, expected);
}

0 comments on commit 84a8506

Please sign in to comment.