From ccd9d63f2ef7a593f250d0eded1285aa2d1fce3c Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Sun, 15 Sep 2024 20:09:22 +0200 Subject: [PATCH] Revert variable reuse optimization in loops --- crates/cubecl-core/src/frontend/branch.rs | 3 ++- crates/cubecl-core/tests/frontend/loop.rs | 10 ++++++---- crates/cubecl-macros/src/expression.rs | 5 ++++- crates/cubecl-macros/src/generate/expression.rs | 4 ++-- crates/cubecl-macros/src/parse/branch.rs | 4 ++-- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index 626aee841..4e4b16610 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -415,7 +415,8 @@ pub fn return_expand(context: &mut CubeContext) { context.register(Branch::Return); } -pub fn loop_expand(context: &mut CubeContext, block: impl FnOnce(&mut CubeContext)) { +// Don't make this `FnOnce`, it must be executable multiple times +pub fn loop_expand(context: &mut CubeContext, mut block: impl FnMut(&mut CubeContext)) { let mut inside_loop = context.child(); block(&mut inside_loop); diff --git a/crates/cubecl-core/tests/frontend/loop.rs b/crates/cubecl-core/tests/frontend/loop.rs index 3eeb85c90..ac76ae251 100644 --- a/crates/cubecl-core/tests/frontend/loop.rs +++ b/crates/cubecl-core/tests/frontend/loop.rs @@ -87,6 +87,7 @@ mod tests { let mut scope = context.into_scope(); let cond = scope.create_local(Item::new(Elem::Bool)); + let y = scope.create_local(item); let lhs: Variable = lhs.into(); cpa!( @@ -97,8 +98,8 @@ mod tests { cpa!(scope, if(cond).then(|scope|{ scope.register(Branch::Break) })); - - cpa!(scope, lhs = lhs % 1i32); + // Must not mutate `lhs` because it is used in every iteration + cpa!(scope, y = lhs % 1i32); }) ); @@ -112,6 +113,7 @@ mod tests { let mut scope = context.into_scope(); let cond = scope.create_local(Item::new(Elem::Bool)); + let y = scope.create_local(item); let lhs: Variable = lhs.into(); cpa!( @@ -124,8 +126,8 @@ mod tests { false => scope.register(Branch::Break) } })); - - cpa!(scope, lhs = lhs % 1i32); + // Must not mutate `lhs` because it is used in every iteration + cpa!(scope, y = lhs % 1i32); }) ); diff --git a/crates/cubecl-macros/src/expression.rs b/crates/cubecl-macros/src/expression.rs index 023cbd380..bc267de17 100644 --- a/crates/cubecl-macros/src/expression.rs +++ b/crates/cubecl-macros/src/expression.rs @@ -82,7 +82,10 @@ pub enum Expression { block: Block, scope: Scope, }, - Loop(Block), + Loop { + block: Block, + scope: Scope, + }, If { condition: Box, then_block: Block, diff --git a/crates/cubecl-macros/src/generate/expression.rs b/crates/cubecl-macros/src/generate/expression.rs index a849c504e..19c3c0301 100644 --- a/crates/cubecl-macros/src/generate/expression.rs +++ b/crates/cubecl-macros/src/generate/expression.rs @@ -261,9 +261,9 @@ impl Expression { } } } - Expression::Loop(block) => { + Expression::Loop { block, scope } => { let loop_ty = frontend_type("branch"); - let block = block.to_tokens(context); + let block = context.in_fn_mut(scope, |ctx| block.to_tokens(ctx)); quote![#loop_ty::loop_expand(context, |context| #block);] } diff --git a/crates/cubecl-macros/src/parse/branch.rs b/crates/cubecl-macros/src/parse/branch.rs index a8524cb30..4c6ecae8b 100644 --- a/crates/cubecl-macros/src/parse/branch.rs +++ b/crates/cubecl-macros/src/parse/branch.rs @@ -67,8 +67,8 @@ fn expand_for_in_loop( } pub fn expand_loop(loop_expr: ExprLoop, context: &mut Context) -> syn::Result { - let (block, _) = context.in_scope(|ctx| Block::from_block(loop_expr.body, ctx))?; - Ok(Expression::Loop(block)) + let (block, scope) = context.in_scope(|ctx| Block::from_block(loop_expr.body, ctx))?; + Ok(Expression::Loop { block, scope }) } pub fn expand_if(if_expr: ExprIf, context: &mut Context) -> syn::Result {