Skip to content

Commit

Permalink
Add tuple expression in cube derive macro (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
cBournhonesque authored Aug 9, 2024
1 parent 32feabc commit cbbf866
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 5 deletions.
37 changes: 36 additions & 1 deletion crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::marker::PhantomData;
pub trait CubeType {
type ExpandType: Clone + Init;

/// Wrapper around the init method, necesary to type inference.
/// Wrapper around the init method, necessary to type inference.
fn init(context: &mut CubeContext, expand: Self::ExpandType) -> Self::ExpandType {
expand.init(context)
}
Expand Down Expand Up @@ -153,6 +153,41 @@ from_const!(f32, F32);
from_const!(bool, Bool);
from_const!(val UInt, I32, I64, F32, F64);

macro_rules! tuple_cube_type {
($($P:ident),*) => {
impl<$($P: CubeType),*> CubeType for ($($P,)*) {
type ExpandType = ($($P::ExpandType,)*);
}
}
}
macro_rules! tuple_init {
($($P:ident),*) => {
impl<$($P: Init),*> Init for ($($P,)*) {
#[allow(non_snake_case)]
fn init(self, context: &mut CubeContext) -> Self {
let ($($P,)*) = self;
($(
$P.init(context),
)*)
}
}
}
}

tuple_cube_type!(P1);
tuple_cube_type!(P1, P2);
tuple_cube_type!(P1, P2, P3);
tuple_cube_type!(P1, P2, P3, P4);
tuple_cube_type!(P1, P2, P3, P4, P5);
tuple_cube_type!(P1, P2, P3, P4, P5, P6);

tuple_init!(P1);
tuple_init!(P1, P2);
tuple_init!(P1, P2, P3);
tuple_init!(P1, P2, P3, P4);
tuple_init!(P1, P2, P3, P4, P5);
tuple_init!(P1, P2, P3, P4, P5, P6);

pub trait ExpandElementBaseInit: CubeType {
fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
}
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-core/src/frontend/element/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ mod slice;
mod tensor;
mod uint;
mod vectorized;

pub use array::*;
pub use base::*;
pub use bool::*;
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/tests/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ mod r#struct;
mod tensor;
mod topology;
mod r#trait;

mod tuple;
mod vectorization;
43 changes: 43 additions & 0 deletions crates/cubecl-core/tests/frontend/tuple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

#[cube]
pub fn tuple_const() -> (UInt, UInt) {
let x = UInt::new(0);
let y = UInt::new(1);
(x, y)
}

mod tests {
use super::*;
use cubecl_core::{
cpa,
ir::{Elem, Item, Operation, Variable},
};

#[test]
fn cube_tuple_const_test() {
let mut context = CubeContext::root();

tuple_const::__expand(&mut context);
let scope = context.into_scope();

assert_eq!(scope.operations, inline_macro_ref_tuple_const());
}

fn inline_macro_ref_tuple_const() -> Vec<Operation> {
let context = CubeContext::root();

let mut scope = context.into_scope();
let x = scope.create_local(Item::new(Elem::UInt));
let y = scope.create_local(Item::new(Elem::UInt));

let zero: Variable = 0u32.into();
let one: Variable = 1u32.into();

cpa!(scope, x = zero);
cpa!(scope, y = one);

scope.operations
}
}
31 changes: 29 additions & 2 deletions crates/cubecl-macros/src/codegen_function/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::tracker::VariableTracker;
use proc_macro2::TokenStream;
use proc_macro2::{Ident, Span, TokenStream};

use super::{
base::{codegen_block, Codegen, CodegenKind},
Expand Down Expand Up @@ -74,8 +74,9 @@ pub(crate) fn codegen_expr(
"Range is not supported, use [range](cubecl::prelude::range) instead.",
)
.to_compile_error(),
syn::Expr::Tuple(tuple) => codegen_tuple(tuple, loop_level, variable_tracker),
_ => {
syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
syn::Error::new_spanned(expr, "Expression Is not supported").to_compile_error()
}
};

Expand All @@ -86,6 +87,32 @@ pub(crate) fn codegen_expr(
}
}

/// Codegen for tuple expressions
pub(crate) fn codegen_tuple(
unary: &syn::ExprTuple,
loop_level: usize,
variable_tracker: &mut VariableTracker,
) -> TokenStream {
let mut res = quote::quote! {};
let mut vars = Vec::new();
for (i, expr) in unary.elems.iter().enumerate() {
let expr_codegen = codegen_expr(expr, loop_level, variable_tracker);
let expr_tokens = expr_codegen.tokens();
let var = Ident::new(&format!("_tuple_{}", i), Span::call_site());
res = quote::quote! {
#res
let #var = #expr_tokens;
};
vars.push(var);
}
quote::quote! {
{
#res
( #(#vars),* )
}
}
}

/// Codegen for an expression containing a block
pub(crate) fn codegen_expr_block(
block: &syn::ExprBlock,
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ fn codegen_cube(
#signature {
#body
}

}
})
}

0 comments on commit cbbf866

Please sign in to comment.