Skip to content

Commit

Permalink
revert while cond (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Aug 20, 2024
1 parent 7308828 commit f8f4c6c
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 9 deletions.
44 changes: 38 additions & 6 deletions crates/cubecl-core/tests/frontend/loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn while_not<I: Int>(lhs: I) {
#[cube]
pub fn manual_loop_break<I: Int>(lhs: I) {
loop {
if lhs != I::from_int(0) {
if lhs == I::from_int(0) {
break;
}
let _ = lhs % I::from_int(1);
Expand All @@ -21,7 +21,7 @@ pub fn manual_loop_break<I: Int>(lhs: I) {
#[cube]
pub fn loop_with_return<I: Int>(lhs: I) {
loop {
if lhs != I::from_int(0) {
if lhs == I::from_int(0) {
return;
}
let _ = lhs % I::from_int(1);
Expand All @@ -46,7 +46,7 @@ mod tests {
while_not::__expand::<ElemType>(&mut context, lhs.into());
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_while());
}

#[test]
Expand All @@ -58,7 +58,10 @@ mod tests {
manual_loop_break::__expand::<ElemType>(&mut context, lhs.into());
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_loop(false)
);
}

#[test]
Expand All @@ -70,10 +73,13 @@ mod tests {
loop_with_return::__expand::<ElemType>(&mut context, lhs.into());
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(true));
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_ref_loop(true)
);
}

fn inline_macro_ref(is_return: bool) -> String {
fn inline_macro_ref_while() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);
Expand All @@ -87,6 +93,32 @@ mod tests {
&mut scope,
loop(|scope| {
cpa!(scope, cond = lhs != 0);
cpa!(scope, cond = !cond);
cpa!(scope, if(cond).then(|scope|{
scope.register(Branch::Break)
}));

cpa!(scope, rhs = lhs % 1i32);
})
);

format!("{:?}", scope.operations)
}

fn inline_macro_ref_loop(is_return: bool) -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let lhs = context.create_local(item);

let mut scope = context.into_scope();
let cond = scope.create_local(Item::new(Elem::Bool));
let lhs: Variable = lhs.into();
let rhs = scope.create_local(item);

cpa!(
&mut scope,
loop(|scope| {
cpa!(scope, cond = lhs == 0);
cpa!(scope, if(cond).then(|scope|{
match is_return {
true => scope.register(Branch::Return),
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/tests/frontend/reuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ mod tests {
&mut scope,
loop(|scope| {
cpa!(scope, cond = x < 10);
cpa!(scope, cond = !cond);
cpa!(scope, if(cond).then(|scope|{
scope.register(Branch::Break);
}));
Expand All @@ -90,6 +91,7 @@ mod tests {
&mut scope,
loop(|scope| {
cpa!(scope, cond = x < 10);
cpa!(scope, cond = !cond);
cpa!(scope, if(cond).then(|scope|{
scope.register(Branch::Break);
}));
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-macros/src/codegen_function/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ impl Codegen {
pub fn tokens(self) -> TokenStream {
self.into_token_stream()
}

pub fn pop_array_indexing(&mut self) -> Option<ArrayIndexing> {
let mut result = None;
core::mem::swap(&mut result, &mut self.array_indexing);
result
}

pub fn set_array_indexing(&mut self, array_indexing: Option<ArrayIndexing>) {
self.array_indexing = array_indexing;
}
Expand Down
13 changes: 10 additions & 3 deletions crates/cubecl-macros/src/codegen_function/branch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use proc_macro2::TokenStream;
use syn::{Expr, ExprUnary, UnOp};

use crate::{
codegen_function::{base::CodegenKind, expr::codegen_expr},
Expand All @@ -8,7 +9,7 @@ use crate::{
use super::{
base::{codegen_block, Codegen},
function::codegen_call,
operation::codegen_binary,
operation::{codegen_binary, codegen_unary},
variable::{codegen_lit, codegen_path_var},
};

Expand Down Expand Up @@ -139,6 +140,7 @@ pub(crate) fn codegen_cond(
variable_tracker: &mut VariableTracker,
) -> Codegen {
match cond {
syn::Expr::Unary(expr) => codegen_unary(expr, loop_level, variable_tracker),
syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_tracker),
syn::Expr::Lit(expr) => Codegen::new(codegen_lit(expr), CodegenKind::Literal),
syn::Expr::Path(expr) => codegen_path_var(expr, loop_level, variable_tracker),
Expand Down Expand Up @@ -228,8 +230,13 @@ pub(crate) fn codegen_while_loop(
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> TokenStream {
let (cond, kind, _) =
codegen_cond(&while_loop.cond, loop_level + 1, variable_tracker).process();
let inverted_cond = Expr::Unary(ExprUnary {
attrs: vec![],
op: UnOp::Not(Default::default()),
expr: Box::new(*while_loop.cond.clone()),
});

let (cond, kind, _) = codegen_cond(&inverted_cond, loop_level + 1, variable_tracker).process();

if let CodegenKind::Comptime = kind {
return syn::Error::new_spanned(while_loop.while_token, "Comptime not supported for while")
Expand Down

0 comments on commit f8f4c6c

Please sign in to comment.