From 56164c2475edc9c01581c4a818220a482337e658 Mon Sep 17 00:00:00 2001 From: cBournhonesque Date: Mon, 12 Aug 2024 16:03:47 -0400 Subject: [PATCH] support destructuring tuples in cube macro --- crates/cubecl-core/tests/frontend/tuple.rs | 34 ++++++++++++++ crates/cubecl-macros/src/analyzer.rs | 46 +++++++++++++------ .../src/codegen_function/variable.rs | 7 +++ 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/crates/cubecl-core/tests/frontend/tuple.rs b/crates/cubecl-core/tests/frontend/tuple.rs index c5d370236..84936f48e 100644 --- a/crates/cubecl-core/tests/frontend/tuple.rs +++ b/crates/cubecl-core/tests/frontend/tuple.rs @@ -8,6 +8,13 @@ pub fn tuple_const() -> (UInt, UInt) { (x, y) } +#[cube] +pub fn tuple_destructuring() -> (UInt, UInt) { + let x = (UInt::new(0), UInt::new(1)); + let (a, b) = x; + (a + UInt::new(1), b) +} + mod tests { use super::*; use cubecl_core::{ @@ -40,4 +47,31 @@ mod tests { scope.operations } + + #[test] + fn cube_tuple_destructuring() { + let mut context = CubeContext::root(); + + tuple_destructuring::__expand(&mut context); + let scope = context.into_scope(); + + assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring()); + } + + fn inline_macro_ref_tuple_destructuring() -> Vec { + let context = CubeContext::root(); + + let mut scope = context.into_scope(); + let a = scope.create_local(Item::new(Elem::UInt)); + let b = scope.create_local(Item::new(Elem::UInt)); + + let zero: Variable = 0u32.into(); + let one: Variable = 1u32.into(); + + cpa!(scope, a = zero); + cpa!(scope, b = one); + cpa!(scope, a = a + 1u32); + + scope.operations + } } diff --git a/crates/cubecl-macros/src/analyzer.rs b/crates/cubecl-macros/src/analyzer.rs index 5dc659bdc..59037273b 100644 --- a/crates/cubecl-macros/src/analyzer.rs +++ b/crates/cubecl-macros/src/analyzer.rs @@ -75,22 +75,23 @@ impl VariableAnalyzer { match stmt { // Declaration syn::Stmt::Local(local) => { - let mut is_comptime = false; - let id = match &local.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - syn::Pat::Type(pat_type) => { - is_comptime = is_ty_comptime(&pat_type.ty); - match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), + match &local.pat { + syn::Pat::Tuple(pat_tuple) => { + for pat in pat_tuple.elems.iter() { + let (id, is_comptime) = find_local_declaration_ident(pat); + if let Some(id) = id { + self.variable_tracker + .analyze_declare(id.to_string(), depth, is_comptime); + } + } + } + _ => { + let (id, is_comptime) = find_local_declaration_ident(&local.pat); + if let Some(id) = id { + self.variable_tracker + .analyze_declare(id.to_string(), depth, is_comptime); } } - syn::Pat::Wild(_) => None, - _ => todo!("Analysis: unsupported path {:?}", local.pat), - }; - if let Some(id) = id { - self.variable_tracker - .analyze_declare(id.to_string(), depth, is_comptime); } if let Some(local_init) = &local.init { self.find_occurrences_in_expr(&local_init.expr, depth) @@ -268,6 +269,23 @@ impl VariableAnalyzer { } } +fn find_local_declaration_ident(pat: &syn::Pat) -> (Option<&syn::Ident>, bool) { + let mut is_comptime = false; + let id = match &pat { + syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), + syn::Pat::Type(pat_type) => { + is_comptime = is_ty_comptime(&pat_type.ty); + match &*pat_type.pat { + syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), + _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), + } + } + syn::Pat::Wild(_) => None, + _ => todo!("Analysis: unsupported path {:?}", pat), + }; + (id, is_comptime) +} + fn is_ty_comptime(ty: &syn::Type) -> bool { if let syn::Type::Path(path) = ty { for segment in path.path.segments.iter() { diff --git a/crates/cubecl-macros/src/codegen_function/variable.rs b/crates/cubecl-macros/src/codegen_function/variable.rs index 97642d399..a8bdc6edf 100644 --- a/crates/cubecl-macros/src/codegen_function/variable.rs +++ b/crates/cubecl-macros/src/codegen_function/variable.rs @@ -42,6 +42,7 @@ pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream { /// let x = ... /// let x: T = ... /// let _ = ... +/// let (a, b) = ... /// let mut _ = ... pub(crate) fn codegen_local( local: &syn::Local, @@ -57,6 +58,12 @@ pub(crate) fn codegen_local( _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), }, syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), + syn::Pat::Tuple(_) => { + // destructuring pattern; we can just return it as is + return quote::quote! { + #local + }; + } _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), };