Skip to content

Commit

Permalink
Revert variable reuse optimization in loops (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Sep 15, 2024
1 parent 230ab97 commit a2a2add
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 10 deletions.
3 changes: 2 additions & 1 deletion crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,8 @@ pub fn return_expand(context: &mut CubeContext) {
context.register(Branch::Return);
}

pub fn loop_expand(context: &mut CubeContext, block: impl FnOnce(&mut CubeContext)) {
// Don't make this `FnOnce`, it must be executable multiple times
pub fn loop_expand(context: &mut CubeContext, mut block: impl FnMut(&mut CubeContext)) {
let mut inside_loop = context.child();

block(&mut inside_loop);
Expand Down
10 changes: 6 additions & 4 deletions crates/cubecl-core/tests/frontend/loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ mod tests {

let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
let lhs: Variable = lhs.into();

cpa!(
Expand All @@ -97,8 +98,8 @@ mod tests {
cpa!(scope, if(cond).then(|scope|{
scope.register(Branch::Break)
}));

cpa!(scope, lhs = lhs % 1i32);
// Must not mutate `lhs` because it is used in every iteration
cpa!(scope, y = lhs % 1i32);
})
);

Expand All @@ -112,6 +113,7 @@ mod tests {

let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let y = scope.create_local(item);
let lhs: Variable = lhs.into();

cpa!(
Expand All @@ -124,8 +126,8 @@ mod tests {
false => scope.register(Branch::Break)
}
}));

cpa!(scope, lhs = lhs % 1i32);
// Must not mutate `lhs` because it is used in every iteration
cpa!(scope, y = lhs % 1i32);
})
);

Expand Down
5 changes: 4 additions & 1 deletion crates/cubecl-macros/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ pub enum Expression {
block: Block,
scope: Scope,
},
Loop(Block),
Loop {
block: Block,
scope: Scope,
},
If {
condition: Box<Expression>,
then_block: Block,
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-macros/src/generate/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ impl Expression {
}
}
}
Expression::Loop(block) => {
Expression::Loop { block, scope } => {
let loop_ty = frontend_type("branch");
let block = block.to_tokens(context);
let block = context.in_fn_mut(scope, |ctx| block.to_tokens(ctx));

quote![#loop_ty::loop_expand(context, |context| #block);]
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-macros/src/parse/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ fn expand_for_in_loop(
}

pub fn expand_loop(loop_expr: ExprLoop, context: &mut Context) -> syn::Result<Expression> {
let (block, _) = context.in_scope(|ctx| Block::from_block(loop_expr.body, ctx))?;
Ok(Expression::Loop(block))
let (block, scope) = context.in_scope(|ctx| Block::from_block(loop_expr.body, ctx))?;
Ok(Expression::Loop { block, scope })
}

pub fn expand_if(if_expr: ExprIf, context: &mut Context) -> syn::Result<Expression> {
Expand Down

0 comments on commit a2a2add

Please sign in to comment.