Skip to content

Commit

Permalink
Fix macro codegen for stepped loop
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Aug 15, 2024
1 parent 4f918bc commit c9fd1fd
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions crates/cubecl-macros/src/codegen_function/branch.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use proc_macro2::TokenStream;
use syn::Ident;

use crate::{
codegen_function::{base::CodegenKind, expr::codegen_expr},
Expand Down Expand Up @@ -48,7 +47,7 @@ pub(crate) fn codegen_for_loop(
}
};

if &func_name.to_string() == "range" || &func_name.to_string() == "range_stepped" {
if &func_name.to_string() == "range" {
let mut args = call.args.clone();

let unroll = codegen_expr(
Expand All @@ -68,17 +67,48 @@ pub(crate) fn codegen_for_loop(
);

let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker);
let expand = Ident::new(
&format!("{}_expand", func_name.to_string()),
func_name.span(),

quote::quote! {
{
let _start = #start;
let _end = #end;
let _unroll = #unroll;
cubecl::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block);
}
}
} else if &func_name.to_string() == "range_stepped" {
let mut args = call.args.clone();

let unroll = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let step = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let end = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);
let start = codegen_expr(
&args.pop().unwrap().into_value(),
loop_level,
variable_tracker,
);

let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker);

quote::quote! {
{
let _start = #start;
let _end = #end;
let _step = #step;
let _unroll = #unroll;
cubecl::frontend::branch::#expand(context, _start, _end, _unroll, |context, #i| #block);
cubecl::frontend::branch::range_stepped_expand(context, _start, _end, _step, _unroll, |context, #i| #block);
}
}
} else {
Expand Down

0 comments on commit c9fd1fd

Please sign in to comment.