diff --git a/crates/cubecl-macros/src/codegen_function/branch.rs b/crates/cubecl-macros/src/codegen_function/branch.rs index 011022e0f..230393f8d 100644 --- a/crates/cubecl-macros/src/codegen_function/branch.rs +++ b/crates/cubecl-macros/src/codegen_function/branch.rs @@ -1,5 +1,4 @@ use proc_macro2::TokenStream; -use syn::Ident; use crate::{ codegen_function::{base::CodegenKind, expr::codegen_expr}, @@ -48,7 +47,7 @@ pub(crate) fn codegen_for_loop( } }; - if &func_name.to_string() == "range" || &func_name.to_string() == "range_stepped" { + if &func_name.to_string() == "range" { let mut args = call.args.clone(); let unroll = codegen_expr( @@ -68,17 +67,48 @@ pub(crate) fn codegen_for_loop( ); let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); - let expand = Ident::new( - &format!("{}_expand", func_name.to_string()), - func_name.span(), + + quote::quote! { + { + let _start = #start; + let _end = #end; + let _unroll = #unroll; + cubecl::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block); + } + } + } else if &func_name.to_string() == "range_stepped" { + let mut args = call.args.clone(); + + let unroll = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_tracker, + ); + let step = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_tracker, + ); + let end = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_tracker, ); + let start = codegen_expr( + &args.pop().unwrap().into_value(), + loop_level, + variable_tracker, + ); + + let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); quote::quote! { { let _start = #start; let _end = #end; + let _step = #step; let _unroll = #unroll; - cubecl::frontend::branch::#expand(context, _start, _end, _unroll, |context, #i| #block); + cubecl::frontend::branch::range_stepped_expand(context, _start, _end, _step, _unroll, |context, #i| #block); } } } else {