From 02bc447253c4c8efadc811a55789db835dfa2938 Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Sun, 8 Sep 2024 22:25:05 +0200 Subject: [PATCH] Fix remaining tests --- .../cubecl-core/src/frontend/element/base.rs | 20 +++++ .../src/frontend/element/numeric.rs | 4 +- .../src/frontend/operation/base.rs | 21 +++-- crates/cubecl-core/tests/frontend/mod.rs | 16 ++-- crates/cubecl-core/tests/frontend/reuse.rs | 6 +- .../tests/frontend/shared_memory.rs | 6 +- crates/cubecl-core/tests/frontend/struct.rs | 10 +-- crates/cubecl-core/tests/frontend/tensor.rs | 10 +-- crates/cubecl-core/tests/frontend/topology.rs | 6 +- crates/cubecl-core/tests/frontend/trait.rs | 12 +-- crates/cubecl-core/tests/frontend/tuple.rs | 22 ++--- .../tests/frontend/vectorization.rs | 24 ++--- crates/cubecl-macros/src/expression.rs | 12 +++ .../cubecl-macros/src/generate/expression.rs | 21 +++-- .../cubecl-macros/src/generate/statement.rs | 46 ++-------- crates/cubecl-macros/src/parse/expression.rs | 15 +++- crates/cubecl-macros/src/statement.rs | 87 +++++++++++++++++-- 17 files changed, 214 insertions(+), 124 deletions(-) diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 340f3670b..99ec4c277 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -186,6 +186,19 @@ macro_rules! tuple_init { } } } +macro_rules! tuple_runtime { + ($($P:ident),*) => { + impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) { + #[allow(non_snake_case)] + fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType { + let ($($P,)*) = self; + ($( + $P.__expand_runtime_method(context), + )*) + } + } + } +} tuple_cube_type!(P1); tuple_cube_type!(P1, P2); @@ -201,6 +214,13 @@ tuple_init!(P1, P2, P3, P4); tuple_init!(P1, P2, P3, P4, P5); tuple_init!(P1, P2, P3, P4, P5, P6); +tuple_runtime!(P1); +tuple_runtime!(P1, P2); +tuple_runtime!(P1, P2, P3); +tuple_runtime!(P1, P2, P3, P4); +tuple_runtime!(P1, P2, P3, P4, P5); +tuple_runtime!(P1, P2, P3, P4, P5, P6); + pub trait ExpandElementBaseInit: CubeType { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement; } diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index 71c2005bc..dbd402846 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -82,7 +82,7 @@ pub trait Numeric: fn __expand_from_vec( context: &mut CubeContext, - vec: [ExpandElementTyped; D], + vec: [u32; D], ) -> ::ExpandType { let new_var = context.create_local(Item::vectorized( Self::as_elem(), @@ -91,7 +91,7 @@ pub trait Numeric: let elem = Self::as_elem(); for (i, element) in vec.iter().enumerate() { - let var: Variable = elem.constant_from_i64(element.constant().unwrap().as_i64()); + let var: Variable = elem.constant_from_i64(*element as i64); let expand = ExpandElement::Plain(var); index_assign::expand::( diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index dea6d8a4b..4e9c30601 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -1,5 +1,3 @@ -use std::num::NonZero; - use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization}; use crate::prelude::{CubeType, ExpandElementTyped}; use crate::{ @@ -23,6 +21,7 @@ where let item_rhs = rhs.item(); let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization); + let item = Item::vectorized(item_lhs.elem, vectorization); // We can only reuse rhs. @@ -196,17 +195,21 @@ where } fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization { + if lhs == rhs { + return lhs; + } match (lhs, rhs) { (None, None) => None, (None, Some(rhs)) => Some(rhs), (Some(lhs), None) => Some(lhs), - (Some(lhs), Some(rhs)) => { - let min = lhs.get().min(rhs.get()); - let common = (0..=min) - .rev() - .find(|i| lhs.get() % i == 0 && rhs.get() % i == 0) - .unwrap_or(1); - NonZero::new(common) + (Some(_), Some(_)) => { + panic!("Auto-matching fixed vectorization currently unsupported"); + // let min = lhs.get().min(rhs.get()); + // let common = (0..=min) + // .rev() + // .find(|i| lhs.get() % i == 0 && rhs.get() % i == 0) + // .unwrap_or(1); + // NonZero::new(common) } } } diff --git a/crates/cubecl-core/tests/frontend/mod.rs b/crates/cubecl-core/tests/frontend/mod.rs index 0bc4abd7d..64cebc692 100644 --- a/crates/cubecl-core/tests/frontend/mod.rs +++ b/crates/cubecl-core/tests/frontend/mod.rs @@ -14,12 +14,12 @@ mod module_import; mod ops; mod parenthesis; mod redeclare; -// mod reuse; -// mod shared_memory; -// mod r#struct; -// mod tensor; -// mod topology; -// mod r#trait; +mod reuse; +mod shared_memory; +mod r#struct; +mod tensor; +mod topology; +mod r#trait; -// mod tuple; -// mod vectorization; +mod tuple; +mod vectorization; diff --git a/crates/cubecl-core/tests/frontend/reuse.rs b/crates/cubecl-core/tests/frontend/reuse.rs index 8ccd69888..c66a1284d 100644 --- a/crates/cubecl-core/tests/frontend/reuse.rs +++ b/crates/cubecl-core/tests/frontend/reuse.rs @@ -26,14 +26,14 @@ mod tests { ir::{Branch, Elem, Item, Variable}, }; - type ElemType = I32; + type ElemType = i32; #[test] fn cube_reuse_assign_test() { let mut context = CubeContext::root(); let x = context.create_local(Item::new(ElemType::as_elem())); - reuse::__expand::(&mut context, x.into()); + reuse::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); @@ -45,7 +45,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - reuse_incr::__expand::(&mut context, x.into()); + reuse_incr::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); diff --git a/crates/cubecl-core/tests/frontend/shared_memory.rs b/crates/cubecl-core/tests/frontend/shared_memory.rs index 603551fde..b41dbe0b2 100644 --- a/crates/cubecl-core/tests/frontend/shared_memory.rs +++ b/crates/cubecl-core/tests/frontend/shared_memory.rs @@ -2,7 +2,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] -pub fn shared_memory_read_write(sm_size: Comptime) { +pub fn shared_memory_read_write(#[comptime] sm_size: u32) { let mut shared = SharedMemory::::new(sm_size); shared[0] = T::from_int(3); let _ = shared[0]; @@ -15,13 +15,13 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_shared_memory() { let mut context = CubeContext::root(); - shared_memory_read_write::__expand::(&mut context, 512); + shared_memory_read_write::expand::(&mut context, 512); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/cubecl-core/tests/frontend/struct.rs b/crates/cubecl-core/tests/frontend/struct.rs index e0deee8a9..4eb21572f 100644 --- a/crates/cubecl-core/tests/frontend/struct.rs +++ b/crates/cubecl-core/tests/frontend/struct.rs @@ -40,7 +40,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_new_struct_test() { @@ -49,7 +49,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - creator::__expand::(&mut context, x.into(), y.into()); + creator::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -69,7 +69,7 @@ mod tests { first: x.into(), second: y.into(), }; - state_receiver_with_reuse::__expand::(&mut context, expanded_state); + state_receiver_with_reuse::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -89,7 +89,7 @@ mod tests { first: x.into(), second: y.into(), }; - attribute_modifier_reuse_field::__expand::(&mut context, expanded_state); + attribute_modifier_reuse_field::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -109,7 +109,7 @@ mod tests { first: x.into(), second: y.into(), }; - attribute_modifier_reuse_struct::__expand::(&mut context, expanded_state); + attribute_modifier_reuse_struct::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( diff --git a/crates/cubecl-core/tests/frontend/tensor.rs b/crates/cubecl-core/tests/frontend/tensor.rs index d7d905bdb..231e30555 100644 --- a/crates/cubecl-core/tests/frontend/tensor.rs +++ b/crates/cubecl-core/tests/frontend/tensor.rs @@ -15,14 +15,14 @@ mod tests { ir::{Item, Operation, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_tensor_metadata() { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - kernel::__expand::(&mut context, input.into()); + kernel::expand::(&mut context, input.into()); assert_eq!(context.into_scope().operations, inline_macro_ref()); } @@ -33,9 +33,9 @@ mod tests { let mut scope = context.into_scope(); let input: Variable = input.into(); - let x = scope.create_local(Item::new(UInt::as_elem())); - let y = scope.create_local(Item::new(UInt::as_elem())); - let z = scope.create_local(Item::new(UInt::as_elem())); + let x = scope.create_local(Item::new(u32::as_elem())); + let y = scope.create_local(Item::new(u32::as_elem())); + let z = scope.create_local(Item::new(u32::as_elem())); cpa!(&mut scope, x = shape(input, 1u32)); cpa!(&mut scope, y = stride(input, 1u32)); diff --git a/crates/cubecl-core/tests/frontend/topology.rs b/crates/cubecl-core/tests/frontend/topology.rs index 816ce5cd7..1cd7263dd 100644 --- a/crates/cubecl-core/tests/frontend/topology.rs +++ b/crates/cubecl-core/tests/frontend/topology.rs @@ -3,7 +3,7 @@ use cubecl_core::prelude::*; #[cube] pub fn topology_kernel(input: Tensor) { - let x = ABSOLUTE_POS + UInt::new(4); + let x = ABSOLUTE_POS + 4; let _ = input[x]; } @@ -14,14 +14,14 @@ mod tests { ir::{Elem, Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_topology() { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - topology_kernel::__expand::(&mut context, input.into()); + topology_kernel::expand::(&mut context, input.into()); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/cubecl-core/tests/frontend/trait.rs b/crates/cubecl-core/tests/frontend/trait.rs index 8d75f27be..bade72acc 100644 --- a/crates/cubecl-core/tests/frontend/trait.rs +++ b/crates/cubecl-core/tests/frontend/trait.rs @@ -64,7 +64,7 @@ impl MethodTypedStrategy for AddStrategy { input_1: ::ExpandType, input_2: ::ExpandType, ) -> ::ExpandType { - add_strategy_operation::__expand::(context, input_1, input_2) + add_strategy_operation::expand::(context, input_1, input_2) } } @@ -80,7 +80,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_strategy_trait_add_test() { let mut context = CubeContext::root(); @@ -88,7 +88,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait::__expand::(&mut context, x.into(), y.into()); + with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -104,7 +104,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait::__expand::(&mut context, x.into(), y.into()); + with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -120,7 +120,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - two_strategy_traits::__expand::( + two_strategy_traits::expand::( &mut context, x.into(), y.into(), @@ -137,7 +137,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_trait_generic_method::__expand::( + with_trait_generic_method::expand::( &mut context, x.into(), y.into(), diff --git a/crates/cubecl-core/tests/frontend/tuple.rs b/crates/cubecl-core/tests/frontend/tuple.rs index 84936f48e..bc37cc56d 100644 --- a/crates/cubecl-core/tests/frontend/tuple.rs +++ b/crates/cubecl-core/tests/frontend/tuple.rs @@ -2,17 +2,17 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] -pub fn tuple_const() -> (UInt, UInt) { - let x = UInt::new(0); - let y = UInt::new(1); +pub fn tuple_const() -> (u32, u32) { + let x = 0u32; + let y = 1u32; (x, y) } #[cube] -pub fn tuple_destructuring() -> (UInt, UInt) { - let x = (UInt::new(0), UInt::new(1)); +pub fn tuple_destructuring() -> (u32, u32) { + let x = (0u32, 1u32); let (a, b) = x; - (a + UInt::new(1), b) + (a + 1, b) } mod tests { @@ -21,12 +21,14 @@ mod tests { cpa, ir::{Elem, Item, Operation, Variable}, }; + use pretty_assertions::assert_eq; #[test] + #[ignore = "Empty body because of constant collapsing"] fn cube_tuple_const_test() { let mut context = CubeContext::root(); - tuple_const::__expand(&mut context); + tuple_const::expand(&mut context); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_tuple_const()); @@ -52,7 +54,7 @@ mod tests { fn cube_tuple_destructuring() { let mut context = CubeContext::root(); - tuple_destructuring::__expand(&mut context); + tuple_destructuring::expand(&mut context); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring()); @@ -65,12 +67,10 @@ mod tests { let a = scope.create_local(Item::new(Elem::UInt)); let b = scope.create_local(Item::new(Elem::UInt)); - let zero: Variable = 0u32.into(); let one: Variable = 1u32.into(); - cpa!(scope, a = zero); + cpa!(scope, a = one); cpa!(scope, b = one); - cpa!(scope, a = a + 1u32); scope.operations } diff --git a/crates/cubecl-core/tests/frontend/vectorization.rs b/crates/cubecl-core/tests/frontend/vectorization.rs index 938750d04..6a95921ee 100644 --- a/crates/cubecl-core/tests/frontend/vectorization.rs +++ b/crates/cubecl-core/tests/frontend/vectorization.rs @@ -12,18 +12,20 @@ pub fn vectorization_cmp(rhs: T) { } mod tests { + use std::num::NonZero; + use super::*; use cubecl_core::ir::Item; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - vectorization_binary::__expand::(&mut context, lhs.into()); + vectorization_binary::expand::(&mut context, lhs.into()); } #[test] @@ -31,18 +33,18 @@ mod tests { fn cube_vectorization_binary_op_with_different_scheme_fails() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - vectorization_binary::__expand::(&mut context, lhs.into()); + vectorization_binary::expand::(&mut context, lhs.into()); } #[test] fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } #[test] @@ -50,17 +52,17 @@ mod tests { fn cube_vectorization_cmp_op_with_different_scheme_fails() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } #[test] fn cube_vectorization_can_be_broadcasted() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), None)); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } } diff --git a/crates/cubecl-macros/src/expression.rs b/crates/cubecl-macros/src/expression.rs index 0727a75a4..4a0b0d6ad 100644 --- a/crates/cubecl-macros/src/expression.rs +++ b/crates/cubecl-macros/src/expression.rs @@ -1,3 +1,5 @@ +use std::{rc::Rc, sync::atomic::AtomicUsize}; + use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::{ @@ -24,10 +26,12 @@ pub enum Expression { Variable { name: Ident, is_mut: bool, + use_count: Rc, ty: Option, }, ConstVariable { name: Ident, + use_count: Rc, ty: Option, }, FieldAccess { @@ -207,6 +211,7 @@ impl Expression { Expression::FieldAccess { base, .. } => base.is_const(), Expression::Reference { inner } => inner.is_const(), Expression::Array { elements, .. } => elements.iter().all(|it| it.is_const()), + Expression::Tuple { elements, .. } => elements.iter().all(|it| it.is_const()), Expression::MethodCall { method, args, .. } => { method == "vectorization_factor" && args.is_empty() } @@ -228,6 +233,13 @@ impl Expression { .collect::>>()?; Some(quote![[#(#elements),*]]) } + Expression::Tuple { elements, .. } => { + let elements = elements + .iter() + .map(|it| it.as_const(context)) + .collect::>>()?; + Some(quote![(#(#elements),*)]) + } Expression::FieldAccess { base, field, .. } => { base.as_const(context).map(|base| quote![#base.#field]) } diff --git a/crates/cubecl-macros/src/generate/expression.rs b/crates/cubecl-macros/src/generate/expression.rs index a61608dcb..f7d8ddb91 100644 --- a/crates/cubecl-macros/src/generate/expression.rs +++ b/crates/cubecl-macros/src/generate/expression.rs @@ -366,12 +366,12 @@ impl Expression { .to_compile_error() } } - Expression::Tuple { span, .. } => { + Expression::Tuple { elements, .. } => { if let Some(constant) = self.as_const(context) { constant } else { - syn::Error::new(*span, "Tuple expressions can't be used at runtime") - .to_compile_error() + let elements = elements.iter().map(|it| it.to_tokens(context)); + quote![(#(#elements),*)] } } @@ -426,11 +426,16 @@ impl Expression { impl Block { pub fn to_tokens(&self, context: &mut Context) -> TokenStream { let inner: Vec<_> = self.inner.iter().map(|it| it.to_tokens(context)).collect(); - let ret = self - .ret - .as_ref() - .map(|ret| ret.to_tokens(context)) - .unwrap_or_else(|| quote![()]); + let ret = if let Some(ret) = self.ret.as_ref() { + let as_const = ret.as_const(context); + if let Some(as_const) = as_const { + quote![#as_const.__expand_runtime_method(context)] + } else { + ret.to_tokens(context) + } + } else { + quote![()] + }; quote! { { diff --git a/crates/cubecl-macros/src/generate/statement.rs b/crates/cubecl-macros/src/generate/statement.rs index af1f7d899..3d4c01870 100644 --- a/crates/cubecl-macros/src/generate/statement.rs +++ b/crates/cubecl-macros/src/generate/statement.rs @@ -1,13 +1,8 @@ use proc_macro2::TokenStream; use quote::{quote, quote_spanned}; -use syn::{spanned::Spanned, Pat, Token}; +use syn::{spanned::Spanned, Token}; -use crate::{ - expression::Expression, - paths::frontend_type, - scope::Context, - statement::{parse_pat, Statement}, -}; +use crate::{expression::Expression, paths::frontend_type, scope::Context, statement::Statement}; impl Statement { pub fn to_tokens(&self, context: &mut Context) -> TokenStream { @@ -64,11 +59,10 @@ impl Statement { quote![let #mutable #name #ty;] } } - Statement::Destructure { fields } => { - let fields = generate_struct_destructure(fields, context); - match fields { - Ok(fields) => fields, - Err(e) => e.to_compile_error(), + Statement::Group { statements } => { + let statements = statements.iter().map(|it| it.to_tokens(context)); + quote! { + #(#statements)* } } Statement::Expression { @@ -89,34 +83,6 @@ impl Statement { } } -fn generate_struct_destructure( - fields: &[(Pat, Expression)], - context: &mut Context, -) -> syn::Result { - let fields = fields - .iter() - .map(|(pat, init)| { - let (ident, ty, mutable) = parse_pat(pat.clone())?; - let statement = Statement::Local { - left: Box::new(Expression::Variable { - name: ident, - ty: None, - is_mut: mutable, - }), - init: Some(Box::new(init.clone())), - mutable, - ty, - }; - let statement = statement.to_tokens(context); - Ok(quote![#statement]) - }) - .collect::>>()?; - - Ok(quote! {span=> - #(#fields)* - }) -} - fn is_mut_init(expr: Option<&Expression>) -> bool { fn is_mut(expr: &Expression) -> bool { match expr { diff --git a/crates/cubecl-macros/src/parse/expression.rs b/crates/cubecl-macros/src/parse/expression.rs index 687db2291..544495d5a 100644 --- a/crates/cubecl-macros/src/parse/expression.rs +++ b/crates/cubecl-macros/src/parse/expression.rs @@ -68,15 +68,24 @@ impl Expression { is_const, is_mut, is_keyword, - .. + use_count, }) = variable { if is_const { - Expression::ConstVariable { name, ty } + Expression::ConstVariable { + name, + ty, + use_count, + } } else if is_keyword { Expression::Keyword { name } } else { - Expression::Variable { name, ty, is_mut } + Expression::Variable { + name, + ty, + is_mut, + use_count, + } } } else { // If it's not in the scope, it's not a managed local variable. Treat it as an diff --git a/crates/cubecl-macros/src/statement.rs b/crates/cubecl-macros/src/statement.rs index 86ab7d6b5..22e0c42c0 100644 --- a/crates/cubecl-macros/src/statement.rs +++ b/crates/cubecl-macros/src/statement.rs @@ -1,7 +1,14 @@ +use std::{ + rc::Rc, + sync::atomic::{AtomicUsize, Ordering}, +}; + use crate::{expression::Expression, scope::Context}; use proc_macro2::Span; use quote::format_ident; -use syn::{spanned::Spanned, Ident, Pat, PatStruct, Stmt, Type}; +use syn::{ + spanned::Spanned, Ident, Index, Member, Pat, PatStruct, PatTuple, PatTupleStruct, Stmt, Type, +}; #[derive(Clone, Debug)] pub enum Statement { @@ -11,8 +18,9 @@ pub enum Statement { mutable: bool, ty: Option, }, - Destructure { - fields: Vec<(Pat, Expression)>, + /// Group of statements generated by desugaring + Group { + statements: Vec, }, Expression { expression: Box, @@ -33,7 +41,11 @@ impl Statement { .map(Box::new); let (ident, ty, mutable) = match local.pat { Pat::Struct(pat) => { - return parse_struct_destructure(pat, *init.unwrap(), context); + return desugar_struct_local(pat, *init.unwrap(), context); + } + Pat::Tuple(PatTuple { elems, .. }) + | Pat::TupleStruct(PatTupleStruct { elems, .. }) => { + return desugar_tuple_local(elems, *init.unwrap(), context) } pat => parse_pat(pat)?, }; @@ -42,6 +54,7 @@ impl Statement { name: ident.clone(), is_mut: mutable, ty: ty.clone(), + use_count: Rc::new(AtomicUsize::new(0)), }); context.push_variable(ident, ty.clone(), is_const && !mutable, mutable); @@ -85,7 +98,7 @@ pub fn parse_pat(pat: Pat) -> syn::Result<(Ident, Option, bool)> { Ok(res) } -fn parse_struct_destructure( +fn desugar_struct_local( pat: PatStruct, init: Expression, context: &mut Context, @@ -102,9 +115,69 @@ fn parse_struct_destructure( }; let (ident, ty, mutable) = parse_pat(*field.pat.clone())?; context.push_variable(ident.clone(), ty.clone(), init.is_const(), mutable); - Ok((*field.pat, access)) + let statement = Statement::Local { + left: Box::new(Expression::Variable { + name: ident, + ty: ty.clone(), + is_mut: mutable, + use_count: AtomicUsize::new(0).into(), + }), + init: Some(Box::new(access)), + mutable, + ty, + }; + Ok(statement) + }) + .collect::>>()?; + + match init { + Expression::Variable { use_count, .. } | Expression::ConstVariable { use_count, .. } => { + use_count.fetch_add(fields.len() - 1, Ordering::AcqRel); + } + _ => {} + } + + Ok(Statement::Group { statements: fields }) +} + +fn desugar_tuple_local( + elems: impl IntoIterator, + init: Expression, + context: &mut Context, +) -> syn::Result { + let fields = elems + .into_iter() + .enumerate() + .map(|(i, pat)| { + let span = pat.span(); + let access = Expression::FieldAccess { + base: Box::new(init.clone()), + field: Member::Unnamed(Index::from(i)), + span, + }; + let (ident, ty, mutable) = parse_pat(pat.clone())?; + context.push_variable(ident.clone(), ty.clone(), init.is_const(), mutable); + let statement = Statement::Local { + left: Box::new(Expression::Variable { + name: ident, + ty: ty.clone(), + is_mut: mutable, + use_count: AtomicUsize::new(0).into(), + }), + init: Some(Box::new(access)), + mutable, + ty, + }; + Ok(statement) }) .collect::>>()?; - Ok(Statement::Destructure { fields }) + match init { + Expression::Variable { use_count, .. } | Expression::ConstVariable { use_count, .. } => { + use_count.fetch_add(fields.len() - 1, Ordering::AcqRel); + } + _ => {} + } + + Ok(Statement::Group { statements: fields }) }