Skip to content

Commit

Permalink
support destructuring tuples in cube macro
Browse files Browse the repository at this point in the history
  • Loading branch information
cBournhonesque committed Aug 12, 2024
1 parent c3ef475 commit 56164c2
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
34 changes: 34 additions & 0 deletions crates/cubecl-core/tests/frontend/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ pub fn tuple_const() -> (UInt, UInt) {
(x, y)
}

#[cube]
pub fn tuple_destructuring() -> (UInt, UInt) {
let x = (UInt::new(0), UInt::new(1));
let (a, b) = x;
(a + UInt::new(1), b)
}

mod tests {
use super::*;
use cubecl_core::{
Expand Down Expand Up @@ -40,4 +47,31 @@ mod tests {

scope.operations
}

#[test]
fn cube_tuple_destructuring() {
let mut context = CubeContext::root();

tuple_destructuring::__expand(&mut context);
let scope = context.into_scope();

assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring());
}

fn inline_macro_ref_tuple_destructuring() -> Vec<Operation> {
let context = CubeContext::root();

let mut scope = context.into_scope();
let a = scope.create_local(Item::new(Elem::UInt));
let b = scope.create_local(Item::new(Elem::UInt));

let zero: Variable = 0u32.into();
let one: Variable = 1u32.into();

cpa!(scope, a = zero);
cpa!(scope, b = one);
cpa!(scope, a = a + 1u32);

scope.operations
}
}
46 changes: 32 additions & 14 deletions crates/cubecl-macros/src/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,23 @@ impl VariableAnalyzer {
match stmt {
// Declaration
syn::Stmt::Local(local) => {
let mut is_comptime = false;
let id = match &local.pat {
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
syn::Pat::Type(pat_type) => {
is_comptime = is_ty_comptime(&pat_type.ty);
match &*pat_type.pat {
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
_ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
match &local.pat {
syn::Pat::Tuple(pat_tuple) => {
for pat in pat_tuple.elems.iter() {
let (id, is_comptime) = find_local_declaration_ident(pat);
if let Some(id) = id {
self.variable_tracker
.analyze_declare(id.to_string(), depth, is_comptime);
}
}
}
_ => {
let (id, is_comptime) = find_local_declaration_ident(&local.pat);
if let Some(id) = id {
self.variable_tracker
.analyze_declare(id.to_string(), depth, is_comptime);
}
}
syn::Pat::Wild(_) => None,
_ => todo!("Analysis: unsupported path {:?}", local.pat),
};
if let Some(id) = id {
self.variable_tracker
.analyze_declare(id.to_string(), depth, is_comptime);
}
if let Some(local_init) = &local.init {
self.find_occurrences_in_expr(&local_init.expr, depth)
Expand Down Expand Up @@ -268,6 +269,23 @@ impl VariableAnalyzer {
}
}

fn find_local_declaration_ident(pat: &syn::Pat) -> (Option<&syn::Ident>, bool) {
let mut is_comptime = false;
let id = match &pat {
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
syn::Pat::Type(pat_type) => {
is_comptime = is_ty_comptime(&pat_type.ty);
match &*pat_type.pat {
syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident),
_ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat),
}
}
syn::Pat::Wild(_) => None,
_ => todo!("Analysis: unsupported path {:?}", pat),
};
(id, is_comptime)
}

fn is_ty_comptime(ty: &syn::Type) -> bool {
if let syn::Type::Path(path) = ty {
for segment in path.path.segments.iter() {
Expand Down
7 changes: 7 additions & 0 deletions crates/cubecl-macros/src/codegen_function/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
/// let x = ...
/// let x: T = ...
/// let _ = ...
/// let (a, b) = ...
/// let mut _ = ...
pub(crate) fn codegen_local(
local: &syn::Local,
Expand All @@ -57,6 +58,12 @@ pub(crate) fn codegen_local(
_ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat),
},
syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(),
syn::Pat::Tuple(_) => {
// destructuring pattern; we can just return it as is
return quote::quote! {
#local
};
}
_ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat),
};

Expand Down

0 comments on commit 56164c2

Please sign in to comment.