diff --git a/.gitignore b/.gitignore index 6985cf1b..17f449fd 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb +**/out +.clangd \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..37441bee --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "files.eol": "\n" +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 0b78e563..b6c2a6a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,17 +4,13 @@ # https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 resolver = "2" -members = [ - "crates/*", - "examples/*", - "xtask", -] +members = ["crates/*", "examples/*", "xtask"] [workspace.package] edition = "2021" -version = "0.2.0" license = "MIT OR Apache-2.0" readme = "README.md" +version = "0.2.0" [workspace.dependencies] @@ -29,23 +25,24 @@ serde = { version = "1.0.204", default-features = false, features = [ serde_json = { version = "1.0.119", default-features = false } dashmap = "5.5.3" -spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } hashbrown = "0.14.5" +spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } +getrandom = { version = "0.2.15", default-features = false } rand = { version = "0.8.5", default-features = false, features = [ "std_rng", ] } # std_rng is for no_std -getrandom = { version = "0.2.15", default-features = false } -pollster = "0.3" +async-channel = "2.3" dirs = "5.0.1" -web-time = "1.1.0" md5 = "0.7.0" -async-channel = "2.3" +pollster = "0.3" +weak-table = "0.3" +web-time = "1.1.0" # Testing -serial_test = "3.1.1" rstest = "0.19.0" +serial_test = "3.1.1" bytemuck = "1.16.1" half = { version = "2.4.1", features = [ @@ -57,15 +54,20 @@ num-traits = { version = "0.2.19", default-features = false, features = [ "libm", ] } # libm is for no_std +darling = "0.20.10" +ident_case = "1" proc-macro2 = "1.0.86" -syn = { version = "2.0.69", features = ["full", "extra-traits"] } quote = "1.0.36" +syn = { version = "2", features = ["full", "extra-traits", "visit-mut"] } ### For xtask crate ### -strum = {version = "0.26.3", features = ["derive"]} +strum = { version = "0.26.3", features = ["derive"] } tracel-xtask = { version = "~1.0" } -portable-atomic-util = { version = "0.2.2", features = ["alloc"] } # alloc is for no_std +portable-atomic-util = { version = "0.2.2", features = [ + "alloc", +] } # alloc is for no_std +pretty_assertions = "1.4" [profile.dev] opt-level = 2 diff --git a/crates/cubecl-common/Cargo.toml b/crates/cubecl-common/Cargo.toml index 73530d09..04290252 100644 --- a/crates/cubecl-common/Cargo.toml +++ b/crates/cubecl-common/Cargo.toml @@ -1,5 +1,8 @@ [package] -authors = ["Dilshod Tadjibaev (@antimora)", "Nathaniel Simard (@nathanielsimard)"] +authors = [ + "Dilshod Tadjibaev (@antimora)", + "Nathaniel Simard (@nathanielsimard)", +] categories = ["science", "mathematics", "algorithms"] description = "Common crate for CubeCL" edition.workspace = true @@ -20,18 +23,22 @@ web-time = { version = "1.1.0" } [dependencies] # ** Please make sure all dependencies support no_std when std is disabled ** -spin = { workspace = true } # using in place of use std::sync::Mutex; derive-new = { workspace = true } -serde = { workspace = true } -rand = { workspace = true } pollster = { workspace = true, optional = true } +rand = { workspace = true } +serde = { workspace = true } +spin = { workspace = true } # using in place of use std::sync::Mutex; [target.'cfg(target_has_atomic = "ptr")'.dependencies] spin = { workspace = true, features = ["mutex", "spin_mutex"] } [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic-util = { workspace = true } -spin = { workspace = true, features = ["mutex", "spin_mutex", "portable_atomic"] } +spin = { workspace = true, features = [ + "mutex", + "spin_mutex", + "portable_atomic", +] } [dev-dependencies] dashmap = { workspace = true } diff --git a/crates/cubecl-common/src/lib.rs b/crates/cubecl-common/src/lib.rs index 12dfabe9..48fc6eca 100644 --- a/crates/cubecl-common/src/lib.rs +++ b/crates/cubecl-common/src/lib.rs @@ -22,6 +22,8 @@ pub mod benchmark; /// notation. pub mod reader; +/// Operators used by macro and IR +pub mod operator; /// Synchronization type module, used both by ComputeServer and Backends. pub mod sync_type; diff --git a/crates/cubecl-common/src/operator.rs b/crates/cubecl-common/src/operator.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/crates/cubecl-common/src/operator.rs @@ -0,0 +1 @@ + diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index d2e4b0f6..e67d2d0a 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -15,21 +15,23 @@ version.workspace = true [features] default = ["cubecl-runtime/default"] +export_tests = [] std = ["cubecl-runtime/std"] template = [] -export_tests = [] [dependencies] cubecl-runtime = { path = "../cubecl-runtime", version = "0.2.0", default-features = false } bytemuck = { workspace = true } -half = { workspace = true, features = ["bytemuck"] } -serde = { workspace = true } +cubecl-common = { path = "../cubecl-common", version = "0.2.0" } cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" } derive-new = { workspace = true } +half = { workspace = true, features = ["bytemuck"] } num-traits = { workspace = true } +serde = { workspace = true } log = { workspace = true } [dev-dependencies] +pretty_assertions = { workspace = true } trybuild = "1" diff --git a/crates/cubecl-core/src/codegen/integrator.rs b/crates/cubecl-core/src/codegen/integrator.rs index 7bf25c23..c43fa5ec 100644 --- a/crates/cubecl-core/src/codegen/integrator.rs +++ b/crates/cubecl-core/src/codegen/integrator.rs @@ -1,3 +1,5 @@ +use std::num::NonZero; + use super::Compiler; use crate::{ ir::{ @@ -95,18 +97,22 @@ impl core::fmt::Display for KernelSettings { } match self.vectorization_global { - Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?, + Some(vectorization) => f.write_fmt(format_args!( + "vg{}", + vectorization.map(NonZero::get).unwrap_or(1) + ))?, None => f.write_str("vn")?, }; for vectorization in self.vectorization_partial.iter() { match vectorization { - VectorizationPartial::Input { pos, vectorization } => { - f.write_fmt(format_args!("v{vectorization}i{pos}"))? - } - VectorizationPartial::Output { pos, vectorization } => { - f.write_fmt(format_args!("v{vectorization}o{pos}"))? - } + VectorizationPartial::Input { pos, vectorization } => f.write_fmt(format_args!( + "v{}i{pos}", + vectorization.map(NonZero::get).unwrap_or(1) + ))?, + VectorizationPartial::Output { pos, vectorization } => f.write_fmt( + format_args!("v{}o{pos}", vectorization.map(NonZero::get).unwrap_or(1)), + )?, }; } @@ -130,7 +136,7 @@ impl KernelSettings { pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self { // Not setting the vectorization factor when it's the default value reduces the kernel id // size. - if vectorization == 1 { + if vectorization.is_none() { return self; } @@ -147,7 +153,7 @@ impl KernelSettings { pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self { // Not setting the vectorization factor when it's the default value reduces the kernel id // size. - if vectorization == 1 { + if vectorization.is_none() { return self; } @@ -173,7 +179,7 @@ impl KernelSettings { } } - 1 + None } /// Fetch the vectorization for the provided output position. @@ -190,7 +196,7 @@ impl KernelSettings { } } - 1 + None } /// Compile the shader with inplace enabled by the given [mapping](InplaceMapping). diff --git a/crates/cubecl-core/src/frontend/branch.rs b/crates/cubecl-core/src/frontend/branch.rs index b95a6029..92035727 100644 --- a/crates/cubecl-core/src/frontend/branch.rs +++ b/crates/cubecl-core/src/frontend/branch.rs @@ -1,159 +1,232 @@ -use std::ops::Deref; +use num_traits::NumCast; -use crate::frontend::{CubeContext, ExpandElement, UInt}; -use crate::ir::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable}; +use crate::frontend::{CubeContext, ExpandElement}; +use crate::ir::{Branch, If, IfElse, Item, Loop, RangeLoop}; -use super::comptime::Comptime; -use super::ExpandElementTyped; +use super::{CubeType, ExpandElementTyped, Int, Numeric}; -/// UInt range. Equivalent to: -/// -/// ```ignore -/// for i in start..end { ... } -/// ``` -pub fn range(start: S, end: E, _unroll: Comptime) -> impl Iterator -where - S: Into, - E: Into, -{ - let start: UInt = start.into(); - let end: UInt = end.into(); +/// Something that can be iterated on by a for loop. Currently only includes `Range`, `StepBy` and +/// `Sequence`. +pub trait Iterable: Sized { + /// Expand a runtime loop without unrolling + /// + /// # Arguments + /// * `context` - the expansion context + /// * `body` - the loop body to be executed repeatedly + fn expand( + self, + context: &mut CubeContext, + body: impl FnMut(&mut CubeContext, ::ExpandType), + ); + /// Expand an unrolled loop. The body should be invoced `n` times, where `n` is the number of + /// iterations. + /// + /// # Arguments + /// * `context` - the expansion context + /// * `body` - the loop body to be executed repeatedly + fn expand_unroll( + self, + context: &mut CubeContext, + body: impl FnMut(&mut CubeContext, ::ExpandType), + ); +} - (start.val..end.val).map(UInt::new) +pub struct RangeExpand { + pub start: ExpandElementTyped, + pub end: ExpandElementTyped, + pub inclusive: bool, } -/// Stepped range. Equivalent to: -/// -/// ```ignore -/// for i in (start..end).step_by(step) { ... } -/// ``` -pub fn range_stepped( - start: S, - end: E, - step: Step, - _unroll: Comptime, -) -> impl Iterator -where - S: Into, - E: Into, - Step: Into, -{ - let start: UInt = start.into(); - let end: UInt = end.into(); - let step: UInt = step.into(); +impl RangeExpand { + pub fn new(start: ExpandElementTyped, end: ExpandElementTyped, inclusive: bool) -> Self { + RangeExpand { + start, + end, + inclusive, + } + } - (start.val..end.val) - .step_by(step.val as usize) - .map(UInt::new) + pub fn __expand_step_by(self, n: impl Into>) -> SteppedRangeExpand { + SteppedRangeExpand { + start: self.start, + end: self.end, + step: n.into(), + inclusive: self.inclusive, + } + } } -pub fn range_expand(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F) -where - F: FnMut(&mut CubeContext, ExpandElementTyped), - S: Into>, - E: Into>, -{ - let start: ExpandElementTyped = start.into(); - let end: ExpandElementTyped = end.into(); - let start = start.expand; - let end = end.expand; +impl Iterable for RangeExpand { + fn expand_unroll( + self, + context: &mut CubeContext, + mut body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + let start = self + .start + .expand + .as_const() + .expect("Only constant start can be unrolled.") + .as_i64(); + let end = self + .end + .expand + .as_const() + .expect("Only constant end can be unrolled.") + .as_i64(); - if unroll { - let start = match start.deref() { - Variable::ConstantScalar(value) => value.as_usize(), - _ => panic!("Only constant start can be unrolled."), - }; - let end = match end.deref() { - Variable::ConstantScalar(value) => value.as_usize(), - _ => panic!("Only constant end can be unrolled."), - }; - - for i in start..end { - let var: ExpandElement = i.into(); - func(context, var.into()) + if self.inclusive { + for i in start..=end { + let var = I::from_int(i); + body(context, var.into()) + } + } else { + for i in start..end { + let var = I::from_int(i); + body(context, var.into()) + } } - } else { + } + + fn expand( + self, + context: &mut CubeContext, + mut body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { let mut child = context.child(); - let index_ty = Item::new(Elem::UInt); + let index_ty = Item::new(I::as_elem()); let i = child.scope.borrow_mut().create_local_undeclared(index_ty); let i = ExpandElement::Plain(i); - func(&mut child, i.clone().into()); + body(&mut child, i.clone().into()); context.register(Branch::RangeLoop(RangeLoop { i: *i, - start: *start, - end: *end, + start: *self.start.expand, + end: *self.end.expand, step: None, scope: child.into_scope(), + inclusive: self.inclusive, })); } } -pub fn range_stepped_expand( - context: &mut CubeContext, - start: S, - end: E, - step: Step, - unroll: bool, - mut func: F, -) where - F: FnMut(&mut CubeContext, ExpandElementTyped), - S: Into>, - E: Into>, - Step: Into>, -{ - let start: ExpandElementTyped = start.into(); - let end: ExpandElementTyped = end.into(); - let step: ExpandElementTyped = step.into(); - let start = start.expand; - let end = end.expand; - let step = step.expand; +pub struct SteppedRangeExpand { + start: ExpandElementTyped, + end: ExpandElementTyped, + step: ExpandElementTyped, + inclusive: bool, +} - if unroll { - let start = match start.deref() { - Variable::ConstantScalar(value) => value.as_usize(), - _ => panic!("Only constant start can be unrolled."), - }; - let end = match end.deref() { - Variable::ConstantScalar(value) => value.as_usize(), - _ => panic!("Only constant end can be unrolled."), - }; - let step: usize = match step.deref() { - Variable::ConstantScalar(value) => value.as_usize(), - _ => panic!("Only constant step can be unrolled."), - }; - - for i in (start..end).step_by(step) { - let var: ExpandElement = i.into(); - func(context, var.into()) - } - } else { +impl> Iterable for SteppedRangeExpand { + fn expand( + self, + context: &mut CubeContext, + mut body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { let mut child = context.child(); - let index_ty = Item::new(Elem::UInt); + let index_ty = Item::new(I::as_elem()); let i = child.scope.borrow_mut().create_local_undeclared(index_ty); let i = ExpandElement::Plain(i); - func(&mut child, i.clone().into()); + body(&mut child, i.clone().into()); context.register(Branch::RangeLoop(RangeLoop { i: *i, - start: *start, - end: *end, - step: Some(*step), + start: *self.start.expand, + end: *self.end.expand, + step: Some(*self.step.expand), scope: child.into_scope(), + inclusive: self.inclusive, })); } + + fn expand_unroll( + self, + context: &mut CubeContext, + mut body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + let start = self + .start + .expand + .as_const() + .expect("Only constant start can be unrolled.") + .as_i64(); + let end = self + .end + .expand + .as_const() + .expect("Only constant end can be unrolled.") + .as_i64(); + let step = self + .step + .expand + .as_const() + .expect("Only constant step can be unrolled.") + .as_usize(); + + if self.inclusive { + for i in (start..=end).step_by(step) { + let var = I::from_int(i); + body(context, var.into()) + } + } else { + for i in (start..end).step_by(step) { + let var = I::from_int(i); + body(context, var.into()) + } + } + } } -pub fn if_expand( +/// integer range. Equivalent to: +/// +/// ```ignore +/// for i in start..end { ... } +/// ``` +pub fn range(start: T, end: T) -> impl Iterator { + let start: i64 = start.to_i64().unwrap(); + let end: i64 = end.to_i64().unwrap(); + (start..end).map(::from).map(Option::unwrap) +} + +/// Stepped range. Equivalent to: +/// +/// ```ignore +/// for i in (start..end).step_by(step) { ... } +/// ``` +pub fn range_stepped(start: I, end: I, step: I) -> impl Iterator +where + RangeExpand: Iterator, +{ + let start = start.to_i64().unwrap(); + let end = end.to_i64().unwrap(); + let step = step.to_usize().unwrap(); + (start..end) + .step_by(step) + .map(::from) + .map(Option::unwrap) +} + +pub fn for_expand( + context: &mut CubeContext, + range: impl Iterable, + unroll: bool, + body: impl FnMut(&mut CubeContext, ExpandElementTyped), +) { + if unroll { + range.expand_unroll(context, body); + } else { + range.expand(context, body); + } +} + +pub fn if_expand( context: &mut CubeContext, - comptime_cond: Option, runtime_cond: ExpandElement, - mut block: IF, -) where - IF: FnMut(&mut CubeContext), -{ + block: impl FnOnce(&mut CubeContext), +) { + let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool()); match comptime_cond { Some(cond) => { if cond { @@ -173,16 +246,13 @@ pub fn if_expand( } } -pub fn if_else_expand( +pub fn if_else_expand( context: &mut CubeContext, - comptime_cond: Option, runtime_cond: ExpandElement, - mut then_block: IF, - mut else_block: EL, -) where - IF: FnMut(&mut CubeContext), - EL: FnMut(&mut CubeContext), -{ + then_block: impl FnOnce(&mut CubeContext), + else_block: impl FnOnce(&mut CubeContext), +) { + let comptime_cond = runtime_cond.as_const().map(|it| it.as_bool()); match comptime_cond { Some(cond) => { if cond { @@ -227,15 +297,15 @@ where })); } -pub fn while_loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) -where - FC: FnMut(&mut CubeContext) -> ExpandElementTyped, - FB: FnMut(&mut CubeContext), -{ +pub fn while_loop_expand( + context: &mut CubeContext, + mut cond_fn: impl FnMut(&mut CubeContext) -> ExpandElementTyped, + block: impl FnOnce(&mut CubeContext), +) { let mut inside_loop = context.child(); let cond: ExpandElement = cond_fn(&mut inside_loop).into(); - if_expand(&mut inside_loop, None, cond, break_expand); + if_expand(&mut inside_loop, cond, break_expand); block(&mut inside_loop); context.register(Branch::Loop(Loop { diff --git a/crates/cubecl-core/src/frontend/cmma.rs b/crates/cubecl-core/src/frontend/cmma.rs index f6737a0a..042f0b70 100644 --- a/crates/cubecl-core/src/frontend/cmma.rs +++ b/crates/cubecl-core/src/frontend/cmma.rs @@ -32,15 +32,15 @@ //! cmma::MatrixLayout::Undefined, //! ); //! cmma::fill::(&c, F32::new(0.0)); -//! cmma::load::(&a, lhs.as_slice(), UInt::new(16)); -//! cmma::load::(&b, rhs.as_slice(), UInt::new(16)); +//! cmma::load::(&a, lhs.as_slice(), u32::new(16)); +//! cmma::load::(&b, rhs.as_slice(), u32::new(16)); //! //! cmma::execute::(&a, &b, &c, &c); //! //! cmma::store::( //! out.as_slice_mut(), //! &c, -//! UInt::new(16), +//! u32::new(16), //! cmma::MatrixLayout::RowMajor, //! ); //! } @@ -54,8 +54,8 @@ use crate::{ }; use super::{ - CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut, - UInt, + CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, IntoRuntime, + Slice, SliceMut, }; pub use ir::{MatrixIdent, MatrixLayout}; @@ -79,6 +79,12 @@ impl CubeType for Matrix { type ExpandType = MatrixExpand; } +impl IntoRuntime for Matrix { + fn __expand_runtime_method(self, _context: &mut CubeContext) -> MatrixExpand { + unimplemented!("Matrices can't exist at compile time") + } +} + impl Init for MatrixExpand { fn init(self, _context: &mut CubeContext) -> Self { self @@ -107,9 +113,9 @@ impl Matrix { pub fn __expand_new( context: &mut CubeContext, ident: MatrixIdent, - m: ExpandElementTyped, - n: ExpandElementTyped, - k: ExpandElementTyped, + m: ExpandElementTyped, + n: ExpandElementTyped, + k: ExpandElementTyped, layout: MatrixLayout, ) -> MatrixExpand { let elem = context.create_matrix(ir::Matrix { @@ -135,7 +141,7 @@ pub mod fill { use super::*; /// Expand method of [fill()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, mat: MatrixExpand, value: ExpandElementTyped, @@ -150,7 +156,7 @@ pub mod fill { /// Load the matrix with the provided array using the stride. #[allow(unused_variables)] -pub fn load(mat: &Matrix, value: &Slice<'_, C>, stride: UInt) { +pub fn load(mat: &Matrix, value: &Slice<'_, C>, stride: u32) { unexpanded!() } @@ -160,11 +166,11 @@ pub mod load { /// Expand method of [load()]. #[allow(unused_variables)] - pub fn __expand( + pub fn expand( context: &mut CubeContext, mat: MatrixExpand, value: ExpandElementTyped>, - stride: ExpandElementTyped, + stride: ExpandElementTyped, ) { let stride: ExpandElement = stride.into(); @@ -181,7 +187,7 @@ pub mod load { pub fn store( output: &mut SliceMut<'_, C>, mat: &Matrix, - stride: UInt, + stride: u32, layout: MatrixLayout, ) { unexpanded!() @@ -193,11 +199,11 @@ pub mod store { /// Expand method of [store()]. #[allow(unused_variables)] - pub fn __expand( + pub fn expand( context: &mut CubeContext, output: ExpandElementTyped>, mat: MatrixExpand, - stride: ExpandElementTyped, + stride: ExpandElementTyped, layout: MatrixLayout, ) { let stride: ExpandElement = stride.into(); @@ -227,7 +233,7 @@ pub mod execute { use super::*; /// Expand method of [execute()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, mat_a: MatrixExpand, mat_b: MatrixExpand, diff --git a/crates/cubecl-core/src/frontend/comptime.rs b/crates/cubecl-core/src/frontend/comptime.rs deleted file mode 100644 index deec54bf..00000000 --- a/crates/cubecl-core/src/frontend/comptime.rs +++ /dev/null @@ -1,160 +0,0 @@ -use crate::{ - frontend::{CubeContext, CubeType}, - unexpanded, -}; - -use super::{CubePrimitive, ExpandElement, ExpandElementTyped, Init, UInt, Vectorized}; - -#[derive(Clone, Copy)] -/// Encapsulates a value to signify it must be used at compilation time rather than in the kernel -/// -/// Use `Comptime>` to have an alternate runtime behaviour if the compilation time value is not present -pub struct Comptime { - pub(crate) inner: T, -} - -/// Type that can be used within [Comptime]. -pub trait ComptimeType: CubeType + Into { - /// Create the expand type from the normal type. - fn into_expand(self) -> Self::ExpandType; -} - -impl ComptimeType for UInt { - fn into_expand(self) -> Self::ExpandType { - ExpandElementTyped::new(self.into()) - } -} - -impl Comptime { - /// Create a new Comptime. Useful when hardcoding values in - /// Cube kernels. For instance: - /// if Comptime::new(false) {...} never generates the inner code block - pub fn new(inner: T) -> Self { - Self { inner } - } - - /// Get the inner value of a Comptime. For instance: - /// let c = Comptime::new(false); - /// if Comptime::get(c) {...} - pub fn get(_comptime: Self) -> T { - unexpanded!() - } - - /// Executes a closure on the comptime and returns a new comptime containing the value. - pub fn map R>(_comptime: Self, _closure: F) -> Comptime { - unexpanded!() - } - - pub fn __expand_map R>(inner: T, closure: F) -> R { - closure(inner) - } -} - -impl Comptime> { - /// Map a Comptime optional to a Comptime boolean that tell - /// whether the optional contained a value - pub fn is_some(comptime: Self) -> Comptime { - Comptime::new(comptime.inner.is_some()) - } - - /// Return the inner value of the Comptime if it exists, - /// otherwise tell how to compute it at runtime - pub fn unwrap_or_else(_comptime: Self, mut _alt: F) -> T - where - F: FnOnce() -> T, - { - unexpanded!() - } - - /// Expanded version of unwrap_or_else - pub fn __expand_unwrap_or_else( - context: &mut CubeContext, - t: Option, - alt: F, - ) -> ::ExpandType - where - F: FnOnce(&mut CubeContext) -> T::ExpandType, - { - match t { - Some(t) => t.into_expand(), - None => alt(context), - } - } -} - -impl CubeType for Comptime { - type ExpandType = T; -} - -impl Comptime { - pub fn vectorization(_state: &T) -> Comptime { - unexpanded!() - } - - pub fn __expand_vectorization(_context: &mut CubeContext, state: T) -> UInt { - state.vectorization_factor() - } -} - -impl> Comptime { - pub fn runtime(_comptime: Self) -> T { - unexpanded!() - } - - pub fn __expand_runtime(_context: &mut CubeContext, inner: T) -> ExpandElementTyped { - let elem: ExpandElement = inner.into(); - elem.into() - } -} - -impl> core::ops::Add for Comptime { - type Output = Comptime; - - fn add(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.add(rhs.inner)) - } -} - -impl> core::ops::Sub for Comptime { - type Output = Comptime; - - fn sub(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.sub(rhs.inner)) - } -} - -impl> core::ops::Div for Comptime { - type Output = Comptime; - - fn div(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.div(rhs.inner)) - } -} - -impl> core::ops::Mul for Comptime { - type Output = Comptime; - - fn mul(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.mul(rhs.inner)) - } -} - -impl> core::ops::Rem for Comptime { - type Output = Comptime; - - fn rem(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.rem(rhs.inner)) - } -} - -impl core::cmp::PartialEq for Comptime { - fn eq(&self, other: &Self) -> bool { - core::cmp::PartialEq::eq(&self.inner, &other.inner) - } -} - -impl core::cmp::PartialOrd for Comptime { - fn partial_cmp(&self, other: &Self) -> Option { - core::cmp::PartialOrd::partial_cmp(&self.inner, &other.inner) - } -} diff --git a/crates/cubecl-core/src/frontend/const_expand.rs b/crates/cubecl-core/src/frontend/const_expand.rs new file mode 100644 index 00000000..fc2a08a1 --- /dev/null +++ b/crates/cubecl-core/src/frontend/const_expand.rs @@ -0,0 +1,19 @@ +use super::{CubeContext, CubeType}; + +pub trait OptionExt { + fn __expand_unwrap_or_else_method( + self, + _context: &mut CubeContext, + other: impl FnOnce(&mut CubeContext) -> T::ExpandType, + ) -> T::ExpandType; +} + +impl> OptionExt for Option { + fn __expand_unwrap_or_else_method( + self, + context: &mut CubeContext, + other: impl FnOnce(&mut CubeContext) -> ::ExpandType, + ) -> ::ExpandType { + self.map(Into::into).unwrap_or_else(|| other(context)) + } +} diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs index 8e2ba9c5..b0520d6d 100644 --- a/crates/cubecl-core/src/frontend/element/array.rs +++ b/crates/cubecl-core/src/frontend/element/array.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, num::NonZero}; use crate::{ compute::{KernelBuilder, KernelLauncher}, @@ -8,12 +8,12 @@ use crate::{ }; use crate::{ frontend::{indexation::Index, CubeContext}, - prelude::{assign, index, index_assign, Comptime}, + prelude::{assign, index, index_assign}, }; use super::{ ArgSettings, CubePrimitive, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, - LaunchArg, LaunchArgExpand, TensorHandleRef, UInt, + LaunchArg, LaunchArgExpand, TensorHandleRef, }; /// A contiguous array of elements. @@ -30,19 +30,18 @@ impl Array { Array { _val: PhantomData } } - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { + pub fn vectorized(_size: S, _vectorization_factor: u32) -> Self { Array { _val: PhantomData } } - pub fn __expand_new( + pub fn __expand_new( context: &mut CubeContext, - size: S, + size: ExpandElementTyped, ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar(value) => value.as_u32(), - _ => panic!("Array need constant initialization value"), - }; + let size = size + .constant() + .expect("Array need constant initialization value") + .as_u32(); context .create_local_array(Item::new(T::as_elem()), size) .into() @@ -51,7 +50,7 @@ impl Array { pub fn __expand_vectorized( context: &mut CubeContext, size: S, - vectorization_factor: UInt, + vectorization_factor: u32, ) -> ::ExpandType { let size = size.value(); let size = match size { @@ -60,13 +59,13 @@ impl Array { }; context .create_local_array( - Item::vectorized(T::as_elem(), vectorization_factor.val as u8), + Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)), size, ) .into() } - pub fn to_vectorized(self, _vectorization_factor: Comptime) -> T { + pub fn to_vectorized(self, _vectorization_factor: u32) -> T { unexpanded!() } } @@ -75,15 +74,21 @@ impl ExpandElementTyped> { pub fn __expand_to_vectorized_method( self, context: &mut CubeContext, - vectorization_factor: UInt, + vectorization_factor: ExpandElementTyped, ) -> ExpandElementTyped { - let factor = vectorization_factor.val; + let factor = vectorization_factor + .constant() + .expect("Vectorization must be comptime") + .as_u32(); let var = self.expand.clone(); - let new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8)); + let new_var = context.create_local(Item::vectorized( + var.item().elem(), + NonZero::new(factor as u8), + )); - if vectorization_factor.val == 1 { + if factor == 1 { let element = index::expand(context, self.clone(), ExpandElementTyped::from_lit(0u32)); - assign::expand(context, element, new_var.clone()); + assign::expand(context, element, new_var.clone().into()); } else { for i in 0..factor { let expand: Self = self.expand.clone().into(); @@ -113,7 +118,8 @@ impl ExpandElementBaseInit for Array { impl Array { /// Obtain the array length - pub fn len(&self) -> UInt { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { unexpanded!() } } @@ -178,7 +184,7 @@ impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { Self::Handle { handle: _, vectorization_factor, - } => settings.vectorize_input(position, *vectorization_factor), + } => settings.vectorize_input(position, NonZero::new(*vectorization_factor)), Self::Alias { input_pos: _ } => { panic!("Not yet supported, only output can be aliased for now."); } @@ -190,7 +196,7 @@ impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { Self::Handle { handle: _, vectorization_factor, - } => settings.vectorize_output(position, *vectorization_factor), + } => settings.vectorize_output(position, NonZero::new(*vectorization_factor)), Self::Alias { input_pos } => { settings.mappings.push(crate::InplaceMapping { pos_input: *input_pos, diff --git a/crates/cubecl-core/src/frontend/element/atomic.rs b/crates/cubecl-core/src/frontend/element/atomic.rs index 5c39a6da..0c722d29 100644 --- a/crates/cubecl-core/src/frontend/element/atomic.rs +++ b/crates/cubecl-core/src/frontend/element/atomic.rs @@ -1,9 +1,9 @@ use super::{ - init_expand_element, ExpandElementBaseInit, ExpandElementTyped, LaunchArgExpand, Numeric, - Vectorized, I32, I64, + init_expand_element, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime, LaunchArgExpand, + Numeric, }; use crate::{ - frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, UInt}, + frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement}, ir::{ BinaryOperator, CompareAndSwapOperator, Elem, IntKind, Item, Operator, UnaryOperator, Vectorization, @@ -278,13 +278,21 @@ macro_rules! impl_atomic_int { #[derive(Clone, Copy, Hash, PartialEq, Eq)] pub struct $type { pub val: $primitive, - pub vectorization: u8, } impl CubeType for $type { type ExpandType = ExpandElementTyped; } + impl IntoRuntime for $type { + fn __expand_runtime_method( + self, + _context: &mut CubeContext, + ) -> ExpandElementTyped { + unimplemented!("Atomics don't exist at compile time") + } + } + impl CubePrimitive for $type { fn as_elem() -> Elem { Elem::AtomicInt(IntKind::$inner_type) @@ -302,95 +310,68 @@ macro_rules! impl_atomic_int { builder: &mut KernelBuilder, vectorization: Vectorization, ) -> ExpandElementTyped { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); + assert_eq!(vectorization, None, "Attempted to vectorize a scalar"); builder.scalar(Elem::AtomicInt(IntKind::$inner_type)).into() } } - - impl Vectorized for $type { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } - } }; } impl_atomic_int!(AtomicI32, I32, i32); impl_atomic_int!(AtomicI64, I64, i64); -/// An atomic version of `UInt`. Can only be acted on atomically. +/// An atomic version of `u32`. Can only be acted on atomically. #[allow(clippy::derived_hash_with_manual_eq)] #[derive(Clone, Copy, Hash, PartialEq, Eq)] /// An atomic unsigned int. -pub struct AtomicUInt { +pub struct AtomicU32 { pub val: u32, - pub vectorization: u8, } -impl core::fmt::Debug for AtomicUInt { +impl core::fmt::Debug for AtomicU32 { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.vectorization == 1 { - f.write_fmt(format_args!("{}", self.val)) - } else { - f.write_fmt(format_args!("{}-{}", self.val, self.vectorization)) - } + f.write_fmt(format_args!("{}", self.val)) } } -impl CubeType for AtomicUInt { +impl CubeType for AtomicU32 { type ExpandType = ExpandElementTyped; } -impl CubePrimitive for AtomicUInt { +impl CubePrimitive for AtomicU32 { fn as_elem() -> Elem { Elem::AtomicUInt } } -impl ExpandElementBaseInit for AtomicUInt { +impl IntoRuntime for AtomicU32 { + fn __expand_runtime_method(self, _context: &mut CubeContext) -> ExpandElementTyped { + unimplemented!("Atomics don't exist at compile time") + } +} + +impl ExpandElementBaseInit for AtomicU32 { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { init_expand_element(context, elem) } } -impl LaunchArgExpand for AtomicUInt { +impl LaunchArgExpand for AtomicU32 { fn expand( builder: &mut KernelBuilder, vectorization: Vectorization, ) -> ExpandElementTyped { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); + assert_eq!(vectorization, None, "Attempted to vectorize a scalar"); builder.scalar(Elem::AtomicUInt).into() } } impl Atomic for AtomicI32 { - type Primitive = I32; + type Primitive = i32; } impl Atomic for AtomicI64 { - type Primitive = I64; + type Primitive = i64; } -impl Atomic for AtomicUInt { - type Primitive = UInt; -} - -impl Vectorized for AtomicUInt { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } +impl Atomic for AtomicU32 { + type Primitive = u32; } diff --git a/crates/cubecl-core/src/frontend/element/base.rs b/crates/cubecl-core/src/frontend/element/base.rs index 663d2c10..f6f2b7f0 100644 --- a/crates/cubecl-core/src/frontend/element/base.rs +++ b/crates/cubecl-core/src/frontend/element/base.rs @@ -1,11 +1,12 @@ -use super::{Bool, CubePrimitive, Numeric, UInt, Vectorized, F32, F64, I32, I64}; +use super::{CubePrimitive, Numeric, Vectorized}; use crate::{ - ir::{ConstantScalarValue, Elem, Item, Operator, Variable, Vectorization}, - prelude::{index_assign, init_expand, CubeContext, KernelBuilder, KernelLauncher}, + ir::{ConstantScalarValue, Elem, FloatKind, Item, Operator, Variable, Vectorization}, + prelude::{index_assign, init_expand, CubeContext, CubeIndex, KernelBuilder, KernelLauncher}, KernelSettings, Runtime, }; use alloc::rc::Rc; -use std::marker::PhantomData; +use half::{bf16, f16}; +use std::{marker::PhantomData, num::NonZero}; /// Types used in a cube function must implement this trait /// @@ -28,6 +29,14 @@ pub trait CubeType { } } +pub trait IntoRuntime: CubeType + Sized { + fn runtime(self) -> Self { + self + } + + fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType; +} + /// Trait to be implemented by [cube types](CubeType) implementations. pub trait Init: Sized { /// Initialize a type within a [context](CubeContext). @@ -123,8 +132,8 @@ pub struct ExpandElementTyped { } macro_rules! from_const { - ($lit:ty, $ty:ty) => { - impl From<$lit> for ExpandElementTyped<$ty> { + ($lit:ty) => { + impl From<$lit> for ExpandElementTyped<$lit> { fn from(value: $lit) -> Self { let variable: Variable = value.into(); @@ -132,26 +141,30 @@ macro_rules! from_const { } } }; - (val $($lit:ty),*) => { - $( - impl From<$lit> for ExpandElementTyped { - fn from(value: $lit) -> Self { - let variable: Variable = value.val.into(); +} - ExpandElement::Plain(variable).into() - } - } - )* - }; +from_const!(u32); +from_const!(i64); +from_const!(i32); +from_const!(f64); +from_const!(f32); +from_const!(bool); + +impl From for ExpandElementTyped { + fn from(value: f16) -> Self { + let variable = + Variable::ConstantScalar(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16)); + ExpandElement::Plain(variable).into() + } } -from_const!(u32, UInt); -from_const!(i64, I64); -from_const!(i32, I32); -from_const!(f64, F64); -from_const!(f32, F32); -from_const!(bool, Bool); -from_const!(val UInt, I32, I64, F32, F64); +impl From for ExpandElementTyped { + fn from(value: bf16) -> Self { + let variable = + Variable::ConstantScalar(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16)); + ExpandElement::Plain(variable).into() + } +} macro_rules! tuple_cube_type { ($($P:ident),*) => { @@ -173,6 +186,19 @@ macro_rules! tuple_init { } } } +macro_rules! tuple_runtime { + ($($P:ident),*) => { + impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) { + #[allow(non_snake_case)] + fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType { + let ($($P,)*) = self; + ($( + $P.__expand_runtime_method(context), + )*) + } + } + } +} tuple_cube_type!(P1); tuple_cube_type!(P1, P2); @@ -188,6 +214,13 @@ tuple_init!(P1, P2, P3, P4); tuple_init!(P1, P2, P3, P4, P5); tuple_init!(P1, P2, P3, P4, P5, P6); +tuple_runtime!(P1); +tuple_runtime!(P1, P2); +tuple_runtime!(P1, P2, P3); +tuple_runtime!(P1, P2, P3, P4); +tuple_runtime!(P1, P2, P3, P4, P5); +tuple_runtime!(P1, P2, P3, P4, P5, P6); + pub trait ExpandElementBaseInit: CubeType { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement; } @@ -199,11 +232,29 @@ impl Init for ExpandElementTyped { } impl Vectorized for ExpandElementTyped { - fn vectorization_factor(&self) -> UInt { + fn vectorization_factor(&self) -> u32 { self.expand.vectorization_factor() } - fn vectorize(self, factor: UInt) -> Self { + fn vectorize(self, factor: u32) -> Self { + Self { + expand: self.expand.vectorize(factor), + _type: PhantomData, + } + } +} + +impl ExpandElementTyped { + // Expanded version of rank. + pub fn __expand_vectorization_factor_method(self, _context: &mut CubeContext) -> u32 { + self.expand + .item() + .vectorization + .map(|it| it.get()) + .unwrap_or(1) as u32 + } + + pub fn __expand_vectorize_method(self, _context: &mut CubeContext, factor: u32) -> Self { Self { expand: self.expand.vectorize(factor), _type: PhantomData, @@ -347,22 +398,6 @@ impl Init for ExpandElement { } } -macro_rules! impl_init_for { - ($($t:ty),*) => { - $( - impl Init for $t { - fn init(self, _context: &mut CubeContext) -> Self { - panic!("Shouln't be called, only for comptime.") - } - } - - )* - }; -} - -// Add all types used within comptime -impl_init_for!(u32, bool, UInt); - impl Init for Option { fn init(self, context: &mut CubeContext) -> Self { self.map(|o| Init::init(o, context)) @@ -384,37 +419,40 @@ impl Init for Vec { } /// Create a constant element of the correct type during expansion. -pub(crate) fn __expand_new( +pub(crate) fn __expand_new( _context: &mut CubeContext, - val: ExpandElementTyped, - elem: Elem, -) -> ExpandElementTyped { - ExpandElement::Plain(elem.from_constant(*val.expand)).into() + val: C, +) -> ExpandElementTyped { + let val = Out::from(val).unwrap(); + val.into() } /// Create a vectorized constant element of the correct type during expansion. -pub(crate) fn __expand_vectorized( +pub(crate) fn __expand_vectorized, Out: Numeric>( context: &mut CubeContext, - val: ExpandElementTyped, - vectorization: UInt, + val: C, + vectorization: u32, elem: Elem, -) -> ExpandElementTyped { - if vectorization.val == 1 { - __expand_new(context, val, elem) - } else { - let new_var = context.create_local(Item::vectorized(elem, vectorization.val as u8)); - - for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() { - let element = elem.from_constant(*element.expand); - - index_assign::expand::( - context, - new_var.clone().into(), - ExpandElementTyped::from_lit(i), - ExpandElement::Plain(element).into(), - ); - } +) -> ExpandElementTyped { + let new_var = context.create_local(Item::vectorized(elem, NonZero::new(vectorization as u8))); + let val = Out::from(val).unwrap(); + let val: ExpandElementTyped = val.into(); + + // Allow setting explicit vectorization of 1 without trying to index assign it + if vectorization == 1 { + return val; + } - new_var.into() + for (i, element) in vec![val; vectorization as usize].iter().enumerate() { + let element = elem.from_constant(*element.expand); + + index_assign::expand::( + context, + new_var.clone().into(), + ExpandElementTyped::from_lit(i), + ExpandElement::Plain(element).into(), + ); } + + new_var.into() } diff --git a/crates/cubecl-core/src/frontend/element/bool.rs b/crates/cubecl-core/src/frontend/element/bool.rs index 2f7c0b85..56ef9568 100644 --- a/crates/cubecl-core/src/frontend/element/bool.rs +++ b/crates/cubecl-core/src/frontend/element/bool.rs @@ -1,15 +1,12 @@ use crate::frontend::{CubePrimitive, CubeType}; use crate::ir::Elem; -use crate::prelude::{ComptimeType, CubeContext}; +use crate::prelude::CubeContext; use super::{ - init_expand_element, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, Vectorized, + init_expand_element, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, Init, + IntoRuntime, }; -// To be consistent with other primitive type. -/// Boolean type. -pub type Bool = bool; - /// Extension trait for [bool]. pub trait BoolOps { #[allow(clippy::new_ret_no_self)] @@ -24,36 +21,27 @@ pub trait BoolOps { } } -impl BoolOps for Bool {} - -impl ComptimeType for Bool { - fn into_expand(self) -> Self::ExpandType { - ExpandElementTyped::new(self.into()) - } -} +impl BoolOps for bool {} impl CubeType for bool { type ExpandType = ExpandElementTyped; } -impl CubePrimitive for Bool { +impl CubePrimitive for bool { fn as_elem() -> Elem { Elem::Bool } } -impl ExpandElementBaseInit for bool { - fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { - init_expand_element(context, elem) +impl IntoRuntime for bool { + fn __expand_runtime_method(self, context: &mut CubeContext) -> ExpandElementTyped { + let expand: ExpandElementTyped = self.into(); + Init::init(expand, context) } } -impl Vectorized for bool { - fn vectorization_factor(&self) -> crate::prelude::UInt { - todo!() - } - - fn vectorize(self, _factor: crate::prelude::UInt) -> Self { - todo!() +impl ExpandElementBaseInit for bool { + fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { + init_expand_element(context, elem) } } diff --git a/crates/cubecl-core/src/frontend/element/cast.rs b/crates/cubecl-core/src/frontend/element/cast.rs index 68998fae..a4765e4e 100644 --- a/crates/cubecl-core/src/frontend/element/cast.rs +++ b/crates/cubecl-core/src/frontend/element/cast.rs @@ -5,24 +5,21 @@ use crate::{ ir::Operator, }; +use super::ExpandElementTyped; + /// Enable elegant casting from any to any CubeElem pub trait Cast: CubePrimitive { fn cast_from(value: From) -> Self; - fn __expand_cast_from( + fn __expand_cast_from( context: &mut CubeContext, - value: From, - ) -> ::ExpandType - where - From: Into, - { - let value: ExpandElement = value.into(); - let var: Variable = *value; + value: ExpandElementTyped, + ) -> ::ExpandType { let new_var = context.create_local(Item::vectorized( ::as_elem(), - var.item().vectorization, + value.expand.item().vectorization, )); - assign::expand(context, value, new_var.clone()); + assign::expand(context, value, new_var.clone().into()); new_var.into() } } diff --git a/crates/cubecl-core/src/frontend/element/cube_elem.rs b/crates/cubecl-core/src/frontend/element/cube_elem.rs index dbc709fe..34382b94 100644 --- a/crates/cubecl-core/src/frontend/element/cube_elem.rs +++ b/crates/cubecl-core/src/frontend/element/cube_elem.rs @@ -1,15 +1,13 @@ -use crate::frontend::UInt; use crate::frontend::{CubeType, ExpandElement}; use crate::ir::{Elem, Variable}; -use super::{ExpandElementTyped, Vectorized}; +use super::{ExpandElementTyped, IntoRuntime}; /// Form of CubeType that encapsulates all primitive types: /// Numeric, UInt, Bool pub trait CubePrimitive: CubeType> - + Vectorized - + core::cmp::Eq + + IntoRuntime + core::cmp::PartialEq + Send + Sync @@ -41,12 +39,3 @@ impl_into_expand_element!(bool); impl_into_expand_element!(f32); impl_into_expand_element!(i32); impl_into_expand_element!(i64); - -/// Useful for Comptime -impl From for ExpandElement { - fn from(value: UInt) -> Self { - ExpandElement::Plain(crate::ir::Variable::ConstantScalar( - crate::ir::ConstantScalarValue::UInt(value.val as u64), - )) - } -} diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 4201c20f..7de1d7b7 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -1,20 +1,14 @@ +use std::num::NonZero; + use half::{bf16, f16}; -use crate::frontend::{ - Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Normalize, Powf, Recip, Round, Sin, Sqrt, Tanh, -}; -use crate::frontend::{ - ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, - ExpandElementTyped, Numeric, +use crate::{ + ir::{Elem, FloatKind, Item, Vectorization}, + prelude::*, + unexpanded, }; -use crate::ir::{ConstantScalarValue, Elem, FloatKind, Item, Variable, Vectorization}; -use super::{ - init_expand_element, LaunchArgExpand, ScalarArgSettings, UInt, Vectorized, __expand_new, - __expand_vectorized, -}; -use crate::compute::{KernelBuilder, KernelLauncher}; -use crate::Runtime; +use super::Numeric; /// Floating point numbers. Used as input in float kernels pub trait Float: @@ -33,199 +27,127 @@ pub trait Float: + Erf + Recip + Normalize - + From - + core::ops::Add - + core::ops::Sub - + core::ops::Mul - + core::ops::Div - + std::ops::AddAssign - + std::ops::SubAssign - + std::ops::MulAssign - + std::ops::DivAssign - + std::cmp::PartialOrd - + std::cmp::PartialEq + + Into + + core::ops::Add + + core::ops::Sub + + core::ops::Mul + + core::ops::Div + + std::ops::AddAssign + + std::ops::SubAssign + + std::ops::MulAssign + + std::ops::DivAssign + + std::cmp::PartialOrd + + std::cmp::PartialEq { fn new(val: f32) -> Self; - fn vectorized(val: f32, vectorization: UInt) -> Self; - fn vectorized_empty(vectorization: UInt) -> Self; - fn __expand_new( - context: &mut CubeContext, - val: Self::ExpandType, - ) -> ::ExpandType { - __expand_new(context, val, Self::as_elem()) + fn vectorized(val: f32, vectorization: u32) -> Self; + fn vectorized_empty(vectorization: u32) -> Self; + fn __expand_new(context: &mut CubeContext, val: f32) -> ::ExpandType { + __expand_new(context, val) } fn __expand_vectorized( context: &mut CubeContext, - val: Self::ExpandType, - vectorization: UInt, + val: f32, + vectorization: u32, ) -> ::ExpandType { __expand_vectorized(context, val, vectorization, Self::as_elem()) } fn __expand_vectorized_empty( context: &mut CubeContext, - vectorization: UInt, + vectorization: u32, ) -> ::ExpandType; } macro_rules! impl_float { - ($type:ident, $primitive:ty) => { - #[derive(Clone, Copy)] - pub struct $type { - pub val: f32, - pub vectorization: u8, - } - - impl CubeType for $type { - type ExpandType = ExpandElementTyped<$type>; + (half $primitive:ident, $kind:ident) => { + impl_float!($primitive, $kind, |val| $primitive::from_f32(val)); + }; + ($primitive:ident, $kind:ident) => { + impl_float!($primitive, $kind, |val| val as $primitive); + }; + ($primitive:ident, $kind:ident, $new:expr) => { + impl CubeType for $primitive { + type ExpandType = ExpandElementTyped<$primitive>; } - impl CubePrimitive for $type { + impl CubePrimitive for $primitive { /// Return the element type to use on GPU fn as_elem() -> Elem { - Elem::Float(FloatKind::$type) + Elem::Float(FloatKind::$kind) } } - impl ComptimeType for $type { - fn into_expand(self) -> Self::ExpandType { - let elem = Self::as_elem(); - let value = self.val as f64; - let value = match elem { - Elem::Float(kind) => ConstantScalarValue::Float(value, kind), - _ => panic!("Wrong elem type"), - }; - - ExpandElementTyped::new(ExpandElement::Plain(Variable::ConstantScalar(value))) + impl IntoRuntime for $primitive { + fn __expand_runtime_method( + self, + context: &mut CubeContext, + ) -> ExpandElementTyped { + let expand: ExpandElementTyped = self.into(); + Init::init(expand, context) } } - impl From<$type> for ExpandElement { - fn from(value: $type) -> Self { - let constant = $type::as_elem().from_constant(value.val.into()); - ExpandElement::Plain(constant) - } - } + impl Numeric for $primitive {} - impl Numeric for $type { - type Primitive = $primitive; - } + impl Vectorized for $primitive { + fn vectorization_factor(&self) -> u32 { + 1 + } - impl From for $type { - fn from(val: u32) -> Self { - $type::from_int(val) + fn vectorize(self, _factor: u32) -> Self { + unexpanded!() } } - impl ExpandElementBaseInit for $type { + impl ExpandElementBaseInit for $primitive { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { init_expand_element(context, elem) } } - impl Float for $type { + impl Float for $primitive { fn new(val: f32) -> Self { - Self { - val, - vectorization: 1, - } + $new(val) } - fn vectorized(val: f32, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val, - vectorization: vectorization.val as u8, - } - } + fn vectorized(val: f32, _vectorization: u32) -> Self { + Self::new(val) } - fn vectorized_empty(vectorization: UInt) -> Self { + fn vectorized_empty(vectorization: u32) -> Self { Self::vectorized(0., vectorization) } fn __expand_vectorized_empty( context: &mut CubeContext, - vectorization: UInt, + vectorization: u32, ) -> ::ExpandType { - if vectorization.val == 1 { - Self::__expand_new(context, ExpandElementTyped::from_lit(0.)) - } else { - context - .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)) - .into() - } + context + .create_local(Item::vectorized( + Self::as_elem(), + NonZero::new(vectorization as u8), + )) + .into() } } - impl LaunchArgExpand for $type { + impl LaunchArgExpand for $primitive { fn expand( builder: &mut KernelBuilder, vectorization: Vectorization, ) -> ExpandElementTyped { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); - builder.scalar($type::as_elem()).into() - } - } - - impl Vectorized for $type { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self + assert_eq!(vectorization, None, "Attempted to vectorize a scalar"); + builder.scalar($primitive::as_elem()).into() } } }; } -impl_float!(F16, f16); -impl_float!(BF16, bf16); -impl_float!(F32, f32); -impl_float!(F64, f64); - -impl From for F32 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for BF16 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for F16 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for F64 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} +impl_float!(half f16, F16); +impl_float!(half bf16, BF16); +impl_float!(f32, F32); +impl_float!(f64, F64); impl ScalarArgSettings for f16 { fn register(&self, settings: &mut KernelLauncher) { diff --git a/crates/cubecl-core/src/frontend/element/int.rs b/crates/cubecl-core/src/frontend/element/int.rs index 7579ea79..a498b86d 100644 --- a/crates/cubecl-core/src/frontend/element/int.rs +++ b/crates/cubecl-core/src/frontend/element/int.rs @@ -1,110 +1,93 @@ -use crate::compute::{KernelBuilder, KernelLauncher}; use crate::frontend::{ - ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, - ExpandElementTyped, Numeric, + CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, + Numeric, }; -use crate::ir::{ConstantScalarValue, Elem, IntKind, Variable, Vectorization}; +use crate::ir::{Elem, IntKind, Vectorization}; use crate::Runtime; +use crate::{ + compute::{KernelBuilder, KernelLauncher}, + unexpanded, +}; use super::{ - init_expand_element, LaunchArgExpand, ScalarArgSettings, UInt, Vectorized, __expand_new, - __expand_vectorized, + init_expand_element, Init, IntoRuntime, LaunchArgExpand, ScalarArgSettings, Vectorized, + __expand_new, __expand_vectorized, }; -/// Signed integer. Used as input in int kernels +/// Signed or unsigned integer. Used as input in int kernels pub trait Int: Numeric + std::ops::Rem - + From - + core::ops::Add - + core::ops::Sub - + core::ops::Mul - + core::ops::Div - + std::ops::AddAssign - + std::ops::SubAssign - + std::ops::MulAssign - + std::ops::DivAssign - + std::cmp::PartialOrd - + std::cmp::PartialEq + + core::ops::Add + + core::ops::Sub + + core::ops::Mul + + core::ops::Div + + core::ops::BitOr + + core::ops::BitAnd + + core::ops::BitXor + + core::ops::Shl + + core::ops::Shr + + std::ops::RemAssign + + std::ops::AddAssign + + std::ops::SubAssign + + std::ops::MulAssign + + std::ops::DivAssign + + std::ops::BitOrAssign + + std::ops::BitAndAssign + + std::ops::BitXorAssign + + std::ops::ShlAssign + + std::ops::ShrAssign + + std::cmp::PartialOrd + + std::cmp::PartialEq { fn new(val: i64) -> Self; - fn vectorized(val: i64, vectorization: UInt) -> Self; - fn __expand_new( - context: &mut CubeContext, - val: Self::ExpandType, - ) -> ::ExpandType { - __expand_new(context, val, Self::as_elem()) + fn vectorized(val: i64, vectorization: u32) -> Self; + fn __expand_new(context: &mut CubeContext, val: i64) -> ::ExpandType { + __expand_new(context, val) } fn __expand_vectorized( context: &mut CubeContext, - val: Self::ExpandType, - vectorization: UInt, + val: i64, + vectorization: u32, ) -> ::ExpandType { __expand_vectorized(context, val, vectorization, Self::as_elem()) } } macro_rules! impl_int { - ($type:ident, $primitive:ty) => { - #[allow(clippy::derived_hash_with_manual_eq)] - #[derive(Clone, Copy, Hash)] - pub struct $type { - pub val: $primitive, - pub vectorization: u8, - } - + ($type:ident, $kind:ident) => { impl CubeType for $type { type ExpandType = ExpandElementTyped; } impl CubePrimitive for $type { fn as_elem() -> Elem { - Elem::Int(IntKind::$type) + Elem::Int(IntKind::$kind) } } - impl From for $type { - fn from(val: u32) -> Self { - Self { - val: val as $primitive, - vectorization: 1, - } - } - } - - impl From for $type { - fn from(val: i32) -> Self { - Self { - val: val as $primitive, - vectorization: 1, - } + impl IntoRuntime for $type { + fn __expand_runtime_method( + self, + context: &mut CubeContext, + ) -> ExpandElementTyped { + let expand: ExpandElementTyped = self.into(); + Init::init(expand, context) } } - impl ComptimeType for $type { - fn into_expand(self) -> Self::ExpandType { - let elem = Self::as_elem(); - let value = match elem { - Elem::Int(kind) => ConstantScalarValue::Int(self.val as i64, kind), - Elem::UInt => ConstantScalarValue::UInt(self.val as u64), - _ => panic!("Wrong elem type"), - }; + impl Numeric for $type {} - ExpandElementTyped::new(ExpandElement::Plain(Variable::ConstantScalar(value))) + impl Vectorized for $type { + fn vectorization_factor(&self) -> u32 { + 1 } - } - impl From<$type> for ExpandElement { - fn from(value: $type) -> Self { - let constant = $type::as_elem().from_constant(value.val.into()); - ExpandElement::Plain(constant) + fn vectorize(self, _factor: u32) -> Self { + unexpanded!() } } - impl Numeric for $type { - type Primitive = $primitive; - } - impl ExpandElementBaseInit for $type { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { init_expand_element(context, elem) @@ -113,21 +96,11 @@ macro_rules! impl_int { impl Int for $type { fn new(val: i64) -> Self { - Self { - val: val as $primitive, - vectorization: 1, - } + val as $type } - fn vectorized(val: i64, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val: val as $primitive, - vectorization: vectorization.val as u8, - } - } + fn vectorized(val: i64, _vectorization: u32) -> Self { + Self::new(val) } } @@ -136,36 +109,33 @@ macro_rules! impl_int { builder: &mut KernelBuilder, vectorization: Vectorization, ) -> ExpandElementTyped { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); + assert_eq!(vectorization, None, "Attempted to vectorize a scalar"); builder.scalar($type::as_elem()).into() } } + }; +} - impl Vectorized for $type { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } +impl_int!(i32, I32); +impl_int!(i64, I64); - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } - } - }; +impl Int for u32 { + fn new(val: i64) -> Self { + val as u32 + } + + fn vectorized(val: i64, _vectorization: u32) -> Self { + Self::new(val) + } } -impl_int!(I32, i32); -impl_int!(I64, i64); +impl Vectorized for u32 { + fn vectorization_factor(&self) -> u32 { + 1 + } -impl From for I64 { - fn from(value: i64) -> Self { - Self { - val: value, - vectorization: 1, - } + fn vectorize(self, _factor: u32) -> Self { + unexpanded!() } } diff --git a/crates/cubecl-core/src/frontend/element/mod.rs b/crates/cubecl-core/src/frontend/element/mod.rs index e1aeee63..f3a8d6cb 100644 --- a/crates/cubecl-core/src/frontend/element/mod.rs +++ b/crates/cubecl-core/src/frontend/element/mod.rs @@ -24,5 +24,4 @@ pub use numeric::*; pub use shared_memory::*; pub use slice::*; pub use tensor::*; -pub use uint::*; pub use vectorized::*; diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index 0d57aa5a..9c19b1c4 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -1,5 +1,8 @@ +use std::num::NonZero; + +use num_traits::NumCast; + use crate::compute::KernelLauncher; -use crate::frontend::{CubeContext, CubePrimitive, CubeType}; use crate::ir::{Item, Variable}; use crate::prelude::Clamp; use crate::Runtime; @@ -7,10 +10,14 @@ use crate::{ frontend::{index_assign, Abs, Max, Min, Remainder}, unexpanded, }; +use crate::{ + frontend::{CubeContext, CubePrimitive, CubeType}, + prelude::CubeIndexMut, +}; use super::{ ArgSettings, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, LaunchArg, - LaunchArgExpand, UInt, I64, + LaunchArgExpand, Vectorized, }; /// Type that encompasses both (unsigned or signed) integers and floats @@ -22,9 +29,15 @@ pub trait Numeric: + Min + Clamp + Remainder - + ExpandElementBaseInit + + Vectorized + CubePrimitive + LaunchArgExpand + + ScalarArgSettings + + ExpandElementBaseInit + + Into> + + CubeIndexMut + + CubeIndexMut, Output = Self> + + num_traits::NumCast + std::ops::AddAssign + std::ops::SubAssign + std::ops::MulAssign @@ -34,24 +47,8 @@ pub trait Numeric: + std::ops::Mul + std::ops::Div + std::cmp::PartialOrd - + core::ops::Index - + core::ops::IndexMut - + core::ops::Index - + core::ops::IndexMut - + From - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div - + std::ops::AddAssign - + std::ops::SubAssign - + std::ops::MulAssign - + std::ops::DivAssign - + std::cmp::PartialOrd - + std::cmp::PartialEq + + std::cmp::PartialEq { - type Primitive: ScalarArgSettings; - /// Create a new constant numeric. /// /// Note: since this must work for both integer and float @@ -60,8 +57,8 @@ pub trait Numeric: /// /// This method panics when unexpanded. For creating an element /// with a val, use the new method of the sub type. - fn from_int(_val: u32) -> Self { - unexpanded!() + fn from_int(val: i64) -> Self { + ::from(val).unwrap() } fn from_vec(_vec: [u32; D]) -> Self { @@ -70,7 +67,7 @@ pub trait Numeric: fn __expand_from_int( _context: &mut CubeContext, - val: ExpandElementTyped, + val: ExpandElementTyped, ) -> ::ExpandType { let elem = Self::as_elem(); let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64()); @@ -80,16 +77,19 @@ pub trait Numeric: fn __expand_from_vec( context: &mut CubeContext, - vec: [ExpandElementTyped; D], + vec: [u32; D], ) -> ::ExpandType { - let new_var = context.create_local(Item::vectorized(Self::as_elem(), vec.len() as u8)); + let new_var = context.create_local(Item::vectorized( + Self::as_elem(), + NonZero::new(vec.len() as u8), + )); let elem = Self::as_elem(); for (i, element) in vec.iter().enumerate() { - let var: Variable = elem.constant_from_i64(element.constant().unwrap().as_i64()); + let var: Variable = elem.constant_from_i64(*element as i64); let expand = ExpandElement::Plain(var); - index_assign::expand::( + index_assign::expand::( context, new_var.clone().into(), ExpandElementTyped::from_lit(i), @@ -110,7 +110,7 @@ pub trait ScalarArgSettings: Send + Sync { #[derive(new)] pub struct ScalarArg { - elem: T::Primitive, + elem: T, } impl ArgSettings for ScalarArg { diff --git a/crates/cubecl-core/src/frontend/element/shared_memory.rs b/crates/cubecl-core/src/frontend/element/shared_memory.rs index 4ca4941e..a2540f1e 100644 --- a/crates/cubecl-core/src/frontend/element/shared_memory.rs +++ b/crates/cubecl-core/src/frontend/element/shared_memory.rs @@ -1,11 +1,11 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, num::NonZero}; use crate::{ frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType}, ir::Item, }; -use super::{ExpandElementTyped, Init, UInt}; +use super::{ExpandElementTyped, Init, IntoRuntime}; #[derive(Clone, Copy)] pub struct SharedMemory { @@ -18,6 +18,12 @@ impl Init for ExpandElementTyped> { } } +impl IntoRuntime for SharedMemory { + fn __expand_runtime_method(self, _context: &mut CubeContext) -> ExpandElementTyped { + unimplemented!("Shared memory can't exist at comptime"); + } +} + impl CubeType for SharedMemory { type ExpandType = ExpandElementTyped>; } @@ -27,36 +33,34 @@ impl SharedMemory { SharedMemory { _val: PhantomData } } - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { + pub fn vectorized(_size: S, _vectorization_factor: u32) -> Self { SharedMemory { _val: PhantomData } } - pub fn __expand_vectorized( + pub fn __expand_vectorized( context: &mut CubeContext, - size: S, - vectorization_factor: UInt, + size: ExpandElementTyped, + vectorization_factor: u32, ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar(value) => value.as_u32(), - _ => panic!("Shared memory need constant initialization value"), - }; + let size = size + .constant() + .expect("Shared memory need constant initialization value") + .as_u32(); let var = context.create_shared( - Item::vectorized(T::as_elem(), vectorization_factor.val as u8), + Item::vectorized(T::as_elem(), NonZero::new(vectorization_factor as u8)), size, ); ExpandElementTyped::new(var) } - pub fn __expand_new( + pub fn __expand_new( context: &mut CubeContext, - size: S, + size: ExpandElementTyped, ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar(value) => value.as_u32(), - _ => panic!("Shared memory need constant initialization value"), - }; + let size = size + .constant() + .expect("Shared memory need constant initialization value") + .as_u32(); let var = context.create_shared(Item::new(T::as_elem()), size); ExpandElementTyped::new(var) } diff --git a/crates/cubecl-core/src/frontend/element/slice.rs b/crates/cubecl-core/src/frontend/element/slice.rs index 582353ac..0ed56965 100644 --- a/crates/cubecl-core/src/frontend/element/slice.rs +++ b/crates/cubecl-core/src/frontend/element/slice.rs @@ -2,7 +2,6 @@ use std::marker::PhantomData; use super::{ Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor, - UInt, }; use crate::{ frontend::indexation::Index, @@ -25,14 +24,16 @@ pub struct SliceMut<'a, E> { impl<'a, E> Slice<'a, E> { /// Get the length of the slice. - pub fn len(&self) -> UInt { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { unexpanded!() } } impl<'a, E> SliceMut<'a, E> { /// Get the length of the slice. - pub fn len(&self) -> UInt { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { unexpanded!() } } @@ -68,11 +69,11 @@ pub trait SliceOperator: CubeType { unexpanded!() } /// Expand function of [SliceOperator::slice]. - fn __expand_slice( + fn __expand_slice( context: &mut CubeContext, expand: Self::Expand, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { expand.__expand_slice_method(context, start, end) } @@ -88,11 +89,11 @@ pub trait SliceOperator: CubeType { } /// Expand function of [SliceOperator::slice_mut]. - fn __expand_slice_mut( + fn __expand_slice_mut( context: &mut CubeContext, expand: Self::Expand, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { expand.__expand_slice_mut_method(context, start, end) } @@ -112,11 +113,11 @@ pub trait SliceOperator: CubeType { } /// Expand function of [SliceOperator::slice_mut_unsafe]. - fn __expand_slice_mut_unsafe( + fn __expand_slice_mut_unsafe( context: &mut CubeContext, expand: Self::Expand, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { expand.__expand_slice_mut_unsafe_method(context, start, end) } @@ -176,29 +177,29 @@ pub trait SliceOperatorExpand: Into + Clone { end: End, ) -> ExpandElement; - fn __expand_slice_method( + fn __expand_slice_method( &self, context: &mut CubeContext, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { ExpandElementTyped::new(self.slice_base(context, start, end)) } - fn __expand_slice_mut_method( + fn __expand_slice_mut_method( &self, context: &mut CubeContext, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { ExpandElementTyped::new(self.slice_base(context, start, end)) } - fn __expand_slice_mut_unsafe_method( + fn __expand_slice_mut_unsafe_method( &self, context: &mut CubeContext, - start: Start, - end: End, + start: ExpandElementTyped, + end: ExpandElementTyped, ) -> ExpandElementTyped> { ExpandElementTyped::new(self.slice_base(context, start, end)) } diff --git a/crates/cubecl-core/src/frontend/element/tensor.rs b/crates/cubecl-core/src/frontend/element/tensor.rs index 9ffce8e6..94ba711e 100644 --- a/crates/cubecl-core/src/frontend/element/tensor.rs +++ b/crates/cubecl-core/src/frontend/element/tensor.rs @@ -1,13 +1,13 @@ use super::{ExpandElementBaseInit, ExpandElementTyped, LaunchArgExpand}; use crate::{ frontend::{ - indexation::Index, ArgSettings, CubeContext, CubePrimitive, CubeType, ExpandElement, UInt, + indexation::Index, ArgSettings, CubeContext, CubePrimitive, CubeType, ExpandElement, }, ir::{Elem, Item, Metadata, Variable, Vectorization}, prelude::{KernelBuilder, KernelLauncher}, unexpanded, KernelSettings, LaunchArg, Runtime, }; -use std::marker::PhantomData; +use std::{marker::PhantomData, num::NonZero}; /// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more /// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape). @@ -143,7 +143,7 @@ impl<'a, R: Runtime> ArgSettings for TensorArg<'a, R> { TensorArg::Handle { handle: _, vectorization_factor, - } => settings.vectorize_input(position, *vectorization_factor), + } => settings.vectorize_input(position, NonZero::new(*vectorization_factor)), TensorArg::Alias { input_pos: _ } => { panic!("Not yet supported, only output can be aliased for now."); } @@ -155,7 +155,7 @@ impl<'a, R: Runtime> ArgSettings for TensorArg<'a, R> { TensorArg::Handle { handle: _, vectorization_factor, - } => settings.vectorize_output(position, *vectorization_factor), + } => settings.vectorize_output(position, NonZero::new(*vectorization_factor)), TensorArg::Alias { input_pos } => { settings.mappings.push(crate::InplaceMapping { pos_input: *input_pos, @@ -169,12 +169,12 @@ impl<'a, R: Runtime> ArgSettings for TensorArg<'a, R> { impl Tensor { /// Obtain the stride of input at dimension dim - pub fn stride(&self, _dim: C) -> UInt { + pub fn stride(&self, _dim: C) -> u32 { unexpanded!() } /// Obtain the shape of input at dimension dim - pub fn shape(&self, _dim: C) -> UInt { + pub fn shape(&self, _dim: C) -> u32 { unexpanded!() } @@ -184,26 +184,28 @@ impl Tensor { /// /// The length will be affected by the vectorization factor. To obtain the number of elements, /// you should multiply the length by the vectorization factor. - pub fn len(&self) -> UInt { + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> u32 { unexpanded!() } /// Returns the rank of the tensor. - pub fn rank(&self) -> UInt { + pub fn rank(&self) -> u32 { unexpanded!() } } impl ExpandElementTyped { // Expanded version of stride - pub fn __expand_stride_method( + pub fn __expand_stride_method( self, context: &mut CubeContext, - dim: C, - ) -> ExpandElementTyped { + dim: ExpandElementTyped, + ) -> ExpandElementTyped { + let dim: ExpandElement = dim.into(); let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::Stride { - dim: dim.value(), + dim: *dim, var: self.expand.into(), out: out.clone().into(), }); @@ -211,14 +213,15 @@ impl ExpandElementTyped { } // Expanded version of shape - pub fn __expand_shape_method( + pub fn __expand_shape_method( self, context: &mut CubeContext, - dim: C, - ) -> ExpandElementTyped { + dim: ExpandElementTyped, + ) -> ExpandElementTyped { + let dim: ExpandElement = dim.into(); let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::Shape { - dim: dim.value(), + dim: *dim, var: self.expand.into(), out: out.clone().into(), }); @@ -226,7 +229,7 @@ impl ExpandElementTyped { } // Expanded version of len - pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped { + pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped { let out = context.create_local(Item::new(Elem::UInt)); context.register(Metadata::Length { var: self.expand.into(), @@ -236,7 +239,7 @@ impl ExpandElementTyped { } // Expanded version of rank. - pub fn __expand_rank_method(self, _context: &mut CubeContext) -> ExpandElementTyped { + pub fn __expand_rank_method(self, _context: &mut CubeContext) -> ExpandElementTyped { ExpandElement::Plain(Variable::Rank).into() } } diff --git a/crates/cubecl-core/src/frontend/element/uint.rs b/crates/cubecl-core/src/frontend/element/uint.rs index 72f2497e..56283a8f 100644 --- a/crates/cubecl-core/src/frontend/element/uint.rs +++ b/crates/cubecl-core/src/frontend/element/uint.rs @@ -1,55 +1,43 @@ use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric}; use crate::ir::{Elem, Vectorization}; use crate::prelude::{KernelBuilder, KernelLauncher}; -use crate::{frontend::Comptime, Runtime}; +use crate::Runtime; use super::{ - init_expand_element, ExpandElementBaseInit, ExpandElementTyped, LaunchArgExpand, - ScalarArgSettings, Vectorized, __expand_new, __expand_vectorized, + init_expand_element, ExpandElementBaseInit, ExpandElementTyped, Init, IntoRuntime, + LaunchArgExpand, ScalarArgSettings, }; -#[allow(clippy::derived_hash_with_manual_eq)] -#[derive(Clone, Copy, Hash)] -/// An unsigned int. -/// Preferred for indexing operations -pub struct UInt { - pub val: u32, - pub vectorization: u8, -} - -impl core::fmt::Debug for UInt { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.vectorization == 1 { - f.write_fmt(format_args!("{}", self.val)) - } else { - f.write_fmt(format_args!("{}-{}", self.val, self.vectorization)) - } - } -} - -impl CubeType for UInt { +impl CubeType for u32 { type ExpandType = ExpandElementTyped; } -impl ExpandElementBaseInit for UInt { +impl ExpandElementBaseInit for u32 { fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement { init_expand_element(context, elem) } } -impl CubePrimitive for UInt { +impl CubePrimitive for u32 { fn as_elem() -> Elem { Elem::UInt } } -impl LaunchArgExpand for UInt { +impl IntoRuntime for u32 { + fn __expand_runtime_method(self, context: &mut CubeContext) -> ExpandElementTyped { + let expand: ExpandElementTyped = self.into(); + Init::init(expand, context) + } +} + +impl LaunchArgExpand for u32 { fn expand( builder: &mut KernelBuilder, vectorization: Vectorization, ) -> ExpandElementTyped { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); - builder.scalar(UInt::as_elem()).into() + assert_eq!(vectorization, None, "Attempted to vectorize a scalar"); + builder.scalar(u32::as_elem()).into() } } @@ -59,78 +47,4 @@ impl ScalarArgSettings for u32 { } } -impl Numeric for UInt { - type Primitive = u32; -} - -impl UInt { - pub const fn new(val: u32) -> Self { - Self { - val, - vectorization: 1, - } - } - - pub fn vectorized(val: u32, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val, - vectorization: vectorization.val as u8, - } - } - } - pub fn __expand_new( - context: &mut CubeContext, - val: ::ExpandType, - ) -> ::ExpandType { - __expand_new(context, val, Self::as_elem()) - } - - pub fn __expand_vectorized( - context: &mut CubeContext, - val: ::ExpandType, - vectorization: UInt, - ) -> ::ExpandType { - __expand_vectorized(context, val, vectorization, Self::as_elem()) - } -} - -impl From for UInt { - fn from(value: u32) -> Self { - UInt::new(value) - } -} - -impl From> for UInt { - fn from(value: Comptime) -> Self { - UInt::new(value.inner) - } -} - -impl From for UInt { - fn from(value: usize) -> Self { - UInt::new(value as u32) - } -} - -impl From for UInt { - fn from(value: i32) -> Self { - UInt::new(value as u32) - } -} - -impl Vectorized for UInt { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } -} +impl Numeric for u32 {} diff --git a/crates/cubecl-core/src/frontend/element/vectorized.rs b/crates/cubecl-core/src/frontend/element/vectorized.rs index e9497acf..126a246a 100644 --- a/crates/cubecl-core/src/frontend/element/vectorized.rs +++ b/crates/cubecl-core/src/frontend/element/vectorized.rs @@ -1,68 +1,105 @@ use crate::unexpanded; -use super::{CubeType, ExpandElement, Tensor, UInt}; +use super::{Array, CubeType, ExpandElement, Tensor}; pub trait Vectorized { - fn vectorization_factor(&self) -> UInt; - fn vectorize(self, factor: UInt) -> Self; + fn vectorization_factor(&self) -> u32; + fn vectorize(self, factor: u32) -> Self; } -impl Vectorized for Tensor { - fn vectorization_factor(&self) -> UInt { +impl Vectorized for Tensor { + fn vectorization_factor(&self) -> u32 { unexpanded!() } - fn vectorize(self, _factor: UInt) -> Self { + fn vectorize(self, _factor: u32) -> Self { unexpanded!() } } -impl Vectorized for &Tensor { - fn vectorization_factor(&self) -> UInt { +impl Vectorized for &Tensor { + fn vectorization_factor(&self) -> u32 { unexpanded!() } - fn vectorize(self, _factor: UInt) -> Self { + fn vectorize(self, _factor: u32) -> Self { unexpanded!() } } -impl Vectorized for &mut Tensor { - fn vectorization_factor(&self) -> UInt { +impl Vectorized for Array { + fn vectorization_factor(&self) -> u32 { unexpanded!() } - fn vectorize(self, _factor: UInt) -> Self { + fn vectorize(self, _factor: u32) -> Self { + unexpanded!() + } +} + +impl Vectorized for &Array { + fn vectorization_factor(&self) -> u32 { + unexpanded!() + } + + fn vectorize(self, _factor: u32) -> Self { + unexpanded!() + } +} + +impl Vectorized for &mut Tensor { + fn vectorization_factor(&self) -> u32 { + unexpanded!() + } + + fn vectorize(self, _factor: u32) -> Self { unexpanded!() } } impl Vectorized for ExpandElement { - fn vectorization_factor(&self) -> UInt { + fn vectorization_factor(&self) -> u32 { let var = match self { ExpandElement::Managed(var) => var, ExpandElement::Plain(var) => var, }; - UInt::new(var.item().vectorization as u32) + var.item().vectorization.map(|it| it.get()).unwrap_or(1) as u32 } - fn vectorize(self, _factor: UInt) -> Self { + fn vectorize(self, _factor: u32) -> Self { todo!() } } impl Vectorized for &ExpandElement { - fn vectorization_factor(&self) -> UInt { + fn vectorization_factor(&self) -> u32 { let var = match self { ExpandElement::Managed(var) => var, ExpandElement::Plain(var) => var, }; - UInt::new(var.item().vectorization as u32) + var.item().vectorization.map(|it| it.get()).unwrap_or(1) as u32 } - fn vectorize(self, _factor: UInt) -> Self { + fn vectorize(self, _factor: u32) -> Self { todo!() } } + +/// Cubecl intrinsic. Gets the vectorization factor of an element at compile time. +pub fn vectorization_of(_element: &C) -> u32 { + 1 +} + +pub mod vectorization_of { + use crate::prelude::*; + + pub fn expand(_context: &mut CubeContext, element: ExpandElementTyped) -> u32 { + let elem: ExpandElement = element.into(); + elem.item() + .vectorization + .map(|it| it.get() as u32) + .unwrap_or(1) + } +} diff --git a/crates/cubecl-core/src/frontend/indexation.rs b/crates/cubecl-core/src/frontend/indexation.rs index e69ead13..5e710ebf 100644 --- a/crates/cubecl-core/src/frontend/indexation.rs +++ b/crates/cubecl-core/src/frontend/indexation.rs @@ -1,23 +1,27 @@ -use super::{Comptime, ExpandElement, ExpandElementTyped, UInt}; -use crate::ir::{IntKind, Variable}; - -pub trait Index { - fn value(self) -> Variable; +use super::{CubeType, ExpandElement, ExpandElementTyped}; +use crate::{ + ir::{IntKind, Variable}, + unexpanded, +}; + +/// Fake indexation so we can rewrite indexes into scalars as calls to this fake function in the +/// non-expanded function +pub trait CubeIndex { + type Output: CubeType; + + fn cube_idx(&self, _i: T) -> &Self::Output { + unexpanded!() + } } -impl Index for Comptime { - fn value(self) -> Variable { - Variable::ConstantScalar(crate::ir::ConstantScalarValue::UInt(self.inner as u64)) +pub trait CubeIndexMut: CubeIndex { + fn cube_idx_mut(&mut self, _i: T) -> &mut Self::Output { + unexpanded!() } } -impl Index for Comptime { - fn value(self) -> Variable { - Variable::ConstantScalar(crate::ir::ConstantScalarValue::Int( - self.inner as i64, - IntKind::I32, - )) - } +pub trait Index { + fn value(self) -> Variable; } impl Index for i32 { @@ -35,21 +39,14 @@ impl Index for u32 { } } -impl Index for UInt { - fn value(self) -> Variable { - Variable::ConstantScalar(crate::ir::ConstantScalarValue::UInt(self.val as u64)) - } -} - impl Index for ExpandElement { fn value(self) -> Variable { *self } } -impl Index for ExpandElementTyped { +impl Index for ExpandElementTyped { fn value(self) -> Variable { - let value: ExpandElement = self.into(); - value.value() + *self.expand } } diff --git a/crates/cubecl-core/src/frontend/mod.rs b/crates/cubecl-core/src/frontend/mod.rs index b2f11c85..82941a05 100644 --- a/crates/cubecl-core/src/frontend/mod.rs +++ b/crates/cubecl-core/src/frontend/mod.rs @@ -3,7 +3,7 @@ pub mod cmma; pub mod synchronization; mod base; -mod comptime; +mod const_expand; mod context; mod element; mod indexation; @@ -12,9 +12,11 @@ mod sequence; mod subcube; mod topology; -pub use comptime::*; +pub use branch::{RangeExpand, SteppedRangeExpand}; +pub use const_expand::*; pub use context::*; pub use element::*; +pub use indexation::*; pub use operation::*; pub use sequence::*; pub use subcube::*; diff --git a/crates/cubecl-core/src/frontend/operation/assignation.rs b/crates/cubecl-core/src/frontend/operation/assignation.rs index 0f8e05cb..6c782870 100644 --- a/crates/cubecl-core/src/frontend/operation/assignation.rs +++ b/crates/cubecl-core/src/frontend/operation/assignation.rs @@ -1,40 +1,26 @@ -use crate::frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor, UInt}; -use crate::frontend::{BF16, F16, F32, F64, I32, I64}; -use crate::{ir, unexpanded}; +use half::{bf16, f16}; -macro_rules! impl_op_assign { - (($tr:ident|$func:ident) => { $($type:ty| $($rhs:ty);*),* }) => { - $( - $( - impl $tr<$rhs> for $type { - fn $func(&mut self, _rhs: $rhs) { - unexpanded!() - } - } - )* - - impl $tr for $type { - fn $func(&mut self, _rhs: Self) { - unexpanded!() - } - } - )* - }; -} +use crate::{ + frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor}, + prelude::{CubeIndex, CubeIndexMut, CubeType}, +}; +use crate::{ir, prelude::Index}; pub mod assign { + use crate::prelude::ExpandElementTyped; + use self::ir::{Operator, UnaryOperator}; use super::*; - pub fn expand, O: Into>( + pub fn expand( context: &mut CubeContext, - input: I, - output: O, + input: ExpandElementTyped, + output: ExpandElementTyped, ) { context.register(Operator::Assign(UnaryOperator { - input: *input.into(), - out: *output.into(), + input: *input.expand, + out: *output.expand, })); } } @@ -43,17 +29,16 @@ pub mod index_assign { use crate::{ frontend::CubeType, prelude::{ExpandElementTyped, SliceMut}, - unexpanded, }; use self::ir::{BinaryOperator, Operator, Variable}; use super::*; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, value: ExpandElementTyped, ) where A::Output: CubeType + Sized, @@ -74,27 +59,13 @@ pub mod index_assign { macro_rules! impl_index { ($type:ident) => { - impl> core::ops::IndexMut for $type { - fn index_mut(&mut self, _index: I) -> &mut Self::Output { - unexpanded!() - } - } + impl CubeIndexMut for $type {} }; } macro_rules! impl_index_vec { ($($type:ident),*) => { $( - impl core::ops::IndexMut for $type { - fn index_mut(&mut self, _index: UInt) -> &mut Self::Output { - unexpanded!() - } - } - impl core::ops::IndexMut for $type { - fn index_mut(&mut self, _index: u32) -> &mut Self::Output { - unexpanded!() - } - } - + impl CubeIndexMut for $type {} )* }; } @@ -102,13 +73,9 @@ pub mod index_assign { impl_index!(Array); impl_index!(Tensor); impl_index!(SharedMemory); - impl_index_vec!(I64, I32, F16, BF16, F32, F64, UInt); + impl_index_vec!(i64, i32, f16, bf16, f32, f64, u32); - impl<'a, E: CubeType, I: Into> core::ops::IndexMut for SliceMut<'a, E> { - fn index_mut(&mut self, _index: I) -> &mut Self::Output { - unexpanded!() - } - } + impl<'a, E: CubeType, I: Index> CubeIndexMut for SliceMut<'a, E> {} } pub mod index { @@ -118,17 +85,16 @@ pub mod index { CubeType, }, prelude::{ExpandElementTyped, Slice, SliceMut}, - unexpanded, }; use self::ir::{Operator, Variable}; use super::*; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, ) -> ExpandElementTyped where A::Output: CubeType + Sized, @@ -153,33 +119,16 @@ pub mod index { macro_rules! impl_index { ($type:ident) => { - impl> core::ops::Index for $type { + impl CubeIndex for $type { type Output = E; - - fn index(&self, _index: I) -> &Self::Output { - unexpanded!() - } } }; } - macro_rules! impl_index_vec { ($($type:ident),*) => { $( - impl core::ops::Index for $type { + impl CubeIndex for $type { type Output = Self; - - fn index(&self, _index: UInt) -> &Self::Output { - unexpanded!() - } - } - - impl core::ops::Index for $type { - type Output = Self; - - fn index(&self, _index: u32) -> &Self::Output { - unexpanded!() - } } )* }; @@ -188,21 +137,14 @@ pub mod index { impl_index!(Array); impl_index!(Tensor); impl_index!(SharedMemory); + impl_index_vec!(i64, i32, f16, bf16, f32, f64, u32); - impl_index_vec!(I64, I32, F16, BF16, F32, F64, UInt); - - impl<'a, E: CubeType, I: Into> core::ops::Index for SliceMut<'a, E> { + impl<'a, E: CubeType, I: Index> CubeIndex for Slice<'a, E> { type Output = E; - fn index(&self, _index: I) -> &Self::Output { - unexpanded!() - } } - impl<'a, E: CubeType, I: Into> core::ops::Index for Slice<'a, E> { + impl<'a, E: CubeType, I: Index> CubeIndex for SliceMut<'a, E> { type Output = E; - fn index(&self, _index: I) -> &Self::Output { - unexpanded!() - } } } @@ -211,10 +153,10 @@ pub mod add_assign_array_op { use super::*; use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, value: ExpandElementTyped, ) where A::Output: CubeType + Sized, @@ -228,10 +170,10 @@ pub mod sub_assign_array_op { use super::*; use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, value: ExpandElementTyped, ) where A::Output: CubeType + Sized, @@ -245,10 +187,10 @@ pub mod mul_assign_array_op { use super::*; use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, value: ExpandElementTyped, ) where A::Output: CubeType + Sized, @@ -262,10 +204,10 @@ pub mod div_assign_array_op { use super::*; use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand>( + pub fn expand>( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, + index: ExpandElementTyped, value: ExpandElementTyped, ) where A::Output: CubeType + Sized, @@ -274,112 +216,250 @@ pub mod div_assign_array_op { } } -pub mod add_assign_op { - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - use core::ops::AddAssign; +pub mod rem_assign_array_op { + use self::ir::Operator; + use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; + pub fn expand>( + context: &mut CubeContext, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::Modulo); + } +} + +pub mod bitor_assign_array_op { use self::ir::Operator; + use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; + pub fn expand>( + context: &mut CubeContext, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseOr); + } +} + +pub mod bitand_assign_array_op { + use self::ir::Operator; use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - pub fn expand, R: Into>( + pub fn expand>( context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - assign_op_expand(context, lhs.into(), rhs.into(), Operator::Add) + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseAnd); } +} + +pub mod bitxor_assign_array_op { + use self::ir::Operator; + use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; - impl_op_assign!( - (AddAssign|add_assign) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); + pub fn expand>( + context: &mut CubeContext, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseXor); + } +} + +pub mod shl_assign_array_op { + use self::ir::Operator; + use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; + + pub fn expand>( + context: &mut CubeContext, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::ShiftLeft); + } +} + +pub mod shr_assign_array_op { + use self::ir::Operator; + use super::*; + use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped}; + + pub fn expand>( + context: &mut CubeContext, + array: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, + ) where + A::Output: CubeType + Sized, + { + array_assign_binary_op_expand(context, array, index, value, Operator::ShiftRight); + } +} + +pub mod add_assign_op { + use std::ops::AddAssign; + + use self::ir::Operator; + use crate::{ + frontend::operation::base::assign_op_expand, + prelude::{CubeType, ExpandElementTyped}, + }; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElementTyped { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::Add).into() + } } pub mod sub_assign_op { use self::ir::Operator; use super::*; - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - use core::ops::SubAssign; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; - pub fn expand, R: Into>( + pub fn expand( context: &mut CubeContext, - lhs: L, - rhs: R, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, ) -> ExpandElement { assign_op_expand(context, lhs.into(), rhs.into(), Operator::Sub) } - - impl_op_assign!( - (SubAssign|sub_assign) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod mul_assign_op { use self::ir::Operator; use super::*; - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - use core::ops::MulAssign; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; - pub fn expand, R: Into>( + pub fn expand( context: &mut CubeContext, - lhs: L, - rhs: R, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, ) -> ExpandElement { assign_op_expand(context, lhs.into(), rhs.into(), Operator::Mul) } - - impl_op_assign!( - (MulAssign|mul_assign) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod div_assign_op { use self::ir::Operator; use super::*; - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - use core::ops::DivAssign; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; - pub fn expand, R: Into>( + pub fn expand( context: &mut CubeContext, - lhs: L, - rhs: R, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, ) -> ExpandElement { assign_op_expand(context, lhs.into(), rhs.into(), Operator::Div) } +} + +pub mod rem_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::Modulo) + } +} + +pub mod bitor_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr) + } +} + +pub mod bitand_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseAnd) + } +} + +pub mod bitxor_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseXor) + } +} - impl_op_assign!( - (DivAssign|div_assign) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); +pub mod shl_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftLeft) + } +} + +pub mod shr_assign_op { + use self::ir::Operator; + use super::*; + use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped}; + + pub fn expand( + context: &mut CubeContext, + lhs: ExpandElementTyped, + rhs: ExpandElementTyped, + ) -> ExpandElement { + assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftRight) + } } diff --git a/crates/cubecl-core/src/frontend/operation/base.rs b/crates/cubecl-core/src/frontend/operation/base.rs index 70d07189..8868acc0 100644 --- a/crates/cubecl-core/src/frontend/operation/base.rs +++ b/crates/cubecl-core/src/frontend/operation/base.rs @@ -1,6 +1,9 @@ -use crate::frontend::{CubeContext, ExpandElement}; use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization}; -use crate::prelude::{CubeType, ExpandElementTyped, UInt}; +use crate::prelude::{CubeType, ExpandElementTyped}; +use crate::{ + frontend::{CubeContext, ExpandElement}, + prelude::CubeIndex, +}; pub(crate) fn binary_expand( context: &mut CubeContext, @@ -17,7 +20,8 @@ where let item_lhs = lhs.item(); let item_rhs = rhs.item(); - let vectorization = check_vectorization(item_lhs.vectorization, item_rhs.vectorization); + let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization); + let item = Item::vectorized(item_lhs.elem, vectorization); // We can only reuse rhs. @@ -94,7 +98,7 @@ where let rhs: Variable = *rhs; let item = lhs.item(); - check_vectorization(item.vectorization, rhs.item().vectorization); + find_vectorization(item.vectorization, rhs.item().vectorization); let out_item = Item { elem: Elem::Bool, @@ -127,7 +131,7 @@ where let lhs_var: Variable = *lhs; let rhs: Variable = *rhs; - check_vectorization(lhs_var.item().vectorization, rhs.item().vectorization); + find_vectorization(lhs_var.item().vectorization, rhs.item().vectorization); let op = func(BinaryOperator { lhs: lhs_var, @@ -190,29 +194,31 @@ where out } -fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization { - let output = u8::max(lhs, rhs); - - if lhs == 1 || rhs == 1 { - return output; +fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization { + match (lhs, rhs) { + (None, None) => None, + (None, Some(rhs)) => Some(rhs), + (Some(lhs), None) => Some(lhs), + (Some(lhs), Some(rhs)) if lhs == rhs => Some(lhs), + (Some(lhs), Some(rhs)) => { + panic!( + "Left and right have different vectorizations. + Left: {lhs}, right: {rhs}. + Auto-matching fixed vectorization currently unsupported." + ); + } } - - assert!( - lhs == rhs, - "Tried to perform binary operation on different vectorization schemes." - ); - - output } pub fn array_assign_binary_op_expand< - A: CubeType + core::ops::Index, + A: CubeType + CubeIndex, + V: CubeType, F: Fn(BinaryOperator) -> Operator, >( context: &mut CubeContext, array: ExpandElementTyped, - index: ExpandElementTyped, - value: ExpandElementTyped, + index: ExpandElementTyped, + value: ExpandElementTyped, func: F, ) where A::Output: CubeType + Sized, diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index dabe8001..eacede0c 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -1,135 +1,55 @@ -use crate::frontend::operation::base::binary_expand; -use crate::frontend::{ - AtomicI32, AtomicI64, AtomicUInt, CubeContext, CubePrimitive, ExpandElementTyped, UInt, BF16, - F16, F32, F64, I32, I64, -}; +use crate::frontend::CubeType; +use crate::frontend::{CubeContext, CubePrimitive, ExpandElementTyped}; use crate::ir::Operator; -use crate::{frontend::CubeType, unexpanded}; - -macro_rules! impl_op { - (($tr:ident|$func:ident|$op:tt) => { $($type:ty| $($rhs:ty);*),* }) => { - $( - $( - impl $tr<$rhs> for $type { - type Output = Self; - - fn $func(self, rhs: $rhs) -> Self::Output { - let rhs: Self = rhs.into(); - self $op rhs - } - } - )* - - impl $tr for $type { - type Output = Self; - - fn $func(self, rhs: Self) -> Self::Output { - (self.val $op rhs.val).into() - } - } - )* - }; -} +use crate::{frontend::operation::base::binary_expand, unexpanded}; +use half::{bf16, f16}; pub mod add { use super::*; - use core::ops::Add; pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::Add).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Add).into() } - - impl_op!( - (Add|add|+) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod sub { use super::*; - use core::ops::Sub; pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::Sub).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Sub).into() } - - impl_op!( - (Sub|sub|-) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod mul { use super::*; - use core::ops::Mul; pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::Mul).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Mul).into() } - - impl_op!( - (Mul|mul|*) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod div { use super::*; - use core::ops::Div; - pub fn expand>>( + pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: R, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - let rhs: ExpandElementTyped = rhs.into(); - binary_expand(context, lhs.into(), rhs.into(), Operator::Div).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Div).into() } - - impl_op!( - (Div|div|/) => { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } - ); } pub mod rem { @@ -137,27 +57,17 @@ pub mod rem { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::Modulo).into() - } - - macro_rules! impl_rem { - ($type:ty) => { - impl core::ops::Rem for $type { - type Output = Self; - - fn rem(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } - }; + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::Modulo, + ) + .into() } - - impl_rem!(I32); - impl_rem!(I64); - impl_rem!(UInt); } pub mod and { @@ -165,10 +75,10 @@ pub mod and { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::And).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::And).into() } } @@ -177,18 +87,16 @@ pub mod bitand { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseAnd).into() - } - - impl core::ops::BitAnd for UInt { - type Output = UInt; - - fn bitand(self, _rhs: Self) -> Self::Output { - unexpanded!() - } + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::BitwiseAnd, + ) + .into() } } @@ -197,18 +105,16 @@ pub mod bitor { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr).into() - } - - impl core::ops::BitOr for UInt { - type Output = UInt; - - fn bitor(self, _rhs: Self) -> Self::Output { - unexpanded!() - } + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::BitwiseOr, + ) + .into() } } @@ -217,10 +123,10 @@ pub mod or { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::Or).into() + binary_expand(context, lhs.into().into(), rhs.into().into(), Operator::Or).into() } } @@ -229,18 +135,16 @@ pub mod bitxor { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseXor).into() - } - - impl core::ops::BitXor for UInt { - type Output = UInt; - - fn bitxor(self, _rhs: Self) -> Self::Output { - unexpanded!() - } + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::BitwiseXor, + ) + .into() } } @@ -249,18 +153,16 @@ pub mod shl { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::ShiftLeft).into() - } - - impl core::ops::Shl for UInt { - type Output = UInt; - - fn shl(self, _rhs: Self) -> Self::Output { - unexpanded!() - } + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::ShiftLeft, + ) + .into() } } @@ -269,30 +171,28 @@ pub mod shr { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - binary_expand(context, lhs.into(), rhs.into(), Operator::ShiftRight).into() - } - - impl core::ops::Shr for UInt { - type Output = UInt; - - fn shr(self, _rhs: Self) -> Self::Output { - unexpanded!() - } + binary_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::ShiftRight, + ) + .into() } } /// For binary functions without special syntax macro_rules! impl_binary_func { - ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => { + ($trait_name:ident, $method_name:ident, $func_name_expand:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => { pub trait $trait_name: CubeType + Sized { fn $method_name(self, _rhs: Self) -> Self { unexpanded!() } - fn $method_name_expand( + fn $func_name_expand( context: &mut CubeContext, lhs: ExpandElementTyped, rhs: ExpandElementTyped, @@ -302,6 +202,11 @@ macro_rules! impl_binary_func { } $(impl $trait_name for $type {})* + $(impl ExpandElementTyped<$type> { + pub fn $method_name_expand(self, context: &mut CubeContext, rhs: ExpandElementTyped<$type>) -> ExpandElementTyped<$type> { + binary_expand(context, self.into(), rhs.into(), $operator).into() + } + })* } } @@ -309,51 +214,52 @@ impl_binary_func!( Powf, powf, __expand_powf, + __expand_powf_method, Operator::Powf, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_binary_func!( Max, max, __expand_max, + __expand_max_method, Operator::Max, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt, - AtomicI32, - AtomicI64, - AtomicUInt + f16, + bf16, + f32, + f64, + i32, + i64, + u32 ); impl_binary_func!( Min, min, __expand_min, + __expand_min_method, Operator::Min, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt + f16, + bf16, + f32, + f64, + i32, + i64, + u32 ); impl_binary_func!( Remainder, rem, __expand_rem, + __expand_rem_method, Operator::Remainder, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt + f16, + bf16, + f32, + f64, + i32, + i64, + u32 ); diff --git a/crates/cubecl-core/src/frontend/operation/clamp.rs b/crates/cubecl-core/src/frontend/operation/clamp.rs index 6a00d643..6e1e6b5f 100644 --- a/crates/cubecl-core/src/frontend/operation/clamp.rs +++ b/crates/cubecl-core/src/frontend/operation/clamp.rs @@ -1,6 +1,8 @@ +use half::{bf16, f16}; + use crate::{ ir::{ClampOperator, Operator}, - prelude::{CubeContext, CubePrimitive, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}, + prelude::{CubeContext, CubePrimitive, ExpandElement}, unexpanded, }; @@ -34,10 +36,10 @@ pub trait Clamp: CubePrimitive + Sized { } } -impl Clamp for F16 {} -impl Clamp for BF16 {} -impl Clamp for F32 {} -impl Clamp for F64 {} -impl Clamp for I32 {} -impl Clamp for I64 {} -impl Clamp for UInt {} +impl Clamp for f16 {} +impl Clamp for bf16 {} +impl Clamp for f32 {} +impl Clamp for f64 {} +impl Clamp for i32 {} +impl Clamp for i64 {} +impl Clamp for u32 {} diff --git a/crates/cubecl-core/src/frontend/operation/cmp.rs b/crates/cubecl-core/src/frontend/operation/cmp.rs index a2d44a84..0a482063 100644 --- a/crates/cubecl-core/src/frontend/operation/cmp.rs +++ b/crates/cubecl-core/src/frontend/operation/cmp.rs @@ -1,74 +1,23 @@ use crate::frontend::operation::base::cmp_expand; -use crate::frontend::{CubeContext, ExpandElementTyped, UInt, BF16, F16, F32, F64, I32, I64}; +use crate::frontend::{CubeContext, ExpandElementTyped}; use crate::ir::Operator; use crate::prelude::CubePrimitive; -macro_rules! impl_cmp { - ({ $($type:ty| $($rhs:ty);*),* }) => { - $( - $( - impl core::cmp::PartialEq<$rhs> for $type { - fn eq(&self, rhs: &$rhs) -> bool { - let rhs: Self = (*rhs).into(); - self == &rhs - } - } - - impl core::cmp::PartialOrd<$rhs> for $type { - fn partial_cmp(&self, rhs: &$rhs) -> Option { - let rhs: Self = (*rhs).into(); - core::cmp::PartialOrd::partial_cmp(self, &rhs) - } - } - - )* - - impl_cmp!($type); - )* - }; - ($type:ty) => { - impl core::cmp::PartialEq for $type { - fn eq(&self, other: &Self) -> bool { - self.val == other.val && self.vectorization == other.vectorization - } - } - - impl core::cmp::Eq for $type {} - - impl core::cmp::PartialOrd for $type { - fn partial_cmp(&self, other: &Self) -> Option { - match self.val.partial_cmp(&other.val) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, - } - self.vectorization.partial_cmp(&other.vectorization) - } - } - }; -} - -impl_cmp!( - { - F16 | f32;u32, - F32 | f32;u32, - BF16 | f32;u32, - F64 | f32;u32, - I32 | i32;u32, - I64 | i32;u32, - UInt | u32 - } -); - pub mod ne { - use super::*; pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::NotEqual).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::NotEqual, + ) + .into() } } @@ -77,10 +26,16 @@ pub mod gt { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::Greater).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::Greater, + ) + .into() } } @@ -89,10 +44,16 @@ pub mod lt { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::Lower).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::Lower, + ) + .into() } } @@ -101,10 +62,16 @@ pub mod ge { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::GreaterEqual).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::GreaterEqual, + ) + .into() } } @@ -113,10 +80,16 @@ pub mod le { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::LowerEqual).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::LowerEqual, + ) + .into() } } @@ -126,10 +99,16 @@ pub mod eq { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::Equal).into() + cmp_expand( + context, + lhs.into().into(), + rhs.into().into(), + Operator::Equal, + ) + .into() } } @@ -138,9 +117,9 @@ pub mod add_assign { pub fn expand( context: &mut CubeContext, - lhs: ExpandElementTyped, - rhs: ExpandElementTyped, + lhs: impl Into>, + rhs: impl Into>, ) -> ExpandElementTyped { - cmp_expand(context, lhs.into(), rhs.into(), Operator::Add).into() + cmp_expand(context, lhs.into().into(), rhs.into().into(), Operator::Add).into() } } diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 51e50eb7..ef5b0bd3 100644 --- a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -1,5 +1,7 @@ +use half::{bf16, f16}; + use crate::{ - frontend::{CubeContext, UInt, BF16, F16, F32, F64, I32, I64}, + frontend::CubeContext, ir::Operator, prelude::{CubePrimitive, ExpandElementTyped}, unexpanded, @@ -18,6 +20,19 @@ pub mod not { } } +pub mod neg { + use crate::prelude::Numeric; + + use super::*; + + pub fn expand( + context: &mut CubeContext, + x: ExpandElementTyped, + ) -> ExpandElementTyped { + unary_expand(context, x.into(), Operator::Neg).into() + } +} + macro_rules! impl_unary_func { ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => { pub trait $trait_name: CubePrimitive + Sized { @@ -40,96 +55,96 @@ impl_unary_func!( abs, __expand_abs, Operator::Abs, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt + f16, + bf16, + f32, + f64, + i32, + i64, + u32 ); -impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64); -impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64); +impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, f16, bf16, f32, f64); +impl_unary_func!(Log, log, __expand_log, Operator::Log, f16, bf16, f32, f64); impl_unary_func!( Log1p, log1p, __expand_log1p, Operator::Log1p, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); -impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64); -impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64); +impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, f16, bf16, f32, f64); +impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, f16, bf16, f32, f64); impl_unary_func!( Tanh, tanh, __expand_tanh, Operator::Tanh, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_unary_func!( Sqrt, sqrt, __expand_sqrt, Operator::Sqrt, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_unary_func!( Round, round, __expand_round, Operator::Round, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_unary_func!( Floor, floor, __expand_floor, Operator::Floor, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_unary_func!( Ceil, ceil, __expand_ceil, Operator::Ceil, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); -impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64); +impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, f16, bf16, f32, f64); impl_unary_func!( Recip, recip, __expand_recip, Operator::Recip, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); impl_unary_func!( Normalize, normalize, __expand_normalize, Operator::Normalize, - F16, - BF16, - F32, - F64 + f16, + bf16, + f32, + f64 ); diff --git a/crates/cubecl-core/src/frontend/sequence.rs b/crates/cubecl-core/src/frontend/sequence.rs index 7f52cced..74583e9c 100644 --- a/crates/cubecl-core/src/frontend/sequence.rs +++ b/crates/cubecl-core/src/frontend/sequence.rs @@ -1,4 +1,4 @@ -use super::{indexation::Index, CubeContext, CubeType, Init}; +use super::{branch::Iterable, indexation::Index, CubeContext, CubeType, ExpandElementTyped, Init}; use crate::unexpanded; use std::{cell::RefCell, rc::Rc}; @@ -54,10 +54,10 @@ impl Sequence { } /// Expand function of [index](Self::index). - pub fn __expand_index( + pub fn __expand_index( context: &mut CubeContext, expand: SequenceExpand, - index: I, + index: ExpandElementTyped, ) -> T::ExpandType { expand.__expand_index_method(context, index) } @@ -70,6 +70,26 @@ pub struct SequenceExpand { values: Rc>>, } +impl Iterable for SequenceExpand { + fn expand( + self, + context: &mut CubeContext, + func: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + self.expand_unroll(context, func); + } + + fn expand_unroll( + self, + context: &mut CubeContext, + mut func: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + for elem in self { + func(context, elem); + } + } +} + impl Init for SequenceExpand { fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { self @@ -115,20 +135,15 @@ impl SequenceExpand { } /// Expand method of [index](Sequence::index). - pub fn __expand_index_method( + pub fn __expand_index_method( &self, _context: &mut CubeContext, - index: I, + index: ExpandElementTyped, ) -> T::ExpandType { - let value = index.value(); - let index = match value { - crate::ir::Variable::ConstantScalar(value) => match value { - crate::ir::ConstantScalarValue::Int(val, _) => val as usize, - crate::ir::ConstantScalarValue::UInt(val) => val as usize, - _ => panic!("Only integer types are supported"), - }, - _ => panic!("Only constant are supported"), - }; + let index = index + .constant() + .expect("Only constant are supported") + .as_usize(); self.values.borrow()[index].clone() } } diff --git a/crates/cubecl-core/src/frontend/subcube.rs b/crates/cubecl-core/src/frontend/subcube.rs index 096a55ea..2fbb7463 100644 --- a/crates/cubecl-core/src/frontend/subcube.rs +++ b/crates/cubecl-core/src/frontend/subcube.rs @@ -1,12 +1,12 @@ -use super::{CubeContext, CubePrimitive, ExpandElement, UInt}; -use crate::prelude::{Bool, ExpandElementTyped}; +use super::{CubeContext, CubePrimitive, ExpandElement}; +use crate::prelude::ExpandElementTyped; use crate::{ ir::{Elem, InitOperator, Item, Operation, Subcube, UnaryOperator}, unexpanded, }; /// Returns true if the cube unit has the lowest subcube_unit_id among active unit in the subcube -pub fn subcube_elect() -> Bool { +pub fn subcube_elect() -> bool { unexpanded!() } @@ -16,7 +16,7 @@ pub mod subcube_elect { use super::*; /// Expand method of [subcube_elect()]. - pub fn __expand(context: &mut CubeContext) -> ExpandElementTyped { + pub fn expand(context: &mut CubeContext) -> ExpandElementTyped { let output = context.create_local(Item::new(Elem::Bool)); let out = *output; @@ -29,7 +29,7 @@ pub mod subcube_elect { /// Broadcasts the value from the specified subcube unit at the given index /// to all active units within that subcube. #[allow(unused_variables)] -pub fn subcube_broadcast(value: E, index: UInt) -> E { +pub fn subcube_broadcast(value: E, index: u32) -> E { unexpanded!() } @@ -39,10 +39,10 @@ pub mod subcube_broadcast { use super::*; /// Expand method of [subcube_broadcast()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, value: ExpandElementTyped, - id: ExpandElementTyped, + id: ExpandElementTyped, ) -> ExpandElementTyped { let output = context.create_local(value.expand.item()); let out = *output; @@ -68,7 +68,7 @@ pub mod subcube_sum { use super::*; /// Expand method of [subcube_sum()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, elem: ExpandElementTyped, ) -> ExpandElementTyped { @@ -97,7 +97,7 @@ pub mod subcube_prod { use super::*; /// Expand method of [subcube_prod()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, elem: ExpandElementTyped, ) -> ExpandElementTyped { @@ -126,7 +126,7 @@ pub mod subcube_max { use super::*; /// Expand method of [subcube_max()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, elem: ExpandElementTyped, ) -> ExpandElementTyped { @@ -155,7 +155,7 @@ pub mod subcube_min { use super::*; /// Expand method of [subcube_min()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, elem: ExpandElementTyped, ) -> ExpandElementTyped { @@ -175,7 +175,7 @@ pub mod subcube_min { } /// Perform a reduce all operation across all units in a subcube. -pub fn subcube_all(_elem: Bool) -> Bool { +pub fn subcube_all(_elem: bool) -> bool { unexpanded!() } @@ -185,10 +185,10 @@ pub mod subcube_all { use super::*; /// Expand method of [subcube_all()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, - elem: ExpandElementTyped, - ) -> ExpandElementTyped { + elem: ExpandElementTyped, + ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); let output = context.create_local(elem.item()); @@ -205,7 +205,7 @@ pub mod subcube_all { } /// Perform a reduce any operation across all units in a subcube. -pub fn subcube_any(_elem: Bool) -> Bool { +pub fn subcube_any(_elem: bool) -> bool { unexpanded!() } @@ -215,10 +215,10 @@ pub mod subcube_any { use super::*; /// Expand method of [subcube_any()]. - pub fn __expand( + pub fn expand( context: &mut CubeContext, - elem: ExpandElementTyped, - ) -> ExpandElementTyped { + elem: ExpandElementTyped, + ) -> ExpandElementTyped { let elem: ExpandElement = elem.into(); let output = context.create_local(elem.item()); diff --git a/crates/cubecl-core/src/frontend/synchronization.rs b/crates/cubecl-core/src/frontend/synchronization.rs index bf0b2e7d..13acf6fd 100644 --- a/crates/cubecl-core/src/frontend/synchronization.rs +++ b/crates/cubecl-core/src/frontend/synchronization.rs @@ -16,7 +16,7 @@ pub fn sync_units() {} pub mod sync_units { use super::*; - pub fn __expand(context: &mut CubeContext) { + pub fn expand(context: &mut CubeContext) { context.register(Synchronization::SyncUnits) } } @@ -29,7 +29,7 @@ pub fn sync_storage() {} pub mod sync_storage { use super::*; - pub fn __expand(context: &mut CubeContext) { + pub fn expand(context: &mut CubeContext) { context.register(Synchronization::SyncStorage) } } diff --git a/crates/cubecl-core/src/frontend/topology.rs b/crates/cubecl-core/src/frontend/topology.rs index 5507755d..139c2166 100644 --- a/crates/cubecl-core/src/frontend/topology.rs +++ b/crates/cubecl-core/src/frontend/topology.rs @@ -2,12 +2,11 @@ //! the expand function, so that a user implicitly imports the expand function when importing the constant. use super::ExpandElementTyped; -use crate::frontend::UInt; macro_rules! constant { ($ident:ident, $var:expr, $doc:expr) => { #[doc = $doc] - pub const $ident: UInt = UInt::new(0u32); + pub const $ident: u32 = 1; #[allow(non_snake_case)] #[doc = $doc] @@ -16,7 +15,7 @@ macro_rules! constant { use crate::frontend::{CubeContext, ExpandElement}; /// Expansion of the constant variable. - pub fn expand(_context: &mut CubeContext) -> ExpandElementTyped { + pub fn expand(_context: &mut CubeContext) -> ExpandElementTyped { ExpandElementTyped::new(ExpandElement::Plain($var)) } } diff --git a/crates/cubecl-core/src/ir/branch.rs b/crates/cubecl-core/src/ir/branch.rs index 320d1b59..1a954d34 100644 --- a/crates/cubecl-core/src/ir/branch.rs +++ b/crates/cubecl-core/src/ir/branch.rs @@ -40,6 +40,7 @@ pub struct RangeLoop { pub start: Variable, pub end: Variable, pub step: Option, + pub inclusive: bool, pub scope: Scope, } @@ -93,6 +94,7 @@ impl RangeLoop { start: Variable, end: Variable, step: Option, + inclusive: bool, func: F, ) { let mut scope = parent_scope.child(); @@ -107,6 +109,7 @@ impl RangeLoop { end, step, scope, + inclusive, })); } } @@ -133,9 +136,20 @@ impl UnrolledRangeLoop { start: u32, end: u32, step: Option, + inclusive: bool, func: F, ) { - if let Some(step) = step { + if inclusive { + if let Some(step) = step { + for i in (start..=end).step_by(step as usize) { + func(i.into(), scope); + } + } else { + for i in start..=end { + func(i.into(), scope); + } + } + } else if let Some(step) = step { for i in (start..end).step_by(step as usize) { func(i.into(), scope); } diff --git a/crates/cubecl-core/src/ir/kernel.rs b/crates/cubecl-core/src/ir/kernel.rs index 59f0f4cf..76d63d14 100644 --- a/crates/cubecl-core/src/ir/kernel.rs +++ b/crates/cubecl-core/src/ir/kernel.rs @@ -196,7 +196,7 @@ impl Item { pub fn new(elem: Elem) -> Self { Self { elem, - vectorization: 1, + vectorization: None, } } diff --git a/crates/cubecl-core/src/ir/macros.rs b/crates/cubecl-core/src/ir/macros.rs index ccbcdce3..5cd0a0aa 100644 --- a/crates/cubecl-core/src/ir/macros.rs +++ b/crates/cubecl-core/src/ir/macros.rs @@ -366,14 +366,14 @@ macro_rules! cpa { }; // range(start, end).for_each(|i, scope| { ... }) ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => { - $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg); + $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg); }; // range(start, end, unroll).for_each(|i, scope| { ... }) ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => { if $unroll { - $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), None, $arg); + $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg); } else { - $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg); + $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg); } }; // range_stepped(start, end, step).for_each(|i, scope| { ... }) diff --git a/crates/cubecl-core/src/ir/operation.rs b/crates/cubecl-core/src/ir/operation.rs index b26c8b67..ffb2a79a 100644 --- a/crates/cubecl-core/src/ir/operation.rs +++ b/crates/cubecl-core/src/ir/operation.rs @@ -61,6 +61,7 @@ pub enum Operator { And(BinaryOperator), Or(BinaryOperator), Not(UnaryOperator), + Neg(UnaryOperator), Max(BinaryOperator), Min(BinaryOperator), BitwiseAnd(BinaryOperator), diff --git a/crates/cubecl-core/src/ir/procedure/assign.rs b/crates/cubecl-core/src/ir/procedure/assign.rs index e72f23dc..499c1c9f 100644 --- a/crates/cubecl-core/src/ir/procedure/assign.rs +++ b/crates/cubecl-core/src/ir/procedure/assign.rs @@ -19,15 +19,18 @@ impl ConditionalAssign { let rhs = self.rhs; let out = self.out; - let index_var = - |scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 { - true => var, - false => { - let out = scope.create_local(var.item().elem()); - cpa!(scope, out = var[index]); - out - } - }; + let index_var = |scope: &mut Scope, var: Variable, index: usize| match var + .item() + .vectorization + .is_none() + { + true => var, + false => { + let out = scope.create_local(var.item().elem()); + cpa!(scope, out = var[index]); + out + } + }; let mut assign_index = |index: usize| { let cond = index_var(scope, cond, index); @@ -44,16 +47,16 @@ impl ConditionalAssign { }; let vectorization = out.item().vectorization; - match vectorization == 1 { - true => { + match vectorization { + None => { cpa!(scope, if (cond).then(|scope| { cpa!(scope, out = lhs); }).else(|scope| { cpa!(scope, out = rhs); })); } - false => { - for i in 0..vectorization { + Some(vectorization) => { + for i in 0..vectorization.get() { assign_index(i as usize); } } diff --git a/crates/cubecl-core/src/ir/procedure/read.rs b/crates/cubecl-core/src/ir/procedure/read.rs index e63065d5..9d6db951 100644 --- a/crates/cubecl-core/src/ir/procedure/read.rs +++ b/crates/cubecl-core/src/ir/procedure/read.rs @@ -143,7 +143,11 @@ impl IndexOffsetGlobalWithLayout { let index_item_ty = Item::new(Elem::UInt); let offset_ref = self.position; let zero: Variable = 0u32.into(); - let vectorization_factor: u8 = self.tensors[0].item().vectorization; + let vectorization_factor: u8 = self.tensors[0] + .item() + .vectorization + .map(|it| it.get()) + .unwrap_or(1); let vectorization_factor: Variable = (vectorization_factor as u32).into(); for index in self.indexes.iter() { cpa!(scope, index = zero); diff --git a/crates/cubecl-core/src/ir/processing.rs b/crates/cubecl-core/src/ir/processing.rs index e1c9798d..7c80a936 100644 --- a/crates/cubecl-core/src/ir/processing.rs +++ b/crates/cubecl-core/src/ir/processing.rs @@ -159,6 +159,7 @@ impl ScopeProcessing { Operator::Not(op) => { sanitize_constant_scalar_ref_elem(&mut op.input, Elem::Bool); } + Operator::Neg(op) => sanitize_constant_scalar_ref_var(&mut op.input, &op.out), Operator::Max(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out); sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out); @@ -248,6 +249,9 @@ impl ScopeProcessing { Branch::RangeLoop(op) => { sanitize_constant_scalar_ref_elem(&mut op.start, Elem::UInt); sanitize_constant_scalar_ref_elem(&mut op.end, Elem::UInt); + if let Some(step) = &mut op.step { + sanitize_constant_scalar_ref_elem(step, Elem::UInt); + } } _ => { // Nothing to do. diff --git a/crates/cubecl-core/src/ir/variable.rs b/crates/cubecl-core/src/ir/variable.rs index 9c81f7a2..172d1478 100644 --- a/crates/cubecl-core/src/ir/variable.rs +++ b/crates/cubecl-core/src/ir/variable.rs @@ -173,6 +173,18 @@ impl ConstantScalarValue { self.try_as_i64() .expect("Only Int and UInt kind can be made into i64.") } + + pub fn try_as_bool(&self) -> Option { + match self { + ConstantScalarValue::Bool(val) => Some(*val), + _ => None, + } + } + + pub fn as_bool(&self) -> bool { + self.try_as_bool() + .expect("Only bool can be made into a bool") + } } impl Variable { @@ -250,6 +262,13 @@ impl Variable { Variable::SubcubeDim => Item::new(Elem::UInt), } } + + pub fn as_const(&self) -> Option { + match self { + Variable::ConstantScalar(constant) => Some(*constant), + _ => None, + } + } } // Useful with the cube_inline macro. diff --git a/crates/cubecl-core/src/ir/vectorization.rs b/crates/cubecl-core/src/ir/vectorization.rs index 9e3f6852..04f1e159 100644 --- a/crates/cubecl-core/src/ir/vectorization.rs +++ b/crates/cubecl-core/src/ir/vectorization.rs @@ -1,9 +1,11 @@ +use std::num::NonZero; + use super::{ BinaryOperator, ClampOperator, CompareAndSwapOperator, FmaOperator, InitOperator, Item, Operation, Operator, SliceOperator, Subcube, UnaryOperator, Variable, }; -pub type Vectorization = u8; +pub type Vectorization = Option>; impl Operation { pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { @@ -78,6 +80,7 @@ impl Operator { Operator::Or(op) => Operator::Or(op.vectorize(vectorization)), Operator::Not(op) => Operator::Not(op.vectorize(vectorization)), Operator::BitwiseOr(op) => Operator::BitwiseOr(op.vectorize(vectorization)), + Operator::Neg(op) => Operator::Neg(op.vectorize(vectorization)), Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)), Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)), Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)), @@ -275,6 +278,7 @@ impl Item { } pub(crate) fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 { - size / (vectorize as u32) + let vec = vectorize.map(|it| it.get()).unwrap_or(1); + size / (vec as u32) } } diff --git a/crates/cubecl-core/src/lib.rs b/crates/cubecl-core/src/lib.rs index d8918dbf..3d4d45ff 100644 --- a/crates/cubecl-core/src/lib.rs +++ b/crates/cubecl-core/src/lib.rs @@ -23,9 +23,7 @@ pub use codegen::*; pub use pod::*; pub use runtime::*; -pub use cubecl_macros::cube; -pub use cubecl_macros::CubeLaunch; -pub use cubecl_macros::CubeType; +pub use cubecl_macros::{comptime, cube, CubeLaunch, CubeType}; pub use cubecl_runtime::benchmark; /// An approximation of the subcube dimension. diff --git a/crates/cubecl-core/src/prelude.rs b/crates/cubecl-core/src/prelude.rs index df6b0ea5..a2f687b7 100644 --- a/crates/cubecl-core/src/prelude.rs +++ b/crates/cubecl-core/src/prelude.rs @@ -5,14 +5,14 @@ pub use crate::compute::{ CompiledKernel, CubeCount, CubeTask, KernelBuilder, KernelLauncher, KernelTask, }; pub use crate::frontend::cmma; -pub use crate::frontend::{branch::*, synchronization::*}; +pub use crate::frontend::{branch::*, synchronization::*, vectorization_of}; pub use crate::ir::{CubeDim, KernelDefinition}; pub use crate::runtime::Runtime; /// Elements pub use crate::frontend::{ - Array, ArrayHandleRef, AtomicI32, AtomicI64, AtomicUInt, Bool, Float, LaunchArg, Slice, - SliceMut, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64, + Array, ArrayHandleRef, AtomicI32, AtomicI64, AtomicU32, Float, LaunchArg, Slice, SliceMut, + Tensor, TensorArg, }; pub use crate::pod::CubeElement; diff --git a/crates/cubecl-core/src/runtime_tests/assign.rs b/crates/cubecl-core/src/runtime_tests/assign.rs index f9c81aae..0d3ff3e3 100644 --- a/crates/cubecl-core/src/runtime_tests/assign.rs +++ b/crates/cubecl-core/src/runtime_tests/assign.rs @@ -3,9 +3,9 @@ use crate as cubecl; use cubecl::prelude::*; #[cube(launch)] -pub fn kernel_assign(output: &mut Array, vectorization: Comptime) { - if UNIT_POS == UInt::new(0) { - let item = F32::vectorized(5.0, Comptime::get(vectorization)); +pub fn kernel_assign(output: &mut Array) { + if UNIT_POS == 0 { + let item = 5.0; output[0] = item; } } @@ -20,7 +20,6 @@ pub fn test_kernel_assign_scalar(client: ComputeClient, rhs: &Array, out: &mut Array) { - let a = cmma::Matrix::::new( +pub fn kernel_simple_1(lhs: &Array, rhs: &Array, out: &mut Array) { + let a = cmma::Matrix::::new( cmma::MatrixIdent::A, 16, 16, 16, cmma::MatrixLayout::RowMajor, ); - let b = cmma::Matrix::::new( + let b = cmma::Matrix::::new( cmma::MatrixIdent::B, 16, 16, 16, cmma::MatrixLayout::ColMajor, ); - let c = cmma::Matrix::::new( + let c = cmma::Matrix::::new( cmma::MatrixIdent::Accumulator, 16, 16, 16, cmma::MatrixLayout::Undefined, ); - cmma::fill::(&c, F32::new(0.0)); - cmma::load::(&a, lhs.as_slice(), UInt::new(16)); - cmma::load::(&b, rhs.as_slice(), UInt::new(16)); + cmma::fill::(&c, 0.0); + cmma::load(&a, lhs.as_slice(), 16); + cmma::load(&b, rhs.as_slice(), 16); - cmma::execute::(&a, &b, &c, &c); + cmma::execute::(&a, &b, &c, &c); - cmma::store::( - out.as_slice_mut(), - &c, - UInt::new(16), - cmma::MatrixLayout::RowMajor, - ); + cmma::store(out.as_slice_mut(), &c, 16, cmma::MatrixLayout::RowMajor); } pub fn test_simple_1(client: ComputeClient) { diff --git a/crates/cubecl-core/src/runtime_tests/launch.rs b/crates/cubecl-core/src/runtime_tests/launch.rs index 38c7d204..a831f080 100644 --- a/crates/cubecl-core/src/runtime_tests/launch.rs +++ b/crates/cubecl-core/src/runtime_tests/launch.rs @@ -1,5 +1,4 @@ use crate as cubecl; - use cubecl::prelude::*; #[cube(launch)] @@ -10,16 +9,16 @@ pub fn kernel_with_generics(output: &mut Array) { } #[cube(launch)] -pub fn kernel_without_generics(output: &mut Array) { - if UNIT_POS == UInt::new(0) { - output[0] = F32::new(5.0); +pub fn kernel_without_generics(output: &mut Array) { + if UNIT_POS == 0 { + output[0] = 5.0; } } pub fn test_kernel_with_generics(client: ComputeClient) { let handle = client.create(f32::as_bytes(&[0.0, 1.0])); - kernel_with_generics::launch::( + kernel_with_generics::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::default(), diff --git a/crates/cubecl-core/src/runtime_tests/sequence.rs b/crates/cubecl-core/src/runtime_tests/sequence.rs index 119acada..4aedc873 100644 --- a/crates/cubecl-core/src/runtime_tests/sequence.rs +++ b/crates/cubecl-core/src/runtime_tests/sequence.rs @@ -1,16 +1,15 @@ use crate as cubecl; - use cubecl::prelude::*; #[cube(launch)] -pub fn sequence_for_loop(output: &mut Array) { - if UNIT_POS != UInt::new(0) { +pub fn sequence_for_loop(output: &mut Array) { + if UNIT_POS != 0 { return; } - let mut sequence = Sequence::::new(); - sequence.push(F32::new(1.0)); - sequence.push(F32::new(4.0)); + let mut sequence = Sequence::::new(); + sequence.push(1.0); + sequence.push(4.0); for value in sequence { output[0] += value; @@ -18,17 +17,17 @@ pub fn sequence_for_loop(output: &mut Array) { } #[cube(launch)] -pub fn sequence_index(output: &mut Array) { - if UNIT_POS != UInt::new(0) { +pub fn sequence_index(output: &mut Array) { + if UNIT_POS != 0 { return; } - let mut sequence = Sequence::::new(); - sequence.push(F32::new(2.0)); - sequence.push(F32::new(4.0)); + let mut sequence = Sequence::::new(); + sequence.push(2.0); + sequence.push(4.0); - output[0] += *sequence.index(0); - output[0] += *Sequence::index(&sequence, 1); + output[0] += sequence.index(0); + output[0] += Sequence::::index(&sequence, 1); } pub fn test_sequence_for_loop(client: ComputeClient) { diff --git a/crates/cubecl-core/src/runtime_tests/slice.rs b/crates/cubecl-core/src/runtime_tests/slice.rs index c151c16f..87a4b1bc 100644 --- a/crates/cubecl-core/src/runtime_tests/slice.rs +++ b/crates/cubecl-core/src/runtime_tests/slice.rs @@ -2,24 +2,24 @@ use crate as cubecl; use cubecl::prelude::*; #[cube(launch)] -pub fn slice_select(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { +pub fn slice_select(input: &Array, output: &mut Array) { + if UNIT_POS == 0 { let slice = input.slice(2, 3); - output[0] = slice[0u32]; + output[0] = slice[0]; } } #[cube(launch)] -pub fn slice_assign(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { - let slice_1 = output.slice_mut(2, 3); - slice_1[0] = input[0u32]; +pub fn slice_assign(input: &Array, output: &mut Array) { + if UNIT_POS == 0 { + let slice_1 = &mut output.slice_mut(2, 3); + slice_1[0] = input[0]; } } #[cube(launch)] -pub fn slice_len(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { +pub fn slice_len(input: &Array, output: &mut Array) { + if UNIT_POS == 0 { let slice = input.slice(2, 4); let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy. output[0] = slice.len(); @@ -31,7 +31,7 @@ pub fn test_slice_select(client: ComputeClient()); unsafe { - slice_select::launch::( + slice_select::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), @@ -51,7 +51,7 @@ pub fn test_slice_len(client: ComputeClient) let output = client.empty(core::mem::size_of::()); unsafe { - slice_len::launch::( + slice_len::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), @@ -71,7 +71,7 @@ pub fn test_slice_assign(client: ComputeClient( + slice_assign::launch::( &client, CubeCount::Static(1, 1, 1), CubeDim::new(1, 1, 1), diff --git a/crates/cubecl-core/src/runtime_tests/subcube.rs b/crates/cubecl-core/src/runtime_tests/subcube.rs index e1232d7f..fc20f2c2 100644 --- a/crates/cubecl-core/src/runtime_tests/subcube.rs +++ b/crates/cubecl-core/src/runtime_tests/subcube.rs @@ -3,61 +3,61 @@ use crate::Feature; use cubecl::prelude::*; #[cube(launch)] -pub fn kernel_sum(output: &mut Tensor) { +pub fn kernel_sum(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_sum::(val); + let val2 = subcube_sum(val); - if UNIT_POS == UInt::new(0) { + if UNIT_POS == 0 { output[0] = val2; } } #[cube(launch)] -pub fn kernel_prod(output: &mut Tensor) { +pub fn kernel_prod(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_prod::(val); + let val2 = subcube_prod(val); - if UNIT_POS == UInt::new(0) { + if UNIT_POS == 0 { output[0] = val2; } } #[cube(launch)] -pub fn kernel_max(output: &mut Tensor) { +pub fn kernel_max(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_max::(val); + let val2 = subcube_max(val); - if UNIT_POS == UInt::new(0) { + if UNIT_POS == 0 { output[0] = val2; } } #[cube(launch)] -pub fn kernel_min(output: &mut Tensor) { +pub fn kernel_min(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_min::(val); + let val2 = subcube_min(val); - if UNIT_POS == UInt::new(0) { + if UNIT_POS == 0 { output[0] = val2; } } #[cube(launch)] -pub fn kernel_all(output: &mut Tensor) { +pub fn kernel_all(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_all(val < 5); - output[UNIT_POS] = F::cast_from(val2); + let val2 = subcube_all(val < 5.0); + output[UNIT_POS] = val2 as u32 as f32; } #[cube(launch)] -pub fn kernel_any(output: &mut Tensor) { +pub fn kernel_any(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_any(val < 5); - output[UNIT_POS] = F::cast_from(val2); + let val2 = subcube_any(val < 5.0); + output[UNIT_POS] = val2 as u32 as f32; } #[cube(launch)] -pub fn kernel_elect(output: &mut Tensor) { +pub fn kernel_elect(output: &mut Tensor) { let val = output[UNIT_POS]; let elect = subcube_elect(); if elect { @@ -66,9 +66,9 @@ pub fn kernel_elect(output: &mut Tensor) { } #[cube(launch)] -pub fn kernel_broadcast(output: &mut Tensor) { +pub fn kernel_broadcast(output: &mut Tensor) { let val = output[UNIT_POS]; - let val2 = subcube_broadcast::(val, UInt::new(2)); + let val2 = subcube_broadcast(val, 2); if UNIT_POS == 0 { output[0] = val2; @@ -82,8 +82,8 @@ pub fn test_subcube_sum( &[4.0, 5.0, 7.0, 1.0], &[17.0, 5.0, 7.0, 1.0], client.clone(), - |cube_count, cube_dim, handle| { - kernel_sum::launch::(&client, cube_count, cube_dim, handle) + |cube_count: CubeCount<::Server>, cube_dim, handle| { + kernel_sum::launch::(&client, cube_count, cube_dim, handle) }, ); } @@ -96,7 +96,7 @@ pub fn test_subcube_prod( &[140.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_prod::launch::(&client, cube_dim, settings, handle) + kernel_prod::launch::(&client, cube_dim, settings, handle) }, ); } @@ -108,7 +108,7 @@ pub fn test_subcube_max( &[7.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_max::launch::(&client, cube_dim, settings, handle) + kernel_max::launch::(&client, cube_dim, settings, handle) }, ); } @@ -121,7 +121,7 @@ pub fn test_subcube_min( &[1.0, 5.0, 7.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_min::launch::(&client, cube_dim, settings, handle) + kernel_min::launch::(&client, cube_dim, settings, handle) }, ); } @@ -134,7 +134,7 @@ pub fn test_subcube_all( &[1.0, 1.0, 1.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_all::launch::(&client, cube_dim, settings, handle) + kernel_all::launch::(&client, cube_dim, settings, handle) }, ); test_subcube_operation::( @@ -142,7 +142,7 @@ pub fn test_subcube_all( &[0.0, 0.0, 0.0, 0.0], client.clone(), |cube_dim, settings, handle| { - kernel_all::launch::(&client, cube_dim, settings, handle) + kernel_all::launch::(&client, cube_dim, settings, handle) }, ); } @@ -155,7 +155,7 @@ pub fn test_subcube_any( &[1.0, 1.0, 1.0, 1.0], client.clone(), |cube_dim, settings, handle| { - kernel_any::launch::(&client, cube_dim, settings, handle) + kernel_any::launch::(&client, cube_dim, settings, handle) }, ); test_subcube_operation::( @@ -163,7 +163,7 @@ pub fn test_subcube_any( &[0.0, 0.0, 0.0, 0.0], client.clone(), |cube_dim, settings, handle| { - kernel_any::launch::(&client, cube_dim, settings, handle) + kernel_any::launch::(&client, cube_dim, settings, handle) }, ); } @@ -176,7 +176,7 @@ pub fn test_subcube_elect( &[2.0, 1.0, 1.0, 5.0], client.clone(), |cube_dim, settings, handle| { - kernel_elect::launch::(&client, cube_dim, settings, handle) + kernel_elect::launch::(&client, cube_dim, settings, handle) }, ); } @@ -189,7 +189,7 @@ pub fn test_subcube_broadcast( &[-6.0, 1.0, -6.0, 3.0], client.clone(), |cube_dim, settings, handle| { - kernel_broadcast::launch::(&client, cube_dim, settings, handle) + kernel_broadcast::launch::(&client, cube_dim, settings, handle) }, ); } @@ -251,7 +251,7 @@ macro_rules! testgen_subcube { #[test] fn test_subcube_min() { let client = TestRuntime::client(&Default::default()); - cubecl_core::runtime_tests::subcube::test_subcube_max::(client); + cubecl_core::runtime_tests::subcube::test_subcube_min::(client); } #[test] @@ -266,10 +266,11 @@ macro_rules! testgen_subcube { cubecl_core::runtime_tests::subcube::test_subcube_any::(client); } + #[ignore] #[test] fn test_subcube_elect() { let client = TestRuntime::client(&Default::default()); - cubecl_core::runtime_tests::subcube::test_subcube_any::(client); + cubecl_core::runtime_tests::subcube::test_subcube_elect::(client); } #[test] diff --git a/crates/cubecl-core/src/runtime_tests/topology.rs b/crates/cubecl-core/src/runtime_tests/topology.rs index 86fe0d73..bbdb3aac 100644 --- a/crates/cubecl-core/src/runtime_tests/topology.rs +++ b/crates/cubecl-core/src/runtime_tests/topology.rs @@ -3,7 +3,7 @@ use crate as cubecl; use cubecl::prelude::*; #[cube(launch)] -pub fn kernel_absolute_pos(output1: &mut Array) { +pub fn kernel_absolute_pos(output1: &mut Array) { if ABSOLUTE_POS >= output1.len() { return; } diff --git a/crates/cubecl-core/tests/error/array_variable.rs b/crates/cubecl-core/tests/error/array_variable.rs index 2634175b..00801636 100644 --- a/crates/cubecl-core/tests/error/array_variable.rs +++ b/crates/cubecl-core/tests/error/array_variable.rs @@ -1,8 +1,8 @@ -use cubecl_core as cubecl; use cubecl::prelude::*; +use cubecl_core as cubecl; #[cube] -fn range(x: UInt, y: UInt) { +fn array_variable(x: u32, y: u32) { let _array = [x, y]; } diff --git a/crates/cubecl-core/tests/error/array_variable.stderr b/crates/cubecl-core/tests/error/array_variable.stderr index b942b91b..56a2ed1f 100644 --- a/crates/cubecl-core/tests/error/array_variable.stderr +++ b/crates/cubecl-core/tests/error/array_variable.stderr @@ -1,4 +1,4 @@ -error: Only arrays of literals are supported +error: Array expressions can't be used at runtime --> tests/error/array_variable.rs:6:18 | 6 | let _array = [x, y]; diff --git a/crates/cubecl-core/tests/error/for_loop_range.rs b/crates/cubecl-core/tests/error/for_loop_range.rs deleted file mode 100644 index b8d21a08..00000000 --- a/crates/cubecl-core/tests/error/for_loop_range.rs +++ /dev/null @@ -1,9 +0,0 @@ -use cubecl_core as cubecl; -use cubecl::prelude::*; - -#[cube] -fn range() { - for _ in 0..10 {} -} - -fn main() {} diff --git a/crates/cubecl-core/tests/error/for_loop_range.stderr b/crates/cubecl-core/tests/error/for_loop_range.stderr deleted file mode 100644 index 0a31e86c..00000000 --- a/crates/cubecl-core/tests/error/for_loop_range.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: Invalid for loop: use [range](cubecl::prelude::range] or [range_stepped](cubecl::prelude::range_stepped) instead. - --> tests/error/for_loop_range.rs:6:14 - | -6 | for _ in 0..10 {} - | ^^^^^ diff --git a/crates/cubecl-core/tests/error/range.rs b/crates/cubecl-core/tests/error/range.rs deleted file mode 100644 index dfe3b696..00000000 --- a/crates/cubecl-core/tests/error/range.rs +++ /dev/null @@ -1,9 +0,0 @@ -use cubecl_core as cubecl; -use cubecl::prelude::*; - -#[cube] -fn range() { - 0..10; -} - -fn main() {} diff --git a/crates/cubecl-core/tests/error/range.stderr b/crates/cubecl-core/tests/error/range.stderr deleted file mode 100644 index c71a420a..00000000 --- a/crates/cubecl-core/tests/error/range.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: Range is not supported, use [range](cubecl::prelude::range) instead. - --> tests/error/range.rs:6:5 - | -6 | 0..10; - | ^^^^^ diff --git a/crates/cubecl-core/tests/error/return_value.rs b/crates/cubecl-core/tests/error/return_value.rs index 73021b07..ab598961 100644 --- a/crates/cubecl-core/tests/error/return_value.rs +++ b/crates/cubecl-core/tests/error/return_value.rs @@ -2,7 +2,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] -fn range(x: UInt, y: UInt) -> UInt { +fn return_value(x: u32, y: u32) -> u32 { if x == y { return x; } diff --git a/crates/cubecl-core/tests/error/return_value.stderr b/crates/cubecl-core/tests/error/return_value.stderr index 770df6a4..9c7a44a5 100644 --- a/crates/cubecl-core/tests/error/return_value.stderr +++ b/crates/cubecl-core/tests/error/return_value.stderr @@ -1,5 +1,5 @@ error: Only void return is supported. - --> tests/error/return_value.rs:7:9 + --> tests/error/return_value.rs:7:16 | 7 | return x; - | ^^^^^^^^ + | ^ diff --git a/crates/cubecl-core/tests/error/undeclared_variable.rs b/crates/cubecl-core/tests/error/undeclared_variable.rs deleted file mode 100644 index 6aeca06a..00000000 --- a/crates/cubecl-core/tests/error/undeclared_variable.rs +++ /dev/null @@ -1,10 +0,0 @@ -use cubecl_core as cubecl; -use cubecl::prelude::*; - -#[cube] -fn kernel(x: UInt) { - if x == y { - } -} - -fn main() {} diff --git a/crates/cubecl-core/tests/error/undeclared_variable.stderr b/crates/cubecl-core/tests/error/undeclared_variable.stderr deleted file mode 100644 index 1b9297d1..00000000 --- a/crates/cubecl-core/tests/error/undeclared_variable.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: Variable not declared - --> tests/error/undeclared_variable.rs:6:13 - | -6 | if x == y { - | ^ - -error[E0425]: cannot find value `y` in this scope - --> tests/error/undeclared_variable.rs:6:13 - | -6 | if x == y { - | ^ help: a local variable with a similar name exists: `x` diff --git a/crates/cubecl-core/tests/frontend/array.rs b/crates/cubecl-core/tests/frontend/array.rs index 0d7b5a23..1d2ba1a6 100644 --- a/crates/cubecl-core/tests/frontend/array.rs +++ b/crates/cubecl-core/tests/frontend/array.rs @@ -1,11 +1,11 @@ +use cubecl::prelude::*; use cubecl_core as cubecl; -use cubecl_core::prelude::*; #[cube] -pub fn array_read_write(array_size: Comptime) { +pub fn array_read_write(#[comptime] array_size: u32) { let mut array = Array::::new(array_size); array[0] = T::from_int(3); - let _ = array[0]; + let _a = array[0]; } #[cube] @@ -13,40 +13,43 @@ pub fn array_to_vectorized_variable() -> T { let mut array = Array::::new(2); array[0] = T::from_int(0); array[1] = T::from_int(1); - array.to_vectorized(Comptime::new(UInt::new(2))) + array.to_vectorized(2) } #[cube] pub fn array_of_one_to_vectorized_variable() -> T { let mut array = Array::::new(1); array[0] = T::from_int(3); - array.to_vectorized(Comptime::new(UInt::new(1))) + array.to_vectorized(1) } #[cube] -pub fn array_add_assign_simple(array: &mut Array) { - array[UInt::new(1)] += UInt::new(1); +pub fn array_add_assign_simple(array: &mut Array) { + array[1] += 1; } #[cube] -pub fn array_add_assign_expr(array: &mut Array) { - array[UInt::new(1) + UInt::new(5)] += UInt::new(1); +pub fn array_add_assign_expr(array: &mut Array) { + array[1 + 5] += 1; } mod tests { + use pretty_assertions::assert_eq; + use std::num::NonZero; + use super::*; use cubecl_core::{ cpa, ir::{self, Elem, Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_array() { let mut context = CubeContext::root(); - array_read_write::__expand::(&mut context, 512); + array_read_write::expand::(&mut context, 512); assert_eq!( context.into_scope().operations, inline_macro_ref_read_write() @@ -58,7 +61,7 @@ mod tests { let mut context = CubeContext::root(); let array = context.input(0, Item::new(Elem::UInt)); - array_add_assign_simple::__expand(&mut context, array.into()); + array_add_assign_simple::expand(&mut context, array.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_array_add_assign_simple()); @@ -68,7 +71,7 @@ mod tests { fn cube_array_to_vectorized() { let mut context = CubeContext::root(); - array_to_vectorized_variable::__expand::(&mut context); + array_to_vectorized_variable::expand::(&mut context); assert_eq!( context.into_scope().operations, inline_macro_ref_to_vectorized() @@ -79,7 +82,7 @@ mod tests { fn cube_array_of_one_to_vectorized() { let mut context = CubeContext::root(); - array_of_one_to_vectorized_variable::__expand::(&mut context); + array_of_one_to_vectorized_variable::expand::(&mut context); assert_eq!( context.into_scope().operations, inline_macro_ref_one_to_vectorized() @@ -111,7 +114,7 @@ mod tests { let mut context = CubeContext::root(); let array = context.input(0, Item::new(Elem::UInt)); - array_add_assign_expr::__expand(&mut context, array.into()); + array_add_assign_expr::expand(&mut context, array.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_array_add_assign_expr()); @@ -140,7 +143,7 @@ mod tests { fn inline_macro_ref_to_vectorized() -> Vec { let context = CubeContext::root(); let scalar_item = Item::new(ElemType::as_elem()); - let vectorized_item = Item::vectorized(ElemType::as_elem(), 2); + let vectorized_item = Item::vectorized(ElemType::as_elem(), NonZero::new(2)); let mut scope = context.into_scope(); let pos0: Variable = 0u32.into(); @@ -162,7 +165,7 @@ mod tests { fn inline_macro_ref_one_to_vectorized() -> Vec { let context = CubeContext::root(); let scalar_item = Item::new(ElemType::as_elem()); - let unvectorized_item = Item::new(ElemType::as_elem()); + let unvectorized_item = Item::vectorized(ElemType::as_elem(), NonZero::new(1)); let mut scope = context.into_scope(); let pos0: Variable = 0u32.into(); @@ -181,18 +184,15 @@ mod tests { let context = CubeContext::root(); let mut scope = context.into_scope(); - let index = scope.create_local(Item::new(Elem::UInt)); let local = scope.create_local(Item::new(Elem::UInt)); let array = Variable::GlobalInputArray { id: 0, item: Item::new(Elem::UInt), }; - let const1: Variable = 1u32.into(); - let const2: Variable = 5u32.into(); + let index: Variable = 6u32.into(); let value: Variable = 1u32.into(); - cpa!(scope, index = const1 + const2); cpa!(scope, local = array[index]); cpa!(scope, local += value); cpa!(scope, array[index] = local); diff --git a/crates/cubecl-core/tests/frontend/assign.rs b/crates/cubecl-core/tests/frontend/assign.rs index 3b807ff3..90b513bb 100644 --- a/crates/cubecl-core/tests/frontend/assign.rs +++ b/crates/cubecl-core/tests/frontend/assign.rs @@ -1,40 +1,44 @@ +#![allow(unused)] + use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] pub fn mut_assign() { - let mut x = UInt::new(0); - x += UInt::new(1); + let mut x: u32 = 0; + x += 1; } #[cube] -pub fn mut_assign_input(y: UInt) -> UInt { +pub fn mut_assign_input(y: u32) -> u32 { let mut x = y; - x += UInt::new(1); - y + UInt::new(2) + x += 1; + y + 2 } #[cube] -pub fn assign_mut_input(mut y: UInt) -> UInt { +pub fn assign_mut_input(mut y: u32) -> u32 { let x = y; - y += UInt::new(1); - x + UInt::new(2) + y += 1; + x + 2 } #[cube] -pub fn assign_vectorized(y: UInt) -> UInt { - let vectorization_factor = Comptime::vectorization(&y); - let x = UInt::vectorized(1, Comptime::get(vectorization_factor)); +pub fn assign_vectorized(y: u32) -> u32 { + let x = u32::vectorized(1, vectorization_of(&y)); x + y } #[cube] -pub fn assign_deref(y: &mut UInt) -> UInt { - *y = UInt::new(1); +pub fn assign_deref(y: &mut u32) -> u32 { + *y = 1; *y } mod tests { + use pretty_assertions::assert_eq; + use std::num::NonZero; + use super::*; use cubecl_core::{ cpa, @@ -45,7 +49,7 @@ mod tests { fn cube_mut_assign_test() { let mut context = CubeContext::root(); - mut_assign::__expand(&mut context); + mut_assign::expand(&mut context); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_mut_assign()); @@ -55,9 +59,9 @@ mod tests { fn cube_mut_assign_input_test() { let mut context = CubeContext::root(); - let y = context.create_local(Item::new(UInt::as_elem())); + let y = context.create_local(Item::new(u32::as_elem())); - mut_assign_input::__expand(&mut context, y.into()); + mut_assign_input::expand(&mut context, y.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_mut_assign_input()); @@ -67,9 +71,9 @@ mod tests { fn cube_assign_mut_input_test() { let mut context = CubeContext::root(); - let y = context.create_local(Item::new(UInt::as_elem())); + let y = context.create_local(Item::new(u32::as_elem())); - assign_mut_input::__expand(&mut context, y.into()); + assign_mut_input::expand(&mut context, y.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_assign_mut_input()); @@ -79,9 +83,9 @@ mod tests { fn cube_assign_vectorized_test() { let mut context = CubeContext::root(); - let y = context.create_local(Item::vectorized(UInt::as_elem(), 4)); + let y = context.create_local(Item::vectorized(u32::as_elem(), NonZero::new(4))); - assign_vectorized::__expand(&mut context, y.into()); + assign_vectorized::expand(&mut context, y.into()); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_assign_vectorized()); @@ -91,8 +95,8 @@ mod tests { fn cube_assign_deref_test() { let mut context = CubeContext::root(); - let y = context.create_local(Item::new(UInt::as_elem())); - assign_deref::__expand(&mut context, y.into()); + let y = context.create_local(Item::new(u32::as_elem())); + assign_deref::expand(&mut context, y.into()); let scope = context.into_scope(); @@ -154,7 +158,7 @@ mod tests { fn inline_macro_ref_assign_vectorized() -> Vec { let mut context = CubeContext::root(); - let item = Item::vectorized(Elem::UInt, 4); + let item = Item::vectorized(Elem::UInt, NonZero::new(4)); let y = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/cubecl-core/tests/frontend/cast_elem.rs b/crates/cubecl-core/tests/frontend/cast_elem.rs index ca91bc5b..23eb15f2 100644 --- a/crates/cubecl-core/tests/frontend/cast_elem.rs +++ b/crates/cubecl-core/tests/frontend/cast_elem.rs @@ -1,117 +1,117 @@ use cubecl_core as cubecl; use cubecl_core::{ cube, - frontend::{Bool, BoolOps, Cast, Numeric, UInt, F32, I32}, + frontend::{Cast, Numeric}, }; // From float #[cube] -pub fn float_to_float(x: F32) { - let y = x + F32::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); +pub fn float_to_float(x: f32) { + let y = x + f32::from_int(2); + let _ = f32::cast_from(y) + f32::from_int(34); } #[cube] -pub fn float_to_int(x: F32) { - let y = x + F32::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); +pub fn float_to_int(x: f32) { + let y = x + f32::from_int(2); + let _ = i32::cast_from(y) + i32::from_int(34); } #[cube] -pub fn float_to_uint(x: F32) { - let y = x + F32::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); +pub fn float_to_u32(x: f32) { + let y = x + f32::from_int(2); + let _ = u32::cast_from(y) + u32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn float_to_bool(x: F32) { - let y = x + F32::from_int(2); - let _ = Bool::cast_from(y) || Bool::new(true); +pub fn float_to_bool(x: f32) { + let y = x + f32::from_int(2); + let _ = bool::cast_from(y) || true; } // From int #[cube] -pub fn int_to_float(x: I32) { - let y = x + I32::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); +pub fn int_to_float(x: i32) { + let y = x + i32::from_int(2); + let _ = f32::cast_from(y) + f32::from_int(34); } #[cube] #[allow(clippy::useless_conversion)] -pub fn int_to_int(x: I32) { - let y = x + I32::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); +pub fn int_to_int(x: i32) { + let y = x + i32::from_int(2); + let _ = i32::cast_from(y) + i32::from_int(34); } #[cube] -pub fn int_to_uint(x: I32) { - let y = x + I32::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); +pub fn int_to_u32(x: i32) { + let y = x + i32::from_int(2); + let _ = u32::cast_from(y) + u32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn int_to_bool(x: I32) { - let y = x + I32::from_int(2); - let _ = Bool::cast_from(y) || Bool::new(true); +pub fn int_to_bool(x: i32) { + let y = x + i32::from_int(2); + let _ = bool::cast_from(y) || true; } -// // From uint +// // From u32 #[cube] -pub fn uint_to_float(x: UInt) { - let y = x + UInt::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); +pub fn u32_to_float(x: u32) { + let y = x + u32::from_int(2); + let _ = f32::cast_from(y) + f32::from_int(34); } #[cube] -pub fn uint_to_int(x: UInt) { - let y = x + UInt::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); +pub fn u32_to_int(x: u32) { + let y = x + u32::from_int(2); + let _ = i32::cast_from(y) + i32::from_int(34); } #[cube] #[allow(clippy::useless_conversion)] -pub fn uint_to_uint(x: UInt) { - let y = x + UInt::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); +pub fn u32_to_u32(x: u32) { + let y = x + u32::from_int(2); + let _ = u32::cast_from(y) + u32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn uint_to_bool(x: UInt) { - let y = x + UInt::from_int(2); - let _ = Bool::cast_from(y) || Bool::new(true); +pub fn u32_to_bool(x: u32) { + let y = x + u32::from_int(2); + let _ = bool::cast_from(y) || true; } // From bool #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_float(x: Bool) { - let y = x && Bool::new(false); - let _ = F32::cast_from(y) + F32::from_int(34); +pub fn bool_to_float(x: bool) { + let y = x && false; + let _ = f32::cast_from(y) + f32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_int(x: Bool) { - let y = x && Bool::new(false); - let _ = I32::cast_from(y) + I32::from_int(34); +pub fn bool_to_int(x: bool) { + let y = x && false; + let _ = i32::cast_from(y) + i32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_uint(x: Bool) { - let y = x && Bool::new(false); - let _ = UInt::cast_from(y) + UInt::from_int(34); +pub fn bool_to_u32(x: bool) { + let y = x && false; + let _ = u32::cast_from(y) + u32::from_int(34); } #[cube] #[allow(clippy::overly_complex_bool_expr)] #[allow(clippy::useless_conversion)] -pub fn bool_to_bool(x: Bool) { - let y = x && Bool::new(false); - let _ = Bool::cast_from(y) || Bool::new(true); +pub fn bool_to_bool(x: bool) { + let y = x && false; + let _ = bool::cast_from(y) || true; } mod tests { @@ -143,112 +143,112 @@ mod tests { cast_test!( cube_float_to_float_test, - float_to_float::__expand, - Item::new(F32::as_elem()), - Item::new(F32::as_elem()) + float_to_float::expand, + Item::new(f32::as_elem()), + Item::new(f32::as_elem()) ); cast_test!( cube_float_to_int_test, - float_to_int::__expand, - Item::new(F32::as_elem()), - Item::new(I32::as_elem()) + float_to_int::expand, + Item::new(f32::as_elem()), + Item::new(i32::as_elem()) ); cast_test!( - cube_float_to_uint_test, - float_to_uint::__expand, - Item::new(F32::as_elem()), + cube_float_to_u32_test, + float_to_u32::expand, + Item::new(f32::as_elem()), Item::new(Elem::UInt) ); cast_test!( cube_float_to_bool_test, - float_to_bool::__expand, - Item::new(F32::as_elem()), + float_to_bool::expand, + Item::new(f32::as_elem()), Item::new(Elem::Bool) ); cast_test!( cube_int_to_float_test, - int_to_float::__expand, - Item::new(I32::as_elem()), - Item::new(F32::as_elem()) + int_to_float::expand, + Item::new(i32::as_elem()), + Item::new(f32::as_elem()) ); cast_test!( cube_int_to_int_test, - int_to_int::__expand, - Item::new(I32::as_elem()), - Item::new(I32::as_elem()) + int_to_int::expand, + Item::new(i32::as_elem()), + Item::new(i32::as_elem()) ); cast_test!( - cube_int_to_uint_test, - int_to_uint::__expand, - Item::new(I32::as_elem()), + cube_int_to_u32_test, + int_to_u32::expand, + Item::new(i32::as_elem()), Item::new(Elem::UInt) ); cast_test!( cube_int_to_bool_test, - int_to_bool::__expand, - Item::new(I32::as_elem()), + int_to_bool::expand, + Item::new(i32::as_elem()), Item::new(Elem::Bool) ); cast_test!( - cube_uint_to_float_test, - uint_to_float::__expand, + cube_u32_to_float_test, + u32_to_float::expand, Item::new(Elem::UInt), - Item::new(F32::as_elem()) + Item::new(f32::as_elem()) ); cast_test!( - cube_uint_to_int_test, - uint_to_int::__expand, + cube_u32_to_int_test, + u32_to_int::expand, Item::new(Elem::UInt), - Item::new(I32::as_elem()) + Item::new(i32::as_elem()) ); cast_test!( - cube_uint_to_uint_test, - uint_to_uint::__expand, + cube_u32_to_u32_test, + u32_to_u32::expand, Item::new(Elem::UInt), Item::new(Elem::UInt) ); cast_test!( - cube_uint_to_bool_test, - uint_to_bool::__expand, + cube_u32_to_bool_test, + u32_to_bool::expand, Item::new(Elem::UInt), Item::new(Elem::Bool) ); cast_test!( cube_bool_to_float_test, - bool_to_float::__expand, + bool_to_float::expand, Item::new(Elem::Bool), - Item::new(F32::as_elem()) + Item::new(f32::as_elem()) ); cast_test!( cube_bool_to_int_test, - bool_to_int::__expand, + bool_to_int::expand, Item::new(Elem::Bool), - Item::new(I32::as_elem()) + Item::new(i32::as_elem()) ); cast_test!( - cube_bool_to_uint_test, - bool_to_uint::__expand, + cube_bool_to_u32_test, + bool_to_u32::expand, Item::new(Elem::Bool), Item::new(Elem::UInt) ); cast_test!( cube_bool_to_bool_test, - bool_to_bool::__expand, + bool_to_bool::expand, Item::new(Elem::Bool), Item::new(Elem::Bool) ); diff --git a/crates/cubecl-core/tests/frontend/cast_kind.rs b/crates/cubecl-core/tests/frontend/cast_kind.rs index 8a191800..dd5fa410 100644 --- a/crates/cubecl-core/tests/frontend/cast_kind.rs +++ b/crates/cubecl-core/tests/frontend/cast_kind.rs @@ -36,18 +36,18 @@ mod tests { use super::*; use cubecl_core::{ cpa, - frontend::{CubeContext, CubePrimitive, F32, F64, I32, I64}, + frontend::{CubeContext, CubePrimitive}, ir::{Item, Variable}, }; #[test] fn cube_cast_float_kind_test() { let mut context = CubeContext::root(); - let item = Item::new(F64::as_elem()); + let item = Item::new(f64::as_elem()); let input = context.create_local(item); - cast_float_kind::__expand::(&mut context, input.into()); + cast_float_kind::expand::(&mut context, input.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); @@ -56,11 +56,11 @@ mod tests { #[test] fn cube_cast_int_kind_test() { let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); + let item = Item::new(i32::as_elem()); let input = context.create_local(item); - cast_int_kind::__expand::(&mut context, input.into()); + cast_int_kind::expand::(&mut context, input.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -69,11 +69,11 @@ mod tests { #[test] fn cube_cast_numeric_kind_test() { let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); + let item = Item::new(i32::as_elem()); let input = context.create_local(item); - cast_numeric_to_kind::__expand::(&mut context, input.into()); + cast_numeric_to_kind::expand::(&mut context, input.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -82,11 +82,11 @@ mod tests { #[test] fn cube_cast_kind_numeric_test() { let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); + let item = Item::new(i32::as_elem()); let input = context.create_local(item); - cast_int_to_numeric::__expand::(&mut context, input.into()); + cast_int_to_numeric::expand::(&mut context, input.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -94,8 +94,8 @@ mod tests { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let float_64 = Item::new(F64::as_elem()); - let float_32 = Item::new(F32::as_elem()); + let float_64 = Item::new(f64::as_elem()); + let float_32 = Item::new(f32::as_elem()); let input = context.create_local(float_64); let mut scope = context.into_scope(); @@ -111,8 +111,8 @@ mod tests { fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); - let int_32 = Item::new(I32::as_elem()); - let int_64 = Item::new(I64::as_elem()); + let int_32 = Item::new(i32::as_elem()); + let int_64 = Item::new(i64::as_elem()); let input = context.create_local(int_32); let mut scope = context.into_scope(); diff --git a/crates/cubecl-core/tests/frontend/comptime.rs b/crates/cubecl-core/tests/frontend/comptime.rs index c4f1c36b..d8714b60 100644 --- a/crates/cubecl-core/tests/frontend/comptime.rs +++ b/crates/cubecl-core/tests/frontend/comptime.rs @@ -1,5 +1,5 @@ -use cubecl_core as cubecl; use cubecl_core::prelude::*; +use cubecl_core::{self as cubecl, comptime}; #[derive(Clone)] pub struct State { @@ -14,8 +14,8 @@ impl Init for State { } #[cube] -pub fn comptime_if_else(lhs: T, cond: Comptime) { - if Comptime::get(cond) { +pub fn comptime_if_else(lhs: T, #[comptime] cond: bool) { + if cond { let _ = lhs + T::from_int(4); } else { let _ = lhs - T::from_int(5); @@ -24,11 +24,11 @@ pub fn comptime_if_else(lhs: T, cond: Comptime) { #[cube] #[allow(clippy::collapsible_else_if)] -pub fn comptime_else_then_if(lhs: T, cond1: Comptime, cond2: Comptime) { - if Comptime::get(cond1) { +pub fn comptime_else_then_if(lhs: T, #[comptime] cond1: bool, #[comptime] cond2: bool) { + if cond1 { let _ = lhs + T::from_int(4); } else { - if Comptime::get(cond2) { + if cond2 { let _ = lhs + T::from_int(5); } else { let _ = lhs - T::from_int(6); @@ -38,15 +38,15 @@ pub fn comptime_else_then_if(lhs: T, cond1: Comptime, cond2: C #[cube] pub fn comptime_float() { - let comptime_float = Comptime::new(F32::new(0.0)); - let _runtime_float = Comptime::runtime(comptime_float); + let comptime_float = 0.0f32; + let _runtime_float = comptime_float.runtime(); } #[cube] -pub fn comptime_elsif(lhs: T, cond1: Comptime, cond2: Comptime) { - if Comptime::get(cond1) { +pub fn comptime_elsif(lhs: T, #[comptime] cond1: bool, #[comptime] cond2: bool) { + if cond1 { let _ = lhs + T::from_int(4); - } else if Comptime::get(cond2) { + } else if cond2 { let _ = lhs + T::from_int(5); } else { let _ = lhs - T::from_int(6); @@ -54,9 +54,9 @@ pub fn comptime_elsif(lhs: T, cond1: Comptime, cond2: Comptime } #[cube] -pub fn comptime_elsif_with_runtime1(lhs: T, comptime_cond: Comptime) { +pub fn comptime_elsif_with_runtime1(lhs: T, #[comptime] comptime_cond: bool) { let runtime_cond = lhs >= T::from_int(2); - if Comptime::get(comptime_cond) { + if comptime_cond { let _ = lhs + T::from_int(4); } else if runtime_cond { let _ = lhs + T::from_int(5); @@ -66,11 +66,11 @@ pub fn comptime_elsif_with_runtime1(lhs: T, comptime_cond: Comptime< } #[cube] -pub fn comptime_elsif_with_runtime2(lhs: T, comptime_cond: Comptime) { +pub fn comptime_elsif_with_runtime2(lhs: T, #[comptime] comptime_cond: bool) { let runtime_cond = lhs >= T::from_int(2); if runtime_cond { let _ = lhs + T::from_int(4); - } else if Comptime::get(comptime_cond) { + } else if comptime_cond { let _ = lhs + T::from_int(5); } else { let _ = lhs - T::from_int(6); @@ -78,7 +78,7 @@ pub fn comptime_elsif_with_runtime2(lhs: T, comptime_cond: Comptime< } #[cube] -pub fn comptime_if_expr(lhs: T, x: Comptime, y: Comptime) { +pub fn comptime_if_expr(lhs: T, #[comptime] x: u32, #[comptime] y: u32) { let y2 = x + y; if x < y2 { @@ -89,11 +89,11 @@ pub fn comptime_if_expr(lhs: T, x: Comptime, y: Comptime } #[cube] -pub fn comptime_with_map_bool(state: Comptime) -> T { - let cond = Comptime::map(state, |s: State| s.cond); +pub fn comptime_with_map_bool(#[comptime] state: State) -> T { + let cond = state.cond; let mut x = T::from_int(3); - if Comptime::get(cond) { + if cond { x += T::from_int(4); } else { x -= T::from_int(4); @@ -102,26 +102,39 @@ pub fn comptime_with_map_bool(state: Comptime) -> T { } #[cube] -pub fn comptime_with_map_uint(state: Comptime) -> T { - let bound = Comptime::map(state, |s: State| s.bound); +pub fn comptime_with_map_uint(#[comptime] state: State) -> T { + let bound = state.bound; let mut x = T::from_int(3); - for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) { + #[unroll] + for _ in 0..bound { x += T::from_int(4); } x } +fn rust_function(input: u32) -> u32 { + input + 2 +} + +#[cube] +pub fn comptime_block(a: T) -> T { + let comptime_val = comptime! { rust_function(2) as i64 }; + + a + T::from_int(comptime_val) +} + mod tests { use super::*; use cubecl_core::{ cpa, - frontend::{CubeContext, CubePrimitive, F32}, + frontend::{CubeContext, CubePrimitive}, ir::{Elem, Item, Variable}, }; + use pretty_assertions::assert_eq; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_comptime_if_test() { @@ -129,7 +142,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_else::__expand::(&mut context, lhs.into(), true); + comptime_if_else::expand::(&mut context, lhs.into(), true); let scope = context.into_scope(); assert_eq!( @@ -144,12 +157,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_expr::__expand::( - &mut context, - lhs.into(), - UInt::new(4), - UInt::new(5), - ); + comptime_if_expr::expand::(&mut context, lhs.into(), 4, 5); let scope = context.into_scope(); assert_eq!( @@ -164,12 +172,12 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_if_else::__expand::(&mut context, lhs.into(), false); + comptime_if_else::expand::(&mut context, lhs.into(), false); let scope = context.into_scope(); assert_eq!( format!("{:?}", scope.operations), - inline_macro_ref_comptime(false) + inline_macro_ref_comptime2(false) ); } @@ -179,17 +187,12 @@ mod tests { for cond2 in [false, true] { let mut context1 = CubeContext::root(); let lhs = context1.create_local(Item::new(ElemType::as_elem())); - comptime_else_then_if::__expand::( - &mut context1, - lhs.into(), - cond1, - cond2, - ); + comptime_else_then_if::expand::(&mut context1, lhs.into(), cond1, cond2); let scope1 = context1.into_scope(); let mut context2 = CubeContext::root(); let lhs = context2.create_local(Item::new(ElemType::as_elem())); - comptime_elsif::__expand::(&mut context2, lhs.into(), cond1, cond2); + comptime_elsif::expand::(&mut context2, lhs.into(), cond1, cond2); let scope2 = context2.into_scope(); assert_eq!( @@ -205,7 +208,7 @@ mod tests { for cond in [false, true] { let mut context = CubeContext::root(); let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime1::__expand::(&mut context, lhs.into(), cond); + comptime_elsif_with_runtime1::expand::(&mut context, lhs.into(), cond); let scope = context.into_scope(); assert_eq!( @@ -221,7 +224,7 @@ mod tests { let mut context = CubeContext::root(); let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime2::__expand::(&mut context, lhs.into(), cond); + comptime_elsif_with_runtime2::expand::(&mut context, lhs.into(), cond); let scope = context.into_scope(); assert_eq!( @@ -245,8 +248,8 @@ mod tests { bound: 4, }; - comptime_with_map_bool::__expand::(&mut context1, comptime_state_true); - comptime_with_map_bool::__expand::(&mut context2, comptime_state_false); + comptime_with_map_bool::expand::(&mut context1, comptime_state_true); + comptime_with_map_bool::expand::(&mut context2, comptime_state_false); let scope1 = context1.into_scope(); let scope2 = context2.into_scope(); @@ -266,13 +269,29 @@ mod tests { bound: 4, }; - comptime_with_map_uint::__expand::(&mut context, comptime_state); + comptime_with_map_uint::expand::(&mut context, comptime_state); let scope = context.into_scope(); assert!(!format!("{:?}", scope.operations).contains("RangeLoop")); } + #[test] + fn cube_comptime_block_test() { + let mut context = CubeContext::root(); + + let a = context.create_local(Item::new(ElemType::as_elem())); + + comptime_block::expand::(&mut context, a.into()); + + let scope = context.into_scope(); + + assert_eq!( + format!("{:?}", scope.operations), + inline_macro_ref_comptime_block() + ); + } + fn inline_macro_ref_comptime(cond: bool) -> String { let mut context = CubeContext::root(); let item = Item::new(ElemType::as_elem()); @@ -291,6 +310,23 @@ mod tests { format!("{:?}", scope.operations) } + fn inline_macro_ref_comptime2(cond: bool) -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let x = context.create_local(item); + + let mut scope = context.into_scope(); + let x: Variable = x.into(); + + if cond { + cpa!(scope, x = x + 4.0f32); + } else { + cpa!(scope, x = x - 5.0f32); + }; + + format!("{:?}", scope.operations) + } + fn inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String { let mut context = CubeContext::root(); let item = Item::new(ElemType::as_elem()); @@ -338,4 +374,17 @@ mod tests { format!("{:?}", scope.operations) } + + fn inline_macro_ref_comptime_block() -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let a = context.create_local(item); + let comptime_var: Variable = ElemType::from_int(4).into(); + + let mut scope = context.into_scope(); + let x: Variable = a.into(); + cpa!(scope, x = x + comptime_var); + + format!("{:?}", scope.operations) + } } diff --git a/crates/cubecl-core/tests/frontend/cube_trait.rs b/crates/cubecl-core/tests/frontend/cube_trait.rs index 9135b61e..5cf0dfb6 100644 --- a/crates/cubecl-core/tests/frontend/cube_trait.rs +++ b/crates/cubecl-core/tests/frontend/cube_trait.rs @@ -60,10 +60,10 @@ mod tests { #[test] fn test_function_generic() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); + let lhs = context.create_local(Item::new(f32::as_elem())); + let rhs = context.create_local(Item::new(f32::as_elem())); - ::__expand_test::(&mut context, lhs.into(), rhs.into()); + ::__expand_test::(&mut context, lhs.into(), rhs.into()); assert_eq!(simple_scope(), context.into_scope()); } @@ -71,10 +71,10 @@ mod tests { #[test] fn test_trait_generic() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); + let lhs = context.create_local(Item::new(f32::as_elem())); + let rhs = context.create_local(Item::new(f32::as_elem())); - >::__expand_test(&mut context, lhs.into(), rhs.into()); + >::__expand_test(&mut context, lhs.into(), rhs.into()); assert_eq!(simple_scope(), context.into_scope()); } @@ -82,10 +82,10 @@ mod tests { #[test] fn test_combined_function_generic() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); + let lhs = context.create_local(Item::new(f32::as_elem())); + let rhs = context.create_local(Item::new(f32::as_elem())); - >::__expand_test::( + >::__expand_test::( &mut context, lhs.into(), rhs.into(), @@ -96,19 +96,19 @@ mod tests { fn simple_scope() -> Scope { let mut context_ref = CubeContext::root(); - let lhs = context_ref.create_local(Item::new(F32::as_elem())); - let rhs = context_ref.create_local(Item::new(F32::as_elem())); + let lhs = context_ref.create_local(Item::new(f32::as_elem())); + let rhs = context_ref.create_local(Item::new(f32::as_elem())); - simple::__expand::(&mut context_ref, lhs.into(), rhs.into()); + simple::expand::(&mut context_ref, lhs.into(), rhs.into()); context_ref.into_scope() } fn with_cast_scope() -> Scope { let mut context_ref = CubeContext::root(); - let lhs = context_ref.create_local(Item::new(F32::as_elem())); - let rhs = context_ref.create_local(Item::new(F32::as_elem())); + let lhs = context_ref.create_local(Item::new(f32::as_elem())); + let rhs = context_ref.create_local(Item::new(f32::as_elem())); - with_cast::__expand::(&mut context_ref, lhs.into(), rhs.into()); + with_cast::expand::(&mut context_ref, lhs.into(), rhs.into()); context_ref.into_scope() } } diff --git a/crates/cubecl-core/tests/frontend/for_loop.rs b/crates/cubecl-core/tests/frontend/for_loop.rs index ba8317d0..18cc9681 100644 --- a/crates/cubecl-core/tests/frontend/for_loop.rs +++ b/crates/cubecl-core/tests/frontend/for_loop.rs @@ -1,18 +1,18 @@ use cubecl_core as cubecl; use cubecl_core::{ cube, - frontend::branch::range, - frontend::{Array, Comptime, CubeContext, CubePrimitive, Float, UInt, F32}, + frontend::{Array, CubeContext, CubePrimitive, Float}, }; -type ElemType = F32; +type ElemType = f32; #[cube] -pub fn for_loop(mut lhs: Array, rhs: F, end: UInt, unroll: Comptime) { +pub fn for_loop(mut lhs: Array, rhs: F, end: u32, #[comptime] unroll: bool) { let tmp1 = rhs * rhs; let tmp2 = tmp1 + rhs; - for i in range(0u32, end, unroll) { + #[unroll(unroll)] + for i in 0..end { lhs[i] = tmp2 + lhs[i]; } } @@ -32,7 +32,7 @@ mod tests { let rhs = context.create_local(Item::new(ElemType::as_elem())); let end: ExpandElement = 4u32.into(); - for_loop::__expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); + for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); @@ -47,7 +47,7 @@ mod tests { let rhs = context.create_local(Item::new(ElemType::as_elem())); let end: ExpandElement = 4u32.into(); - for_loop::__expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); + for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); diff --git a/crates/cubecl-core/tests/frontend/function_call.rs b/crates/cubecl-core/tests/frontend/function_call.rs index 56c097d7..ec60ab66 100644 --- a/crates/cubecl-core/tests/frontend/function_call.rs +++ b/crates/cubecl-core/tests/frontend/function_call.rs @@ -1,37 +1,34 @@ use cubecl_core as cubecl; -use cubecl_core::{ - cube, - frontend::{Numeric, UInt}, -}; +use cubecl_core::{cube, frontend::Numeric}; #[cube] -pub fn caller_no_arg(x: UInt) { +pub fn caller_no_arg(x: u32) { let _ = x + callee_no_arg(); } #[cube] -pub fn callee_no_arg() -> UInt { - UInt::from_int(8) +pub fn callee_no_arg() -> u32 { + u32::from_int(8) } #[cube] -pub fn no_call_no_arg(x: UInt) { - let _ = x + UInt::from_int(8); +pub fn no_call_no_arg(x: u32) { + let _ = x + u32::from_int(8); } #[cube] -pub fn caller_with_arg(x: UInt) { +pub fn caller_with_arg(x: u32) { let _ = x + callee_with_arg(x); } #[cube] -pub fn callee_with_arg(x: UInt) -> UInt { - x * UInt::from_int(8) +pub fn callee_with_arg(x: u32) -> u32 { + x * u32::from_int(8) } #[cube] -pub fn no_call_with_arg(x: UInt) { - let _ = x + x * UInt::from_int(8); +pub fn no_call_with_arg(x: u32) { + let _ = x + x * u32::from_int(8); } #[cube] @@ -52,7 +49,7 @@ pub fn no_call_with_generics(x: T) { mod tests { use super::*; use cubecl_core::{ - frontend::{CubeContext, CubePrimitive, I64}, + frontend::{CubeContext, CubePrimitive}, ir::{Elem, Item}, }; @@ -60,12 +57,12 @@ mod tests { fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_no_arg::__expand(&mut caller_context, x.into()); + caller_no_arg::expand(&mut caller_context, x.into()); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_no_arg::__expand(&mut no_call_context, x.into()); + no_call_no_arg::expand(&mut no_call_context, x.into()); let no_call_scope = no_call_context.into_scope(); assert_eq!( @@ -79,12 +76,12 @@ mod tests { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_with_arg::__expand(&mut caller_context, x.into()); + caller_with_arg::expand(&mut caller_context, x.into()); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_with_arg::__expand(&mut no_call_context, x.into()); + no_call_with_arg::expand(&mut no_call_context, x.into()); let no_call_scope = no_call_context.into_scope(); assert_eq!( @@ -96,14 +93,14 @@ mod tests { #[test] fn cube_call_equivalent_to_no_call_with_generics_test() { let mut caller_context = CubeContext::root(); - type ElemType = I64; + type ElemType = i64; let x = caller_context.create_local(Item::new(ElemType::as_elem())); - caller_with_generics::__expand::(&mut caller_context, x.into()); + caller_with_generics::expand::(&mut caller_context, x.into()); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - no_call_with_generics::__expand::(&mut no_call_context, x.into()); + no_call_with_generics::expand::(&mut no_call_context, x.into()); let no_call_scope = no_call_context.into_scope(); assert_eq!( diff --git a/crates/cubecl-core/tests/frontend/generic_kernel.rs b/crates/cubecl-core/tests/frontend/generic_kernel.rs index c969a3d0..39f9115f 100644 --- a/crates/cubecl-core/tests/frontend/generic_kernel.rs +++ b/crates/cubecl-core/tests/frontend/generic_kernel.rs @@ -9,7 +9,7 @@ pub fn generic_kernel(lhs: T) { mod tests { use cubecl_core::{ cpa, - frontend::{CubeContext, CubePrimitive, F32, I32}, + frontend::{CubeContext, CubePrimitive}, ir::{Item, Variable}, }; @@ -19,9 +19,9 @@ mod tests { fn cube_generic_float_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); + let lhs = context.create_local(Item::new(f32::as_elem())); - generic_kernel::__expand::(&mut context, lhs.into()); + generic_kernel::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); @@ -31,9 +31,9 @@ mod tests { fn cube_generic_int_test() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(I32::as_elem())); + let lhs = context.create_local(Item::new(i32::as_elem())); - generic_kernel::__expand::(&mut context, lhs.into()); + generic_kernel::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); @@ -41,7 +41,7 @@ mod tests { fn inline_macro_ref_float() -> String { let mut context = CubeContext::root(); - let item = Item::new(F32::as_elem()); + let item = Item::new(f32::as_elem()); let var = context.create_local(item); let mut scope = context.into_scope(); @@ -53,7 +53,7 @@ mod tests { fn inline_macro_ref_int() -> String { let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); + let item = Item::new(i32::as_elem()); let var = context.create_local(item); let mut scope = context.into_scope(); diff --git a/crates/cubecl-core/tests/frontend/if.rs b/crates/cubecl-core/tests/frontend/if.rs index 38d074f8..36635241 100644 --- a/crates/cubecl-core/tests/frontend/if.rs +++ b/crates/cubecl-core/tests/frontend/if.rs @@ -39,13 +39,13 @@ pub fn elsif(lhs: F) { mod tests { use cubecl_core::{ cpa, - frontend::{CubeContext, CubePrimitive, F32}, + frontend::{CubeContext, CubePrimitive}, ir::{Elem, Item, Variable}, }; use super::*; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_if_test() { @@ -53,7 +53,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - if_greater::__expand::(&mut context, lhs.into()); + if_greater::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if()); @@ -65,7 +65,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - if_then_else::__expand::(&mut context, lhs.into()); + if_then_else::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!( @@ -80,7 +80,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - elsif::__expand::(&mut context, lhs.into()); + elsif::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif()); diff --git a/crates/cubecl-core/tests/frontend/intrinsics.rs b/crates/cubecl-core/tests/frontend/intrinsics.rs new file mode 100644 index 00000000..0f5616d8 --- /dev/null +++ b/crates/cubecl-core/tests/frontend/intrinsics.rs @@ -0,0 +1,54 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +fn assert_comptime(_elem: T) {} +pub mod assert_comptime { + use cubecl_core::prelude::CubeContext; + pub fn expand(_context: &mut CubeContext, _elem: T) {} +} + +#[cube] +pub fn vectorization_of_intrinsic(input: F) -> u32 { + let vec = vectorization_of(&input); + assert_comptime::(vec); + vec +} + +mod tests { + use pretty_assertions::assert_eq; + use std::num::NonZero; + + use super::*; + use cubecl_core::{ + cpa, + ir::{Item, Variable}, + }; + + type ElemType = f32; + + #[test] + fn vectorization_of_test() { + let mut context = CubeContext::root(); + + let input = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(3))); + + vectorization_of_intrinsic::expand::(&mut context, input.into()); + let scope = context.into_scope(); + + assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); + } + + fn inline_macro_ref() -> String { + let mut context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + let _input = context.create_local(item); + let out = context.create_local(Item::new(u32::as_elem())); + + let mut scope = context.into_scope(); + let out: Variable = out.into(); + let three: Variable = 3u32.into(); + cpa!(scope, out = three); + + format!("{:?}", scope.operations) + } +} diff --git a/crates/cubecl-core/tests/frontend/literal.rs b/crates/cubecl-core/tests/frontend/literal.rs index 101d2818..c5a98aad 100644 --- a/crates/cubecl-core/tests/frontend/literal.rs +++ b/crates/cubecl-core/tests/frontend/literal.rs @@ -18,7 +18,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_literal_test() { @@ -26,7 +26,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - literal::__expand::(&mut context, lhs.into()); + literal::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); @@ -38,7 +38,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - literal_float_no_decimals::__expand::(&mut context, lhs.into()); + literal_float_no_decimals::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/cubecl-core/tests/frontend/loop.rs b/crates/cubecl-core/tests/frontend/loop.rs index fb4acd3d..7b3414f1 100644 --- a/crates/cubecl-core/tests/frontend/loop.rs +++ b/crates/cubecl-core/tests/frontend/loop.rs @@ -35,7 +35,7 @@ mod tests { ir::{Branch, Elem, Item, Variable}, }; - type ElemType = I32; + type ElemType = i32; #[test] fn cube_while_test() { @@ -43,7 +43,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - while_not::__expand::(&mut context, lhs.into()); + while_not::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_while()); @@ -55,7 +55,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - manual_loop_break::__expand::(&mut context, lhs.into()); + manual_loop_break::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!( @@ -70,7 +70,7 @@ mod tests { let lhs = context.create_local(Item::new(ElemType::as_elem())); - loop_with_return::__expand::(&mut context, lhs.into()); + loop_with_return::expand::(&mut context, lhs.into()); let scope = context.into_scope(); assert_eq!( diff --git a/crates/cubecl-core/tests/frontend/mod.rs b/crates/cubecl-core/tests/frontend/mod.rs index 64cebc69..d5743ad9 100644 --- a/crates/cubecl-core/tests/frontend/mod.rs +++ b/crates/cubecl-core/tests/frontend/mod.rs @@ -8,6 +8,7 @@ mod for_loop; mod function_call; mod generic_kernel; mod r#if; +mod intrinsics; mod literal; mod r#loop; mod module_import; @@ -20,6 +21,5 @@ mod r#struct; mod tensor; mod topology; mod r#trait; - mod tuple; mod vectorization; diff --git a/crates/cubecl-core/tests/frontend/module_import.rs b/crates/cubecl-core/tests/frontend/module_import.rs index dde7aeb2..a16956f4 100644 --- a/crates/cubecl-core/tests/frontend/module_import.rs +++ b/crates/cubecl-core/tests/frontend/module_import.rs @@ -28,18 +28,18 @@ mod tests { use super::*; use cubecl_core::ir::Item; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_call_equivalent_to_no_call_no_arg_test() { let mut caller_context = CubeContext::root(); let x = caller_context.create_local(Item::new(ElemType::as_elem())); - here::caller::__expand::(&mut caller_context, x.into()); + here::caller::expand::(&mut caller_context, x.into()); let caller_scope = caller_context.into_scope(); let mut no_call_context = CubeContext::root(); let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - here::no_call_ref::__expand::(&mut no_call_context, x.into()); + here::no_call_ref::expand::(&mut no_call_context, x.into()); let no_call_scope = no_call_context.into_scope(); assert_eq!( diff --git a/crates/cubecl-core/tests/frontend/ops.rs b/crates/cubecl-core/tests/frontend/ops.rs index 064e8b9f..f092b863 100644 --- a/crates/cubecl-core/tests/frontend/ops.rs +++ b/crates/cubecl-core/tests/frontend/ops.rs @@ -122,7 +122,7 @@ pub fn greater_equal_op(a: T, b: T) -> bool { } #[cube] -pub fn modulo_op(a: UInt, b: UInt) -> UInt { +pub fn modulo_op(a: u32, b: u32) -> u32 { a % b } @@ -157,27 +157,27 @@ pub fn not_op(a: bool) -> bool { } #[cube] -pub fn bitand_op(a: UInt, b: UInt) -> UInt { +pub fn bitand_op(a: u32, b: u32) -> u32 { a & b } #[cube] -pub fn bitor_op(a: UInt, b: UInt) -> UInt { +pub fn bitor_op(a: u32, b: u32) -> u32 { a | b } #[cube] -pub fn bitxor_op(a: UInt, b: UInt) -> UInt { +pub fn bitxor_op(a: u32, b: u32) -> u32 { a ^ b } #[cube] -pub fn shl_op(a: UInt, b: UInt) -> UInt { +pub fn shl_op(a: u32, b: u32) -> u32 { a << b } #[cube] -pub fn shr_op(a: UInt, b: UInt) -> UInt { +pub fn shr_op(a: u32, b: u32) -> u32 { a >> b } @@ -201,9 +201,40 @@ pub fn div_assign_op(mut a: T, b: T) { a /= b; } +#[cube] +pub fn rem_assign_op(mut a: T, b: T) { + a %= b; +} + +#[cube] +pub fn bitor_assign_op(mut a: T, b: T) { + a |= b; +} + +#[cube] +pub fn bitand_assign_op(mut a: T, b: T) { + a &= b; +} + +#[cube] +pub fn bitxor_assign_op(mut a: T, b: T) { + a ^= b; +} + +#[cube] +pub fn shl_assign_op(mut a: T, b: u32) { + a <<= b; +} + +#[cube] +pub fn shr_assign_op(mut a: T, b: u32) { + a >>= b; +} + mod tests { use super::*; use cubecl_core::ir::{Elem, FloatKind, Item}; + use pretty_assertions::assert_eq; macro_rules! binary_test { ($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => { @@ -258,7 +289,7 @@ mod tests { }; } - macro_rules! binary_uint_test { + macro_rules! binary_u32_test { ($test_name:ident, $op_expand:expr, $op_name:expr) => { #[test] fn $test_name() { @@ -270,98 +301,134 @@ mod tests { assert_eq!( format!("{:?}", context.into_scope().operations), - ref_ops_binary_uint($op_name) + ref_ops_binary_u32($op_name) ); } }; } - binary_test!(cube_can_add, add_op::__expand::, "Add", ref_ops_binary); - binary_test!(cube_can_sub, sub_op::__expand::, "Sub", ref_ops_binary); - binary_test!(cube_can_mul, mul_op::__expand::, "Mul", ref_ops_binary); - binary_test!(cube_can_div, div_op::__expand::, "Div", ref_ops_binary); - unary_test!(cube_can_abs, abs_op::__expand::, "Abs"); - unary_test!(cube_can_exp, exp_op::__expand::, "Exp"); - unary_test!(cube_can_log, log_op::__expand::, "Log"); - unary_test!(cube_can_log1p, log1p_op::__expand::, "Log1p"); - unary_test!(cube_can_cos, cos_op::__expand::, "Cos"); - unary_test!(cube_can_sin, sin_op::__expand::, "Sin"); - unary_test!(cube_can_tanh, tanh_op::__expand::, "Tanh"); + binary_test!(cube_can_add, add_op::expand::, "Add", ref_ops_binary); + binary_test!(cube_can_sub, sub_op::expand::, "Sub", ref_ops_binary); + binary_test!(cube_can_mul, mul_op::expand::, "Mul", ref_ops_binary); + binary_test!(cube_can_div, div_op::expand::, "Div", ref_ops_binary); + unary_test!(cube_can_abs, abs_op::expand::, "Abs"); + unary_test!(cube_can_exp, exp_op::expand::, "Exp"); + unary_test!(cube_can_log, log_op::expand::, "Log"); + unary_test!(cube_can_log1p, log1p_op::expand::, "Log1p"); + unary_test!(cube_can_cos, cos_op::expand::, "Cos"); + unary_test!(cube_can_sin, sin_op::expand::, "Sin"); + unary_test!(cube_can_tanh, tanh_op::expand::, "Tanh"); binary_test!( cube_can_powf, - powf_op::__expand::, + powf_op::expand::, "Powf", ref_ops_binary ); - unary_test!(cube_can_sqrt, sqrt_op::__expand::, "Sqrt"); - unary_test!(cube_can_erf, erf_op::__expand::, "Erf"); - unary_test!(cube_can_recip, recip_op::__expand::, "Recip"); - unary_test!(cube_can_round, round_op::__expand::, "Round"); - unary_test!(cube_can_floor, floor_op::__expand::, "Floor"); - unary_test!(cube_can_ceil, ceil_op::__expand::, "Ceil"); - binary_test!(cube_can_eq, equal_op::__expand::, "Equal", ref_ops_cmp); + unary_test!(cube_can_sqrt, sqrt_op::expand::, "Sqrt"); + unary_test!(cube_can_erf, erf_op::expand::, "Erf"); + unary_test!(cube_can_recip, recip_op::expand::, "Recip"); + unary_test!(cube_can_round, round_op::expand::, "Round"); + unary_test!(cube_can_floor, floor_op::expand::, "Floor"); + unary_test!(cube_can_ceil, ceil_op::expand::, "Ceil"); + binary_test!(cube_can_eq, equal_op::expand::, "Equal", ref_ops_cmp); binary_test!( cube_can_ne, - not_equal_op::__expand::, + not_equal_op::expand::, "NotEqual", ref_ops_cmp ); - binary_test!(cube_can_lt, lower_op::__expand::, "Lower", ref_ops_cmp); + binary_test!(cube_can_lt, lower_op::expand::, "Lower", ref_ops_cmp); binary_test!( cube_can_le, - lower_equal_op::__expand::, + lower_equal_op::expand::, "LowerEqual", ref_ops_cmp ); binary_test!( cube_can_ge, - greater_equal_op::__expand::, + greater_equal_op::expand::, "GreaterEqual", ref_ops_cmp ); binary_test!( cube_can_gt, - greater_op::__expand::, + greater_op::expand::, "Greater", ref_ops_cmp ); - binary_test!(cube_can_max, max_op::__expand::, "Max", ref_ops_binary); - binary_test!(cube_can_min, min_op::__expand::, "Min", ref_ops_binary); + binary_test!(cube_can_max, max_op::expand::, "Max", ref_ops_binary); + binary_test!(cube_can_min, min_op::expand::, "Min", ref_ops_binary); binary_test!( cube_can_add_assign, - add_assign_op::__expand::, + add_assign_op::expand::, "Add", ref_ops_binary ); binary_test!( cube_can_sub_assign, - sub_assign_op::__expand::, + sub_assign_op::expand::, "Sub", ref_ops_binary ); binary_test!( cube_can_mul_assign, - mul_assign_op::__expand::, + mul_assign_op::expand::, "Mul", ref_ops_binary ); binary_test!( cube_can_div_assign, - div_assign_op::__expand::, + div_assign_op::expand::, "Div", ref_ops_binary ); - binary_boolean_test!(cube_can_and, and_op::__expand, "And"); - binary_boolean_test!(cube_can_or, or_op::__expand, "Or"); - binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd"); - binary_uint_test!(cube_can_bitor, bitor_op::__expand, "BitwiseOr"); - binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor"); - binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft"); - binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight"); - binary_uint_test!(cube_can_mod, modulo_op::__expand, "Modulo"); + binary_test!( + cube_can_rem_assign, + rem_assign_op::expand::, + "Modulo", + ref_ops_binary + ); + binary_test!( + cube_can_bitor_assign, + bitor_assign_op::expand::, + "BitwiseOr", + ref_ops_binary + ); + binary_test!( + cube_can_bitand_assign, + bitand_assign_op::expand::, + "BitwiseAnd", + ref_ops_binary + ); + binary_test!( + cube_can_bitxor_assign, + bitxor_assign_op::expand::, + "BitwiseXor", + ref_ops_binary + ); + binary_test!( + cube_can_shl_assign, + shl_assign_op::expand::, + "ShiftLeft", + ref_ops_binary + ); + binary_test!( + cube_can_shr_assign, + shr_assign_op::expand::, + "ShiftRight", + ref_ops_binary + ); + binary_boolean_test!(cube_can_and, and_op::expand, "And"); + binary_boolean_test!(cube_can_or, or_op::expand, "Or"); + binary_u32_test!(cube_can_bitand, bitand_op::expand, "BitwiseAnd"); + binary_u32_test!(cube_can_bitor, bitor_op::expand, "BitwiseOr"); + binary_u32_test!(cube_can_bitxor, bitxor_op::expand, "BitwiseXor"); + binary_u32_test!(cube_can_shl, shl_op::expand, "ShiftLeft"); + binary_u32_test!(cube_can_shr, shr_op::expand, "ShiftRight"); + binary_u32_test!(cube_can_mod, modulo_op::expand, "Modulo"); binary_test!( cube_can_rem, - remainder_op::__expand::, + remainder_op::expand::, "Remainder", ref_ops_binary ); @@ -371,7 +438,7 @@ mod tests { let mut context = CubeContext::root(); let x = context.create_local(Item::new(Elem::Bool)); - not_op::__expand(&mut context, x.into()); + not_op::expand(&mut context, x.into()); assert_eq!( format!("{:?}", context.into_scope().operations), @@ -399,7 +466,7 @@ mod tests { ref_ops_template(ops_name, "Bool", "Bool", true) } - fn ref_ops_binary_uint(ops_name: &str) -> String { + fn ref_ops_binary_u32(ops_name: &str) -> String { ref_ops_template(ops_name, "UInt", "UInt", true) } @@ -410,15 +477,15 @@ mod tests { "[Operator({ops_name}(BinaryOperator {{ \ lhs: Local {{ id: 0, item: Item {{ \ elem: {in_type}, \ - vectorization: 1 \ + vectorization: None \ }}, depth: 0 }}, \ rhs: Local {{ id: 1, item: Item {{ \ elem: {in_type}, \ - vectorization: 1 \ + vectorization: None \ }}, depth: 0 }}, \ out: Local {{ id: {out_number}, item: Item {{ \ elem: {out_type}, \ - vectorization: 1 \ + vectorization: None \ }}, depth: 0 }} \ }}))]" ) @@ -427,11 +494,11 @@ mod tests { "[Operator({ops_name}(UnaryOperator {{ \ input: Local {{ id: 0, item: Item {{ \ elem: {in_type}, \ - vectorization: 1 \ + vectorization: None \ }}, depth: 0 }}, \ out: Local {{ id: 0, item: Item {{ \ elem: {out_type}, \ - vectorization: 1 \ + vectorization: None \ }}, depth: 0 }} \ }}))]" ) diff --git a/crates/cubecl-core/tests/frontend/parenthesis.rs b/crates/cubecl-core/tests/frontend/parenthesis.rs index 72d636e8..701ce413 100644 --- a/crates/cubecl-core/tests/frontend/parenthesis.rs +++ b/crates/cubecl-core/tests/frontend/parenthesis.rs @@ -13,7 +13,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_parenthesis_priority_test() { @@ -23,7 +23,7 @@ mod tests { let y = context.create_local(Item::new(ElemType::as_elem())); let z = context.create_local(Item::new(ElemType::as_elem())); - parenthesis::__expand::(&mut context, x.into(), y.into(), z.into()); + parenthesis::expand::(&mut context, x.into(), y.into(), z.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); diff --git a/crates/cubecl-core/tests/frontend/redeclare.rs b/crates/cubecl-core/tests/frontend/redeclare.rs index eb5eb214..34c30678 100644 --- a/crates/cubecl-core/tests/frontend/redeclare.rs +++ b/crates/cubecl-core/tests/frontend/redeclare.rs @@ -21,18 +21,19 @@ pub fn redeclare_same_scope_other_type(mut x: I) -> F { pub fn redeclare_different_scope(mut x: I) { let y = I::new(1); x += y; - for _ in range(0u32, 2u32, Comptime::new(false)) { + for _ in 0..2u32 { let y = I::new(2); x += y; } } #[cube] -pub fn redeclare_two_for_loops(mut x: UInt) { - for i in range(0u32, 2u32, Comptime::new(false)) { +#[allow(unused)] +pub fn redeclare_two_for_loops(mut x: u32) { + for i in 0..2 { x += i; } - for i in range(0u32, 2u32, Comptime::new(false)) { + for i in 0..2 { x += i; x += i; } @@ -43,10 +44,11 @@ mod tests { cpa, ir::{Item, Variable}, }; + use pretty_assertions::assert_eq; use super::*; - type ElemType = I32; + type ElemType = i32; #[test] fn cube_redeclare_same_scope_test() { @@ -54,11 +56,11 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_same_scope::__expand::(&mut context, x.into()); + redeclare_same_scope::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!( - format!("{:?}", scope.operations), + format!("{:#?}", scope.operations), inline_macro_ref_same_scope() ); } @@ -69,7 +71,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_same_scope_other_type::__expand::(&mut context, x.into()); + redeclare_same_scope_other_type::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!( @@ -84,11 +86,11 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - redeclare_different_scope::__expand::(&mut context, x.into()); + redeclare_different_scope::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!( - format!("{:?}", scope.operations), + format!("{:#?}", scope.operations), inline_macro_ref_different() ); } @@ -97,9 +99,9 @@ mod tests { fn cube_redeclare_two_for_loops_test() { let mut context = CubeContext::root(); - let x = context.create_local(Item::new(UInt::as_elem())); + let x = context.create_local(Item::new(u32::as_elem())); - redeclare_two_for_loops::__expand(&mut context, x.into()); + redeclare_two_for_loops::expand(&mut context, x.into()); let scope = context.into_scope(); assert_eq!( @@ -116,16 +118,17 @@ mod tests { let mut scope = context.into_scope(); let x: Variable = x.into(); - let i = scope.create_with_value(1, item); - cpa!(scope, x += i); + let value: ExpandElement = ElemType::from(1).into(); + let value: Variable = *value; - let value: ExpandElement = ElemType::from(2).into_expand().into(); + cpa!(scope, x += value); + + let value: ExpandElement = ElemType::from(2).into(); let value: Variable = *value; - cpa!(scope, i = value); - cpa!(scope, x += i); + cpa!(scope, x += value); - format!("{:?}", scope.operations) + format!("{:#?}", scope.operations) } fn inline_macro_ref_same_scope_other_type() -> String { @@ -136,10 +139,12 @@ mod tests { let mut scope = context.into_scope(); let x: Variable = x.into(); - let i = scope.create_with_value(1, item); + let i: ExpandElement = ElemType::new(1).into(); + let i = *i; cpa!(scope, x += i); - let i = scope.create_with_value(2, Item::new(F32::as_elem())); - let y = scope.create_local(Item::new(F32::as_elem())); + let i: ExpandElement = 2f32.into(); + let i = *i; + let y = scope.create_local(Item::new(f32::as_elem())); cpa!(scope, y = i + i); format!("{:?}", scope.operations) @@ -154,26 +159,26 @@ mod tests { let mut scope = context.into_scope(); let x: Variable = x.into(); - let y = scope.create_with_value(1, item); + let y: ExpandElement = ElemType::new(1).into(); + let y = *y; cpa!(scope, x += y); cpa!( &mut scope, range(0u32, end, false).for_each(|_, scope| { - let value: ExpandElement = ElemType::from(2).into_expand().into(); + let value: ExpandElement = ElemType::new(2).into(); let value: Variable = *value; - cpa!(scope, y = value); - cpa!(scope, x += y); + cpa!(scope, x += value); }) ); - format!("{:?}", scope.operations) + format!("{:#?}", scope.operations) } fn inline_macro_ref_two_for_loops() -> String { let mut context = CubeContext::root(); - let item = Item::new(UInt::as_elem()); + let item = Item::new(u32::as_elem()); let x = context.create_local(item); let end = 2u32; diff --git a/crates/cubecl-core/tests/frontend/reuse.rs b/crates/cubecl-core/tests/frontend/reuse.rs index 8ccd6988..c66a1284 100644 --- a/crates/cubecl-core/tests/frontend/reuse.rs +++ b/crates/cubecl-core/tests/frontend/reuse.rs @@ -26,14 +26,14 @@ mod tests { ir::{Branch, Elem, Item, Variable}, }; - type ElemType = I32; + type ElemType = i32; #[test] fn cube_reuse_assign_test() { let mut context = CubeContext::root(); let x = context.create_local(Item::new(ElemType::as_elem())); - reuse::__expand::(&mut context, x.into()); + reuse::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); @@ -45,7 +45,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); - reuse_incr::__expand::(&mut context, x.into()); + reuse_incr::expand::(&mut context, x.into()); let scope = context.into_scope(); assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); diff --git a/crates/cubecl-core/tests/frontend/shared_memory.rs b/crates/cubecl-core/tests/frontend/shared_memory.rs index 603551fd..b41dbe0b 100644 --- a/crates/cubecl-core/tests/frontend/shared_memory.rs +++ b/crates/cubecl-core/tests/frontend/shared_memory.rs @@ -2,7 +2,7 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] -pub fn shared_memory_read_write(sm_size: Comptime) { +pub fn shared_memory_read_write(#[comptime] sm_size: u32) { let mut shared = SharedMemory::::new(sm_size); shared[0] = T::from_int(3); let _ = shared[0]; @@ -15,13 +15,13 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_shared_memory() { let mut context = CubeContext::root(); - shared_memory_read_write::__expand::(&mut context, 512); + shared_memory_read_write::expand::(&mut context, 512); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/cubecl-core/tests/frontend/struct.rs b/crates/cubecl-core/tests/frontend/struct.rs index e0deee8a..4eb21572 100644 --- a/crates/cubecl-core/tests/frontend/struct.rs +++ b/crates/cubecl-core/tests/frontend/struct.rs @@ -40,7 +40,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_new_struct_test() { @@ -49,7 +49,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - creator::__expand::(&mut context, x.into(), y.into()); + creator::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -69,7 +69,7 @@ mod tests { first: x.into(), second: y.into(), }; - state_receiver_with_reuse::__expand::(&mut context, expanded_state); + state_receiver_with_reuse::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -89,7 +89,7 @@ mod tests { first: x.into(), second: y.into(), }; - attribute_modifier_reuse_field::__expand::(&mut context, expanded_state); + attribute_modifier_reuse_field::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( @@ -109,7 +109,7 @@ mod tests { first: x.into(), second: y.into(), }; - attribute_modifier_reuse_struct::__expand::(&mut context, expanded_state); + attribute_modifier_reuse_struct::expand::(&mut context, expanded_state); let scope = context.into_scope(); assert_eq!( diff --git a/crates/cubecl-core/tests/frontend/tensor.rs b/crates/cubecl-core/tests/frontend/tensor.rs index d7d905bd..231e3055 100644 --- a/crates/cubecl-core/tests/frontend/tensor.rs +++ b/crates/cubecl-core/tests/frontend/tensor.rs @@ -15,14 +15,14 @@ mod tests { ir::{Item, Operation, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_tensor_metadata() { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - kernel::__expand::(&mut context, input.into()); + kernel::expand::(&mut context, input.into()); assert_eq!(context.into_scope().operations, inline_macro_ref()); } @@ -33,9 +33,9 @@ mod tests { let mut scope = context.into_scope(); let input: Variable = input.into(); - let x = scope.create_local(Item::new(UInt::as_elem())); - let y = scope.create_local(Item::new(UInt::as_elem())); - let z = scope.create_local(Item::new(UInt::as_elem())); + let x = scope.create_local(Item::new(u32::as_elem())); + let y = scope.create_local(Item::new(u32::as_elem())); + let z = scope.create_local(Item::new(u32::as_elem())); cpa!(&mut scope, x = shape(input, 1u32)); cpa!(&mut scope, y = stride(input, 1u32)); diff --git a/crates/cubecl-core/tests/frontend/topology.rs b/crates/cubecl-core/tests/frontend/topology.rs index 816ce5cd..1cd7263d 100644 --- a/crates/cubecl-core/tests/frontend/topology.rs +++ b/crates/cubecl-core/tests/frontend/topology.rs @@ -3,7 +3,7 @@ use cubecl_core::prelude::*; #[cube] pub fn topology_kernel(input: Tensor) { - let x = ABSOLUTE_POS + UInt::new(4); + let x = ABSOLUTE_POS + 4; let _ = input[x]; } @@ -14,14 +14,14 @@ mod tests { ir::{Elem, Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_support_topology() { let mut context = CubeContext::root(); let input = context.input(0, Item::new(ElemType::as_elem())); - topology_kernel::__expand::(&mut context, input.into()); + topology_kernel::expand::(&mut context, input.into()); assert_eq!( format!("{:?}", context.into_scope().operations), inline_macro_ref() diff --git a/crates/cubecl-core/tests/frontend/trait.rs b/crates/cubecl-core/tests/frontend/trait.rs index 8d75f27b..bade72ac 100644 --- a/crates/cubecl-core/tests/frontend/trait.rs +++ b/crates/cubecl-core/tests/frontend/trait.rs @@ -64,7 +64,7 @@ impl MethodTypedStrategy for AddStrategy { input_1: ::ExpandType, input_2: ::ExpandType, ) -> ::ExpandType { - add_strategy_operation::__expand::(context, input_1, input_2) + add_strategy_operation::expand::(context, input_1, input_2) } } @@ -80,7 +80,7 @@ mod tests { ir::{Item, Variable}, }; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_strategy_trait_add_test() { let mut context = CubeContext::root(); @@ -88,7 +88,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait::__expand::(&mut context, x.into(), y.into()); + with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -104,7 +104,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_strategy_trait::__expand::(&mut context, x.into(), y.into()); + with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( @@ -120,7 +120,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - two_strategy_traits::__expand::( + two_strategy_traits::expand::( &mut context, x.into(), y.into(), @@ -137,7 +137,7 @@ mod tests { let x = context.create_local(Item::new(ElemType::as_elem())); let y = context.create_local(Item::new(ElemType::as_elem())); - with_trait_generic_method::__expand::( + with_trait_generic_method::expand::( &mut context, x.into(), y.into(), diff --git a/crates/cubecl-core/tests/frontend/tuple.rs b/crates/cubecl-core/tests/frontend/tuple.rs index 84936f48..0bbefa62 100644 --- a/crates/cubecl-core/tests/frontend/tuple.rs +++ b/crates/cubecl-core/tests/frontend/tuple.rs @@ -2,17 +2,17 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; #[cube] -pub fn tuple_const() -> (UInt, UInt) { - let x = UInt::new(0); - let y = UInt::new(1); +pub fn tuple_const() -> (u32, u32) { + let x = 0u32; + let y = 1u32; (x, y) } #[cube] -pub fn tuple_destructuring() -> (UInt, UInt) { - let x = (UInt::new(0), UInt::new(1)); +pub fn tuple_destructuring() -> (u32, u32) { + let x = (0u32, 1u32); let (a, b) = x; - (a + UInt::new(1), b) + (a + 1, b) } mod tests { @@ -21,12 +21,13 @@ mod tests { cpa, ir::{Elem, Item, Operation, Variable}, }; + use pretty_assertions::assert_eq; #[test] fn cube_tuple_const_test() { let mut context = CubeContext::root(); - tuple_const::__expand(&mut context); + tuple_const::expand(&mut context); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_tuple_const()); @@ -52,7 +53,7 @@ mod tests { fn cube_tuple_destructuring() { let mut context = CubeContext::root(); - tuple_destructuring::__expand(&mut context); + tuple_destructuring::expand(&mut context); let scope = context.into_scope(); assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring()); @@ -65,12 +66,10 @@ mod tests { let a = scope.create_local(Item::new(Elem::UInt)); let b = scope.create_local(Item::new(Elem::UInt)); - let zero: Variable = 0u32.into(); let one: Variable = 1u32.into(); - cpa!(scope, a = zero); + cpa!(scope, a = one); cpa!(scope, b = one); - cpa!(scope, a = a + 1u32); scope.operations } diff --git a/crates/cubecl-core/tests/frontend/vectorization.rs b/crates/cubecl-core/tests/frontend/vectorization.rs index 938750d0..6a95921e 100644 --- a/crates/cubecl-core/tests/frontend/vectorization.rs +++ b/crates/cubecl-core/tests/frontend/vectorization.rs @@ -12,18 +12,20 @@ pub fn vectorization_cmp(rhs: T) { } mod tests { + use std::num::NonZero; + use super::*; use cubecl_core::ir::Item; - type ElemType = F32; + type ElemType = f32; #[test] fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - vectorization_binary::__expand::(&mut context, lhs.into()); + vectorization_binary::expand::(&mut context, lhs.into()); } #[test] @@ -31,18 +33,18 @@ mod tests { fn cube_vectorization_binary_op_with_different_scheme_fails() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - vectorization_binary::__expand::(&mut context, lhs.into()); + vectorization_binary::expand::(&mut context, lhs.into()); } #[test] fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(2))); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } #[test] @@ -50,17 +52,17 @@ mod tests { fn cube_vectorization_cmp_op_with_different_scheme_fails() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), NonZero::new(4))); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } #[test] fn cube_vectorization_can_be_broadcasted() { let mut context = CubeContext::root(); - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1)); + let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), None)); - vectorization_cmp::__expand::(&mut context, lhs.into()); + vectorization_cmp::expand::(&mut context, lhs.into()); } } diff --git a/crates/cubecl-cuda/Cargo.toml b/crates/cubecl-cuda/Cargo.toml index f753a076..ec3470d4 100644 --- a/crates/cubecl-cuda/Cargo.toml +++ b/crates/cubecl-cuda/Cargo.toml @@ -20,18 +20,22 @@ default = [ std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] [dependencies] +cubecl-common = { path = "../cubecl-common", version = "0.2.0" } +cubecl-core = { path = "../cubecl-core", version = "0.2.0" } cubecl-runtime = { path = "../cubecl-runtime", version = "0.2.0", default-features = false, features = [ "channel-mutex", ] } -cubecl-common = { path = "../cubecl-common", version = "0.2.0" } -cubecl-core = { path = "../cubecl-core", version = "0.2.0" } -half = { workspace = true } bytemuck = { workspace = true } -cudarc = { version = "0.12", features = ["std", "driver", "cuda-version-from-build-system"], default-features = false } +cudarc = { version = "0.12", features = [ + "std", + "driver", + "cuda-version-from-build-system", +], default-features = false } -log = { workspace = true } derive-new = { workspace = true } +half = { workspace = true } +log = { workspace = true } [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.2.0", features = [ @@ -40,3 +44,4 @@ cubecl-core = { path = "../cubecl-core", version = "0.2.0", features = [ cubecl-linalg = { path = "../cubecl-linalg", version = "0.2.0", features = [ "export_tests", ] } +pretty_assertions = { workspace = true } diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index 7939f927..d010e245 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, num::NonZero}; use cubecl_core::{ ir::{self as gpu, ConstantScalarValue}, @@ -294,6 +294,7 @@ impl CudaCompiler { start: self.compile_variable(range_loop.start), end: self.compile_variable(range_loop.end), step: range_loop.step.map(|it| self.compile_variable(it)), + inclusive: range_loop.inclusive, instructions: self.compile_scope(&mut range_loop.scope), }), gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop { @@ -544,6 +545,9 @@ impl CudaCompiler { val: self.compile_variable(op.val), out: self.compile_variable(op.out), }), + gpu::Operator::Neg(op) => { + instructions.push(Instruction::Negate(self.compile_unary(op))) + } gpu::Operator::Normalize(op) => { instructions.push(Instruction::Normalize(self.compile_unary(op))) } @@ -717,7 +721,10 @@ impl CudaCompiler { } fn compile_item(&mut self, item: gpu::Item) -> super::Item { - let item = super::Item::new(self.compile_elem(item.elem), item.vectorization.into()); + let item = super::Item::new( + self.compile_elem(item.elem), + item.vectorization.map(NonZero::get).unwrap_or(1).into(), + ); self.items.insert(item); self.items.insert(item.optimized()); item diff --git a/crates/cubecl-cuda/src/compiler/instruction.rs b/crates/cubecl-cuda/src/compiler/instruction.rs index d418fc38..82fe333f 100644 --- a/crates/cubecl-cuda/src/compiler/instruction.rs +++ b/crates/cubecl-cuda/src/compiler/instruction.rs @@ -50,6 +50,7 @@ pub enum Instruction { start: Variable, end: Variable, step: Option, + inclusive: bool, instructions: Vec, }, Loop { @@ -138,6 +139,7 @@ pub enum Instruction { val: Variable, out: Variable, }, + Negate(UnaryInstruction), Normalize(UnaryInstruction), } @@ -188,15 +190,17 @@ impl Display for Instruction { start, end, step, + inclusive, instructions, } => { let increment = step .map(|step| format!("{i} += {step}")) .unwrap_or_else(|| format!("++{i}")); + let cmp = if *inclusive { "<=" } else { "<" }; f.write_fmt(format_args!( " -for (uint {i} = {start}; {i} < {end}; {increment}) {{ +for (uint {i} = {start}; {i} {cmp} {end}; {increment}) {{ " ))?; for instruction in instructions { @@ -381,6 +385,9 @@ for (uint {i} = {start}; {i} < {end}; {increment}) {{ f.write_fmt(format_args!("atomicExch({out}, {input});\n")) } Instruction::Remainder(inst) => Remainder::format(f, &inst.lhs, &inst.rhs, &inst.out), + Instruction::Negate(UnaryInstruction { input, out }) => { + f.write_fmt(format_args!("{out} = !{input};\n")) + } Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out), } } diff --git a/crates/cubecl-linalg/Cargo.toml b/crates/cubecl-linalg/Cargo.toml index 74ef713c..b554c74f 100644 --- a/crates/cubecl-linalg/Cargo.toml +++ b/crates/cubecl-linalg/Cargo.toml @@ -15,14 +15,15 @@ version.workspace = true [features] default = [] +export_tests = ["pretty_assertions"] std = [] -export_tests = [] [dependencies] +bytemuck = { workspace = true } cubecl-core = { path = "../cubecl-core", version = "0.2.0", default-features = false } cubecl-runtime = { path = "../cubecl-runtime", version = "0.2.0", default-features = false } -bytemuck = { workspace = true } half = { workspace = true, features = ["bytemuck"] } +pretty_assertions = { workspace = true, optional = true } [dev-dependencies] trybuild = "1" diff --git a/crates/cubecl-linalg/src/matmul/cmma/base.rs b/crates/cubecl-linalg/src/matmul/cmma/base.rs index d723e5e2..cc37ca6b 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/base.rs @@ -1,5 +1,5 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; +use cubecl_core::cube; +use cubecl_core::{self as cubecl, prelude::*}; use super::block_loop::block_loop; use super::config::ComptimeCmmaInfo; @@ -10,7 +10,7 @@ pub fn cmma_kernel( lhs: &Tensor, rhs: &Tensor, out: &mut Tensor, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { let ids = get_ids(); let dims = get_dims::(lhs, rhs); @@ -32,15 +32,15 @@ pub fn cmma_kernel( #[derive(CubeType, Copy, Clone)] pub(crate) struct Dimensions { - pub m: UInt, - pub k: UInt, - pub n: UInt, + pub m: u32, + pub k: u32, + pub n: u32, } #[derive(CubeType, Copy, Clone)] pub(crate) struct Ids { - pub coop: UInt, - pub lane: UInt, + pub coop: u32, + pub lane: u32, } #[derive(CubeType, Copy, Clone)] @@ -61,18 +61,18 @@ pub(crate) struct SharedMemories { /// /// Note: batch offsets take stride into account, but not the others pub(crate) struct Offsets { - pub batch_lhs: UInt, - pub batch_rhs: UInt, - pub batch_out: UInt, - pub cube_row: UInt, - pub cube_col: UInt, + pub batch_lhs: u32, + pub batch_rhs: u32, + pub batch_out: u32, + pub cube_row: u32, + pub cube_col: u32, } #[cube] fn get_dims(lhs: &Tensor, rhs: &Tensor) -> Dimensions { let rank = lhs.rank(); - let first_dim = rank - UInt::new(2); - let second_dim = rank - UInt::new(1); + let first_dim = rank - 2; + let second_dim = rank - 1; let m = lhs.shape(first_dim); let k = lhs.shape(second_dim); let n = rhs.shape(second_dim); @@ -85,27 +85,27 @@ fn calculate_offsets( lhs: &Tensor, rhs: &Tensor, out: &Tensor, - config: Comptime, + #[comptime] config: ComptimeCmmaInfo, ) -> Offsets { - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_n); + let block_size_m = config.block_size_m; + let block_size_n = config.block_size_m; // Cube offset - let cube_row = CUBE_POS_X * Comptime::runtime(block_size_m); - let cube_col = CUBE_POS_Y * Comptime::runtime(block_size_n); + let cube_row = CUBE_POS_X * block_size_m; + let cube_col = CUBE_POS_Y * block_size_n; let rank = out.rank(); - let dim_m = lhs.shape(rank - UInt::new(2)); - let dim_n = rhs.shape(rank - UInt::new(1)); + let dim_m = lhs.shape(rank - 2); + let dim_n = rhs.shape(rank - 1); // Batch offset for output let batch_out = dim_m * dim_n * CUBE_POS_Z; - let mut batch_lhs = UInt::new(0); - let mut batch_rhs = UInt::new(0); + let mut batch_lhs = 0; + let mut batch_rhs = 0; // Batch offset for lhs, rhs - for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) { + for b in 0..rank - 2 { let tmp = batch_out / out.stride(b); batch_lhs += tmp % lhs.shape(b) * lhs.stride(b); batch_rhs += tmp % rhs.shape(b) * rhs.stride(b); @@ -121,25 +121,26 @@ fn calculate_offsets( } #[cube] -fn make_shared_memories(config: Comptime) -> SharedMemories { - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_n = Comptime::map(config, |c| c.block_size_n); +fn make_shared_memories(#[comptime] config: ComptimeCmmaInfo) -> SharedMemories { + let block_size_m = config.block_size_m; + let block_size_k = config.block_size_k; + let block_size_n = config.block_size_n; - let lhs = SharedMemory::::new(Comptime::get(block_size_k * block_size_m)); - let rhs = SharedMemory::::new(Comptime::get(block_size_k * block_size_n)); + let lhs = SharedMemory::::new(block_size_k * block_size_m); + let rhs = SharedMemory::::new(block_size_k * block_size_n); - SharedMemories { lhs, rhs } + SharedMemories:: { lhs, rhs } } #[cube] pub(crate) fn make_accumulators( - config: Comptime, + #[comptime] config: ComptimeCmmaInfo, ) -> Sequence> { - let num_accumulators = Comptime::map(config, |c| c.num_accumulators); + let num_accumulators = config.num_accumulators; let mut accumulators = Sequence::>::new(); - for _ in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + #[unroll] + for _ in 0..num_accumulators { let acc = cmma::Matrix::::new( cmma::MatrixIdent::Accumulator, 16, diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_io/base.rs b/crates/cubecl-linalg/src/matmul/cmma/block_io/base.rs index b0e29364..d6c26abe 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_io/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_io/base.rs @@ -1,19 +1,18 @@ +use crate::matmul::cmma::base::Dimensions; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::Dimensions; - #[cube] -pub(crate) trait BlockLoader: Send + Sync + 'static { +pub(crate) trait BlockLoader { fn load_tile( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - read_row: UInt, - read_col: UInt, - write_pos: UInt, - dim_vertical: UInt, - dim_horizontal: UInt, + batch_offset: u32, + read_row: u32, + read_col: u32, + write_pos: u32, + dim_vertical: u32, + dim_horizontal: u32, ); } @@ -23,10 +22,10 @@ pub(crate) trait BlockWriter: Send + Sync + 'static { fn write_output( out: &mut Tensor, accumulator_sm: SharedMemory, - batch_offset: UInt, - read_position: UInt, - write_row: UInt, - write_col: UInt, + batch_offset: u32, + read_position: u32, + write_row: u32, + write_col: u32, dims: Dimensions, ); } diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_io/horizontal_block_check.rs b/crates/cubecl-linalg/src/matmul/cmma/block_io/horizontal_block_check.rs index 98b05171..0c933bb3 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_io/horizontal_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_io/horizontal_block_check.rs @@ -12,30 +12,31 @@ impl BlockLoader for HorizontalCheckBlockIO { fn load_tile( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - read_row: UInt, - read_col: UInt, - write_pos: UInt, - _dim_vertical: UInt, - dim_horizontal: UInt, + batch_offset: u32, + read_row: u32, + read_col: u32, + write_pos: u32, + _dim_vertical: u32, + dim_horizontal: u32, ) { - let tensor_vec = Comptime::vectorization(tensor); - let tensor_vec_r = Comptime::runtime(tensor_vec); - let is_scalar = Comptime::map(tensor_vec, |v| v.val == 1); + let tensor_vec = vectorization_of(tensor); + let is_scalar = tensor_vec == 1; if read_col < dim_horizontal { - let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec_r; + let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec; let value = tensor[read_pos]; - if Comptime::get(is_scalar) { + if is_scalar { shared_memory[write_pos] = FC::cast_from(value); } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::cast_from(value[i]); } } } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::new(0.); } } @@ -47,30 +48,30 @@ impl BlockWriter for HorizontalCheckBlockIO { fn write_output( out: &mut Tensor, accumulator_sm: SharedMemory, - batch_offset: UInt, - read_position: UInt, - write_row: UInt, - write_col: UInt, + batch_offset: u32, + read_position: u32, + write_row: u32, + write_col: u32, dims: Dimensions, ) { - let out_vec = Comptime::vectorization(out); - let out_vec_r = Comptime::runtime(out_vec); - let is_scalar = Comptime::map(out_vec, |v| v.val == 1); + let out_vec = vectorization_of(out); + let is_scalar = out_vec == 1; if write_col < dims.n { let write_position = batch_offset + write_row * dims.n + write_col; - if Comptime::get(is_scalar) { + if is_scalar { let val = accumulator_sm[read_position]; - out[write_position / out_vec_r] = val; + out[write_position / out_vec] = val; } else { - let mut value = F::vectorized_empty(Comptime::get(out_vec)); + let mut value = F::vectorized_empty(out_vec); - for i in range(0u32, Comptime::get(out_vec), Comptime::new(true)) { + #[unroll] + for i in 0..out_vec { value[i] = accumulator_sm[read_position + i]; } - out[write_position / out_vec_r] = value; + out[write_position / out_vec] = value; } } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_io/unchecked_block.rs b/crates/cubecl-linalg/src/matmul/cmma/block_io/unchecked_block.rs index 72c0c47b..fdcc5e31 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_io/unchecked_block.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_io/unchecked_block.rs @@ -1,8 +1,7 @@ +use crate::matmul::cmma::base::Dimensions; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::Dimensions; - use super::base::{BlockLoader, BlockWriter}; /// Assumes block sizes divide tensor shape @@ -13,24 +12,24 @@ impl BlockLoader for UncheckedBlockIO { fn load_tile( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - read_row: UInt, - read_col: UInt, - write_pos: UInt, - _dim_vertical: UInt, - dim_horizontal: UInt, + batch_offset: u32, + read_row: u32, + read_col: u32, + write_pos: u32, + _dim_vertical: u32, + dim_horizontal: u32, ) { - let tensor_vec = Comptime::vectorization(tensor); - let tensor_vec_r = Comptime::runtime(tensor_vec); - let is_scalar = Comptime::map(tensor_vec, |v| v.val == 1); + let tensor_vec = vectorization_of(tensor); + let is_scalar = tensor_vec == 1; - let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec_r; + let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec; let value = tensor[read_pos]; - if Comptime::get(is_scalar) { + if is_scalar { shared_memory[write_pos] = FC::cast_from(value); } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::cast_from(value[i]); } } @@ -42,29 +41,29 @@ impl BlockWriter for UncheckedBlockIO { fn write_output( out: &mut Tensor, accumulator_sm: SharedMemory, - batch_offset: UInt, - read_position: UInt, - write_row: UInt, - write_col: UInt, + batch_offset: u32, + read_position: u32, + write_row: u32, + write_col: u32, dims: Dimensions, ) { - let out_vec = Comptime::vectorization(out); - let out_vec_r = Comptime::runtime(out_vec); - let is_scalar = Comptime::map(out_vec, |v| v.val == 1); + let out_vec = vectorization_of(out); + let is_scalar = out_vec == 1; let write_position = batch_offset + write_row * dims.n + write_col; - if Comptime::get(is_scalar) { + if is_scalar { let val = accumulator_sm[read_position]; - out[write_position / out_vec_r] = val; + out[write_position / out_vec] = val; } else { - let mut value = F::vectorized_empty(Comptime::get(out_vec)); + let mut value = F::vectorized_empty(out_vec); - for i in range(0u32, Comptime::get(out_vec), Comptime::new(true)) { + #[unroll] + for i in 0..out_vec { value[i] = accumulator_sm[read_position + i]; } - out[write_position / out_vec_r] = value; + out[write_position / out_vec] = value; } } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_io/vertical_block_check.rs b/crates/cubecl-linalg/src/matmul/cmma/block_io/vertical_block_check.rs index ec48748b..26b35569 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_io/vertical_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_io/vertical_block_check.rs @@ -1,8 +1,7 @@ +use crate::matmul::cmma::base::Dimensions; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::Dimensions; - use super::base::{BlockLoader, BlockWriter}; pub(crate) struct VerticalCheckBlockIO; @@ -12,30 +11,31 @@ impl BlockLoader for VerticalCheckBlockIO { fn load_tile( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - read_row: UInt, - read_col: UInt, - write_pos: UInt, - dim_vertical: UInt, - dim_horizontal: UInt, + batch_offset: u32, + read_row: u32, + read_col: u32, + write_pos: u32, + dim_vertical: u32, + dim_horizontal: u32, ) { - let tensor_vec = Comptime::vectorization(tensor); - let tensor_vec_r = Comptime::runtime(tensor_vec); - let is_scalar = Comptime::map(tensor_vec, |v| v.val == 1); + let tensor_vec = vectorization_of(tensor); + let is_scalar = tensor_vec == 1; if read_row < dim_vertical { - let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec_r; + let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec; let value = tensor[read_pos]; - if Comptime::get(is_scalar) { + if is_scalar { shared_memory[write_pos] = FC::cast_from(value); } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::cast_from(value[i]); } } } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::new(0.); } } @@ -47,30 +47,30 @@ impl BlockWriter for VerticalCheckBlockIO { fn write_output( out: &mut Tensor, accumulator_sm: SharedMemory, - batch_offset: UInt, - read_position: UInt, - write_row: UInt, - write_col: UInt, + batch_offset: u32, + read_position: u32, + write_row: u32, + write_col: u32, dims: Dimensions, ) { - let out_vec = Comptime::vectorization(out); - let out_vec_r = Comptime::runtime(out_vec); - let is_scalar = Comptime::map(out_vec, |v| v.val == 1); + let out_vec = vectorization_of(out); + let is_scalar = out_vec == 1; if write_row < dims.m { let write_position = batch_offset + write_row * dims.n + write_col; - if Comptime::get(is_scalar) { + if is_scalar { let val = accumulator_sm[read_position]; - out[write_position / out_vec_r] = val; + out[write_position / out_vec] = val; } else { - let mut value = F::vectorized_empty(Comptime::get(out_vec)); + let mut value = F::vectorized_empty(out_vec); - for i in range(0u32, Comptime::get(out_vec), Comptime::new(true)) { + #[unroll] + for i in 0..out_vec { value[i] = accumulator_sm[read_position + i]; } - out[write_position / out_vec_r] = value; + out[write_position / out_vec] = value; } } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_io/whole_block_check.rs b/crates/cubecl-linalg/src/matmul/cmma/block_io/whole_block_check.rs index dff4ae85..185da7e6 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_io/whole_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_io/whole_block_check.rs @@ -1,8 +1,7 @@ +use crate::matmul::cmma::base::Dimensions; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::Dimensions; - use super::base::{BlockLoader, BlockWriter}; pub(crate) struct WholeCheckBlockIO; @@ -12,30 +11,31 @@ impl BlockLoader for WholeCheckBlockIO { fn load_tile( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - read_row: UInt, - read_col: UInt, - write_pos: UInt, - dim_vertical: UInt, - dim_horizontal: UInt, + batch_offset: u32, + read_row: u32, + read_col: u32, + write_pos: u32, + dim_vertical: u32, + dim_horizontal: u32, ) { - let tensor_vec = Comptime::vectorization(tensor); - let tensor_vec_r = Comptime::runtime(tensor_vec); - let is_scalar = Comptime::map(tensor_vec, |v| v.val == 1); + let tensor_vec = vectorization_of(tensor); + let is_scalar = tensor_vec == 1; if read_col < dim_horizontal && read_row < dim_vertical { - let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec_r; + let read_pos = (batch_offset + read_row * dim_horizontal + read_col) / tensor_vec; let value = tensor[read_pos]; - if Comptime::get(is_scalar) { + if is_scalar { shared_memory[write_pos] = FC::cast_from(value); } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::cast_from(value[i]); } } } else { - for i in range(0u32, Comptime::get(tensor_vec), Comptime::new(true)) { + #[unroll] + for i in 0..tensor_vec { shared_memory[write_pos + i] = FC::new(0.); } } @@ -47,30 +47,30 @@ impl BlockWriter for WholeCheckBlockIO { fn write_output( out: &mut Tensor, accumulator_sm: SharedMemory, - batch_offset: UInt, - read_position: UInt, - write_row: UInt, - write_col: UInt, + batch_offset: u32, + read_position: u32, + write_row: u32, + write_col: u32, dims: Dimensions, ) { - let out_vec = Comptime::vectorization(out); - let out_vec_r = Comptime::runtime(out_vec); - let is_scalar = Comptime::map(out_vec, |v| v.val == 1); + let out_vec = vectorization_of(out); + let is_scalar = out_vec == 1; if write_row < dims.m && write_col < dims.n { let write_position = batch_offset + write_row * dims.n + write_col; - if Comptime::get(is_scalar) { + if is_scalar { let val = accumulator_sm[read_position]; - out[write_position / out_vec_r] = val; + out[write_position / out_vec] = val; } else { - let mut value = F::vectorized_empty(Comptime::get(out_vec)); + let mut value = F::vectorized_empty(out_vec); - for i in range(0u32, Comptime::get(out_vec), Comptime::new(true)) { + #[unroll] + for i in 0..out_vec { value[i] = accumulator_sm[read_position + i]; } - out[write_position / out_vec_r] = value; + out[write_position / out_vec] = value; } } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs b/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs index 10fe9f74..a3a3de5c 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs @@ -17,16 +17,16 @@ pub(crate) fn block_loop( shared_memories: SharedMemories, mut accumulators: Sequence>, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let block_size_k = Comptime::runtime(Comptime::map(comptime_info, |c| c.block_size_k)); - let write_out_reuse_smem = Comptime::map(comptime_info, |c| c.write_out_reuse_smem); + let block_size_k = comptime_info.block_size_k; + let write_out_reuse_smem = comptime_info.write_out_reuse_smem; // Equals ceil(dims.k / block_size_k) let dims = runtime_info.dims; let num_loops = (dims.k + block_size_k - 1) / block_size_k; - for block in range(0u32, num_loops, Comptime::new(false)) { + for block in 0..num_loops { let k_offset = block * block_size_k; load_to_shared_memories::( @@ -50,7 +50,7 @@ pub(crate) fn block_loop( sync_units(); } - if Comptime::get(write_out_reuse_smem) { + if write_out_reuse_smem { ReuseSmemWriter::write_to_output(out, accumulators, runtime_info, comptime_info); } else { LargeSmemWriter::write_to_output(out, accumulators, runtime_info, comptime_info); diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs index 272f8cf7..d11da9fa 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs @@ -10,17 +10,18 @@ pub(crate) fn compute_loop( shared_memories: SharedMemories, accumulators: &mut Sequence>, ids: Ids, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let block_size_n = Comptime::map(comptime_info, |c| c.block_size_n); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let num_accumulators = Comptime::map(comptime_info, |c| c.num_accumulators); - let num_coop_per_row = Comptime::runtime((block_size_n / tile_size) / num_accumulators); + let block_size_n = comptime_info.block_size_n; + let tile_size = comptime_info.tile_size; + let num_accumulators = comptime_info.num_accumulators; + let num_coop_per_row = (block_size_n / tile_size) / num_accumulators; let tile_row = ids.coop / num_coop_per_row; - let tile_col_base = (ids.coop % num_coop_per_row) * Comptime::runtime(num_accumulators); + let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators; - for n in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + #[unroll] + for n in 0..num_accumulators { compute_tile::( tile_row, tile_col_base + n, @@ -33,20 +34,21 @@ pub(crate) fn compute_loop( #[cube] fn compute_tile( - tile_row: UInt, - tile_col: UInt, + tile_row: u32, + tile_col: u32, shared_memories: SharedMemories, accumulator: cmma::Matrix, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let block_size_k = Comptime::map(comptime_info, |c| c.block_size_k); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let unroll = Comptime::map(comptime_info, |c| c.unroll); + let block_size_k = comptime_info.block_size_k; + let tile_size = comptime_info.tile_size; + let unroll = comptime_info.unroll; - let smem_stride = Comptime::runtime(tile_size * tile_size); - let num_tiles_in_k = Comptime::runtime(block_size_k / tile_size); + let smem_stride = tile_size * tile_size; + let num_tiles_in_k = block_size_k / tile_size; - for k_iter in range(0u32, num_tiles_in_k, unroll) { + #[unroll(unroll)] + for k_iter in 0..num_tiles_in_k { let shared_lhs_tile = tile_row * num_tiles_in_k + k_iter; let shared_rhs_tile = tile_col * num_tiles_in_k + k_iter; let shared_lhs_pos = shared_lhs_tile * smem_stride; @@ -74,8 +76,8 @@ fn compute_tile( cmma::MatrixLayout::RowMajor, ); - cmma::load::(&a, lhs_slice, UInt::new(16)); - cmma::load::(&b, rhs_slice, UInt::new(16)); + cmma::load(&a, lhs_slice, 16); + cmma::load(&b, rhs_slice, 16); cmma::execute::(&a, &b, &accumulator, &accumulator); } diff --git a/crates/cubecl-linalg/src/matmul/cmma/config.rs b/crates/cubecl-linalg/src/matmul/cmma/config.rs index f355a753..d3124568 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/config.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/config.rs @@ -56,17 +56,17 @@ impl CmmaConfig { let num_coops = self.b_mn * self.b_k / (CMMA_TILE_SIZE * CMMA_TILE_SIZE); ComptimeCmmaInfo { - block_size_m: self.b_mn.into(), - block_size_k: self.b_k.into(), - block_size_n: self.b_mn.into(), - tile_size: CMMA_TILE_SIZE.into(), + block_size_m: self.b_mn as u32, + block_size_k: self.b_k as u32, + block_size_n: self.b_mn as u32, + tile_size: CMMA_TILE_SIZE as u32, unroll: self.unroll, check_m_bounds: m % self.b_mn != 0, check_k_bounds: k % self.b_k != 0, check_n_bounds: n % self.b_mn != 0, - coop_dim: CMMA_COOP_DIM.into(), - num_coops: UInt::new(num_coops as u32), - num_accumulators: UInt::new(self.alpha as u32), + coop_dim: CMMA_COOP_DIM as u32, + num_coops: num_coops as u32, + num_accumulators: self.alpha as u32, write_out_reuse_smem: self.write_out_strategy == WriteOutStrategy::ReuseSmem, } } @@ -117,13 +117,13 @@ impl Init for ComptimeCmmaInfo { /// Tiling 2D parameters pub struct ComptimeCmmaInfo { /// Block size along dimension of lhs - pub block_size_m: UInt, + pub block_size_m: u32, /// Block size along common dimension - pub block_size_k: UInt, + pub block_size_k: u32, /// Block size along dimension of rhs - pub block_size_n: UInt, + pub block_size_n: u32, /// Tile size (dimension of one side). Should correspond to cmma supported tile size - pub tile_size: UInt, + pub tile_size: u32, /// Bounds must be checked on lhs dimension pub check_m_bounds: bool, /// Bounds must be checked on common dimension @@ -133,11 +133,11 @@ pub struct ComptimeCmmaInfo { /// Unroll pub unroll: bool, /// The number of units that can collaborate - pub coop_dim: UInt, + pub coop_dim: u32, /// The number of collaboration groups - pub num_coops: UInt, + pub num_coops: u32, /// Number of cmma per subcube performed in one pass - pub num_accumulators: UInt, + pub num_accumulators: u32, /// Write out strategy: false = large, true = reuse pub write_out_reuse_smem: bool, } diff --git a/crates/cubecl-linalg/src/matmul/cmma/launch.rs b/crates/cubecl-linalg/src/matmul/cmma/launch.rs index d0e99f2f..094a57cb 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/launch.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/launch.rs @@ -1,9 +1,10 @@ use cubecl_core::{ client::ComputeClient, - frontend::{Float, TensorArg, TensorHandleRef, F16}, + frontend::{Float, TensorArg, TensorHandleRef}, ir::{Elem, FloatKind}, tensor_vectorization_factor, Feature, Runtime, }; +use half::f16; use crate::{ matmul::cmma::{base::cmma_kernel, config::CmmaConfig}, @@ -121,7 +122,7 @@ fn matmul_cmma_ref_no_check( tensor_vectorization_factor(&available_vectorizations, out.shape, out.strides, rank - 1); unsafe { - cmma_kernel::launch_unchecked::( + cmma_kernel::launch_unchecked::( client, cmma_config.cube_count::(out.shape), cmma_config.cube_dim(), diff --git a/crates/cubecl-linalg/src/matmul/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/cmma/load_shared_memory.rs index 1febb69a..99c247bc 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/load_shared_memory.rs @@ -16,14 +16,14 @@ use crate::matmul::cmma::block_io::{ pub(crate) fn load_to_shared_memories( lhs: &Tensor, rhs: &Tensor, - k_offset: UInt, + k_offset: u32, mut shared: SharedMemories, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let block_size_k = Comptime::map(comptime_info, |c| c.block_size_k); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let num_tiles_in_k = Comptime::runtime(block_size_k / tile_size); + let block_size_k = comptime_info.block_size_k; + let tile_size = comptime_info.tile_size; + let num_tiles_in_k = block_size_k / tile_size; load_lhs( lhs, @@ -47,13 +47,13 @@ pub(crate) fn load_to_shared_memories( pub(crate) fn load_lhs( lhs: &Tensor, shared_lhs: &mut SharedMemory, - num_tiles_in_k: UInt, - k_offset: UInt, + num_tiles_in_k: u32, + k_offset: u32, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let check_m_bounds = Comptime::map(comptime_info, |c| c.check_m_bounds); - let check_k_bounds = Comptime::map(comptime_info, |c| c.check_k_bounds); + let check_m_bounds = comptime_info.check_m_bounds; + let check_k_bounds = comptime_info.check_k_bounds; let ids = runtime_info.ids; let dims = runtime_info.dims; let offsets = runtime_info.offsets; @@ -61,8 +61,8 @@ pub(crate) fn load_lhs( let tile_row = ids.coop / num_tiles_in_k; let tile_col = ids.coop % num_tiles_in_k; - if Comptime::get(check_m_bounds) { - if Comptime::get(check_k_bounds) { + if check_m_bounds { + if check_k_bounds { load_tile::( lhs, shared_lhs, @@ -91,7 +91,7 @@ pub(crate) fn load_lhs( comptime_info, ); } - } else if Comptime::get(check_k_bounds) { + } else if check_k_bounds { load_tile::( lhs, shared_lhs, @@ -126,13 +126,13 @@ pub(crate) fn load_lhs( pub(crate) fn load_rhs( rhs: &Tensor, shared_rhs: &mut SharedMemory, - num_tiles_in_k: UInt, - k_offset: UInt, + num_tiles_in_k: u32, + k_offset: u32, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let check_k_bounds = Comptime::map(comptime_info, |c| c.check_k_bounds); - let check_n_bounds = Comptime::map(comptime_info, |c| c.check_n_bounds); + let check_k_bounds = comptime_info.check_k_bounds; + let check_n_bounds = comptime_info.check_n_bounds; let ids = runtime_info.ids; let dims = runtime_info.dims; let offsets = runtime_info.offsets; @@ -140,8 +140,8 @@ pub(crate) fn load_rhs( let tile_row = ids.coop % num_tiles_in_k; let tile_col = ids.coop / num_tiles_in_k; - if Comptime::get(check_k_bounds) { - if Comptime::get(check_n_bounds) { + if check_k_bounds { + if check_n_bounds { load_tile::( rhs, shared_rhs, @@ -170,7 +170,7 @@ pub(crate) fn load_rhs( comptime_info, ); } - } else if Comptime::get(check_n_bounds) { + } else if check_n_bounds { load_tile::( rhs, shared_rhs, @@ -204,41 +204,40 @@ pub(crate) fn load_rhs( fn load_tile>( tensor: &Tensor, shared_memory: &mut SharedMemory, - batch_offset: UInt, - tile_row: UInt, - tile_col: UInt, - dim_vertical: UInt, - dim_horizontal: UInt, - skip_row: UInt, - skip_col: UInt, + batch_offset: u32, + tile_row: u32, + tile_col: u32, + dim_vertical: u32, + dim_horizontal: u32, + skip_row: u32, + skip_col: u32, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let tile_size_r = Comptime::runtime(tile_size); - let tensor_vec = Comptime::vectorization(tensor); - let tensor_vec_r = Comptime::runtime(tensor_vec); + let tile_size = comptime_info.tile_size; + let tensor_vec = vectorization_of(tensor); let ids = runtime_info.ids; // Must equal SUBCUBE_DIM, but must be known comptime too - let coop_dim = Comptime::map(comptime_info, |c| c.coop_dim); + let coop_dim = comptime_info.coop_dim; let num_unit_reads = tile_size * tile_size / (tensor_vec * coop_dim); - let num_units_per_row = Comptime::runtime(tile_size / tensor_vec); + let num_units_per_row = tile_size / tensor_vec; - let lane_row_step = Comptime::runtime(coop_dim * tensor_vec / tile_size); + let lane_row_step = coop_dim * tensor_vec / tile_size; let lane_row_offset = ids.lane / num_units_per_row; - let read_row_offset = skip_row + tile_row * tile_size_r + lane_row_offset; + let read_row_offset = skip_row + tile_row * tile_size + lane_row_offset; - let lane_col_offset = ids.lane % num_units_per_row * tensor_vec_r; - let read_col = skip_col + tile_col * tile_size_r + lane_col_offset; + let lane_col_offset = ids.lane % num_units_per_row * tensor_vec; + let read_col = skip_col + tile_col * tile_size + lane_col_offset; - let sm_stride = Comptime::runtime(tile_size * tile_size); + let sm_stride = tile_size * tile_size; - let write_offset = ids.coop * sm_stride + ids.lane * tensor_vec_r; - let sm_step = Comptime::runtime(coop_dim * tensor_vec); + let write_offset = ids.coop * sm_stride + ids.lane * tensor_vec; + let sm_step = coop_dim * tensor_vec; - for i in range(0u32, Comptime::get(num_unit_reads), Comptime::new(true)) { + #[unroll] + for i in 0..num_unit_reads { let read_row = read_row_offset + i * lane_row_step; let write_pos = write_offset + i * sm_step; diff --git a/crates/cubecl-linalg/src/matmul/cmma/write_output/base.rs b/crates/cubecl-linalg/src/matmul/cmma/write_output/base.rs index f8af705f..27727650 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/write_output/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/write_output/base.rs @@ -19,24 +19,24 @@ pub(crate) trait OutputWriter: Send + Sync + 'static { out: &mut Tensor, accumulators: Sequence>, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ); } #[cube] pub(crate) fn shared_memory_to_output( out: &mut Tensor, - smem_position: UInt, + smem_position: u32, accumulator_sm: SharedMemory, - n_iter: UInt, + n_iter: u32, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let check_m_bounds = Comptime::map(comptime_info, |c| c.check_m_bounds); - let check_n_bounds = Comptime::map(comptime_info, |c| c.check_n_bounds); + let check_m_bounds = comptime_info.check_m_bounds; + let check_n_bounds = comptime_info.check_n_bounds; - if Comptime::get(check_m_bounds) { - if Comptime::get(check_n_bounds) { + if check_m_bounds { + if check_n_bounds { write_tile::( out, smem_position, @@ -55,7 +55,7 @@ pub(crate) fn shared_memory_to_output( comptime_info, ); } - } else if Comptime::get(check_n_bounds) { + } else if check_n_bounds { write_tile::( out, smem_position, @@ -79,46 +79,43 @@ pub(crate) fn shared_memory_to_output( #[cube] fn write_tile>( out: &mut Tensor, - smem_position: UInt, + smem_position: u32, accumulator_sm: SharedMemory, - n_iter: UInt, + n_iter: u32, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let tile_size_r = Comptime::runtime(tile_size); - let num_accumulators = Comptime::map(comptime_info, |c| c.num_accumulators); - let block_size_n = Comptime::map(comptime_info, |c| c.block_size_n); - let num_accum_groups_in_block_row = - Comptime::runtime(block_size_n / (tile_size * num_accumulators)); + let tile_size = comptime_info.tile_size; + let num_accumulators = comptime_info.num_accumulators; + let block_size_n = comptime_info.block_size_n; + let num_accum_groups_in_block_row = block_size_n / (tile_size * num_accumulators); - let out_vec = Comptime::vectorization(out); - let out_vec_r = Comptime::runtime(out_vec); - let n_units_per_tile_row = Comptime::runtime(tile_size / out_vec); - let sm_stride = Comptime::runtime(tile_size * tile_size); - let coop_dim = Comptime::map(comptime_info, |c| c.coop_dim); + let out_vec = vectorization_of(out); + let n_units_per_tile_row = tile_size / out_vec; + let sm_stride = tile_size * tile_size; + let coop_dim = comptime_info.coop_dim; let dims = runtime_info.dims; let ids = runtime_info.ids; let offsets = runtime_info.offsets; let tile_row = ids.coop / num_accum_groups_in_block_row; - let tile_col = (ids.coop % num_accum_groups_in_block_row) * Comptime::runtime(num_accumulators); + let tile_col = (ids.coop % num_accum_groups_in_block_row) * num_accumulators; let num_unit_writes = tile_size * tile_size / (out_vec * coop_dim); - let smem_offset = smem_position * sm_stride + ids.lane * out_vec_r; - let sm_step = Comptime::runtime(coop_dim * out_vec); + let smem_offset = smem_position * sm_stride + ids.lane * out_vec; + let sm_step = coop_dim * out_vec; - let lane_row_step = Comptime::runtime(coop_dim * out_vec / tile_size); + let lane_row_step = coop_dim * out_vec / tile_size; let unit_write_row = ids.lane / n_units_per_tile_row; - let unit_write_col = ids.lane % n_units_per_tile_row * out_vec_r; + let unit_write_col = ids.lane % n_units_per_tile_row * out_vec; - let row_offset = offsets.cube_row + tile_row * tile_size_r; - let write_col = - offsets.cube_col + tile_col * tile_size_r + unit_write_col + n_iter * tile_size_r; + let row_offset = offsets.cube_row + tile_row * tile_size; + let write_col = offsets.cube_col + tile_col * tile_size + unit_write_col + n_iter * tile_size; - for i in range(0u32, Comptime::get(num_unit_writes), Comptime::new(true)) { + #[unroll] + for i in 0..num_unit_writes { let read_pos = smem_offset + i * sm_step; let write_row = row_offset + unit_write_row + i * lane_row_step; diff --git a/crates/cubecl-linalg/src/matmul/cmma/write_output/large_smem.rs b/crates/cubecl-linalg/src/matmul/cmma/write_output/large_smem.rs index 29d3652c..b3c31030 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/write_output/large_smem.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/write_output/large_smem.rs @@ -16,37 +16,38 @@ impl OutputWriter for LargeSmemWriter { out: &mut Tensor, accumulators: Sequence>, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let num_accumulators = Comptime::map(comptime_info, |c| c.num_accumulators); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let num_coops = Comptime::map(comptime_info, |c| c.num_coops); + let num_accumulators = comptime_info.num_accumulators; + let tile_size = comptime_info.tile_size; + let num_coops = comptime_info.num_coops; let ids = runtime_info.ids; let smem_stride = tile_size * tile_size; - let smem_stride_r = Comptime::runtime(smem_stride); let smem_size = num_accumulators * num_coops * smem_stride; - let mut acc_sm = SharedMemory::::new(Comptime::get(smem_size)); + let mut acc_sm = SharedMemory::::new(smem_size); - let slice_offset = ids.coop * Comptime::runtime(num_accumulators * smem_stride); - let smem_position_base = Comptime::runtime(num_accumulators) * ids.coop; + let slice_offset = ids.coop * num_accumulators * smem_stride; + let smem_position_base = num_accumulators * ids.coop; - for n in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { - let slice_start = slice_offset + n * smem_stride_r; - let slice_end = slice_start + smem_stride_r; + #[unroll] + for n in 0..num_accumulators { + let slice_start = slice_offset + n * smem_stride; + let slice_end = slice_start + smem_stride; let slice = acc_sm.slice_mut(slice_start, slice_end); cmma::store::( slice, accumulators.index(n), - UInt::new(16), + 16, cmma::MatrixLayout::RowMajor, ); } - for n in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + #[unroll] + for n in 0..num_accumulators { let smem_position = smem_position_base + n; shared_memory_to_output(out, smem_position, acc_sm, n, runtime_info, comptime_info); } diff --git a/crates/cubecl-linalg/src/matmul/cmma/write_output/reuse_smem.rs b/crates/cubecl-linalg/src/matmul/cmma/write_output/reuse_smem.rs index 2dde03a3..f0286422 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/write_output/reuse_smem.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/write_output/reuse_smem.rs @@ -16,27 +16,27 @@ impl OutputWriter for ReuseSmemWriter { out: &mut Tensor, accumulators: Sequence>, runtime_info: RuntimeCmmaInfo, - comptime_info: Comptime, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let num_accumulators = Comptime::map(comptime_info, |c| c.num_accumulators); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let num_coops = Comptime::map(comptime_info, |c| c.num_coops); + let num_accumulators = comptime_info.num_accumulators; + let tile_size = comptime_info.tile_size; + let num_coops = comptime_info.num_coops; let ids = runtime_info.ids; let sm_stride = tile_size * tile_size; let sm_size = num_coops * sm_stride; - let acc_sm = SharedMemory::::new(Comptime::get(sm_size)); + let acc_sm = SharedMemory::::new(sm_size); - let slice_offset = ids.coop * Comptime::runtime(sm_stride); - let slice = - acc_sm.slice_mut_unsafe(slice_offset, slice_offset + Comptime::runtime(sm_stride)); + let slice_offset = ids.coop * sm_stride; + let slice = acc_sm.slice_mut_unsafe(slice_offset, slice_offset + sm_stride); - for n in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + #[unroll] + for n in 0..num_accumulators { cmma::store::( slice, accumulators.index(n), - UInt::new(16), + 16, cmma::MatrixLayout::RowMajor, ); diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs index 58309e5f..ee8a4423 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs @@ -1,38 +1,39 @@ +use cubecl::prelude::*; use cubecl_core as cubecl; -use cubecl_core::prelude::*; use crate::matmul::cmma::{ - base::{make_accumulators, Ids, IdsExpand, SharedMemories, SharedMemoriesExpand}, + base::{make_accumulators, Ids, SharedMemories}, compute_loop::compute_loop, config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy}, }; use crate::matmul::tests::test_utils::{ assert_equals, cmma_available, create_empty, range_tensor_f16, }; +use half::f16; #[cube(launch_unchecked)] fn compute_loop_test( lhs_tensor: &Tensor, rhs_tensor: &Tensor, accumulate_array: &mut Array, - b_mn: Comptime, - b_k: Comptime, - comptime_info: Comptime, + #[comptime] b_mn: u32, + #[comptime] b_k: u32, + #[comptime] comptime_info: ComptimeCmmaInfo, ) { - let mut lhs = SharedMemory::::new(Comptime::get(b_mn * b_k)); - let mut rhs = SharedMemory::::new(Comptime::get(b_k * b_mn)); + let mut lhs = SharedMemory::::new(b_mn * b_k); + let mut rhs = SharedMemory::::new(b_k * b_mn); - for i in range(0u32, Comptime::get(b_mn * b_k), Comptime::new(false)) { + for i in 0..b_mn * b_k { lhs[i] = lhs_tensor[i]; } - for i in range(0u32, Comptime::get(b_k * b_mn), Comptime::new(false)) { + for i in 0..b_k * b_mn { rhs[i] = rhs_tensor[i]; } - for i in range(0u32, Comptime::get(b_mn * b_mn), Comptime::new(false)) { + for i in 0..b_mn * b_mn { accumulate_array[i] = F::new(0.); } - let shared_memories = SharedMemories { lhs, rhs }; + let shared_memories = SharedMemories:: { lhs, rhs }; let mut accumulators = make_accumulators::(comptime_info); compute_loop( @@ -45,18 +46,19 @@ fn compute_loop_test( comptime_info, ); - let num_accumulators = Comptime::map(comptime_info, |c| c.num_accumulators); - let tile_size = Comptime::map(comptime_info, |c| c.tile_size); - let slice_offset = Comptime::runtime(tile_size * tile_size); - let offset = UNIT_POS_Y * slice_offset * Comptime::runtime(num_accumulators); + let num_accumulators = comptime_info.num_accumulators; + let tile_size = comptime_info.tile_size; + let slice_offset = tile_size * tile_size; + let offset = UNIT_POS_Y * slice_offset * num_accumulators; - for n in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + #[unroll] + for n in 0..num_accumulators { let slice = accumulate_array.slice_mut(offset + n * slice_offset, offset + (n + 1) * slice_offset); cmma::store::( slice, - &accumulators.index(n), - UInt::new(16), + accumulators.index(n), + 16, cmma::MatrixLayout::RowMajor, ); } @@ -83,15 +85,15 @@ fn compute_loop_test_case( block_config.comptime_info(block_config.b_mn, block_config.b_k, block_config.b_mn); unsafe { - compute_loop_test::launch_unchecked::( + compute_loop_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), ArrayArg::from_raw_parts(&results, block_config.b_mn * block_config.b_mn, 1), - UInt::new(block_config.b_mn as u32), - UInt::new(block_config.b_k as u32), + block_config.b_mn as u32, + block_config.b_k as u32, comptime_info, ); }; diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs index 658837e3..5ba8d174 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs @@ -3,10 +3,7 @@ use std::ops::Range; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::{ - Dimensions, DimensionsExpand, Ids, IdsExpand, Offsets, OffsetsExpand, RuntimeCmmaInfo, - RuntimeCmmaInfoExpand, -}; +use crate::matmul::cmma::base::{Dimensions, Ids, Offsets, RuntimeCmmaInfo}; use crate::matmul::cmma::config::{CmmaConfig, WriteOutStrategy}; use crate::matmul::tests::test_utils::{assert_equals_range, create_empty}; use crate::matmul::{ @@ -20,27 +17,27 @@ use super::base::{DimsTestCase, B_K, B_MN}; fn load_lhs_test( lhs_tensor: &Tensor, lhs_sm_arr: &mut Array, - k_offset: UInt, - m: UInt, - k: UInt, - n: UInt, - config: Comptime, + k_offset: u32, + m: u32, + k: u32, + n: u32, + #[comptime] config: ComptimeCmmaInfo, ) { - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_m = config.block_size_m; + let block_size_k = config.block_size_k; let sm_size = block_size_k * block_size_m; - let mut lhs_sm = SharedMemory::::new(Comptime::get(sm_size)); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + let mut lhs_sm = SharedMemory::::new(sm_size); + for i in 0..sm_size { lhs_sm[i] = F::new(0.); } let offsets = Offsets { - batch_lhs: UInt::new(0), - batch_rhs: UInt::new(0), - batch_out: UInt::new(0), - cube_row: UInt::new(0), - cube_col: UInt::new(0), + batch_lhs: 0, + batch_rhs: 0, + batch_out: 0, + cube_row: 0, + cube_col: 0, }; let dims = Dimensions { m, k, n }; let ids = Ids { @@ -49,16 +46,9 @@ fn load_lhs_test( }; let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - load_lhs( - lhs_tensor, - &mut lhs_sm, - UInt::new(2), - k_offset, - runtime_info, - config, - ); + load_lhs(lhs_tensor, &mut lhs_sm, 2, k_offset, runtime_info, config); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { lhs_sm_arr[i] = lhs_sm[i]; } } @@ -67,27 +57,27 @@ fn load_lhs_test( fn load_rhs_test( rhs_tensor: &Tensor, rhs_sm_arr: &mut Array, - k_offset: UInt, - m: UInt, - k: UInt, - n: UInt, - config: Comptime, + k_offset: u32, + m: u32, + k: u32, + n: u32, + #[comptime] config: ComptimeCmmaInfo, ) { - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_n = Comptime::map(config, |c| c.block_size_n); + let block_size_k = config.block_size_k; + let block_size_n = config.block_size_n; let sm_size = block_size_k * block_size_n; - let mut rhs_sm = SharedMemory::::new(Comptime::get(sm_size)); + let mut rhs_sm = SharedMemory::::new(sm_size); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { rhs_sm[i] = F::new(0.); } let offsets = Offsets { - batch_lhs: UInt::new(0), - batch_rhs: UInt::new(0), - batch_out: UInt::new(0), - cube_row: UInt::new(0), - cube_col: UInt::new(0), + batch_lhs: 0, + batch_rhs: 0, + batch_out: 0, + cube_row: 0, + cube_col: 0, }; let dims = Dimensions { m, k, n }; let ids = Ids { @@ -96,16 +86,9 @@ fn load_rhs_test( }; let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - load_rhs( - rhs_tensor, - &mut rhs_sm, - UInt::new(2), - k_offset, - runtime_info, - config, - ); + load_rhs(rhs_tensor, &mut rhs_sm, 2, k_offset, runtime_info, config); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { rhs_sm_arr[i] = rhs_sm[i]; } } @@ -141,7 +124,7 @@ fn load_shared_memory_test_case( }; unsafe { - load_lhs_test::launch_unchecked::( + load_lhs_test::launch_unchecked::( &R::client(device), config.cube_count::(&[dims.m, dims.n]), config.cube_dim(), diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs index 96409a61..de5d27b9 100644 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs @@ -3,10 +3,7 @@ use std::ops::Range; use cubecl_core as cubecl; use cubecl_core::prelude::*; -use crate::matmul::cmma::base::{ - Dimensions, DimensionsExpand, Ids, IdsExpand, Offsets, OffsetsExpand, RuntimeCmmaInfo, - RuntimeCmmaInfoExpand, -}; +use crate::matmul::cmma::base::{Dimensions, Ids, Offsets, RuntimeCmmaInfo}; use crate::matmul::cmma::config::{CmmaConfig, ComptimeCmmaInfo, WriteOutStrategy}; use crate::matmul::cmma::write_output::base::shared_memory_to_output; use crate::matmul::tests::test_utils::{ @@ -19,31 +16,31 @@ use super::base::DimsTestCase; fn write_output_test( out: &mut Tensor, acc_sm_arr: &mut Array, - m: UInt, - k: UInt, - n: UInt, - config: Comptime, + m: u32, + k: u32, + n: u32, + #[comptime] config: ComptimeCmmaInfo, ) { - let num_accumulators = Comptime::map(config, |c| c.num_accumulators); - let tile_size = Comptime::map(config, |c| c.tile_size); - let num_coops = Comptime::map(config, |c| c.num_coops); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_n); + let num_accumulators = config.num_accumulators; + let tile_size = config.tile_size; + let num_coops = config.num_coops; + let block_size_m = config.block_size_m; + let block_size_n = config.block_size_n; let sm_stride = tile_size * tile_size; let sm_size = num_accumulators * num_coops * sm_stride; - let mut accumulate = SharedMemory::::new(Comptime::get(sm_size)); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + let mut accumulate = SharedMemory::::new(sm_size); + for i in 0..sm_size { accumulate[i] = acc_sm_arr[i]; } let offsets = Offsets { - batch_lhs: UInt::new(0), - batch_rhs: UInt::new(0), - batch_out: UInt::new(0), - cube_row: CUBE_POS_X * Comptime::runtime(block_size_m), - cube_col: CUBE_POS_Y * Comptime::runtime(block_size_n), + batch_lhs: 0, + batch_rhs: 0, + batch_out: 0, + cube_row: CUBE_POS_X * block_size_m, + cube_col: CUBE_POS_Y * block_size_n, }; let dims = Dimensions { m, k, n }; let ids = Ids { @@ -52,8 +49,9 @@ fn write_output_test( }; let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - let smem_position_base = Comptime::runtime(num_accumulators) * ids.coop; - for n_iter in range(0u32, Comptime::get(num_accumulators), Comptime::new(true)) { + let smem_position_base = num_accumulators * ids.coop; + #[unroll] + for n_iter in 0..num_accumulators { shared_memory_to_output( out, smem_position_base + n_iter, @@ -80,7 +78,7 @@ fn write_output_test_case( let acc_sm = range_tensor::(&client, config.b_mn, config.b_mn); unsafe { - write_output_test::launch_unchecked::( + write_output_test::launch_unchecked::( &client, config.cube_count::(&[dims.m, dims.n]), config.cube_dim(), diff --git a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs index 6c482a3f..0bda6d29 100644 --- a/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs +++ b/crates/cubecl-linalg/src/matmul/tests/matmul_tests.rs @@ -1,4 +1,4 @@ -use cubecl_core::{frontend::F32, CubeElement, Runtime}; +use cubecl_core::{CubeElement, Runtime}; use half::f16; use crate::{ @@ -202,7 +202,7 @@ impl MatmulTestCase { f32::from_bytes(&R::client(device).read(tensor_2.handle.clone().binding())), ); - let out = tiling2d::launch::(&client, tensor_1, tensor_2, out, Default::default()); + let out = tiling2d::launch::(&client, tensor_1, tensor_2, out, Default::default()); assert_equals_approx::(&client, out.handle, &expected, self.epsilon); } @@ -228,7 +228,7 @@ impl MatmulTestCase { f32::from_bytes(&client.read(tensor_2.handle.clone().binding())), ); - let out = launch::(&client, tensor_1, tensor_2, out, config); + let out = launch::(&client, tensor_1, tensor_2, out, config); assert_equals_approx::(&client, out.handle, &expected, self.epsilon); } diff --git a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs index e6a88f55..fd6824a0 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs @@ -1,11 +1,11 @@ use bytemuck::cast_slice; use cubecl_core::{ client::ComputeClient, - frontend::{F16, F32}, ir::{Elem, FloatKind}, server::Handle, CubeElement, Feature, Runtime, }; +use half::f16; use std::ops::Range; use crate::{ @@ -17,7 +17,7 @@ pub(crate) fn range_tensor_f16( client: &ComputeClient, x: usize, y: usize, -) -> TensorHandle { +) -> TensorHandle { let n_elements = x * y; let mut data = Vec::with_capacity(n_elements); @@ -34,7 +34,7 @@ pub(crate) fn range_tensor( client: &ComputeClient, x: usize, y: usize, -) -> TensorHandle { +) -> TensorHandle { let n_elements = x * y; let mut data: Vec = Vec::with_capacity(n_elements); @@ -53,7 +53,7 @@ pub(crate) fn range_tensor_with_factor( x: usize, y: usize, factor: f32, -) -> TensorHandle { +) -> TensorHandle { let n_elements = batch * x * y; let mut data: Vec = Vec::with_capacity(n_elements); @@ -70,7 +70,7 @@ pub(crate) fn range_tensor_transposed( client: &ComputeClient, x: usize, y: usize, -) -> TensorHandle { +) -> TensorHandle { let n_elements = x * y; let mut data: Vec = Vec::with_capacity(n_elements); @@ -90,7 +90,7 @@ pub(crate) fn zeros_tensor( client: &ComputeClient, x: usize, y: usize, -) -> TensorHandle { +) -> TensorHandle { let n_elements = x * y; let data: Vec = vec![0.; n_elements]; @@ -115,7 +115,7 @@ pub(crate) fn assert_equals( let actual = client.read(output.binding()); let actual = f32::from_bytes(&actual); - assert_eq!(actual, expected); + pretty_assertions::assert_eq!(actual, expected); } pub(crate) fn assert_equals_approx( diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs index 7a3db32b..9f40ff71 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/compute_loop.rs @@ -7,7 +7,7 @@ use crate::matmul::{ assert_equals, create_empty, make_tiling2d_config, range_tensor, range_tensor_transposed, }, tiling2d::{ - base::{Coordinates, CoordinatesExpand, TILE_SIZE}, + base::{Coordinates, TILE_SIZE}, compute_loop::compute_loop, config::CubeTiling2dConfig, }, @@ -19,19 +19,15 @@ fn tile_outer_product_test( register_m: Array, register_n: Array, results: &mut Array, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { // We launch with array then convert to vectorized float, // because direct launch of vectorized float is not supported - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; let register_m = register_m.to_vectorized(tile_size); let register_n = register_n.to_vectorized(tile_size); - for i in range( - 0u32, - Comptime::get(tile_size * tile_size), - Comptime::new(false), - ) { + for i in 0..tile_size * tile_size { results[i] = F::new(0.); } tile_outer_product::(register_m, register_n, results, config) @@ -51,7 +47,7 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); unsafe { - tile_outer_product_test::launch_unchecked::( + tile_outer_product_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -73,42 +69,42 @@ pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) fn compute_loop_test( lhs: &Tensor, rhs: &Tensor, - unit_row: UInt, - unit_col: UInt, + unit_row: u32, + unit_col: u32, results: &mut Array, - lhs_len: Comptime, - rhs_len: Comptime, - config: Comptime, + #[comptime] lhs_len: u32, + #[comptime] rhs_len: u32, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_m); + let tile_size = config.tile_size; + let block_size_m = config.block_size_m; + let block_size_k = config.block_size_m; + let block_size_n = config.block_size_m; let sm_size_lhs = block_size_m * block_size_k / tile_size; let sm_size_rhs = block_size_n * block_size_k / tile_size; // Shared memories are not launchable, so we launch with tensor and convert to shared memory - let mut shared_lhs = - SharedMemory::::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size)); - for i in range(0u32, Comptime::get(lhs_len), Comptime::new(true)) { + let mut shared_lhs = SharedMemory::::vectorized(sm_size_lhs, tile_size); + #[unroll] + for i in 0..lhs_len { shared_lhs[i] = lhs[i]; } - let mut shared_rhs = - SharedMemory::::vectorized(Comptime::get(sm_size_rhs), Comptime::get(tile_size)); - for i in range(0u32, Comptime::get(rhs_len), Comptime::new(true)) { + let mut shared_rhs = SharedMemory::::vectorized(sm_size_rhs, tile_size); + #[unroll] + for i in 0..rhs_len { shared_rhs[i] = rhs[i]; } - for i in range(0u32, 16u32, Comptime::new(false)) { + for i in 0..16 { results[i] = F::new(0.); } let coordinates = Coordinates { unit_row, unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), + skip_row: 0, + skip_col: 0, }; compute_loop(coordinates, shared_lhs, shared_rhs, results, config) @@ -127,7 +123,7 @@ pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); unsafe { - tile_outer_product_test::launch_unchecked::( + tile_outer_product_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -157,7 +153,7 @@ pub fn compute_loop_unit_test(device: &R::Device) { let config = make_tiling2d_config(SOME_DIM, SOME_DIM, SOME_DIM); unsafe { - compute_loop_test::launch_unchecked::( + compute_loop_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -166,8 +162,8 @@ pub fn compute_loop_unit_test(device: &R::Device) { ScalarArg::new(0), ScalarArg::new(0), ArrayArg::from_raw_parts(&results, 16, 1), - UInt::new(16), - UInt::new(16), + 16, + 16, config, ); }; @@ -191,7 +187,7 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { let config = make_tiling2d_config(4, 8, 4); unsafe { - compute_loop_test::launch_unchecked::( + compute_loop_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -200,8 +196,8 @@ pub fn compute_loop_unit_offset_test(device: &R::Device) { ScalarArg::new(4), ScalarArg::new(4), ArrayArg::from_raw_parts(&results, 16, 1), - UInt::new(8), - UInt::new(8), + 8, + 8, config, ); }; diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs index f9299ec9..11502a66 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/load_shared_memory.rs @@ -9,9 +9,9 @@ use crate::matmul::tiling2d::tile::loader::TileLoader; use crate::matmul::{ tests::test_utils::{assert_equals, create_empty, range_tensor}, tiling2d::{ - base::{Coordinates, CoordinatesExpand, Dimensions, DimensionsExpand, TILE_SIZE}, + base::{Coordinates, Dimensions, TILE_SIZE}, config::CubeTiling2dConfig, - load_shared_memory::{LoadInfo, LoadInfoExpand}, + load_shared_memory::LoadInfo, }, }; @@ -19,68 +19,65 @@ use crate::matmul::{ fn load_tensor_test( tensor: &Tensor, sm_out: &mut Array, - unit_row: UInt, - unit_col: UInt, - k: UInt, - config: Comptime, - is_lhs: Comptime, + unit_row: u32, + unit_col: u32, + k: u32, + #[comptime] config: CubeTiling2dConfig, + #[comptime] is_lhs: bool, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); + let tile_size = config.tile_size; + let block_size_k = config.block_size_k; + let block_size_m = config.block_size_m; let sm_size = block_size_k * block_size_m / tile_size; - let mut shared_memory = - SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + let mut shared_memory = SharedMemory::::vectorized(sm_size, tile_size); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = F::vectorized(0., Comptime::get(tile_size)); - shared_memory[i] = F::vectorized(0., Comptime::get(tile_size)); + for i in 0..sm_size { + sm_out[i] = F::vectorized(0., tile_size); + shared_memory[i] = F::vectorized(0., tile_size); } - let batch_offset = UInt::new(0); + let batch_offset = 0; let coordinates = Coordinates { unit_row, unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), + skip_row: 0, + skip_col: 0, }; - if Comptime::get(is_lhs) { + if is_lhs { let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(2)), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: UInt::new(0), + m: tensor.shape(tensor.rank() - 2), + k: tensor.shape(tensor.rank() - 1), + n: 0, }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; load_lhs_transposed::>(tensor, info, config); } else { let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: tensor.shape(tensor.rank() - UInt::new(1)), + m: 0, + k: tensor.shape(tensor.rank() - 2), + n: tensor.shape(tensor.rank() - 1), }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; load_rhs_plain::>(tensor, info, config); } - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { sm_out[i] = shared_memory[i]; } } @@ -89,46 +86,44 @@ fn load_tensor_test( fn load_tensor_permuted_test( tensor: &Tensor, sm_out: &mut Array, - unit_row: UInt, - unit_col: UInt, - k: UInt, - config: Comptime, - is_lhs: Comptime, + unit_row: u32, + unit_col: u32, + k: u32, + #[comptime] config: CubeTiling2dConfig, + #[comptime] is_lhs: bool, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); + let tile_size = config.tile_size; + let block_size_k = config.block_size_k; + let block_size_m = config.block_size_m; let sm_size = block_size_k * block_size_m / tile_size; - let mut shared_memory = - SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + let mut shared_memory = SharedMemory::::vectorized(sm_size, tile_size); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = F::vectorized(0., Comptime::get(tile_size)); - shared_memory[i] = F::vectorized(0., Comptime::get(tile_size)); + for i in 0..sm_size { + sm_out[i] = F::vectorized(0., tile_size); + shared_memory[i] = F::vectorized(0., tile_size); } - let batch_offset = UInt::new(0); + let batch_offset = 0; let coordinates = Coordinates { unit_row, unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), + skip_row: 0, + skip_col: 0, }; - if Comptime::get(is_lhs) { + if is_lhs { // Permuted let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(1)), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: UInt::new(0), + m: tensor.shape(tensor.rank() - 1), + k: tensor.shape(tensor.rank() - 2), + n: 0, }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; @@ -136,23 +131,22 @@ fn load_tensor_permuted_test( } else { // Permuted let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: tensor.shape(tensor.rank() - UInt::new(2)), + m: 0, + k: tensor.shape(tensor.rank() - 1), + n: tensor.shape(tensor.rank() - 2), }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; load_rhs_transposed::>(tensor, info, config); } - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { sm_out[i] = shared_memory[i]; } } @@ -161,68 +155,65 @@ fn load_tensor_permuted_test( fn load_tensor_multiple_tiles_test( tensor: &Tensor, sm_out: &mut Array, - k: UInt, - config: Comptime, - is_lhs: Comptime, + k: u32, + #[comptime] config: CubeTiling2dConfig, + #[comptime] is_lhs: bool, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); + let tile_size = config.tile_size; + let block_size_k = config.block_size_k; + let block_size_m = config.block_size_m; let sm_size = block_size_k * block_size_m / tile_size; - let mut shared_memory = - SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); + let mut shared_memory = SharedMemory::::vectorized(sm_size, tile_size); - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = F::vectorized(0., Comptime::get(tile_size)); - shared_memory[i] = F::vectorized(0., Comptime::get(tile_size)); + for i in 0..sm_size { + sm_out[i] = F::vectorized(0., tile_size); + shared_memory[i] = F::vectorized(0., tile_size); } - let unit_row = UInt::new(4) * UNIT_POS_X; - let unit_col = UInt::new(4) * UNIT_POS_Y; - let batch_offset = UInt::new(0); + let unit_row = 4 * UNIT_POS_X; + let unit_col = 4 * UNIT_POS_Y; + let batch_offset = 0; let coordinates = Coordinates { unit_row, unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), + skip_row: 0, + skip_col: 0, }; - if Comptime::get(is_lhs) { + if is_lhs { let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(2)), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: UInt::new(0), + m: tensor.shape(tensor.rank() - 2), + k: tensor.shape(tensor.rank() - 1), + n: 0, }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; load_lhs_transposed::>(tensor, info, config); } else { let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: tensor.shape(tensor.rank() - UInt::new(1)), + m: 0, + k: tensor.shape(tensor.rank() - 2), + n: tensor.shape(tensor.rank() - 1), }; - let info = LoadInfo { + let info = LoadInfo:: { coordinates, k, batch_offset, shared_memory, - config, dims, }; load_rhs_plain::>(tensor, info, config); } - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { + for i in 0..sm_size { sm_out[i] = shared_memory[i]; } } @@ -238,7 +229,7 @@ pub fn load_lhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); unsafe { - load_tensor_test::launch_unchecked::( + load_tensor_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -273,7 +264,7 @@ pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Devic let config = make_tiling2d_config(5, 1, 1); unsafe { - load_tensor_multiple_tiles_test::launch_unchecked::( + load_tensor_multiple_tiles_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -310,7 +301,7 @@ pub fn load_lhs_transposed_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); unsafe { - load_tensor_multiple_tiles_test::launch_unchecked::( + load_tensor_multiple_tiles_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -343,7 +334,7 @@ pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 16); unsafe { - load_tensor_multiple_tiles_test::launch_unchecked::( + load_tensor_multiple_tiles_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -376,7 +367,7 @@ pub fn load_rhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 16, 16); unsafe { - load_tensor_test::launch_unchecked::( + load_tensor_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -410,7 +401,7 @@ pub fn load_rhs_plain_cube_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); unsafe { - load_tensor_multiple_tiles_test::launch_unchecked::( + load_tensor_multiple_tiles_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -443,7 +434,7 @@ pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); unsafe { - load_tensor_multiple_tiles_test::launch_unchecked::( + load_tensor_multiple_tiles_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -476,7 +467,7 @@ pub fn load_lhs_plain_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); unsafe { - load_tensor_permuted_test::launch_unchecked::( + load_tensor_permuted_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -511,7 +502,7 @@ pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { let config = make_tiling2d_config(m, k, 8); unsafe { - load_tensor_permuted_test::launch_unchecked::( + load_tensor_permuted_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -545,7 +536,7 @@ pub fn load_rhs_transposed_unit_test(device: &R::Device) { let config = make_tiling2d_config(16, 16, 8); unsafe { - load_tensor_permuted_test::launch_unchecked::( + load_tensor_permuted_test::launch_unchecked::( &client, cube_count, cube_dim, @@ -580,7 +571,7 @@ pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Devic let config = make_tiling2d_config(8, k, n); unsafe { - load_tensor_permuted_test::launch_unchecked::( + load_tensor_permuted_test::launch_unchecked::( &client, cube_count, cube_dim, diff --git a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs index 41c2f293..50995a9c 100644 --- a/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tests/tiling2d/write_output.rs @@ -9,7 +9,7 @@ use crate::matmul::tiling2d::write_output::write_to_output; use crate::matmul::{ tests::test_utils::{assert_equals, range_tensor}, tiling2d::{ - base::{Coordinates, CoordinatesExpand, Dimensions, DimensionsExpand, TILE_SIZE}, + base::{Coordinates, Dimensions, TILE_SIZE}, config::CubeTiling2dConfig, }, }; @@ -18,42 +18,42 @@ use crate::matmul::{ fn write_to_output_test( out: &mut Tensor, results: &mut Array, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { let coordinates = Coordinates { - unit_row: UInt::new(4), - unit_col: UInt::new(4), - skip_row: UInt::new(0), - skip_col: UInt::new(0), + unit_row: 4, + unit_col: 4, + skip_row: 0, + skip_col: 0, }; let dims = Dimensions { - m: out.shape(out.rank() - UInt::new(2)), - k: UInt::new(0), - n: out.shape(out.rank() - UInt::new(1)), + m: out.shape(out.rank() - 2), + k: 0, + n: out.shape(out.rank() - 1), }; - write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); + write_to_output::>(out, results, coordinates, 0, dims, config); } #[cube(launch_unchecked)] fn write_results_to_output_out_of_bounds_test( out: &mut Tensor, results: &mut Array, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { let coordinates = Coordinates { - unit_row: UNIT_POS_X * UInt::new(4), - unit_col: UNIT_POS_Y * UInt::new(4), - skip_row: UInt::new(0), - skip_col: UInt::new(0), + unit_row: UNIT_POS_X * 4, + unit_col: UNIT_POS_Y * 4, + skip_row: 0, + skip_col: 0, }; let dims = Dimensions { - m: out.shape(out.rank() - UInt::new(2)), - k: UInt::new(0), - n: out.shape(out.rank() - UInt::new(1)), + m: out.shape(out.rank() - 2), + k: 0, + n: out.shape(out.rank() - 1), }; - write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); + write_to_output::>(out, results, coordinates, 0, dims, config); } /// Exported test @@ -67,7 +67,7 @@ pub fn write_to_output_over_height_unit_test(device: &R::Device) { let config = make_tiling2d_config(6, 8, 8); unsafe { - write_to_output_test::launch_unchecked::( + write_to_output_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -96,7 +96,7 @@ pub fn write_to_output_over_width_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 4); unsafe { - write_to_output_test::launch_unchecked::( + write_to_output_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -125,7 +125,7 @@ pub fn write_to_output_vectorized_less_than_tile_unit_test(device: & let config = make_tiling2d_config(8, 8, 8); unsafe { - write_to_output_test::launch_unchecked::( + write_to_output_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -156,7 +156,7 @@ pub fn write_to_output_scalar_unit_test(device: &R::Device) { let config = make_tiling2d_config(8, 8, 8); unsafe { - write_to_output_test::launch_unchecked::( + write_to_output_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, @@ -187,7 +187,7 @@ pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::De let config = make_tiling2d_config(5, 8, 1); unsafe { - write_results_to_output_out_of_bounds_test::launch_unchecked::( + write_results_to_output_out_of_bounds_test::launch_unchecked::( &R::client(device), cube_count, cube_dim, diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs index 4418a3d5..ac24fa99 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/base.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/base.rs @@ -1,5 +1,5 @@ -use cubecl_core as cubecl; use cubecl_core::prelude::*; +use cubecl_core::{self as cubecl, CubeType}; use super::{block_loop::block_loop, config::CubeTiling2dConfig}; @@ -12,7 +12,7 @@ pub fn tiling2d_cube_kernel( lhs: &Tensor, rhs: &Tensor, out: &mut Tensor, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { let dims = get_dims::(lhs, rhs); let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config); @@ -34,9 +34,9 @@ pub fn tiling2d_cube_kernel( /// Information available at runtime only /// Strides assume contiguous pub(crate) struct Dimensions { - pub m: UInt, - pub k: UInt, - pub n: UInt, + pub m: u32, + pub k: u32, + pub n: u32, } #[derive(CubeType, Copy, Clone)] @@ -49,24 +49,24 @@ pub(crate) struct SharedMemories { /// Number of elements in previous batches /// Not divided by vectorization facto pub(crate) struct BatchOffsets { - pub lhs: UInt, - pub rhs: UInt, - pub out: UInt, + pub lhs: u32, + pub rhs: u32, + pub out: u32, } #[derive(CubeType, Copy, Clone)] pub(crate) struct Coordinates { - pub unit_row: UInt, - pub unit_col: UInt, - pub skip_row: UInt, - pub skip_col: UInt, + pub unit_row: u32, + pub unit_col: u32, + pub skip_row: u32, + pub skip_col: u32, } #[cube] fn get_dims(lhs: &Tensor, rhs: &Tensor) -> Dimensions { let rank = lhs.rank(); - let first_dim = rank - UInt::new(2); - let second_dim = rank - UInt::new(1); + let first_dim = rank - 2; + let second_dim = rank - 1; let m = lhs.shape(first_dim); let k = lhs.shape(second_dim); let n = rhs.shape(second_dim); @@ -76,26 +76,24 @@ fn get_dims(lhs: &Tensor, rhs: &Tensor) -> Dimensions { #[cube] fn calculate_coordinates( - cube_pos_x: UInt, - cube_pos_y: UInt, - unit_pos: UInt, - config: Comptime, + cube_pos_x: u32, + cube_pos_y: u32, + unit_pos: u32, + #[comptime] config: CubeTiling2dConfig, ) -> Coordinates { - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - let tile_size = Comptime::map(config, |c| c.tile_size); + let block_size_m = config.block_size_m; + let block_size_n = config.block_size_n; + let tile_size = config.tile_size; - let n_units_per_row = ((Comptime::runtime(block_size_n) - UInt::new(1)) - / Comptime::runtime(tile_size)) - + UInt::new(1); + let n_units_per_row = ((block_size_n - 1) / tile_size) + 1; // Cube offset - let skip_row = cube_pos_x * Comptime::runtime(block_size_m); - let skip_col = cube_pos_y * Comptime::runtime(block_size_n); + let skip_row = cube_pos_x * block_size_m; + let skip_col = cube_pos_y * block_size_n; // Position of the first element of the unit, relative to the cube - let unit_row = (unit_pos / n_units_per_row) * Comptime::runtime(tile_size); - let unit_col = (unit_pos % n_units_per_row) * Comptime::runtime(tile_size); + let unit_row = (unit_pos / n_units_per_row) * tile_size; + let unit_col = (unit_pos % n_units_per_row) * tile_size; Coordinates { unit_row, @@ -111,20 +109,20 @@ fn calculate_batch_offsets( lhs: &Tensor, rhs: &Tensor, out: &Tensor, - batch_number: UInt, + batch_number: u32, ) -> BatchOffsets { let rank = out.rank(); - let dim_m = lhs.shape(rank - UInt::new(2)); - let dim_n = rhs.shape(rank - UInt::new(1)); + let dim_m = lhs.shape(rank - 2); + let dim_n = rhs.shape(rank - 1); // Batch offset for output let mut offset_out = dim_m * dim_n * batch_number; - let mut offset_lhs = UInt::new(0); - let mut offset_rhs = UInt::new(0); + let mut offset_lhs = 0; + let mut offset_rhs = 0; // Batch offset for lhs, rhs - for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) { + for b in 0..rank - 2 { let tmp = offset_out / out.stride(b); offset_lhs += tmp % lhs.shape(b) * lhs.stride(b); offset_rhs += tmp % rhs.shape(b) * rhs.stride(b); @@ -138,21 +136,14 @@ fn calculate_batch_offsets( } #[cube] -fn make_shared_memories(config: Comptime) -> SharedMemories { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - - let lhs = SharedMemory::::vectorized( - Comptime::get(block_size_k * block_size_m / tile_size), - Comptime::get(tile_size), - ); +fn make_shared_memories(#[comptime] config: CubeTiling2dConfig) -> SharedMemories { + let tile_size = config.tile_size; + let block_size_m = config.block_size_m; + let block_size_k = config.block_size_k; + let block_size_n = config.block_size_n; - let rhs = SharedMemory::::vectorized( - Comptime::get(block_size_k * block_size_n / tile_size), - Comptime::get(tile_size), - ); + let lhs = SharedMemory::::vectorized(block_size_k * block_size_m / tile_size, tile_size); + let rhs = SharedMemory::::vectorized(block_size_k * block_size_n / tile_size, tile_size); - SharedMemories { lhs, rhs } + SharedMemories:: { lhs, rhs } } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/block_loop.rs b/crates/cubecl-linalg/src/matmul/tiling2d/block_loop.rs index 5adae96b..75df8b56 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/block_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/block_loop.rs @@ -18,14 +18,14 @@ pub(crate) fn block_loop( coordinates: Coordinates, offsets: BatchOffsets, shared: SharedMemories, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, dims: Dimensions, ) { let mut results = init_results::(config); - let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + let block_size_k = config.block_size_k; let n_loops = (dims.k + block_size_k - 1) / block_size_k; - for k in range(0u32, n_loops, Comptime::new(false)) { + for k in 0..n_loops { let k = k * block_size_k; load_to_shared_memories::>( @@ -50,12 +50,13 @@ pub(crate) fn block_loop( } #[cube] -fn init_results(config: Comptime) -> Array { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); +fn init_results(#[comptime] config: CubeTiling2dConfig) -> Array { + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - let mut results = Array::::new(Comptime::get(tile_size * tile_size)); - for i in range(0u32, Comptime::get(tile_size * tile_size), unroll) { + let mut results = Array::::new(tile_size * tile_size); + #[unroll(unroll)] + for i in 0..tile_size * tile_size { results[i] = F::new(0.); } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tiling2d/compute_loop.rs index f80bd14d..74e737b5 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/compute_loop.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/compute_loop.rs @@ -10,22 +10,21 @@ pub(crate) fn compute_loop( shared_lhs: SharedMemory, shared_rhs: SharedMemory, results: &mut Array, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - let unroll = Comptime::map(config, |c| c.unroll_compute); + let tile_size = config.tile_size; + let block_size_m = config.block_size_m; + let block_size_k = config.block_size_k; + let block_size_n = config.block_size_n; + let unroll = config.unroll_compute; let unit_row = coordinates.unit_row; let unit_col = coordinates.unit_col; - for dot_index in range(0u32, block_size_k, unroll) { - let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m)) - / Comptime::runtime(tile_size)]; - let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n)) - / Comptime::runtime(tile_size)]; + #[unroll(unroll)] + for dot_index in 0..block_size_k { + let register_m = shared_lhs[(unit_row + dot_index * block_size_m) / tile_size]; + let register_n = shared_rhs[(unit_col + dot_index * block_size_n) / tile_size]; tile_outer_product::(register_m, register_n, results, config); } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/config.rs b/crates/cubecl-linalg/src/matmul/tiling2d/config.rs index 69a69e79..d6cfe578 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/config.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/config.rs @@ -1,9 +1,8 @@ use cubecl_core::{ - compute::CubeCount, - frontend::{CubeContext, Init, UInt}, - ir::CubeDim, - Runtime, + self as cubecl, + prelude::{CubeContext, Init}, }; +use cubecl_core::{compute::CubeCount, ir::CubeDim, CubeType, Runtime}; use super::base::TILE_SIZE; @@ -34,21 +33,15 @@ impl Default for Tiling2dConfig { } } -impl Init for CubeTiling2dConfig { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug, CubeType)] /// Tiling 2D parameters pub struct CubeTiling2dConfig { /// Block size along dimension of lhs - pub block_size_m: UInt, + pub block_size_m: u32, /// Block size along common dimension - pub block_size_k: UInt, + pub block_size_k: u32, /// Block size along dimension of rhs - pub block_size_n: UInt, + pub block_size_n: u32, /// Loop unrolling for inner compute loop. Probably slower pub unroll_compute: bool, /// Loop unrolling for all loops related to vectorization/tile size. Probably faster @@ -60,13 +53,19 @@ pub struct CubeTiling2dConfig { /// Bounds must be checked on rhs dimension pub check_n_bounds: bool, /// Tile size. Should correspond to vectorization of inputs/outputs/shared memory - pub tile_size: UInt, + pub tile_size: u32, /// Lhs is transposed in global memory pub lhs_transposed: bool, /// Rhs is transposed in global memory pub rhs_transposed: bool, } +impl Init for CubeTiling2dConfig { + fn init(self, _context: &mut CubeContext) -> Self { + self + } +} + impl CubeTiling2dConfig { pub fn new( config: &Tiling2dConfig, @@ -89,15 +88,15 @@ impl CubeTiling2dConfig { ); CubeTiling2dConfig { - block_size_m: UInt::new(config.block_size_m as u32), - block_size_k: UInt::new(config.block_size_k as u32), - block_size_n: UInt::new(config.block_size_n as u32), + block_size_m: config.block_size_m as u32, + block_size_k: config.block_size_k as u32, + block_size_n: config.block_size_n as u32, unroll_compute: config.unroll, unroll_tile: true, check_m_bounds: m % config.block_size_m != 0, check_k_bounds: k % config.block_size_k != 0, check_n_bounds: n % config.block_size_n != 0, - tile_size: UInt::new(config.tile_size as u32), + tile_size: config.tile_size as u32, lhs_transposed, rhs_transposed, } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tiling2d/load_shared_memory.rs index 4a841955..416d0aef 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/load_shared_memory.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/load_shared_memory.rs @@ -1,5 +1,5 @@ use cubecl_core as cubecl; -use cubecl_core::prelude::*; +use cubecl_core::{prelude::*, CubeType}; use super::{ base::{BatchOffsets, Coordinates, Dimensions, SharedMemories}, @@ -15,19 +15,34 @@ use super::{ #[allow(dead_code)] pub(crate) struct LoadInfo { pub coordinates: Coordinates, - pub k: UInt, - pub batch_offset: UInt, + pub k: u32, + pub batch_offset: u32, pub shared_memory: SharedMemory, - pub config: Comptime, pub dims: Dimensions, } #[cube] pub(crate) trait Loader: Sync + Send + 'static { - fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo); - fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo); - fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo); - fn load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo); + fn load_lhs_plain>( + lhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ); + fn load_lhs_transposed>( + lhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ); + fn load_rhs_plain>( + rhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ); + fn load_rhs_transposed>( + rhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ); } #[cube] @@ -35,41 +50,39 @@ pub(crate) fn load_to_shared_memories>( lhs: &Tensor, rhs: &Tensor, coordinates: Coordinates, - k: UInt, + k: u32, offsets: BatchOffsets, shared: SharedMemories, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, dims: Dimensions, ) { - let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed); - let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed); + let lhs_transposed = config.lhs_transposed; + let rhs_transposed = config.rhs_transposed; - let lhs_load_info = LoadInfo { + let lhs_load_info = LoadInfo:: { coordinates, k, batch_offset: offsets.lhs, shared_memory: shared.lhs, - config, dims, }; - let rhs_load_info = LoadInfo { + let rhs_load_info = LoadInfo:: { coordinates, k, batch_offset: offsets.rhs, shared_memory: shared.rhs, - config, dims, }; // Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain. - if Comptime::get(lhs_transposed) { + if lhs_transposed { load_lhs_plain::(lhs, lhs_load_info, config); } else { load_lhs_transposed::(lhs, lhs_load_info, config); } // Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back. - if Comptime::get(rhs_transposed) { + if rhs_transposed { load_rhs_transposed::(rhs, rhs_load_info, config); } else { load_rhs_plain::(rhs, rhs_load_info, config); @@ -80,21 +93,21 @@ pub(crate) fn load_to_shared_memories>( pub(crate) fn load_lhs_transposed>( lhs: &Tensor, load_info: LoadInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + let check_m_bounds = config.check_m_bounds; + let check_k_bounds = config.check_k_bounds; - if Comptime::get(check_m_bounds) { - if Comptime::get(check_k_bounds) { - L::load_lhs_transposed::(lhs, load_info); + if check_m_bounds { + if check_k_bounds { + L::load_lhs_transposed::(lhs, load_info, config); } else { - L::load_lhs_transposed::(lhs, load_info); + L::load_lhs_transposed::(lhs, load_info, config); } - } else if Comptime::get(check_k_bounds) { - L::load_lhs_transposed::(lhs, load_info); + } else if check_k_bounds { + L::load_lhs_transposed::(lhs, load_info, config); } else { - L::load_lhs_transposed::(lhs, load_info); + L::load_lhs_transposed::(lhs, load_info, config); } } @@ -102,21 +115,21 @@ pub(crate) fn load_lhs_transposed>( pub(crate) fn load_lhs_plain>( lhs: &Tensor, load_info: LoadInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); + let check_m_bounds = config.check_m_bounds; + let check_k_bounds = config.check_k_bounds; - if Comptime::get(check_k_bounds) { - if Comptime::get(check_m_bounds) { - L::load_lhs_plain::(lhs, load_info); + if check_k_bounds { + if check_m_bounds { + L::load_lhs_plain::(lhs, load_info, config); } else { - L::load_lhs_plain::(lhs, load_info); + L::load_lhs_plain::(lhs, load_info, config); } - } else if Comptime::get(check_m_bounds) { - L::load_lhs_plain::(lhs, load_info); + } else if check_m_bounds { + L::load_lhs_plain::(lhs, load_info, config); } else { - L::load_lhs_plain::(lhs, load_info); + L::load_lhs_plain::(lhs, load_info, config); } } @@ -124,21 +137,21 @@ pub(crate) fn load_lhs_plain>( pub(crate) fn load_rhs_transposed>( rhs: &Tensor, load_info: LoadInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + let check_k_bounds = config.check_k_bounds; + let check_n_bounds = config.check_n_bounds; - if Comptime::get(check_n_bounds) { - if Comptime::get(check_k_bounds) { - L::load_rhs_transposed::(rhs, load_info); + if check_n_bounds { + if check_k_bounds { + L::load_rhs_transposed::(rhs, load_info, config); } else { - L::load_rhs_transposed::(rhs, load_info); + L::load_rhs_transposed::(rhs, load_info, config); } - } else if Comptime::get(check_k_bounds) { - L::load_rhs_transposed::(rhs, load_info); + } else if check_k_bounds { + L::load_rhs_transposed::(rhs, load_info, config); } else { - L::load_rhs_transposed::(rhs, load_info); + L::load_rhs_transposed::(rhs, load_info, config); } } @@ -146,20 +159,20 @@ pub(crate) fn load_rhs_transposed>( pub(crate) fn load_rhs_plain>( rhs: &Tensor, load_info: LoadInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + let check_k_bounds = config.check_k_bounds; + let check_n_bounds = config.check_n_bounds; - if Comptime::get(check_k_bounds) { - if Comptime::get(check_n_bounds) { - L::load_rhs_plain::(rhs, load_info); + if check_k_bounds { + if check_n_bounds { + L::load_rhs_plain::(rhs, load_info, config); } else { - L::load_rhs_plain::(rhs, load_info); + L::load_rhs_plain::(rhs, load_info, config); } - } else if Comptime::get(check_n_bounds) { - L::load_rhs_plain::(rhs, load_info); + } else if check_n_bounds { + L::load_rhs_plain::(rhs, load_info, config); } else { - L::load_rhs_plain::(rhs, load_info); + L::load_rhs_plain::(rhs, load_info, config); } } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/outer_product.rs b/crates/cubecl-linalg/src/matmul/tiling2d/outer_product.rs index 4d471e19..86161c97 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/outer_product.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/outer_product.rs @@ -8,14 +8,16 @@ pub(crate) fn tile_outer_product( register_m: F, register_n: F, results: &mut Array, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) { - let res_pos_base = res_idx_m * Comptime::runtime(tile_size); - for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for res_idx_m in 0..tile_size { + let res_pos_base = res_idx_m * tile_size; + #[unroll(unroll)] + for res_idx_n in 0..tile_size { let mul = register_m[res_idx_m] * register_n[res_idx_n]; results[res_pos_base + res_idx_n] += mul; } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/base.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/base.rs index 3fd8481e..49216104 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/base.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/base.rs @@ -12,7 +12,7 @@ pub(crate) trait BlockLoader: Send + Sync + 'static { tensor: &Tensor, shared_memory: &mut SharedMemory, read_tile_info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ); @@ -20,7 +20,7 @@ pub(crate) trait BlockLoader: Send + Sync + 'static { tensor: &Tensor, shared_memory: &mut SharedMemory, read_tile_info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ); } @@ -31,7 +31,7 @@ pub(crate) trait BlockWriter: Send + Sync + 'static { out: &mut Tensor, results: &Array, write_tile_info: WriteTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ); } @@ -39,16 +39,16 @@ pub(crate) trait BlockWriter: Send + Sync + 'static { #[cube] pub(crate) fn all_zeros_runtime( shared_memory: &mut SharedMemory, - start: UInt, - sm_position_base: UInt, - sm_stride: UInt, - config: Comptime, + start: u32, + sm_position_base: u32, + sm_stride: u32, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let zeros = F::vectorized(0., Comptime::get(tile_size)); + let tile_size = config.tile_size; + let zeros = F::vectorized(0., tile_size); - for i in range(start, Comptime::get(tile_size), Comptime::new(false)) { - let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); + for i in start..tile_size { + let sm_position = (sm_position_base + i * sm_stride) / tile_size; shared_memory[sm_position] = zeros; } @@ -57,16 +57,17 @@ pub(crate) fn all_zeros_runtime( #[cube] pub(crate) fn all_zeros_comptime( shared_memory: &mut SharedMemory, - sm_position_base: UInt, - sm_stride: UInt, - config: Comptime, + sm_position_base: u32, + sm_stride: u32, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let zeros = F::vectorized(0., Comptime::get(tile_size)); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; + let zeros = F::vectorized(0., tile_size); - for i in range(0u32, Comptime::get(tile_size), unroll) { - let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); + #[unroll(unroll)] + for i in 0..tile_size { + let sm_position = (sm_position_base + i * sm_stride) / tile_size; shared_memory[sm_position] = zeros; } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/horizontal_block_check.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/horizontal_block_check.rs index 4f55b7b2..096070b8 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/horizontal_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/horizontal_block_check.rs @@ -5,10 +5,7 @@ use crate::matmul::tiling2d::{ config::CubeTiling2dConfig, tile::{ loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, + memory_access::{ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions}, }, write_output::WriteTileInfo, }; @@ -23,20 +20,19 @@ impl BlockLoader for HorizontalCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let vectorization = vectorization_of(tensor); + let unroll = config.unroll_tile; let col = check_bounds.skip_col + info.read_col; if check_bounds.dim_horizontal > col { - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + #[unroll(unroll)] + for i in 0..tile_size { + let gm_position = (info.gm_position_base + i * info.gm_stride) / vectorization; + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); @@ -50,22 +46,21 @@ impl BlockLoader for HorizontalCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; - let mut num_reads = UInt::new(0); + let mut num_reads = 0; let col = check_bounds.skip_col + info.read_col; let dim_horizontal = check_bounds.dim_horizontal; if dim_horizontal > col { - num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); + num_reads = Min::min(dim_horizontal - col, tile_size); } - for i in range(0u32, num_reads, Comptime::new(false)) { + for i in 0..num_reads { let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( tensor, @@ -91,11 +86,11 @@ impl BlockWriter for HorizontalCheckBlockIO { out: &mut Tensor, results: &Array, info: WriteTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; let coordinates = info.coordinates; let col = coordinates.skip_col + coordinates.unit_col; @@ -104,9 +99,10 @@ impl BlockWriter for HorizontalCheckBlockIO { let row = coordinates.skip_row + coordinates.unit_row; let out_position_base = row * info.out_stride + col + info.offset_output; - for result_index in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for result_index in 0..tile_size { let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), + result: result_index * tile_size, out: out_position_base + result_index * info.out_stride, }; diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/unchecked_block.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/unchecked_block.rs index ebc73439..74cb61d3 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/unchecked_block.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/unchecked_block.rs @@ -5,10 +5,7 @@ use crate::matmul::tiling2d::{ config::CubeTiling2dConfig, tile::{ loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, + memory_access::{ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions}, }, write_output::WriteTileInfo, }; @@ -24,18 +21,17 @@ impl BlockLoader for UncheckedBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, _check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization = Comptime::vectorization(&tensor); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; + let vectorization = vectorization_of(tensor); - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + #[unroll(unroll)] + for i in 0..tile_size { + let gm_position = (info.gm_position_base + i * info.gm_stride) / vectorization; + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); } @@ -45,16 +41,16 @@ impl BlockLoader for UncheckedBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, _check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - for i in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for i in 0..tile_size { let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( tensor, @@ -72,20 +68,21 @@ impl BlockWriter for UncheckedBlockIO { out: &mut Tensor, results: &Array, info: WriteTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, _check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; let coordinates = info.coordinates; let row = coordinates.skip_row + coordinates.unit_row; let col = coordinates.skip_col + coordinates.unit_col; let out_position_base = row * info.out_stride + col + info.offset_output; - for result_index in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for result_index in 0..tile_size { let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), + result: result_index * tile_size, out: out_position_base + result_index * info.out_stride, }; diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/vertical_block_check.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/vertical_block_check.rs index ea61f6ae..8ff3014f 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/vertical_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/vertical_block_check.rs @@ -5,10 +5,7 @@ use crate::matmul::tiling2d::{ config::CubeTiling2dConfig, tile::{ loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, + memory_access::{ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions}, }, write_output::WriteTileInfo, }; @@ -23,26 +20,21 @@ impl BlockLoader for VerticalCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); + let tile_size = config.tile_size; + let vectorization = vectorization_of(tensor); - let mut num_reads = UInt::new(0); + let mut num_reads = 0; let row = check_bounds.skip_row + info.read_row; if check_bounds.dim_vertical > row { - num_reads = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); + num_reads = Min::min(check_bounds.dim_horizontal - row, tile_size); } - for i in range(0u32, num_reads, Comptime::new(false)) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + for i in 0..num_reads { + let gm_position = (info.gm_position_base + i * info.gm_stride) / vectorization; + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); } @@ -60,16 +52,16 @@ impl BlockLoader for VerticalCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - for i in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for i in 0..tile_size { let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( tensor, @@ -89,27 +81,24 @@ impl BlockWriter for VerticalCheckBlockIO { out: &mut Tensor, results: &Array, info: WriteTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; let coordinates = info.coordinates; let row = coordinates.skip_row + coordinates.unit_row; let col = coordinates.skip_col + coordinates.unit_col; let out_position_base = row * info.out_stride + col + info.offset_output; - let mut num_writes = UInt::new(0); + let mut num_writes = 0; if check_bounds.dim_vertical > row { - num_writes = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); + num_writes = Min::min(check_bounds.dim_vertical - row, tile_size); } - for result_index in range(0u32, num_writes, Comptime::new(false)) { + for result_index in 0..num_writes { let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), + result: result_index * tile_size, out: out_position_base + result_index * info.out_stride, }; diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/whole_block_check.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/whole_block_check.rs index d1ed794c..faa4fa90 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/whole_block_check.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/block_io/whole_block_check.rs @@ -5,10 +5,7 @@ use crate::matmul::tiling2d::{ config::CubeTiling2dConfig, tile::{ loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, + memory_access::{ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions}, }, write_output::WriteTileInfo, }; @@ -23,28 +20,23 @@ impl BlockLoader for WholeCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); + let tile_size = config.tile_size; + let vectorization = vectorization_of(tensor); let col = check_bounds.skip_col + info.read_col; if check_bounds.dim_horizontal > col { - let mut num_reads_vertical = UInt::new(0); + let mut num_reads_vertical = 0; let row = check_bounds.skip_row + info.read_row; if check_bounds.dim_vertical > row { - num_reads_vertical = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); + num_reads_vertical = Min::min(check_bounds.dim_vertical - row, tile_size); } - for i in range(0u32, num_reads_vertical, Comptime::new(false)) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + for i in 0..num_reads_vertical { + let gm_position = (info.gm_position_base + i * info.gm_stride) / vectorization; + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); @@ -65,22 +57,21 @@ impl BlockLoader for WholeCheckBlockIO { tensor: &Tensor, shared_memory: &mut SharedMemory, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; - let mut num_reads_horizontal = UInt::new(0); + let mut num_reads_horizontal = 0; let col = check_bounds.skip_col + info.read_col; let dim_horizontal = check_bounds.dim_horizontal; if dim_horizontal > col { - num_reads_horizontal = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); + num_reads_horizontal = Min::min(dim_horizontal - col, tile_size); } - for i in range(0u32, num_reads_horizontal, Comptime::new(false)) { + for i in 0..num_reads_horizontal { let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); + let sm_position = (info.sm_position_base + i * info.sm_stride) / tile_size; shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( tensor, @@ -108,30 +99,27 @@ impl BlockWriter for WholeCheckBlockIO { out: &mut Tensor, results: &Array, info: WriteTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, check_bounds: CheckBounds, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; let coordinates = info.coordinates; let col = coordinates.skip_col + coordinates.unit_col; if check_bounds.dim_horizontal > col { - let mut num_writes_vertical = UInt::new(0); + let mut num_writes_vertical = 0; let row = coordinates.skip_row + coordinates.unit_row; if check_bounds.dim_vertical > row { - num_writes_vertical = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); + num_writes_vertical = Min::min(check_bounds.dim_vertical - row, tile_size); } let out_position_base = row * info.out_stride + col + info.offset_output; - for result_index in range(0u32, num_writes_vertical, Comptime::new(false)) { + for result_index in 0..num_writes_vertical { let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), + result: result_index * tile_size, out: out_position_base + result_index * info.out_stride, }; diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/loader.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/loader.rs index fd08ad93..71e263c1 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/loader.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/loader.rs @@ -1,8 +1,11 @@ -use cubecl_core as cubecl; use cubecl_core::prelude::*; +use cubecl_core::{self as cubecl, CubeType}; use std::marker::PhantomData; -use crate::matmul::tiling2d::load_shared_memory::{LoadInfo, Loader}; +use crate::matmul::tiling2d::{ + config::CubeTiling2dConfig, + load_shared_memory::{LoadInfo, Loader}, +}; use super::{ block_io::base::BlockLoader, @@ -17,33 +20,36 @@ pub(crate) struct TileLoader { #[derive(CubeType)] pub(crate) struct LoadIndices { - pub offset: UInt, - pub gm_stride: UInt, - pub sm_stride: UInt, + pub offset: u32, + pub gm_stride: u32, + pub sm_stride: u32, } #[derive(CubeType, Copy, Clone)] pub(crate) struct CheckBounds { - pub dim_vertical: UInt, - pub dim_horizontal: UInt, - pub skip_row: UInt, - pub skip_col: UInt, + pub dim_vertical: u32, + pub dim_horizontal: u32, + pub skip_row: u32, + pub skip_col: u32, } #[derive(CubeType, Copy, Clone)] pub(crate) struct ReadTileInfo { - pub read_row: UInt, - pub read_col: UInt, - pub gm_position_base: UInt, - pub sm_position_base: UInt, - pub gm_stride: UInt, - pub sm_stride: UInt, + pub read_row: u32, + pub read_col: u32, + pub gm_position_base: u32, + pub sm_position_base: u32, + pub gm_stride: u32, + pub sm_stride: u32, } #[cube] impl Loader for TileLoader { - fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; + fn load_lhs_plain>( + lhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ) { let dims = load_info.dims; let coordinates = load_info.coordinates; let gm_stride = dims.m; @@ -51,7 +57,7 @@ impl Loader for TileLoader { let load_indices = LoadIndices { offset: coordinates.skip_row + load_info.k * gm_stride + load_info.batch_offset, gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + sm_stride: config.block_size_n, }; let check_bounds = CheckBounds { dim_vertical: dims.k, @@ -60,11 +66,14 @@ impl Loader for TileLoader { skip_col: coordinates.skip_row, }; - load_plain::(lhs, load_info, load_indices, check_bounds); + load_plain::(lhs, load_info, load_indices, check_bounds, config); } - fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; + fn load_lhs_transposed>( + lhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ) { let dims = load_info.dims; let coordinates = load_info.coordinates; let gm_stride = dims.k; @@ -72,7 +81,7 @@ impl Loader for TileLoader { let load_indices = LoadIndices { offset: coordinates.skip_row * gm_stride + load_info.k + load_info.batch_offset, gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_m)), + sm_stride: config.block_size_m, }; let check_bounds = CheckBounds { dim_vertical: dims.m, @@ -81,19 +90,22 @@ impl Loader for TileLoader { skip_col: load_info.k, }; - load_transposed::(lhs, load_info, load_indices, check_bounds); + load_transposed::(lhs, load_info, load_indices, check_bounds, config); } - fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo) { + fn load_rhs_plain>( + rhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ) { let coordinates = load_info.coordinates; let dims = load_info.dims; - let config = load_info.config; let gm_stride = dims.n; let load_indices = LoadIndices { offset: coordinates.skip_col + load_info.k * gm_stride + load_info.batch_offset, gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + sm_stride: config.block_size_n, }; let check_bounds = CheckBounds { dim_vertical: dims.k, @@ -102,11 +114,14 @@ impl Loader for TileLoader { skip_col: coordinates.skip_col, }; - load_plain::(rhs, load_info, load_indices, check_bounds); + load_plain::(rhs, load_info, load_indices, check_bounds, config); } - fn load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; + fn load_rhs_transposed>( + rhs: &Tensor, + load_info: LoadInfo, + #[comptime] config: CubeTiling2dConfig, + ) { let dims = load_info.dims; let coordinates = load_info.coordinates; let gm_stride = dims.k; @@ -114,7 +129,7 @@ impl Loader for TileLoader { let load_indices = LoadIndices { offset: coordinates.skip_col * gm_stride + load_info.k + load_info.batch_offset, gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), + sm_stride: config.block_size_n, }; let check_bounds = CheckBounds { dim_vertical: dims.n, @@ -123,7 +138,7 @@ impl Loader for TileLoader { skip_col: load_info.k, }; - load_transposed::(rhs, load_info, load_indices, check_bounds); + load_transposed::(rhs, load_info, load_indices, check_bounds, config); } } @@ -133,13 +148,13 @@ pub(crate) fn load_plain>( load_info: LoadInfo, load_indices: LoadIndices, check_bounds: CheckBounds, + #[comptime] config: CubeTiling2dConfig, ) { let coordinates = load_info.coordinates; - let config = load_info.config; - let vectorization = Comptime::vectorization(tensor); - let tile_size = Comptime::map(config, |c| c.tile_size); - let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + let vectorization = vectorization_of(tensor); + let tile_size = config.tile_size; + let sm_dim_vertical = config.block_size_k; let read_row = coordinates.unit_row; let read_col = coordinates.unit_col; @@ -186,11 +201,11 @@ pub(crate) fn load_transposed>( load_info: LoadInfo, load_indices: LoadIndices, check_bounds: CheckBounds, + #[comptime] config: CubeTiling2dConfig, ) { let coordinates = load_info.coordinates; - let config = load_info.config; - let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + let sm_dim_vertical = config.block_size_k; let read_row = coordinates.unit_row; let read_col = coordinates.unit_col; diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/memory_access.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/memory_access.rs index 736787f2..4022dec3 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/memory_access.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/memory_access.rs @@ -1,5 +1,5 @@ -use cubecl_core as cubecl; use cubecl_core::prelude::*; +use cubecl_core::{self as cubecl, CubeType}; use crate::matmul::tiling2d::config::CubeTiling2dConfig; @@ -7,31 +7,31 @@ use super::loader::{CheckBounds, ReadTileInfo}; #[derive(CubeType)] pub(crate) struct WritePositions { - pub out: UInt, - pub result: UInt, + pub out: u32, + pub result: u32, } #[cube] pub(crate) trait ContiguousAccess: Send + Sync + 'static { fn read_contiguous_unchecked( tensor: &Tensor, - gm_position: UInt, - config: Comptime, + gm_position: u32, + #[comptime] config: CubeTiling2dConfig, ) -> F; fn read_contiguous_checked( tensor: &Tensor, - gm_position: UInt, + gm_position: u32, check_bounds: CheckBounds, read_info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) -> F; fn write_contiguous_unchecked( out: &mut Tensor, results: &Array, positions: WritePositions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ); fn write_contiguous_checked( @@ -39,8 +39,8 @@ pub(crate) trait ContiguousAccess: Send + Sync + 'static { results: &Array, positions: WritePositions, check_bounds: CheckBounds, - write_col: UInt, - config: Comptime, + write_col: u32, + #[comptime] config: CubeTiling2dConfig, ); } @@ -48,18 +48,18 @@ pub(crate) trait ContiguousAccess: Send + Sync + 'static { pub(crate) trait StridedAccess: Send + Sync + 'static { fn read_strided_unchecked( tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - config: Comptime, + gm_position: u32, + gm_stride: u32, + #[comptime] config: CubeTiling2dConfig, ) -> F; fn read_strided_checked( tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, + gm_position: u32, + gm_stride: u32, check_bounds: CheckBounds, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) -> F; } @@ -73,18 +73,18 @@ pub(crate) struct UnmatchingVectorization; impl ContiguousAccess for MatchingVectorization { fn read_contiguous_unchecked( tensor: &Tensor, - gm_position: UInt, - _config: Comptime, + gm_position: u32, + #[comptime] _config: CubeTiling2dConfig, ) -> F { tensor[gm_position] } fn read_contiguous_checked( tensor: &Tensor, - gm_position: UInt, + gm_position: u32, _check_bounds: CheckBounds, _read_info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) -> F { // If vectorization matches, then it's certain to fit since tile_size divides block_sizes MatchingVectorization::read_contiguous_unchecked(tensor, gm_position, config) @@ -94,18 +94,19 @@ impl ContiguousAccess for MatchingVectorization { out: &mut Tensor, results: &Array, positions: WritePositions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - let mut output_elem = F::vectorized_empty(Comptime::get(tile_size)); + let mut output_elem = F::vectorized_empty(tile_size); - for i in range(0u32, Comptime::get(tile_size), unroll) { + #[unroll(unroll)] + for i in 0..tile_size { output_elem[i] = results[positions.result + i]; } - out[positions.out / Comptime::runtime(tile_size)] = output_elem; + out[positions.out / tile_size] = output_elem; } fn write_contiguous_checked( @@ -113,8 +114,8 @@ impl ContiguousAccess for MatchingVectorization { results: &Array, positions: WritePositions, _check_bounds: CheckBounds, - _write_col: UInt, - config: Comptime, + _write_col: u32, + #[comptime] config: CubeTiling2dConfig, ) { // If vectorization matches, then it's certain to fit since tile_size divides block_sizes MatchingVectorization::write_contiguous_unchecked(out, results, positions, config) @@ -125,30 +126,26 @@ impl ContiguousAccess for MatchingVectorization { impl ContiguousAccess for UnmatchingVectorization { fn read_contiguous_unchecked( tensor: &Tensor, - gm_position: UInt, - config: Comptime, + gm_position: u32, + #[comptime] config: CubeTiling2dConfig, ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(tensor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; + let vectorization_factor = vectorization_of(tensor); + let is_scalar = vectorization_factor == 1; - let mut vector = F::vectorized(0., Comptime::get(tile_size)); + let mut vector = F::vectorized(0., tile_size); - for i in range( - 0u32, - Comptime::get(tile_size / vectorization_factor), - unroll, - ) { - let runtime_vectorization = Comptime::runtime(vectorization_factor); - - if Comptime::get(is_scalar) { + #[unroll(unroll)] + for i in 0u32..tile_size / vectorization_factor { + if is_scalar { vector[i] = tensor[gm_position + i]; } else { let intermediate = tensor[gm_position + i]; - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - vector[i * runtime_vectorization + j] = intermediate[j]; + #[unroll(unroll)] + for j in 0..vectorization_factor { + vector[i * vectorization_factor + j] = intermediate[j]; } } } @@ -158,36 +155,33 @@ impl ContiguousAccess for UnmatchingVectorization { fn read_contiguous_checked( tensor: &Tensor, - gm_position: UInt, + gm_position: u32, check_bounds: CheckBounds, read_info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(tensor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - let runtime_vectorization = Comptime::runtime(vectorization_factor); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; + let vectorization_factor = vectorization_of(tensor); + let is_scalar = vectorization_factor == 1; - let mut vector = F::vectorized(0., Comptime::get(tile_size)); + let mut vector = F::vectorized(0., tile_size); - let mut num_loops = UInt::new(0); + let mut num_loops = 0; if check_bounds.dim_horizontal > read_info.read_col { - let num_reads = UInt::min( - check_bounds.dim_horizontal - read_info.read_col, - Comptime::runtime(tile_size), - ); - num_loops = num_reads / runtime_vectorization; + let num_reads = Min::min(check_bounds.dim_horizontal - read_info.read_col, tile_size); + num_loops = num_reads / vectorization_factor; } - for i in range(0u32, num_loops, Comptime::new(false)) { - if Comptime::get(is_scalar) { + for i in 0..num_loops { + if is_scalar { vector[i] = tensor[gm_position + i]; } else { let intermediate = tensor[gm_position + i]; - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - vector[i * runtime_vectorization + j] = intermediate[j]; + #[unroll(unroll)] + for j in 0..vectorization_factor { + vector[i * vectorization_factor + j] = intermediate[j]; } } } @@ -199,30 +193,27 @@ impl ContiguousAccess for UnmatchingVectorization { out: &mut Tensor, results: &Array, positions: WritePositions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(out); - let runtime_vectorization = Comptime::runtime(vectorization_factor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - - for i in range( - 0u32, - Comptime::get(tile_size / vectorization_factor), - unroll, - ) { - if Comptime::get(is_scalar) { + let tile_size = config.tile_size; + let unroll = config.unroll_tile; + let vectorization_factor = vectorization_of(out); + let is_scalar = vectorization_factor == 1; + + #[unroll(unroll)] + for i in 0..tile_size / vectorization_factor { + if is_scalar { out[i + positions.out] = results[positions.result + i]; } else { - let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); + let mut output_elem = F::vectorized_empty(vectorization_factor); - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - let index = i * runtime_vectorization + j; + #[unroll(unroll)] + for j in 0..vectorization_factor { + let index = i * vectorization_factor + j; output_elem[j] = results[positions.result + index]; } - out[i + positions.out / runtime_vectorization] = output_elem; + out[i + positions.out / vectorization_factor] = output_elem; } } } @@ -232,37 +223,34 @@ impl ContiguousAccess for UnmatchingVectorization { results: &Array, positions: WritePositions, check_bounds: CheckBounds, - write_col: UInt, - config: Comptime, + write_col: u32, + #[comptime] config: CubeTiling2dConfig, ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization_factor = Comptime::vectorization(out); - let runtime_vectorization = Comptime::runtime(vectorization_factor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); + let tile_size = config.tile_size; + let vectorization_factor = vectorization_of(out); + let is_scalar = vectorization_factor == 1; - let mut num_loops = UInt::new(0); + let mut num_loops = 0; if check_bounds.dim_horizontal > write_col { - let num_writes = UInt::min( - check_bounds.dim_horizontal - write_col, - Comptime::runtime(tile_size), - ); - num_loops = num_writes / runtime_vectorization; + let num_writes = Min::min(check_bounds.dim_horizontal - write_col, tile_size); + num_loops = num_writes / vectorization_factor; } - for i in range(0u32, num_loops, Comptime::new(false)) { - let unroll = Comptime::map(config, |c| c.unroll_tile); + for i in 0..num_loops { + let unroll = config.unroll_tile; - if Comptime::get(is_scalar) { + if is_scalar { out[i + positions.out] = results[positions.result + i]; } else { - let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); + let mut output_elem = F::vectorized_empty(vectorization_factor); - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - let index = i * runtime_vectorization + j; + #[unroll(unroll)] + for j in 0u32..vectorization_factor { + let index = i * vectorization_factor + j; output_elem[j] = results[positions.result + index]; } - out[i + positions.out / runtime_vectorization] = output_elem; + out[i + positions.out / vectorization_factor] = output_elem; } } } @@ -272,15 +260,16 @@ impl ContiguousAccess for UnmatchingVectorization { impl StridedAccess for UnmatchingVectorization { fn read_strided_unchecked( tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - config: Comptime, + gm_position: u32, + gm_stride: u32, + #[comptime] config: CubeTiling2dConfig, ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); + let tile_size = config.tile_size; + let unroll = config.unroll_tile; - let mut vertical = F::vectorized_empty(Comptime::get(tile_size)); - for i in range(0u32, Comptime::get(tile_size), unroll) { + let mut vertical = F::vectorized_empty(tile_size); + #[unroll(unroll)] + for i in 0..tile_size { vertical[i] = tensor[gm_position + i * gm_stride]; } @@ -289,27 +278,27 @@ impl StridedAccess for UnmatchingVectorization { fn read_strided_checked( tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, + gm_position: u32, + gm_stride: u32, check_bounds: CheckBounds, info: ReadTileInfo, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); + let tile_size = config.tile_size; - let mut vertical = F::vectorized_empty(Comptime::get(tile_size)); + let mut vertical = F::vectorized_empty(tile_size); - let mut num_reads = UInt::new(0); + let mut num_reads = 0; let row = check_bounds.skip_row + info.read_row; let dim_vertical = check_bounds.dim_vertical; if dim_vertical > row { - num_reads = UInt::min(dim_vertical - row, Comptime::runtime(tile_size)); + num_reads = Min::min(dim_vertical - row, tile_size); } - for i in range(0u32, num_reads, Comptime::new(false)) { + for i in 0..num_reads { vertical[i] = tensor[gm_position + i * gm_stride]; } - for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) { + for i in num_reads..tile_size { vertical[i] = F::new(0.); } diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/tile/writer.rs b/crates/cubecl-linalg/src/matmul/tiling2d/tile/writer.rs index 556a3538..aa13d57e 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/tile/writer.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/tile/writer.rs @@ -11,9 +11,10 @@ use crate::matmul::tiling2d::{ use super::{ block_io::base::BlockWriter, - loader::{CheckBounds, CheckBoundsExpand}, + loader::CheckBounds, memory_access::{MatchingVectorization, UnmatchingVectorization}, }; + pub(crate) struct TileWriter { _f: PhantomData, } @@ -25,10 +26,10 @@ impl OutputWriter for TileWriter { results: &Array, write_info: WriteTileInfo, dims: Dimensions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let vectorization = Comptime::vectorization(out); - let tile_size = Comptime::map(config, |c| c.tile_size); + let vectorization = vectorization_of(out); + let tile_size = config.tile_size; let coordinates = write_info.coordinates; let check_bounds = CheckBounds { diff --git a/crates/cubecl-linalg/src/matmul/tiling2d/write_output.rs b/crates/cubecl-linalg/src/matmul/tiling2d/write_output.rs index 23132b5f..71bfb9e9 100644 --- a/crates/cubecl-linalg/src/matmul/tiling2d/write_output.rs +++ b/crates/cubecl-linalg/src/matmul/tiling2d/write_output.rs @@ -1,5 +1,5 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; +use cubecl_core::{self as cubecl}; +use cubecl_core::{prelude::*, CubeType}; use super::{ base::{Coordinates, Dimensions}, @@ -14,8 +14,8 @@ use super::{ #[derive(CubeType)] pub(crate) struct WriteTileInfo { pub coordinates: Coordinates, - pub offset_output: UInt, - pub out_stride: UInt, + pub offset_output: u32, + pub out_stride: u32, } #[cube] @@ -25,7 +25,7 @@ pub(crate) trait OutputWriter: Sync + Send + 'static { results: &Array, write_tile_info: WriteTileInfo, dims: Dimensions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ); } @@ -34,12 +34,12 @@ pub(crate) fn write_to_output>( out: &mut Tensor, results: &Array, coordinates: Coordinates, - offset_output: UInt, + offset_output: u32, dims: Dimensions, - config: Comptime, + #[comptime] config: CubeTiling2dConfig, ) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); + let check_m_bounds = config.check_m_bounds; + let check_n_bounds = config.check_n_bounds; let write_info = WriteTileInfo { coordinates, @@ -47,13 +47,13 @@ pub(crate) fn write_to_output>( out_stride: dims.n, }; - if Comptime::get(check_m_bounds) { - if Comptime::get(check_n_bounds) { + if check_m_bounds { + if check_n_bounds { W::write_output::(out, results, write_info, dims, config); } else { W::write_output::(out, results, write_info, dims, config); } - } else if Comptime::get(check_n_bounds) { + } else if check_n_bounds { W::write_output::(out, results, write_info, dims, config); } else { W::write_output::(out, results, write_info, dims, config); diff --git a/crates/cubecl-linalg/src/tensor/contiguous.rs b/crates/cubecl-linalg/src/tensor/contiguous.rs index 081f9d5d..8605b1c3 100644 --- a/crates/cubecl-linalg/src/tensor/contiguous.rs +++ b/crates/cubecl-linalg/src/tensor/contiguous.rs @@ -1,38 +1,36 @@ -use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_vectorization_factor}; - -use cubecl::prelude::*; - use super::TensorHandle; +use cubecl::prelude::*; +use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_vectorization_factor}; /// Returns the offset of the tensor corresponding to the layout tensor. #[cube] pub fn index_offset_with_layout( tensor: &Tensor, layout: &Tensor, - offset_layout: UInt, - dim_start: UInt, - dim_end: UInt, - unroll: Comptime, -) -> UInt { - let vectorization_factor = Comptime::vectorization(tensor); - let vectorization_factor_runtime = Comptime::runtime(vectorization_factor); + offset_layout: u32, + dim_start: u32, + dim_end: u32, + #[comptime] unroll: bool, +) -> u32 { + let vectorization = vectorization_of(tensor); - let offset_ref = offset_layout * vectorization_factor_runtime; - let mut offset = UInt::new(0); + let offset_ref = offset_layout * vectorization; + let mut offset = 0; - for i in range(dim_start, dim_end, unroll) { + #[unroll(unroll)] + for i in dim_start..dim_end { let ogwl = offset_ref / layout.stride(i); offset += ogwl % tensor.shape(i) * tensor.stride(i); } - offset / vectorization_factor_runtime + offset / vectorization } #[cube(launch_unchecked)] fn into_contiguous_kernel( input: &Tensor, output: &mut Tensor, - rank: Comptime>, + #[comptime] rank: Option, ) { let offset_output = ABSOLUTE_POS; @@ -44,9 +42,9 @@ fn into_contiguous_kernel( input, output, offset_output, - UInt::new(0), - Comptime::unwrap_or_else(rank, || output.rank()), - Comptime::is_some(rank), + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), ); output[offset_output] = input[offset_input]; @@ -76,7 +74,7 @@ pub fn into_contiguous( cube_dim, input.as_tensor_arg(vectorization_factor), output.as_ref().as_tensor_arg(vectorization_factor), - Some(UInt::new(rank as u32)), + Some(rank as u32), ); } diff --git a/crates/cubecl-macros/Cargo.toml b/crates/cubecl-macros/Cargo.toml index 638966de..38cc3f2c 100644 --- a/crates/cubecl-macros/Cargo.toml +++ b/crates/cubecl-macros/Cargo.toml @@ -21,7 +21,19 @@ default = [] std = [] [dependencies] +darling = { workspace = true } +derive-new = { workspace = true } +ident_case = { workspace = true } +prettyplease = "0.2" proc-macro2 = { workspace = true } quote = { workspace = true } syn = { workspace = true } -derive-new = { workspace = true } + +cubecl-common = { path = "../cubecl-common", version = "0.2", default-features = false } + +[dev-dependencies] +cubecl-core = { path = "../cubecl-core", version = "0.2", default-features = false } +cubecl-cuda = { path = "../cubecl-cuda", version = "0.2", default-features = false } +cubecl-linalg = { path = "../cubecl-linalg", version = "0.2", default-features = false } +cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.2", default-features = false } +pretty_assertions = { workspace = true } diff --git a/crates/cubecl-macros/LICENSE-APACHE b/crates/cubecl-macros/LICENSE-APACHE deleted file mode 120000 index 1cd601d0..00000000 --- a/crates/cubecl-macros/LICENSE-APACHE +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cubecl-macros/LICENSE-APACHE b/crates/cubecl-macros/LICENSE-APACHE new file mode 100644 index 00000000..1cd601d0 --- /dev/null +++ b/crates/cubecl-macros/LICENSE-APACHE @@ -0,0 +1 @@ +../../LICENSE-APACHE \ No newline at end of file diff --git a/crates/cubecl-macros/LICENSE-MIT b/crates/cubecl-macros/LICENSE-MIT deleted file mode 120000 index b2cfbdc7..00000000 --- a/crates/cubecl-macros/LICENSE-MIT +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-MIT \ No newline at end of file diff --git a/crates/cubecl-macros/LICENSE-MIT b/crates/cubecl-macros/LICENSE-MIT new file mode 100644 index 00000000..b2cfbdc7 --- /dev/null +++ b/crates/cubecl-macros/LICENSE-MIT @@ -0,0 +1 @@ +../../LICENSE-MIT \ No newline at end of file diff --git a/crates/cubecl-macros/src/analyzer.rs b/crates/cubecl-macros/src/analyzer.rs deleted file mode 100644 index 3e734f3d..00000000 --- a/crates/cubecl-macros/src/analyzer.rs +++ /dev/null @@ -1,305 +0,0 @@ -use syn::{Member, Pat, PathArguments, Stmt}; - -use crate::tracker::VariableTracker; - -pub const KEYWORDS: [&str; 21] = [ - "ABSOLUTE_POS", - "ABSOLUTE_POS_X", - "ABSOLUTE_POS_Y", - "ABSOLUTE_POS_Z", - "UNIT_POS", - "UNIT_POS_X", - "UNIT_POS_Y", - "UNIT_POS_Z", - "CUBE_POS", - "CUBE_POS_X", - "CUBE_POS_Y", - "CUBE_POS_Z", - "CUBE_DIM", - "CUBE_DIM_X", - "CUBE_DIM_Y", - "CUBE_DIM_Z", - "CUBE_COUNT", - "CUBE_COUNT_X", - "CUBE_COUNT_Y", - "CUBE_COUNT_Z", - "SUBCUBE_DIM", -]; - -#[derive(Debug, Default)] -/// Reads the whole Cube code and accumulates information, -/// to generate a VariableTracker that looked variable uses ahead -pub(crate) struct VariableAnalyzer { - variable_tracker: VariableTracker, -} - -impl VariableAnalyzer { - pub fn create_tracker(func: &syn::ItemFn) -> VariableTracker { - let analyzer = VariableAnalyzer::default(); - analyzer.analyze(func) - } -} - -impl VariableAnalyzer { - fn analyze(mut self, func: &syn::ItemFn) -> VariableTracker { - // Build the vector of (Id, depth), using recursion - self.signature_declarations(&func.sig); - self.find_occurrences_in_stmts(&func.block.stmts, 0); - - self.variable_tracker - } - - fn signature_declarations(&mut self, sig: &syn::Signature) { - for input in &sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = &*pat.pat; - let is_comptime = is_ty_comptime(&pat.ty); - - match ident { - syn::Pat::Ident(pat_ident) => { - let id = &pat_ident.ident; - self.variable_tracker - .analyze_declare(id.to_string(), 0, is_comptime); - } - _ => todo!("Analysis: unsupported ident {ident:?}"), - } - } - _ => todo!("Analysis: unsupported input {input:?}"), - } - } - } - - fn find_occurrences_in_stmts(&mut self, stmts: &Vec, depth: u8) { - for stmt in stmts { - match stmt { - // Declaration - syn::Stmt::Local(local) => { - match &local.pat { - syn::Pat::Tuple(pat_tuple) => { - for pat in pat_tuple.elems.iter() { - let (id, is_comptime) = find_local_declaration_ident(pat); - if let Some(id) = id { - self.variable_tracker.analyze_declare( - id.to_string(), - depth, - is_comptime, - ); - } - } - } - _ => { - let (id, is_comptime) = find_local_declaration_ident(&local.pat); - if let Some(id) = id { - self.variable_tracker.analyze_declare( - id.to_string(), - depth, - is_comptime, - ); - } - } - } - if let Some(local_init) = &local.init { - self.find_occurrences_in_expr(&local_init.expr, depth) - } - } - syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth), - _ => todo!("Analysis: unsupported stmt {stmt:?}"), - } - } - } - - fn find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: u8) { - match expr { - syn::Expr::ForLoop(expr) => { - self.find_occurrences_in_expr(&expr.expr, depth); - - let depth = depth + 1; - - if let syn::Pat::Ident(pat_ident) = &*expr.pat { - let id = &pat_ident.ident; - self.variable_tracker - .analyze_declare(id.to_string(), depth, false); - } - - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::While(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_expr(&expr.cond, depth); - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::Loop(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::If(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_expr(&expr.cond, depth); - self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); - if let Some((_, expr)) = &expr.else_branch { - match &**expr { - syn::Expr::Block(expr_block) => { - self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); - } - syn::Expr::If(expr) => { - self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth); - } - _ => unreachable!(), - } - } - } - syn::Expr::Assign(expr) => { - self.find_occurrences_in_expr(&expr.left, depth); - self.find_occurrences_in_expr(&expr.right, depth); - } - syn::Expr::Index(expr) => { - self.find_occurrences_in_expr(&expr.expr, depth); - self.find_occurrences_in_expr(&expr.index, depth); - } - syn::Expr::Path(expr) => { - if let Some(ident) = expr.path.get_ident() { - if !KEYWORDS.contains(&ident.to_string().as_str()) { - self.variable_tracker.analyze_reuse(ident, depth, None); - } - } - } - syn::Expr::Binary(expr) => { - self.find_occurrences_in_expr(&expr.left, depth); - self.find_occurrences_in_expr(&expr.right, depth); - } - syn::Expr::Lit(_) => {} - syn::Expr::Call(expr) => { - match &*expr.func { - syn::Expr::Path(expr_path) => { - if let Some(first_segment) = expr_path.path.segments.first() { - // Check if the path segment has generic arguments - if let PathArguments::AngleBracketed(arguments) = - &first_segment.arguments - { - // Extract the generic arguments - for arg in &arguments.args { - match arg { - syn::GenericArgument::Type(_) - | syn::GenericArgument::Constraint(_) => {} - _ => todo!("Analysis: Generic {:?} not supported", arg), - } - } - } - } - } - _ => todo!("Analysis: unsupported func expr {:?}", expr.func), - } - for arg in expr.args.iter() { - self.find_occurrences_in_expr(arg, depth); - } - } - syn::Expr::MethodCall(expr) => { - self.find_occurrences_in_expr(&expr.receiver, depth); - for arg in expr.args.iter() { - self.find_occurrences_in_expr(arg, depth); - } - } - syn::Expr::Break(_) => {} - syn::Expr::Return(expr) => { - if expr.expr.is_some() { - // Unsupported: handled in codegen. - } - } - syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Array(_expr) => { - // No analysis since only literals are supported - } - syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Closure(expr) => { - let depth = depth + 1; - - for path in expr.inputs.iter() { - let mut is_comptime = false; - let ident = match path { - Pat::Ident(pat_ident) => &pat_ident.ident, - Pat::Type(pat_type) => { - is_comptime = is_ty_comptime(&pat_type.ty); - - if let Pat::Ident(pat_ident) = &*pat_type.pat { - &pat_ident.ident - } else { - todo!("Analysis: {:?} not supported in closure inputs. ", path); - } - } - _ => todo!("Analysis: {:?} not supported in closure inputs. ", path), - }; - - self.variable_tracker - .analyze_declare(ident.to_string(), depth, is_comptime); - } - - self.find_occurrences_in_expr(&expr.body, depth) - } - syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Field(expr) => { - if let Member::Named(attribute_ident) = &expr.member { - if let syn::Expr::Path(struct_expr) = &*expr.base { - let struct_ident = struct_expr - .path - .get_ident() - .expect("Analysis: field access only supported on ident struct."); - - self.variable_tracker.analyze_reuse( - struct_ident, - depth, - Some(attribute_ident.to_string()), - ); - } else { - todo!("Analysis: field access only supported on ident struct."); - } - } else { - todo!("Analysis: unnamed attribute not supported."); - } - } - syn::Expr::Struct(expr) => { - for field in expr.fields.iter() { - self.find_occurrences_in_expr(&field.expr, depth) - } - } - syn::Expr::Range(_range) => { - // Error is handled during codegen. - } - _ => { - // Error is handled during codegen. - } - } - } -} - -fn find_local_declaration_ident(pat: &syn::Pat) -> (Option<&syn::Ident>, bool) { - let mut is_comptime = false; - let id = match &pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - syn::Pat::Type(pat_type) => { - is_comptime = is_ty_comptime(&pat_type.ty); - match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), - } - } - syn::Pat::Wild(_) => None, - _ => todo!("Analysis: unsupported path {:?}", pat), - }; - (id, is_comptime) -} - -fn is_ty_comptime(ty: &syn::Type) -> bool { - if let syn::Type::Path(path) = ty { - for segment in path.path.segments.iter() { - if segment.ident == "Comptime" { - return true; - } - } - } - - false -} diff --git a/crates/cubecl-macros/src/codegen_common/mod.rs b/crates/cubecl-macros/src/codegen_common/mod.rs deleted file mode 100644 index ed3f3a2d..00000000 --- a/crates/cubecl-macros/src/codegen_common/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) mod signature; diff --git a/crates/cubecl-macros/src/codegen_common/signature.rs b/crates/cubecl-macros/src/codegen_common/signature.rs deleted file mode 100644 index ee11395f..00000000 --- a/crates/cubecl-macros/src/codegen_common/signature.rs +++ /dev/null @@ -1,70 +0,0 @@ -use quote::ToTokens; - -use crate::tracker::VariableTracker; - -#[derive(Copy, Clone, Debug)] -pub enum ExpandMode { - FuncImpl, - MethodImpl, -} - -pub fn expand_sig( - sig: &syn::Signature, - visibility: &syn::Visibility, - mut variable_tracker: Option<&mut VariableTracker>, - mode: ExpandMode, -) -> proc_macro2::TokenStream { - let mut inputs = quote::quote!(); - - for input in &sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = pat.pat.clone(); - - if let syn::Pat::Ident(ident) = ident.as_ref() { - if let Some(vars) = &mut variable_tracker { - vars.codegen_declare(ident.ident.to_string(), 0); - } - } - - let ty = no_ref(pat.ty.as_ref()); - inputs.extend(quote::quote! { - #ident: <#ty as cubecl::frontend::CubeType>::ExpandType, - }); - } - _ => todo!("Only Typed inputs are supported"), - } - } - - let mut output = quote::quote!(); - - match &sig.output { - syn::ReturnType::Default => output.extend(quote::quote! { ()}), - syn::ReturnType::Type(_, ty) => { - let ty = no_ref(ty.as_ref()); - output.extend(quote::quote! { - <#ty as cubecl::frontend::CubeType>::ExpandType - }); - } - } - - let ident = &sig.ident; - let ident = match mode { - ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()), - _ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()), - }; - - let generics = sig.generics.clone().into_token_stream(); - - quote::quote! { - /// Expanded Cube function - #visibility fn #ident #generics (context: &mut cubecl::frontend::CubeContext, #inputs) -> #output - } -} - -pub fn no_ref(ty: &syn::Type) -> &syn::Type { - match ty { - syn::Type::Reference(val) => &val.elem, - _ => ty, - } -} diff --git a/crates/cubecl-macros/src/codegen_function/base.rs b/crates/cubecl-macros/src/codegen_function/base.rs deleted file mode 100644 index f40aa4e4..00000000 --- a/crates/cubecl-macros/src/codegen_function/base.rs +++ /dev/null @@ -1,132 +0,0 @@ -use proc_macro2::TokenStream; -use quote::ToTokens; - -use super::{expr::codegen_expr, variable::codegen_local}; -use crate::tracker::VariableTracker; - -/// Codegen for a statement (generally one line) -/// Entry point of code generation -pub fn codegen_statement( - statement: &syn::Stmt, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - match statement { - syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_tracker), - syn::Stmt::Expr(expr, semi) => { - let expr = codegen_expr(expr, loop_level, variable_tracker).tokens; - - match semi { - Some(_semi) => quote::quote!( - #expr; - ), - None => expr, - } - } - _ => todo!("Codegen: statement {statement:?} not supported"), - } -} - -/// Codegen for a code block (a list of statements) -pub(crate) fn codegen_block( - block: &syn::Block, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut statements = quote::quote!(); - - for statement in block.stmts.iter() { - statements.extend(codegen_statement(statement, loop_level, variable_tracker)); - } - - quote::quote! { - { - #statements - } - } -} - -#[derive(Clone, Copy)] -pub(crate) enum CodegenKind { - Comptime, - Literal, - Expand, -} - -#[derive(Clone)] -pub(crate) struct Codegen { - tokens: proc_macro2::TokenStream, - array_indexing: Option, - kind: CodegenKind, -} - -#[derive(Clone)] -pub(crate) struct ArrayIndexing { - pub array: proc_macro2::TokenStream, - pub index: proc_macro2::TokenStream, -} - -impl From for Codegen { - fn from(tokens: proc_macro2::TokenStream) -> Self { - Self { - tokens, - kind: CodegenKind::Expand, - array_indexing: None, - } - } -} - -impl Codegen { - pub fn new>(tokens: S, kind: CodegenKind) -> Self { - Self { - tokens: tokens.into(), - kind, - array_indexing: None, - } - } - - pub fn process(mut self) -> (proc_macro2::TokenStream, CodegenKind, Option) { - let kind = self.kind; - let array_indexing = self.pop_array_indexing(); - let tokens = self.tokens(); - - (tokens, kind, array_indexing) - } - - pub fn tokens(self) -> TokenStream { - self.into_token_stream() - } - - pub fn pop_array_indexing(&mut self) -> Option { - let mut result = None; - core::mem::swap(&mut result, &mut self.array_indexing); - result - } - - pub fn set_array_indexing(&mut self, array_indexing: Option) { - self.array_indexing = array_indexing; - } -} - -impl ToTokens for Codegen { - fn to_tokens(&self, tokens: &mut TokenStream) { - let cloned = self.clone(); - let toks = cloned.into_token_stream(); - tokens.extend(toks); - } - fn into_token_stream(self) -> TokenStream - where - Self: Sized, - { - match self.kind { - CodegenKind::Comptime => self.tokens, - CodegenKind::Expand => self.tokens, - CodegenKind::Literal => { - let lit = self.tokens; - quote::quote! { - cubecl::frontend::ExpandElementTyped::from_lit(#lit) - } - } - } - } -} diff --git a/crates/cubecl-macros/src/codegen_function/branch.rs b/crates/cubecl-macros/src/codegen_function/branch.rs deleted file mode 100644 index 0305aff5..00000000 --- a/crates/cubecl-macros/src/codegen_function/branch.rs +++ /dev/null @@ -1,251 +0,0 @@ -use proc_macro2::TokenStream; -use syn::{Expr, ExprUnary, UnOp}; - -use crate::{ - codegen_function::{base::CodegenKind, expr::codegen_expr}, - tracker::VariableTracker, -}; - -use super::{ - base::{codegen_block, Codegen}, - function::codegen_call, - operation::{codegen_binary, codegen_unary}, - variable::{codegen_lit, codegen_path_var}, -}; - -/// Codegen of for loops -/// Supports range: -/// ```ignore -/// for i in range(start, end, unroll) {...} -/// ``` -/// and range_stepped: -/// ```ignore -/// for i in range_stepped(start, end, step, unroll) {...} -/// ``` -pub(crate) fn codegen_for_loop( - for_loop: &syn::ExprForLoop, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let i = &for_loop.pat; - - if let syn::Pat::Ident(pat_ident) = &*for_loop.pat { - let id = &pat_ident.ident; - variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1); - } - - let invalid_for_loop = || { - syn::Error::new_spanned( - &for_loop.expr, - "Invalid for loop: use [range](cubecl::prelude::range] or [range_stepped](cubecl::prelude::range_stepped) instead.", - ) - .into_compile_error() - }; - - match for_loop.expr.as_ref() { - syn::Expr::Call(call) => { - let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => match path.path.get_ident() { - Some(ident) => ident, - None => return invalid_for_loop(), - }, - _ => { - return invalid_for_loop(); - } - }; - - if &func_name.to_string() == "range" { - let mut args = call.args.clone(); - - let unroll = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let end = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let start = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - - let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - { - let _start = #start; - let _end = #end; - let _unroll = #unroll; - cubecl::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block); - } - } - } else if &func_name.to_string() == "range_stepped" { - let mut args = call.args.clone(); - - let unroll = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let step = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let end = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let start = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - - let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - { - let _start = #start; - let _end = #end; - let _step = #step; - let _unroll = #unroll; - cubecl::frontend::branch::range_stepped_expand(context, _start, _end, _step, _unroll, |context, #i| #block); - } - } - } else { - invalid_for_loop() - } - } - syn::Expr::Path(pat) => { - let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - for #i in #pat #block - } - } - _ => invalid_for_loop(), - } -} - -/// Codegen for condition of an if or a while -pub(crate) fn codegen_cond( - cond: &syn::Expr, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - match cond { - syn::Expr::Unary(expr) => codegen_unary(expr, loop_level, variable_tracker), - syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_tracker), - syn::Expr::Lit(expr) => Codegen::new(codegen_lit(expr), CodegenKind::Literal), - syn::Expr::Path(expr) => codegen_path_var(expr, loop_level, variable_tracker), - syn::Expr::Call(expr) => codegen_call(expr, loop_level, variable_tracker), - _ => todo!("{cond:?} cond not supported"), - } -} - -/// Codegen for break statement -pub(crate) fn codegen_break() -> TokenStream { - quote::quote! { - cubecl::frontend::branch::break_expand(context); - } -} - -/// Codegen for return statement -pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream { - if expr_return.expr.is_some() { - return syn::Error::new_spanned(expr_return, "Only void return is supported.") - .into_compile_error(); - } - - quote::quote! { - cubecl::frontend::branch::return_expand(context); - } -} - -/// Codegen for if and if/else statements -/// Supports: -/// if cond {...} -/// if cond {...} else {...} -/// if Comptime::get(...) {...} [else {...}] -/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}] -pub(crate) fn codegen_if( - expr_if: &syn::ExprIf, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let (cond, kind, _) = codegen_cond(&expr_if.cond, loop_level, variable_tracker).process(); - let comptime_bool = if let CodegenKind::Comptime = kind { - quote::quote! { Some(#cond) } - } else { - quote::quote! { None } - }; - - let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker); - - if let Some((_, expr)) = &expr_if.else_branch { - let else_block = match &**expr { - syn::Expr::Block(expr_block) => { - codegen_block(&expr_block.block, loop_level + 1, variable_tracker) - } - - syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker), - _ => unreachable!(), - }; - quote::quote! { - { - let _cond = #cond; - cubecl::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block); - } - } - } else { - quote::quote! { - let _cond = #cond; - cubecl::frontend::branch::if_expand(context, #comptime_bool, _cond.into(), |context| #then_block); - } - } -} - -/// Codegen of loop -pub(crate) fn codegen_loop( - loop_expr: &syn::ExprLoop, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let block = codegen_block(&loop_expr.body, loop_level + 1, variable_tracker); - - quote::quote! { - cubecl::frontend::branch::loop_expand(context, |context| #block); - } -} - -/// Codegen for while loop -pub(crate) fn codegen_while_loop( - while_loop: &syn::ExprWhile, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let inverted_cond = Expr::Unary(ExprUnary { - attrs: vec![], - op: UnOp::Not(Default::default()), - expr: Box::new(*while_loop.cond.clone()), - }); - - let (cond, kind, _) = codegen_cond(&inverted_cond, loop_level + 1, variable_tracker).process(); - - if let CodegenKind::Comptime = kind { - return syn::Error::new_spanned(while_loop.while_token, "Comptime not supported for while") - .into_compile_error(); - } - - let block = codegen_block(&while_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - cubecl::frontend::branch::while_loop_expand(context, |context| #cond, |context| #block); - } -} diff --git a/crates/cubecl-macros/src/codegen_function/expr.rs b/crates/cubecl-macros/src/codegen_function/expr.rs deleted file mode 100644 index 84fbdb1b..00000000 --- a/crates/cubecl-macros/src/codegen_function/expr.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::tracker::VariableTracker; -use proc_macro2::{Ident, Span, TokenStream}; - -use super::{ - base::{codegen_block, Codegen, CodegenKind}, - branch::{ - codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_return, - codegen_while_loop, - }, - function::{codegen_call, codegen_closure, codegen_expr_method_call}, - operation::{codegen_binary, codegen_unary}, - variable::{ - codegen_array_lit, codegen_assign, codegen_field, codegen_index, codegen_lit, - codegen_path_var, codegen_struct, - }, -}; - -/// Codegen for expressions -pub(crate) fn codegen_expr( - expr: &syn::Expr, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - match expr { - syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker), - syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker), - _ => { - let mut array_indexing = None; - let mut kind = CodegenKind::Expand; - let tokens = match expr { - syn::Expr::Path(path) => { - return codegen_path_var(path, loop_level, variable_tracker) - } - syn::Expr::Binary(op) => return codegen_binary(op, loop_level, variable_tracker), - syn::Expr::Unary(op) => return codegen_unary(op, loop_level, variable_tracker), - syn::Expr::Lit(lit) => { - kind = CodegenKind::Literal; - codegen_lit(lit) - } - syn::Expr::Closure(closure) => { - codegen_closure(closure, loop_level, variable_tracker) - } - syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_tracker), - syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_tracker), - syn::Expr::ForLoop(for_loop) => { - codegen_for_loop(for_loop, loop_level, variable_tracker) - } - syn::Expr::While(while_loop) => { - codegen_while_loop(while_loop, loop_level, variable_tracker) - } - syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_tracker), - syn::Expr::Break(_) => codegen_break(), - syn::Expr::Return(return_expr) => codegen_return(return_expr), - syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_tracker), - syn::Expr::MethodCall(call) => { - codegen_expr_method_call(call, loop_level, variable_tracker) - } - syn::Expr::Index(index) => { - let (tokens, index_kind, index_array_indexing) = - codegen_index(index, loop_level, variable_tracker).process(); - - array_indexing = index_array_indexing; - kind = index_kind; - tokens - } - syn::Expr::Array(array) => codegen_array_lit(array), - syn::Expr::Reference(reference) => { - codegen_ref(reference, loop_level, variable_tracker) - } - syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker), - syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker), - syn::Expr::Range(range) => syn::Error::new_spanned( - range, - "Range is not supported, use [range](cubecl::prelude::range) instead.", - ) - .to_compile_error(), - syn::Expr::Tuple(tuple) => codegen_tuple(tuple, loop_level, variable_tracker), - _ => { - syn::Error::new_spanned(expr, "Expression Is not supported").to_compile_error() - } - }; - - let mut codegen = Codegen::new(tokens, kind); - codegen.set_array_indexing(array_indexing); - codegen - } - } -} - -/// Codegen for tuple expressions -pub(crate) fn codegen_tuple( - unary: &syn::ExprTuple, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut res = quote::quote! {}; - let mut vars = Vec::new(); - for (i, expr) in unary.elems.iter().enumerate() { - let expr_codegen = codegen_expr(expr, loop_level, variable_tracker); - let expr_tokens = expr_codegen.tokens(); - let var = Ident::new(&format!("_tuple_{}", i), Span::call_site()); - res = quote::quote! { - #res - let #var = #expr_tokens; - }; - vars.push(var); - } - quote::quote! { - { - #res - ( #(#vars),* ) - } - } -} - -/// Codegen for an expression containing a block -pub(crate) fn codegen_expr_block( - block: &syn::ExprBlock, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - codegen_block(&block.block, loop_level, variable_tracker) -} - -pub(crate) fn codegen_ref( - reference: &syn::ExprReference, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - // We ignore reference for the expansion. - let inner = codegen_expr(&reference.expr, loop_level, variable_tracker); - quote::quote! { #inner } -} diff --git a/crates/cubecl-macros/src/codegen_function/function.rs b/crates/cubecl-macros/src/codegen_function/function.rs deleted file mode 100644 index 9626f554..00000000 --- a/crates/cubecl-macros/src/codegen_function/function.rs +++ /dev/null @@ -1,261 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::quote_spanned; -use syn::{ - punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident, - PathArguments, Token, -}; - -use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker}; - -use super::base::{Codegen, CodegenKind}; - -/// Codegen for method call -/// Supports [expr].method(args) -pub(crate) fn codegen_expr_method_call( - call: &syn::ExprMethodCall, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let receiver = codegen_expr(&call.receiver, loop_level, variable_tracker); - - if call.method == "into" { - let (tokens, kind, _) = receiver.process(); - - return match kind { - CodegenKind::Comptime => quote::quote! { #tokens.into() }, - CodegenKind::Literal => quote::quote! { #tokens }, - CodegenKind::Expand => quote::quote! { #tokens.into() }, - }; - } - - let method_expand = syn::Ident::new( - format!("__expand_{}_method", call.method).as_str(), - proc_macro2::Span::call_site(), - ); - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - quote::quote!( { - #expansion - #receiver . #method_expand ( #variables ) - }) -} - -/// Codegen for a closure -pub(crate) fn codegen_closure( - closure: &syn::ExprClosure, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut inputs = quote::quote! {}; - for input in closure.inputs.iter() { - let (ident, ty) = match input { - syn::Pat::Ident(ident) => (&ident.ident, None), - syn::Pat::Type(pat_type) => ( - if let syn::Pat::Ident(ident) = &*pat_type.pat { - &ident.ident - } else { - return syn::Error::new_spanned(pat_type, "Unsupported input") - .into_compile_error(); - }, - Some(pat_type.ty.clone()), - ), - _ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(), - }; - - if let Some(ty) = ty { - inputs.extend(quote::quote! { - #ident: <#ty as CubeType>::ExpandType, - }); - } else { - inputs.extend(quote::quote! { - #ident, - }); - } - } - - let body = codegen_expr(closure.body.as_ref(), loop_level, variable_tracker); - - quote::quote! { - |context: &mut CubeContext, #inputs| #body - } -} - -/// Maps -/// [A[::<...>]?::]^* func[::<...>] (args) -/// to -/// [A[::<...>]?::]^* func_expand[::<...>] (context, args) -/// -/// Also returns a bool that is true if it's comptime -pub(crate) fn codegen_call( - call: &syn::ExprCall, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - // We start with parsing the function path - let path: Vec<(&Ident, Option<&AngleBracketedGenericArguments>)> = match call.func.as_ref() { - syn::Expr::Path(expr_path) => { - let mut path = Vec::new(); - for segment in expr_path.path.segments.iter() { - let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments - { - Some(arguments) - } else { - None - }; - path.push((&segment.ident, generics)); - } - path - } - _ => { - return Codegen::new( - syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(), - CodegenKind::Expand, - ) - } - }; - - // Path - let mut path_tokens = TokenStream::new(); - let mut is_comptime = false; - let mut is_plain_func = true; - let mut comptime_func: Option = None; - - for (i, (ident, generics)) in path.iter().enumerate() { - let name = ident.to_string(); - - if name == "Comptime" { - is_comptime = true; - continue; - } - - if let Some(first_char) = name.chars().next() { - if first_char.is_uppercase() { - is_plain_func = false; - } - } - - if i == path.len() - 1 { - if is_comptime { - comptime_func = Some(ident.to_string()); - break; - } - - let func_name_expand = if is_plain_func { - quote::quote! { - #ident::__expand - } - } else { - let ident = syn::Ident::new( - format!("__expand_{ident}").as_str(), - proc_macro2::Span::call_site(), - ); - quote::quote! { - #ident - } - }; - path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand }); - if let Some(generics) = generics { - path_tokens.extend(quote_spanned! {generics.span() => #generics }); - } - } else if let Some(generics) = generics { - path_tokens.extend(quote_spanned! {ident.span() => #ident }); - path_tokens.extend(quote_spanned! {generics.span() => #generics :: }); - } else { - path_tokens.extend(quote_spanned! {ident.span() => #ident :: }); - } - } - - // Arguments - if let Some(func_name) = comptime_func { - let tokens = match func_name.as_str() { - "get" | "new" => { - let code = call.args.first().unwrap(); - quote::quote! {#code} - } - "map" => { - let args = &call.args; - - // Codegen - quote::quote! { - { - Comptime::__expand_map(#args) - } - } - } - "unwrap_or_else" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::__expand_unwrap_or_else(#variables) - }} - } - "is_some" => { - let code = call.args.first().unwrap(); - quote::quote! { #code.is_some() } - } - "vectorization" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::__expand_vectorization(#variables) - }} - } - "vectorize" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::__expand_vectorize(#variables) - }} - } - "runtime" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::__expand_runtime(#variables) - }} - } - - _ => panic!("Codegen: Comptime function {:?} does not exist", func_name), - }; - - Codegen::new(tokens, CodegenKind::Comptime) - } else { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - let tokens = quote::quote! {{ - #expansion - #path_tokens (#variables) - }}; - - Codegen::new(tokens, CodegenKind::Expand) - } -} - -fn codegen_args( - args: &Punctuated, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> (TokenStream, TokenStream) { - let mut expansion = quote::quote! {}; - let mut variables = quote::quote! {}; - - variables.extend(quote::quote! { context, }); - - for (i, argument) in args.iter().enumerate() { - let ident = Ident::new(format!("_var_{i}").as_str(), Span::call_site()); - let arg_token = codegen_expr(argument, loop_level, variable_tracker); - expansion.extend(quote::quote! { let #ident = #arg_token; }); - variables.extend(quote::quote! { #ident, }); - } - - (expansion, variables) -} diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs deleted file mode 100644 index c4ec8647..00000000 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ /dev/null @@ -1,546 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use syn::{parse_quote, Generics, Ident}; - -#[derive(Default)] -struct Codegen { - // Basic attributes. - name: String, - generics: Generics, - fn_inputs: TokenStream, - fn_output: TokenStream, - // States to generate code. - state_comptimes: Vec<(syn::Type, Ident)>, - state_args: Vec, - state_inputs: Vec<(Ident, syn::Type)>, - state_outputs: Vec<(Ident, syn::Type)>, - unchecked: bool, -} - -impl Codegen { - fn from_sig(sig: &syn::Signature, unchecked: bool) -> Self { - let mut codegen = Codegen { - name: snake_to_pascal_case(&sig.ident.to_string()), - generics: sig.generics.clone(), - unchecked, - ..Codegen::default() - }; - - let mut inputs = quote::quote!(); - - for input in &sig.inputs { - let mut is_output = false; - let mut comptime = false; - - match input { - syn::FnArg::Typed(pat) => { - let (ty, ident) = match pat.pat.as_ref() { - syn::Pat::Ident(ident) => { - if ident.mutability.is_some() { - is_output = true; - } - - if let syn::Type::Reference(ty) = pat.ty.as_ref() { - if ty.mutability.is_some() { - is_output = true; - } - }; - - if let syn::Type::Path(pat) = pat.ty.as_ref() { - if let Some(name) = pat.path.segments.first() { - let name = name.ident.to_string(); - - if name == "Comptime" { - comptime = true; - } - } - }; - - (pat.ty.clone(), ident.ident.clone()) - } - _ => panic!("Nop"), - }; - - if comptime { - codegen.state_args.push(quote::quote! { - self.#ident - }); - } else { - codegen.state_args.push(quote::quote! { - #ident - }); - } - - if comptime { - let ty = no_ref(&ty); - inputs.extend(quote::quote! { - #ident: <#ty as cubecl::frontend::CubeType>::ExpandType, - }); - } else { - let ty = no_ref(&ty); - inputs.extend(quote::quote! { - #ident: RuntimeArg<'a, #ty, R>, - }); - } - - if is_output { - codegen - .state_outputs - .push((ident.clone(), no_ref(&ty).clone())); - } else if comptime { - codegen - .state_comptimes - .push((first_generic_ty(&ty).clone(), ident.clone())); - } else { - codegen - .state_inputs - .push((ident.clone(), no_ref(&ty).clone())); - } - } - _ => panic!("Only Typed inputs are supported"), - }; - } - - let mut output = quote::quote!(); - - match &sig.output { - syn::ReturnType::Default => output.extend(quote::quote! {()}), - syn::ReturnType::Type(_, ty) => { - output.extend(quote::quote! { - <#ty as cubecl::frontend::CubeType>::ExpandType - }); - } - } - - codegen.fn_inputs = inputs; - codegen.fn_output = output; - - codegen - } - - fn gen_kernel_struct(&self) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let phantoms = self.phantoms(&generics, true); - let mut comptimes = quote::quote! {}; - - for (ty, ident) in self.state_comptimes.iter() { - comptimes.extend(quote::quote! { - #ident: #ty, - }); - } - - quote::quote! { - /// Kernel - pub struct #ident #generics { - settings: KernelSettings, - #comptimes - #phantoms - } - } - } - - fn gen_settings(&self) -> TokenStream { - let mut variables = quote::quote! {}; - - for (pos, (ident, _ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - settings = ArgSettings::::configure_input(&#ident, #pos, settings); - }); - } - - for (pos, (ident, _ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - settings = ArgSettings::::configure_output(&#ident, #pos, settings); - }); - } - - quote::quote! { - let mut settings = KernelSettings::default(); - settings = settings.cube_dim(cube_dim); - #variables - } - } - - fn gen_register_input(&self) -> TokenStream { - let generics = &self.generics; - let mut variables = quote::quote! {}; - - for (pos, (_ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand(builder, settings.vectorization_input(#pos))), - }); - } - - quote::quote! { - #[allow(unused)] - fn register_input #generics( - builder: &mut KernelBuilder, - settings: &KernelSettings, - position: usize, - ) -> std::sync::Arc { - match position { - #variables - _ => panic!("Input {position} is invalid."), - } - } - } - } - - fn gen_register_output(&self) -> TokenStream { - let generics = &self.generics; - let mut variables = quote::quote! {}; - - for (pos, (_ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand_output(builder, settings.vectorization_output(#pos))), - }); - } - - quote::quote! { - #[allow(unused)] - fn register_output #generics ( - builder: &mut KernelBuilder, - settings: &KernelSettings, - position: usize, - ) -> std::sync::Arc { - match position { - #variables - _ => panic!("Input {position} is invalid."), - } - } - } - } - - fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream { - let mut expand_args = quote::quote! { &mut builder.context, }; - - let mut variables = quote::quote! {}; - - for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident: &<#ty as CubeType>::ExpandType = inputs - .get(&#pos) - .unwrap() - .downcast_ref() - .expect("Input type should be correct. It could be caused by an invalid kernel input/output alias."); - }); - } - - for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident: &<#ty as CubeType>::ExpandType = outputs - .get(&#pos) - .unwrap() - .downcast_ref() - .expect("Output type should be correct. It could be caused by an invalid kernel input/output alias."); - }); - } - - for arg in self.state_args.iter() { - expand_args.extend(quote::quote! { - #arg.clone(), - }) - } - - let expand_func = match self.generics.params.is_empty() { - true => quote::quote! { #expand }, - false => { - let generics = self.generics.split_for_impl().1; - quote::quote! { #expand::#generics } - } - }; - - quote::quote! { - #variables - #expand_func(#expand_args); - builder.build(self.settings.clone()) - } - } - - fn gen_define_args(&self) -> TokenStream { - let num_inputs = self.state_inputs.len(); - let num_outputs = self.state_outputs.len(); - - let register_input = self.gen_register_input(); - let register_output = self.gen_register_output(); - - let (register_input_call, register_output_call) = match self.generics.params.is_empty() { - true => ( - quote::quote! { register_input }, - quote::quote! { register_output }, - ), - false => { - let generics = self.generics.split_for_impl().1; - - ( - quote::quote! { register_input::#generics }, - quote::quote! { register_output::#generics }, - ) - } - }; - - let mut variables = quote::quote! {}; - - for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident = <&#ty as CubeType>::ExpandType = - *inputs.remove(&#pos).unwrap().downcast().unwrap(); - }); - } - - for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident = <&mut #ty as CubeType>::ExpandType = - *outputs.remove(&#pos).unwrap().downcast().unwrap(); - }); - } - - let mut tokens = quote::quote! { - let mut builder = KernelBuilder::default(); - - let mut inputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - let mut outputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - - #register_input - #register_output - }; - - if num_inputs > 0 { - tokens.extend(quote::quote! { - for i in 0..#num_inputs { - if !inputs.contains_key(&i) { - inputs.insert(i, #register_input_call(&mut builder, &self.settings, i)); - } - } - }); - } - - tokens.extend(quote::quote! { - for mapping in self.settings.mappings.iter() { - let input = inputs.get(&mapping.pos_input).unwrap(); - outputs.insert(mapping.pos_output, input.clone()); - } - }); - - if num_outputs > 0 { - tokens.extend(quote::quote! { - for i in 0..#num_outputs { - if !outputs.contains_key(&i) { - outputs.insert(i, #register_output_call(&mut builder, &self.settings, i)); - } - } - }); - } - - tokens - } - - fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let (impl_gen, ty_gen, where_gen) = generics.split_for_impl(); - - let mut args = quote::quote! { self.settings.clone(), }; - - for (_, ident) in self.state_comptimes.iter() { - args.extend(quote::quote! { self.#ident.clone(), }); - } - - let define_args = self.gen_define_args(); - let define_impl = self.gen_define_impl(expand); - - quote::quote! { - impl #impl_gen Kernel for #ident #ty_gen #where_gen { - fn define(&self) -> KernelDefinition { - #define_args - #define_impl - } - - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::().info((#args)) - } - } - } - } - - fn phantoms(&self, generics: &Generics, declaration: bool) -> TokenStream { - let mut phantoms = quote::quote! {}; - - for param in generics.params.iter() { - let ty = match param { - syn::GenericParam::Type(ty) => ty, - _ => continue, - }; - let ident = Ident::new( - format!("_{}", ty.ident.to_string().to_lowercase()).as_str(), - Span::call_site(), - ); - let ty = &ty.ident; - if declaration { - phantoms.extend(quote::quote! { - #ident: core::marker::PhantomData<#ty>, - }); - } else { - phantoms.extend(quote::quote! { - #ident: core::marker::PhantomData::<#ty>, - }); - } - } - phantoms - } - - fn gen_launch_body(&self) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let phantoms = self.phantoms(&generics, false); - - let mut comptimes = quote::quote! {}; - let settings = self.gen_settings(); - - let mut body = quote::quote! { - let mut launcher = KernelLauncher::::default(); - }; - - for (input, _) in self.state_inputs.iter() { - body.extend(quote::quote! { - #input.register(&mut launcher); - }); - } - - for (input, _) in self.state_outputs.iter() { - body.extend(quote::quote! { - #input.register(&mut launcher); - }); - } - - for (_ty, ident) in self.state_comptimes.iter() { - comptimes.extend(quote::quote! { - #ident, - }); - } - - let kernel = quote::quote! { - #ident { - settings, - #comptimes - #phantoms - } - }; - - let mut tokens = quote::quote! { - #settings - - let kernel = #kernel; - - #body - }; - - if self.unchecked { - tokens.extend(quote::quote! { - launcher.launch_unchecked(cube_count, kernel, client); - }); - } else { - tokens.extend(quote::quote! { - launcher.launch(cube_count, kernel, client); - }); - } - - tokens - } -} - -pub fn codegen_launch(sig: &syn::Signature, unchecked: bool) -> TokenStream { - let codegen = Codegen::from_sig(sig, unchecked); - - let ident = &sig.ident; - - let ident_expand = quote::quote! { - __expand - }; - - let generics = add_runtime(add_lifetime(sig.generics.clone())); - let body = codegen.gen_launch_body(); - let kernel = codegen.gen_kernel_struct(); - let compile = codegen.gen_compile_impl(&ident_expand); - let (inputs, output) = (codegen.fn_inputs, codegen.fn_output); - let doc = format!("Launch the kernel [{ident}()] on the given runtime."); - - let maybe_unsafe = if unchecked { - quote::quote! {unsafe} - } else { - quote::quote! {} - }; - let launch_name = if unchecked { - quote::quote! { launch_unchecked} - } else { - quote::quote! { launch} - }; - - quote::quote! { - #kernel - #compile - - #[allow(clippy::too_many_arguments)] - #[doc = #doc] - pub #maybe_unsafe fn #launch_name #generics ( - client: &ComputeClient, - cube_count: CubeCount, - cube_dim: CubeDim, - #inputs - ) -> #output { - #body; - } - } -} - -pub fn add_lifetime(mut generics: Generics) -> Generics { - let lifetime: syn::Generics = parse_quote! {<'a>}; - - generics - .params - .insert(0, lifetime.params.into_iter().next().unwrap()); - generics -} - -pub fn add_runtime(mut generics: Generics) -> Generics { - let runtime: syn::Generics = parse_quote! { }; - - generics - .params - .push(runtime.params.into_iter().next().unwrap()); - generics -} - -fn first_generic_ty(ty: &syn::Type) -> syn::Type { - match ty { - syn::Type::Path(pat) => match &pat.path.segments.first().unwrap().arguments { - syn::PathArguments::AngleBracketed(ty) => match ty.args.first().unwrap() { - syn::GenericArgument::Type(ty) => ty.clone(), - _ => panic!("Should have a generic type"), - }, - _ => panic!("Comptime must have a generic"), - }, - _ => todo!(), - } -} - -fn no_ref(ty: &syn::Type) -> &syn::Type { - match ty { - syn::Type::Reference(val) => &val.elem, - _ => ty, - } -} - -fn snake_to_pascal_case(input: &str) -> String { - input - .split('_') - .filter(|s| !s.is_empty()) - .map(|s| { - let mut c = s.chars(); - match c.next() { - None => String::new(), - Some(f) => f.to_uppercase().collect::() + c.as_str(), - } - }) - .collect() -} diff --git a/crates/cubecl-macros/src/codegen_function/mod.rs b/crates/cubecl-macros/src/codegen_function/mod.rs deleted file mode 100644 index ed9bff87..00000000 --- a/crates/cubecl-macros/src/codegen_function/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod base; -mod branch; -mod expr; -mod function; -mod launch; -mod operation; -mod variable; - -pub(crate) use base::codegen_statement; -pub(crate) use launch::codegen_launch; diff --git a/crates/cubecl-macros/src/codegen_function/operation.rs b/crates/cubecl-macros/src/codegen_function/operation.rs deleted file mode 100644 index 18b6ee80..00000000 --- a/crates/cubecl-macros/src/codegen_function/operation.rs +++ /dev/null @@ -1,282 +0,0 @@ -use crate::tracker::VariableTracker; - -use super::{ - base::{Codegen, CodegenKind}, - expr::codegen_expr, -}; - -/// Codegen for binary operations (+, -, *, etc.) -pub(crate) fn codegen_binary( - binary: &syn::ExprBinary, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let lhs = codegen_expr(&binary.left, loop_level, variable_tracker); - let (lhs, kind_lhs, lhs_array) = lhs.process(); - let (rhs, kind_rhs, _) = codegen_expr(&binary.right, loop_level, variable_tracker).process(); - - if matches!(kind_lhs, CodegenKind::Comptime) && matches!(kind_rhs, CodegenKind::Comptime) { - return Codegen::new( - quote::quote! { - #binary - }, - CodegenKind::Comptime, - ); - } - - Codegen::new( - match binary.op { - syn::BinOp::Add(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::add::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Sub(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::sub::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Mul(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::mul::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Div(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::div::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Rem(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::rem::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Ne(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::ne::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Gt(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::gt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Ge(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::ge::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Lt(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::lt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Le(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::le::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Eq(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::eq::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::AddAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - cubecl::frontend::add_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::add_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::SubAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - cubecl::frontend::sub_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::sub_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::MulAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - cubecl::frontend::mul_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::mul_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::DivAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - cubecl::frontend::div_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::div_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::And(_) => quote::quote! { - { - - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::and::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Or(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::or::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitOr(_) => quote::quote! { - { - - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::bitor::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitAnd(_) => quote::quote! { - { - - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::bitand::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitXor(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::bitxor::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Shl(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::shl::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Shr(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - cubecl::frontend::shr::expand(context, _lhs, _rhs) - } - }, - _ => todo!("Codegen: unsupported op {:?}", binary.op), - }, - CodegenKind::Expand, - ) -} - -/// Codegen for unary operations -pub(crate) fn codegen_unary( - unary: &syn::ExprUnary, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let (inner, kind, _) = codegen_expr(&unary.expr, loop_level, variable_tracker).process(); - - if matches!(kind, CodegenKind::Comptime) { - return Codegen::new( - quote::quote! { - #unary - }, - CodegenKind::Comptime, - ); - } - - Codegen::new( - match unary.op { - syn::UnOp::Not(_) => quote::quote! { - { - let _inner = #inner; - cubecl::frontend::not::expand(context, _inner) - } - }, - syn::UnOp::Deref(_) => inner, - _ => todo!("Codegen: unsupported op {:?}", unary.op), - }, - CodegenKind::Expand, - ) -} diff --git a/crates/cubecl-macros/src/codegen_function/variable.rs b/crates/cubecl-macros/src/codegen_function/variable.rs deleted file mode 100644 index e6f7f2fd..00000000 --- a/crates/cubecl-macros/src/codegen_function/variable.rs +++ /dev/null @@ -1,322 +0,0 @@ -use proc_macro2::TokenStream; -use quote::ToTokens; -use syn::{punctuated::Punctuated, FieldValue, Lit, Member, PathArguments, Token}; - -use crate::{analyzer::KEYWORDS, codegen_function::expr::codegen_expr, tracker::VariableTracker}; - -use super::base::{Codegen, CodegenKind}; - -/// Codegen for literals -pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { - match lit.lit { - // We treat floats differently to avoid getting 4..into() for instance - Lit::Float(_) => { - let lit_str = lit.lit.to_token_stream().to_string(); - let float_lit = lit_str.parse::().unwrap(); - quote::quote! { #float_lit } - } - _ => { - quote::quote! { #lit } - } - } -} - -/// Codegen for arrays of literals -pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream { - let mut tokens = quote::quote! {}; - for element in array.elems.iter() { - let token = match element { - syn::Expr::Lit(lit) => Codegen::new(codegen_lit(lit), CodegenKind::Literal), - _ => { - return syn::Error::new_spanned(array, "Only arrays of literals are supported") - .into_compile_error() - } - }; - tokens.extend(quote::quote! { #token, }); - } - quote::quote! { [ #tokens ] } -} - -/// Codegen for a local declaration (let ...) -/// Supports: -/// let x = ... -/// let x: T = ... -/// let _ = ... -/// let (a, b) = ... -/// let mut _ = ... -pub(crate) fn codegen_local( - local: &syn::Local, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let let_tok = local.let_token; - - let ident = match &local.pat { - syn::Pat::Ident(ident) => ident.to_token_stream(), - syn::Pat::Type(pat_type) => match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), - _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), - }, - syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), - syn::Pat::Tuple(_) => { - // destructuring pattern; we can just return it as is - return quote::quote! { - #local - }; - } - _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), - }; - - variable_tracker.codegen_declare(ident.to_string(), loop_level as u8); - - match local.init.as_ref() { - Some(init) => { - let (init, kind, _) = codegen_expr(&init.expr, loop_level, variable_tracker).process(); - - if matches!(kind, CodegenKind::Comptime) { - variable_tracker - .set_as_comptime(ident.to_string(), loop_level as u8, None) - .unwrap(); - } - - if matches!(kind, CodegenKind::Comptime) { - quote::quote! { - #let_tok #ident = #init; - } - } else { - quote::quote! { - #let_tok #ident = { - let _inner = #init; - cubecl::frontend::Init::init(_inner, context) - }; - } - } - } - None => { - quote::quote! { - #let_tok #ident; - } - } - } -} - -/// Codegen for indexed access -pub(crate) fn codegen_index( - index: &syn::ExprIndex, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let array = codegen_expr(&index.expr, loop_level, variable_tracker); - let index = codegen_expr(&index.index, loop_level, variable_tracker); - - let tokens = quote::quote! { - { - let _array = #array; - let _index = #index; - cubecl::frontend::index::expand(context, _array, _index) - } - }; - - let mut codegen = Codegen::new(tokens, CodegenKind::Expand); - codegen.set_array_indexing(Some(super::base::ArrayIndexing { - array: array.tokens(), - index: index.tokens(), - })); - - codegen -} - -/// Codegen for assignation -/// Supports: -/// - scalar -/// - indexed array -pub(crate) fn codegen_assign( - assign: &syn::ExprAssign, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - match assign.left.as_ref() { - syn::Expr::Index(index) => { - let array = codegen_expr(&index.expr, loop_level, variable_tracker); - let index = codegen_expr(&index.index, loop_level, variable_tracker); - let value = codegen_expr(&assign.right, loop_level, variable_tracker); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #value; - cubecl::frontend::index_assign::expand(context, _array, _index, _value) - } - } - } - syn::Expr::Unary(_) | syn::Expr::Field(_) | syn::Expr::Path(_) => { - let lhs = codegen_expr(&assign.left, loop_level, variable_tracker); - let rhs = codegen_expr(&assign.right, loop_level, variable_tracker); - - quote::quote! { - { - let _assign_lhs = #lhs; - let _assign_rhs = #rhs; - cubecl::frontend::assign::expand(context, _assign_rhs, _assign_lhs) - } - } - } - _ => todo!("Assign of expr {:?} unsupported", assign.left), - } -} - -pub(crate) fn codegen_path_var( - path: &syn::ExprPath, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let ident = match path.path.get_ident() { - Some(ident) => ident, - None => { - return Codegen::new( - quote::quote! { - #path - }, - CodegenKind::Expand, - ); - } - }; - - let name = ident.to_string(); - - if name == "None" { - return Codegen::new(quote::quote! { None }, CodegenKind::Comptime); - } - - if KEYWORDS.contains(&name.as_str()) { - Codegen::new( - quote::quote! { - #ident :: expand(context) - }, - CodegenKind::Expand, - ) - } else { - let (will_be_used_again, is_comptime) = variable_tracker - .codegen_reuse(name, loop_level as u8, None) - .unwrap_or((true, false)); - - let kind = if is_comptime { - CodegenKind::Comptime - } else { - CodegenKind::Expand - }; - - let output = if will_be_used_again { - quote::quote! { - #ident.clone() - } - } else { - quote::quote! { - #ident - } - }; - - Codegen::new(output, kind) - } -} - -/// Codegen for a field used in rhs of a statement -/// This function adds cloning when necessary -pub(crate) fn codegen_field( - field: &syn::ExprField, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let (struct_, field) = if let Member::Named(attribute_ident) = &field.member { - if let syn::Expr::Path(struct_expr) = &*field.base { - let struct_ident = struct_expr - .path - .get_ident() - .expect("Codegen: field access only supported on ident struct."); - - (struct_ident, attribute_ident) - } else { - todo!("Codegen: field access only supported on ident struct."); - } - } else { - todo!("Codegen: unnamed attribute not supported."); - }; - - let (will_be_used_again, _) = variable_tracker - .codegen_reuse( - struct_.to_string(), - loop_level as u8, - Some(field.to_string()), - ) - .unwrap(); - - if will_be_used_again { - quote::quote! { - #struct_ . #field .clone() - } - } else { - quote::quote! { - #struct_ . #field - } - } -} - -// Codegen for a struct declaration -pub(crate) fn codegen_struct( - struct_: &syn::ExprStruct, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut deconstructed_path = Vec::new(); - for segment in struct_.path.segments.iter() { - let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments { - Some(arguments) - } else { - None - }; - deconstructed_path.push((&segment.ident, generics)); - } - - let (struct_name, generics) = deconstructed_path - .pop() - .expect("At least one ident in the path"); - - // This is hacky but using ::ExpandType {...} is experimental in Rust - let expanded_struct_name = syn::Ident::new( - format!("{}Expand", struct_name).as_str(), - proc_macro2::Span::call_site(), - ); - - deconstructed_path.push((&expanded_struct_name, generics)); - - // Reconstruct the path - let mut path_tokens = quote::quote! {}; - for (ident, angle_bracketed_generics) in deconstructed_path { - let ident_tokens = ident.to_token_stream(); - let generics_tokens = angle_bracketed_generics.to_token_stream(); - - path_tokens.extend(quote::quote! { - #ident_tokens #generics_tokens - }); - } - - let fields = codegen_field_creation(&struct_.fields, loop_level, variable_tracker); - quote::quote! { - #path_tokens { #fields } - } -} - -fn codegen_field_creation( - fields: &Punctuated, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut field_tokens = quote::quote! {}; - for field in fields.iter() { - let field_name_token = &field.member; - let field_value_token = codegen_expr(&field.expr, loop_level, variable_tracker); - field_tokens.extend(quote::quote! { #field_name_token : #field_value_token, }); - } - field_tokens -} diff --git a/crates/cubecl-macros/src/codegen_trait/mod.rs b/crates/cubecl-macros/src/codegen_trait/mod.rs deleted file mode 100644 index 7ee64401..00000000 --- a/crates/cubecl-macros/src/codegen_trait/mod.rs +++ /dev/null @@ -1,112 +0,0 @@ -use proc_macro2::TokenStream; - -use crate::codegen_common::signature::{expand_sig, ExpandMode}; - -pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream { - let mut expand_items = Vec::new(); - - for item in tr.items.iter() { - match item { - syn::TraitItem::Fn(func) => { - let expand = expand_sig( - &func.sig, - &syn::Visibility::Inherited, - None, - ExpandMode::MethodImpl, - ); - expand_items.push(syn::parse_quote!(#expand;)); - } - _ => continue, - } - } - tr.items.append(&mut expand_items); - - quote::quote! { - #[allow(clippy::too_many_arguments)] - #tr - } -} - -pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream { - let mut expand_items = Vec::new(); - - for item in tr.items.iter() { - match item { - syn::ImplItem::Fn(func) => { - let ident = &func.sig.ident; - let ident = quote::quote! {#ident::__expand}; - let mut inputs = quote::quote!(); - - for input in &func.sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = pat.pat.clone(); - inputs.extend(quote::quote! { - #ident, - }); - } - _ => todo!("Only Typed inputs are supported"), - } - } - - let expand = expand_sig( - &func.sig, - &syn::Visibility::Inherited, - None, - ExpandMode::MethodImpl, - ); - - let tokens = if !tr.generics.params.is_empty() { - let mut func = func.clone(); - for param in tr.generics.params.iter() { - func.sig.generics.params.push(param.clone()); - } - register_expand(&func, &ident, expand, inputs) - } else { - register_expand(func, &ident, expand, inputs) - }; - - expand_items.push(syn::parse2(tokens).unwrap()); - } - _ => continue, - } - } - tr.items.append(&mut expand_items); - - quote::quote! { - #[allow(clippy::too_many_arguments)] - #tr - } -} - -fn register_expand( - func: &syn::ImplItemFn, - name: &TokenStream, - expand: proc_macro2::TokenStream, - inputs: proc_macro2::TokenStream, -) -> proc_macro2::TokenStream { - let (func, func_expand) = if func.sig.generics.params.is_empty() { - ( - quote::quote! { #func }, - quote::quote! { - #name(context, #inputs) - }, - ) - } else { - let (_, gen, _) = &func.sig.generics.split_for_impl(); - ( - quote::quote! { #func }, - quote::quote! { - #name::#gen(context, #inputs) - }, - ) - }; - - quote::quote! ( - #expand { - #[cube] - #func - #func_expand - } - ) -} diff --git a/crates/cubecl-macros/src/codegen_type/base.rs b/crates/cubecl-macros/src/codegen_type/base.rs deleted file mode 100644 index 47557307..00000000 --- a/crates/cubecl-macros/src/codegen_type/base.rs +++ /dev/null @@ -1,295 +0,0 @@ -use proc_macro::TokenStream; -use quote::quote; -use syn::Ident; - -use super::GenericsCodegen; - -struct TypeCodegen { - name: syn::Ident, - name_launch: syn::Ident, - name_expand: syn::Ident, - fields: Vec, - generics: GenericsCodegen, - vis: syn::Visibility, -} - -impl TypeCodegen { - pub fn expand_ty(&self) -> proc_macro2::TokenStream { - let mut fields = quote::quote! {}; - let name = &self.name_expand; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - fields.extend(quote! { - #vis #ident: <#ty as CubeType>::ExpandType, - }); - } - - let generics = self.generics.type_definitions(); - let vis = &self.vis; - - quote! { - #[derive(Clone)] - #vis struct #name #generics { - #fields - } - } - } - - pub fn launch_ty(&self) -> proc_macro2::TokenStream { - let mut fields = quote::quote! {}; - let name = &self.name_launch; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - fields.extend(quote! { - #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, - }); - } - - let generics = self.generics.all_definitions(); - let vis = &self.vis; - - quote! { - #vis struct #name #generics { - #fields - } - } - } - - pub fn launch_new(&self) -> proc_macro2::TokenStream { - let mut args = quote::quote! {}; - let mut fields = quote::quote! {}; - let name = &self.name_launch; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - args.extend(quote! { - #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, - }); - fields.extend(quote! { - #ident, - }); - } - - let generics_impl = self.generics.all_definitions(); - let generics_use = self.generics.all_in_use(); - let vis = &self.vis; - - quote! { - impl #generics_impl #name #generics_use { - /// New kernel - #[allow(clippy::too_many_arguments)] - #vis fn new(#args) -> Self { - Self { - #fields - } - } - } - } - } - - pub fn arg_settings_impl(&self) -> proc_macro2::TokenStream { - let mut register_body = quote::quote! {}; - let mut configure_input_body = quote::quote! {}; - let mut configure_output_body = quote::quote! {}; - let name = &self.name_launch; - - for (pos, field) in self.fields.iter().enumerate() { - let ident = &field.ident; - - register_body.extend(quote! { - self.#ident.register(launcher); - }); - configure_input_body.extend(quote! { - settings = ArgSettings::::configure_input(&self.#ident, #pos, settings); - }); - configure_output_body.extend(quote! { - settings = ArgSettings::::configure_output(&self.#ident, #pos, settings); - }); - } - - let generics_impl = self.generics.all_definitions(); - let generics_use = self.generics.all_in_use(); - - quote! { - impl #generics_impl ArgSettings for #name #generics_use { - fn register(&self, launcher: &mut KernelLauncher) { - #register_body - } - - fn configure_input(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - #configure_input_body - - settings - } - - fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - #configure_output_body - - settings - } - } - } - } - - pub fn cube_type_impl(&self) -> proc_macro2::TokenStream { - let name = &self.name; - let name_expand = &self.name_expand; - - let generics_impl = self.generics.type_definitions(); - let generics_use = self.generics.type_in_use(); - - quote! { - impl #generics_impl CubeType for #name #generics_use { - type ExpandType = #name_expand #generics_use; - } - } - } - - pub fn launch_arg_impl(&self) -> proc_macro2::TokenStream { - let mut body_input = quote::quote! {}; - let mut body_output = quote::quote! {}; - let name = &self.name; - let name_launch = &self.name_launch; - let name_expand = &self.name_expand; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - body_input.extend(quote! { - #vis #ident: <#ty as LaunchArgExpand>::expand(builder, vectorization), - }); - body_output.extend(quote! { - #vis #ident: <#ty as LaunchArgExpand>::expand_output(builder, vectorization), - }); - } - - let type_generics_impl = self.generics.type_definitions(); - let type_generics_use = self.generics.type_in_use(); - - let runtime_generics_impl = self.generics.runtime_definitions(); - let all_generics_use = self.generics.all_in_use(); - - quote! { - impl #type_generics_impl LaunchArg for #name #type_generics_use { - type RuntimeArg #runtime_generics_impl = #name_launch #all_generics_use; - } - - impl #type_generics_impl LaunchArgExpand for #name #type_generics_use { - fn expand( - builder: &mut KernelBuilder, - vectorization: cubecl::ir::Vectorization, - ) -> ::ExpandType { - #name_expand { - #body_input - } - } - fn expand_output( - builder: &mut KernelBuilder, - vectorization: cubecl::ir::Vectorization, - ) -> ::ExpandType { - #name_expand { - #body_output - } - } - } - } - } - - pub fn expand_type_impl(&self) -> proc_macro2::TokenStream { - let name_expand = &self.name_expand; - let type_generics_impl = self.generics.type_definitions(); - let type_generics_use = self.generics.type_in_use(); - - let mut body = quote::quote! {}; - for field in self.fields.iter() { - let ident = &field.ident; - body.extend(quote::quote! { - #ident: Init::init(self.#ident, context), - }); - } - - quote! { - impl #type_generics_impl Init for #name_expand #type_generics_use { - fn init(self, context: &mut CubeContext) -> Self { - Self { - #body - } - } - } - } - } -} - -pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream { - let name = ast.ident.clone(); - let generics = ast.generics.clone(); - let visibility = ast.vis.clone(); - - let name_string = name.to_string(); - let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span()); - let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span()); - - let mut fields = Vec::new(); - - match &ast.data { - syn::Data::Struct(struct_data) => { - for field in struct_data.fields.iter() { - fields.push(field.clone()); - } - } - syn::Data::Enum(_) => panic!("Only struct can be derived"), - syn::Data::Union(_) => panic!("Only struct can be derived"), - }; - - let codegen = TypeCodegen { - name, - name_launch, - name_expand, - fields, - generics: GenericsCodegen::new(generics), - vis: visibility, - }; - - let expand_ty = codegen.expand_ty(); - let launch_ty = codegen.launch_ty(); - let launch_new = codegen.launch_new(); - - let cube_type_impl = codegen.cube_type_impl(); - let arg_settings_impl = codegen.arg_settings_impl(); - let launch_arg_impl = codegen.launch_arg_impl(); - let expand_type_impl = codegen.expand_type_impl(); - - if with_launch { - quote! { - #expand_ty - #launch_ty - #launch_new - - #cube_type_impl - #arg_settings_impl - #launch_arg_impl - #expand_type_impl - } - .into() - } else { - quote! { - #expand_ty - #cube_type_impl - #expand_type_impl - } - .into() - } -} diff --git a/crates/cubecl-macros/src/codegen_type/generics.rs b/crates/cubecl-macros/src/codegen_type/generics.rs deleted file mode 100644 index b92170a0..00000000 --- a/crates/cubecl-macros/src/codegen_type/generics.rs +++ /dev/null @@ -1,81 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::ToTokens; -use syn::{GenericParam, Generics, Ident, Lifetime, LifetimeParam, TypeParam}; - -pub(crate) struct GenericsCodegen { - arg_lifetime: syn::Generics, - arg_runtime: syn::Generics, - type_gens: syn::Generics, -} - -impl GenericsCodegen { - pub(crate) fn new(type_gens: syn::Generics) -> Self { - Self { - arg_lifetime: Self::arg_lifetime(), - arg_runtime: Self::arg_runtime(), - type_gens, - } - } - - fn arg_lifetime() -> Generics { - let mut generics = Generics::default(); - let lifetime = - GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'a", Span::call_site()))); - generics.params.push(lifetime); - generics - } - - fn arg_runtime() -> Generics { - let mut generics = Generics::default(); - let mut runtime_param = TypeParam::from(Ident::new("R", Span::call_site())); - runtime_param - .bounds - .push(syn::parse_str("Runtime").unwrap()); - let runtime = GenericParam::Type(runtime_param); - generics.params.push(runtime); - generics - } - - pub(crate) fn type_definitions(&self) -> TokenStream { - self.type_gens.to_token_stream() - } - - pub(crate) fn type_in_use(&self) -> TokenStream { - generics_in_use_codegen(self.type_gens.clone()) - } - - pub(crate) fn runtime_definitions(&self) -> TokenStream { - let mut generics = self.arg_runtime.clone(); - generics.params.extend(self.arg_lifetime.params.clone()); - generics.to_token_stream() - } - - pub(crate) fn all_definitions(&self) -> TokenStream { - let mut generics = self.arg_lifetime.clone(); - generics.params.extend(self.arg_runtime.params.clone()); - generics.params.extend(self.type_gens.params.clone()); - generics.to_token_stream() - } - - pub(crate) fn all_in_use(&self) -> TokenStream { - let mut generics = self.arg_lifetime.clone(); - generics.params.extend(self.arg_runtime.params.clone()); - generics.params.extend(self.type_gens.params.clone()); - generics_in_use_codegen(generics) - } -} - -fn generics_in_use_codegen(generics: Generics) -> TokenStream { - let mut tokens = quote::quote! {<}; - for generic in generics.params.iter() { - let ident = match generic { - GenericParam::Lifetime(param) => param.lifetime.to_token_stream(), - GenericParam::Type(param) => param.ident.to_token_stream(), - GenericParam::Const(_) => todo!("Const generic not supported"), - }; - tokens.extend(quote::quote! { #ident, }) - } - tokens.extend(quote::quote! {>}); - - tokens -} diff --git a/crates/cubecl-macros/src/codegen_type/mod.rs b/crates/cubecl-macros/src/codegen_type/mod.rs deleted file mode 100644 index 68f38dcd..00000000 --- a/crates/cubecl-macros/src/codegen_type/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod base; -mod generics; - -pub(crate) use base::*; -pub(crate) use generics::*; diff --git a/crates/cubecl-macros/src/error.rs b/crates/cubecl-macros/src/error.rs new file mode 100644 index 00000000..cfedb389 --- /dev/null +++ b/crates/cubecl-macros/src/error.rs @@ -0,0 +1,82 @@ +// modified from https://github.com/elastio/bon/blob/master/bon-macros/src/error.rs + +use proc_macro2::{TokenStream, TokenTree}; +use quote::{quote, ToTokens}; +use syn::parse::Parse; + +use crate::parse::helpers::is_helper; + +/// Handle the error returned from the macro logic. This may be either a syntax +/// error or a logic error. In either case, we want to return a [`TokenStream2`] +/// that still provides good IDE experience. See [`Fallback`] for details. +pub(crate) fn error_into_token_stream(err: syn::Error, item: TokenStream) -> TokenStream { + let compile_error = err.to_compile_error(); + + syn::parse2::(item) + .map(|fallback| quote!(#compile_error #fallback)) + .unwrap_or_else(|_| compile_error) +} + +/// This is used in error handling for better IDE experience. For example, while +/// the developer is writing the function code they'll have a bunch of syntax +/// errors in the process. While that happens the proc macro should output at +/// least some representation of the input code that the developer wrote with +/// a separate compile error entry. This keeps the syntax highlighting and IDE +/// type analysis, completions and other hints features working even if macro +/// fails to parse some syntax or finds some other logic errors. +/// +/// This utility does very low-level parsing to strip helper attributes (i.e. `#[comptime]`) from +/// the input. This is to prevent the IDE from showing errors for helper attributes that need to be +/// processed by this macro. +struct Fallback { + output: TokenStream, +} + +impl Parse for Fallback { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let mut output = TokenStream::new(); + + loop { + let found_attr = input.step(|cursor| { + let mut cursor = *cursor; + while let Some((tt, next)) = cursor.token_tree() { + match &tt { + TokenTree::Group(group) => { + let fallback: Self = syn::parse2(group.stream())?; + let new_group = + proc_macro2::Group::new(group.delimiter(), fallback.output); + + output.extend([TokenTree::Group(new_group)]); + } + TokenTree::Punct(punct) if punct.as_char() == '#' => { + return Ok((true, cursor)); + } + TokenTree::Punct(_) | TokenTree::Ident(_) | TokenTree::Literal(_) => { + output.extend([tt]); + } + } + + cursor = next; + } + + Ok((false, cursor)) + })?; + + if !found_attr { + return Ok(Self { output }); + } + + input + .call(syn::Attribute::parse_outer)? + .into_iter() + .filter(|attr| !is_helper(attr)) + .for_each(|attr| attr.to_tokens(&mut output)); + } + } +} + +impl ToTokens for Fallback { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.output.to_tokens(tokens); + } +} diff --git a/crates/cubecl-macros/src/expression.rs b/crates/cubecl-macros/src/expression.rs new file mode 100644 index 00000000..d8238750 --- /dev/null +++ b/crates/cubecl-macros/src/expression.rs @@ -0,0 +1,273 @@ +use std::{rc::Rc, sync::atomic::AtomicUsize}; + +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; +use syn::{ + AngleBracketedGenericArguments, Ident, Lit, Member, Pat, Path, PathArguments, PathSegment, Type, +}; + +use crate::{operator::Operator, scope::Context, statement::Statement}; + +#[derive(Clone, Debug)] +pub enum Expression { + Binary { + left: Box, + operator: Operator, + right: Box, + ty: Option, + }, + Unary { + input: Box, + operator: Operator, + ty: Option, + }, + Variable { + name: Ident, + is_ref: bool, + is_mut: bool, + use_count: Rc, + ty: Option, + }, + ConstVariable { + name: Ident, + use_count: Rc, + ty: Option, + }, + FieldAccess { + base: Box, + field: Member, + }, + Path { + path: Path, + }, + Literal { + value: Lit, + ty: Type, + }, + Assigment { + left: Box, + right: Box, + ty: Option, + }, + Block(Block), + FunctionCall { + func: Box, + args: Vec, + associated_type: Option<(Path, PathSegment)>, + }, + CompilerIntrinsic { + func: Path, + args: Vec, + }, + MethodCall { + receiver: Box, + method: Ident, + generics: Option, + args: Vec, + }, + Closure { + params: Vec, + body: Box, + }, + Cast { + from: Box, + to: Type, + }, + Break, + /// Tokens not relevant to parsing + Verbatim { + tokens: TokenStream, + }, + VerbatimTerminated { + tokens: TokenStream, + }, + Continue(Span), + ForLoop { + range: Box, + unroll: Option>, + var_name: syn::Ident, + var_ty: Option, + block: Block, + }, + WhileLoop { + condition: Box, + block: Block, + }, + Loop(Block), + If { + condition: Box, + then_block: Block, + else_branch: Option>, + }, + Return { + expr: Option>, + span: Span, + _ty: Type, + }, + Range { + start: Box, + end: Option>, + span: Span, + inclusive: bool, + }, + Array { + elements: Vec, + span: Span, + }, + Tuple { + elements: Vec, + }, + Index { + expr: Box, + index: Box, + }, + Slice { + expr: Box, + span: Span, + _ranges: Vec, + }, + ArrayInit { + init: Box, + len: Box, + }, + Reference { + inner: Box, + }, + StructInit { + path: Path, + fields: Vec<(Member, Expression)>, + }, + Keyword { + name: syn::Ident, + }, +} + +#[derive(Clone, Debug)] +pub struct Block { + pub inner: Vec, + pub ret: Option>, + pub ty: Option, +} + +impl Expression { + pub fn ty(&self) -> Option { + match self { + Expression::Binary { ty, .. } => ty.clone(), + Expression::Unary { ty, .. } => ty.clone(), + Expression::Variable { ty, .. } => ty.clone(), + Expression::ConstVariable { ty, .. } => ty.clone(), + Expression::Literal { ty, .. } => Some(ty.clone()), + Expression::Assigment { ty, .. } => ty.clone(), + Expression::Verbatim { .. } => None, + Expression::Block(block) => block.ty.clone(), + Expression::FunctionCall { .. } => None, + Expression::Break { .. } => None, + Expression::Cast { to, .. } => Some(to.clone()), + Expression::Continue { .. } => None, + Expression::ForLoop { .. } => None, + Expression::FieldAccess { .. } => None, + Expression::MethodCall { .. } => None, + Expression::Path { .. } => None, + Expression::Range { start, .. } => start.ty(), + Expression::WhileLoop { .. } => None, + Expression::Loop { .. } => None, + Expression::If { then_block, .. } => then_block.ty.clone(), + Expression::Return { expr, .. } => expr.as_ref().and_then(|expr| expr.ty()), + Expression::Array { .. } => None, + Expression::Index { .. } => None, + Expression::Tuple { .. } => None, + Expression::Slice { expr, .. } => expr.ty(), + Expression::ArrayInit { init, .. } => init.ty(), + Expression::VerbatimTerminated { .. } => None, + Expression::Reference { inner } => inner.ty(), + Expression::StructInit { .. } => None, + Expression::Closure { .. } => None, + Expression::Keyword { .. } => None, + Expression::CompilerIntrinsic { .. } => None, + } + } + + pub fn is_const(&self) -> bool { + match self { + Expression::Literal { .. } => true, + Expression::Path { .. } => true, + Expression::Verbatim { .. } => true, + Expression::VerbatimTerminated { .. } => true, + Expression::ConstVariable { .. } => true, + Expression::FieldAccess { base, .. } => base.is_const(), + Expression::Reference { inner } => inner.is_const(), + Expression::Array { elements, .. } => elements.iter().all(|it| it.is_const()), + Expression::Tuple { elements, .. } => elements.iter().all(|it| it.is_const()), + Expression::CompilerIntrinsic { .. } => true, + Expression::MethodCall { + receiver, method, .. + } => receiver.is_const() && method != "runtime", + _ => false, + } + } + + pub fn as_const(&self, context: &mut Context) -> Option { + match self { + Expression::Literal { value, .. } => Some(quote![#value]), + Expression::Verbatim { tokens, .. } => Some(tokens.clone()), + Expression::VerbatimTerminated { tokens, .. } => Some(tokens.clone()), + Expression::ConstVariable { name, .. } => Some(quote![#name.clone()]), + Expression::Path { path, .. } => Some(quote![#path]), + Expression::Array { elements, .. } => { + let elements = elements + .iter() + .map(|it| it.as_const(context)) + .collect::>>()?; + Some(quote![[#(#elements),*]]) + } + Expression::Tuple { elements, .. } => { + let elements = elements + .iter() + .map(|it| it.as_const(context)) + .collect::>>()?; + Some(quote![(#(#elements),*)]) + } + Expression::FieldAccess { base, field, .. } => { + base.as_const(context).map(|base| quote![#base.#field]) + } + Expression::Reference { inner } => inner.as_const(context).map(|base| quote![&#base]), + Expression::MethodCall { .. } if self.is_const() => Some(self.to_tokens(context)), + _ => None, + } + } + + pub fn as_index(&self) -> Option<(&Expression, &Expression)> { + match self { + Expression::Index { expr, index, .. } => Some((&**expr, &**index)), + _ => None, + } + } + + pub fn needs_terminator(&self) -> bool { + match self { + Expression::If { then_block, .. } => then_block.ret.is_some(), + Expression::Block(block) => block.ret.is_some(), + Expression::ForLoop { .. } => false, + Expression::WhileLoop { .. } => false, + Expression::Loop { .. } => false, + Expression::VerbatimTerminated { .. } => false, + _ => true, + } + } +} + +pub fn is_intrinsic(path: &Path) -> bool { + // Add both possible import paths + let intrinsic_paths = [ + "::cubecl::prelude::vectorization_of", + "::cubecl::frontend::vectorization_of", + ]; + + let mut path = path.clone(); + // Strip function generics + path.segments.last_mut().unwrap().arguments = PathArguments::None; + let func_path = path.to_token_stream().to_string(); + intrinsic_paths + .iter() + .any(|path| path.ends_with(&func_path)) +} diff --git a/crates/cubecl-macros/src/generate/cube_trait.rs b/crates/cubecl-macros/src/generate/cube_trait.rs new file mode 100644 index 00000000..39b647a8 --- /dev/null +++ b/crates/cubecl-macros/src/generate/cube_trait.rs @@ -0,0 +1,58 @@ +use crate::parse::cube_trait::{CubeTrait, CubeTraitImpl, CubeTraitImplItem, CubeTraitItem}; +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; + +impl ToTokens for CubeTrait { + fn to_tokens(&self, tokens: &mut TokenStream) { + let original_body = &self.original_trait.items; + let colon = &self.original_trait.colon_token; + let base_traits = &self.original_trait.supertraits; + let attrs = &self.attrs; + let vis = &self.vis; + let unsafety = &self.unsafety; + let name = &self.name; + let generics = &self.generics; + let fns = self.items.iter().filter_map(CubeTraitItem::func); + + let out = quote! { + #(#attrs)* + #[allow(clippy::too_many_arguments)] + #vis #unsafety trait #name #generics #colon #base_traits { + #(#original_body)* + + #( + #[allow(clippy::too_many_arguments)] + #fns; + )* + } + }; + tokens.extend(out); + } +} + +impl CubeTraitImpl { + pub fn to_tokens_mut(&mut self) -> TokenStream { + let unsafety = &self.unsafety; + let items = &self.original_items; + let fns = &self + .items + .iter_mut() + .filter_map(CubeTraitImplItem::func) + .map(|it| it.to_tokens_mut()) + .collect::>(); + let struct_name = &self.struct_name; + let trait_name = &self.trait_name; + let (generics, _, impl_where) = self.generics.split_for_impl(); + + quote! { + #unsafety impl #generics #trait_name for #struct_name #impl_where { + #(#items)* + #( + #[allow(unused, clone_on_copy, clippy::all)] + #fns + )* + } + } + } +} diff --git a/crates/cubecl-macros/src/generate/cube_type.rs b/crates/cubecl-macros/src/generate/cube_type.rs new file mode 100644 index 00000000..aa213ac9 --- /dev/null +++ b/crates/cubecl-macros/src/generate/cube_type.rs @@ -0,0 +1,261 @@ +use darling::FromDeriveInput; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{Ident, Type, Visibility}; + +use crate::{ + parse::cube_type::{TypeCodegen, TypeField}, + paths::{core_type, prelude_type}, +}; + +impl TypeField { + pub fn expand_field(&self) -> TokenStream { + let cube_type = prelude_type("CubeType"); + let vis = &self.vis; + let name = self.ident.as_ref().unwrap(); + let ty = &self.ty; + if self.comptime.is_present() { + quote![#vis #name: #ty] + } else { + quote![#vis #name: <#ty as #cube_type>::ExpandType] + } + } + + pub fn launch_field(&self) -> TokenStream { + let launch_arg = prelude_type("LaunchArg"); + let vis = &self.vis; + let name = self.ident.as_ref().unwrap(); + let ty = &self.ty; + quote![#vis #name: <#ty as #launch_arg>::RuntimeArg<'a, R>] + } + + pub fn split(&self) -> (&Visibility, &Ident, &Type) { + (&self.vis, self.ident.as_ref().unwrap(), &self.ty) + } +} + +impl TypeCodegen { + pub fn expand_ty(&self) -> proc_macro2::TokenStream { + let fields = self.fields.iter().map(TypeField::expand_field); + let name = &self.name_expand; + let generics = &self.generics; + let vis = &self.vis; + + quote! { + #[derive(Clone)] + #vis struct #name #generics { + #(#fields),* + } + } + } + + pub fn launch_ty(&self) -> proc_macro2::TokenStream { + let name = &self.name_launch; + let fields = self.fields.iter().map(TypeField::launch_field); + let generics = self.expanded_generics(); + let vis = &self.vis; + + quote! { + #vis struct #name #generics { + #(#fields),* + } + } + } + + pub fn launch_new(&self) -> proc_macro2::TokenStream { + let args = self.fields.iter().map(TypeField::launch_field); + let fields = self.fields.iter().map(|field| &field.ident); + let name = &self.name_launch; + + let generics = self.expanded_generics(); + let (generics_impl, generics_use, where_clause) = generics.split_for_impl(); + let vis = &self.vis; + + quote! { + impl #generics_impl #name #generics_use #where_clause { + /// New kernel + #[allow(clippy::too_many_arguments)] + #vis fn new(#(#args),*) -> Self { + Self { + #(#fields),* + } + } + } + } + } + + pub fn arg_settings_impl(&self) -> proc_macro2::TokenStream { + let arg_settings = prelude_type("ArgSettings"); + let kernel_launcher = prelude_type("KernelLauncher"); + let kernel_settings = core_type("KernelSettings"); + let name = &self.name_launch; + let register_body = self + .fields + .iter() + .map(TypeField::split) + .map(|(_, ident, _)| quote![self.#ident.register(launcher)]); + let config_input_body = self.fields.iter().enumerate().map(|(pos, field)| { + let ident = &field.ident; + quote![settings = #arg_settings::::configure_input(&self.#ident, #pos, settings)] + }); + let config_output_body = self.fields.iter().enumerate().map(|(pos, field)| { + let ident = &field.ident; + quote![settings = #arg_settings::::configure_output(&self.#ident, #pos, settings)] + }); + + let generics = self.expanded_generics(); + let (generics, generic_names, where_clause) = generics.split_for_impl(); + + quote! { + impl #generics #arg_settings for #name #generic_names #where_clause { + fn register(&self, launcher: &mut #kernel_launcher) { + #(#register_body;)* + } + + fn configure_input(&self, position: usize, mut settings: #kernel_settings) -> #kernel_settings { + #(#config_input_body;)* + + settings + } + + fn configure_output(&self, position: usize, mut settings: #kernel_settings) -> #kernel_settings { + #(#config_output_body;)* + + settings + } + } + } + } + + pub fn cube_type_impl(&self) -> proc_macro2::TokenStream { + let cube_type = prelude_type("CubeType"); + let name = &self.ident; + let name_expand = &self.name_expand; + + let (generics, generic_names, where_clause) = self.generics.split_for_impl(); + + quote! { + impl #generics #cube_type for #name #generic_names #where_clause { + type ExpandType = #name_expand #generic_names; + } + } + } + + pub fn launch_arg_impl(&self) -> proc_macro2::TokenStream { + let launch_arg_expand = prelude_type("LaunchArgExpand"); + let body_input = self.fields.iter().map(TypeField::split).map(|(vis, name, ty)| { + quote![#vis #name: <#ty as #launch_arg_expand>::expand(builder, vectorization)] + }); + let body_output = self.fields.iter().map(TypeField::split).map(|(vis, name, ty)| { + quote![#vis #name: <#ty as #launch_arg_expand>::expand_output(builder, vectorization)] + }); + + let name = &self.ident; + let name_launch = &self.name_launch; + let name_expand = &self.name_expand; + + let (type_generics, type_generic_names, where_clause) = self.generics.split_for_impl(); + + let assoc_generics = self.assoc_generics(); + let all = self.expanded_generics(); + let (_, all_generic_names, _) = all.split_for_impl(); + + quote! { + impl #type_generics LaunchArg for #name #type_generic_names #where_clause { + type RuntimeArg #assoc_generics = #name_launch #all_generic_names; + } + + impl #type_generics LaunchArgExpand for #name #type_generic_names #where_clause { + fn expand( + builder: &mut KernelBuilder, + vectorization: cubecl::ir::Vectorization, + ) -> ::ExpandType { + #name_expand { + #(#body_input),* + } + } + fn expand_output( + builder: &mut KernelBuilder, + vectorization: cubecl::ir::Vectorization, + ) -> ::ExpandType { + #name_expand { + #(#body_output),* + } + } + } + } + } + + pub fn expand_type_impl(&self) -> proc_macro2::TokenStream { + let init = prelude_type("Init"); + let into_runtime = prelude_type("IntoRuntime"); + let context = prelude_type("CubeContext"); + let name = &self.ident; + let name_expand = &self.name_expand; + let (generics, generic_names, where_clause) = self.generics.split_for_impl(); + let body = self + .fields + .iter() + .map(TypeField::split) + .map(|(_, ident, _)| quote![#ident: #init::init(self.#ident, context)]); + let fields_to_runtime = self + .fields + .iter() + .map(TypeField::split) + .map(|(_, name, _)| quote![#name: self.#name.__expand_runtime_method(context)]); + + quote! { + impl #generics #init for #name_expand #generic_names #where_clause { + fn init(self, context: &mut #context) -> Self { + Self { + #(#body),* + } + } + } + + impl #generics #into_runtime for #name #generic_names #where_clause { + fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType { + let expand = #name_expand { + #(#fields_to_runtime),* + }; + Init::init(expand, context) + } + } + } + } +} + +pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream { + let codegen = match TypeCodegen::from_derive_input(ast) { + Ok(codegen) => codegen, + Err(e) => return e.write_errors(), + }; + + let expand_ty = codegen.expand_ty(); + let launch_ty = codegen.launch_ty(); + let launch_new = codegen.launch_new(); + + let cube_type_impl = codegen.cube_type_impl(); + let arg_settings_impl = codegen.arg_settings_impl(); + let launch_arg_impl = codegen.launch_arg_impl(); + let expand_type_impl = codegen.expand_type_impl(); + + if with_launch { + quote! { + #expand_ty + #launch_ty + #launch_new + + #cube_type_impl + #arg_settings_impl + #launch_arg_impl + #expand_type_impl + } + } else { + quote! { + #expand_ty + #cube_type_impl + #expand_type_impl + } + } +} diff --git a/crates/cubecl-macros/src/generate/expression.rs b/crates/cubecl-macros/src/generate/expression.rs new file mode 100644 index 00000000..46c52955 --- /dev/null +++ b/crates/cubecl-macros/src/generate/expression.rs @@ -0,0 +1,503 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, quote_spanned}; +use syn::{spanned::Spanned, Member, PathArguments}; + +use crate::{ + expression::{Block, Expression}, + operator::Operator, + paths::{frontend_path, frontend_type, prelude_type}, + scope::Context, +}; + +macro_rules! error { + ($span:expr, $fmt:literal $(,$args:expr)*) => { + syn::Error::new($span, format!($fmt $(,$args)*)).into_compile_error() + }; +} + +impl Expression { + pub fn to_tokens(&self, context: &mut Context) -> TokenStream { + match self { + Expression::Binary { + left, + operator, + right, + .. + } if operator.is_assign() && matches!(**left, Expression::Index { .. }) => { + let elem = frontend_type("ExpandElementTyped"); + let frontend_path = frontend_path(); + let (array, index) = left.as_index().unwrap(); + let array = array.to_tokens(context); + let index = index + .as_const(context) + .map(|as_const| quote![#elem::from_lit(#as_const)]) + .unwrap_or_else(|| index.to_tokens(context)); + let right = right + .as_const(context) + .map(|as_const| quote![#elem::from_lit(#as_const)]) + .unwrap_or_else(|| right.to_tokens(context)); + let op = format_ident!("{}", operator.array_op_name()); + quote! { + { + let _array = #array; + let _index = #index; + let _value = #right; + #frontend_path::#op::expand(context, _array, _index, _value) + } + } + } + Expression::Binary { + left, + operator, + right, + .. + } => { + let frontend_path = frontend_path(); + let op = format_ident!("{}", operator.op_name()); + let left = left.to_tokens(context); + let right = right.to_tokens(context); + quote! { + { + let _lhs = #left; + let _rhs = #right; + #frontend_path::#op::expand(context, _lhs, _rhs) + } + } + } + Expression::Unary { + input, + operator: Operator::Deref, + .. + } => input.to_tokens(context), + Expression::Unary { + input, operator, .. + } => { + let frontend_path = frontend_path(); + let input = input.to_tokens(context); + let op = format_ident!("{}", operator.op_name()); + quote! { + { + let _inner = #input; + #frontend_path::#op::expand(context, _inner) + } + } + } + Expression::Keyword { name } => { + quote![#name::expand(context)] + } + Expression::Variable { name, .. } => { + let last_use = context.try_consume(name); + if last_use { + quote![#name] + } else { + quote![#name.clone()] + } + } + Expression::FieldAccess { base, field, .. } => { + let base = base + .as_const(context) + .unwrap_or_else(|| base.to_tokens(context)); + quote![#base.#field.clone()] + } + Expression::Literal { value, .. } => { + let expand_elem = frontend_type("ExpandElementTyped"); + quote![#expand_elem::from_lit(#value)] + } + Expression::ConstVariable { name, .. } => { + let expand_elem = frontend_type("ExpandElementTyped"); + quote![#expand_elem::from_lit(#name)] + } + Expression::Assigment { left, right, .. } + if matches!(**left, Expression::Index { .. }) => + { + let (array, index) = left.as_index().unwrap(); + let array = array.to_tokens(context); + let index = index.to_tokens(context); + let right = right.to_tokens(context); + let frontend_path = frontend_path(); + quote! { + { + let _array = #array; + let _index = #index; + let _value = #right; + #frontend_path::index_assign::expand(context, _array, _index, _value) + } + } + } + Expression::Assigment { left, right, .. } => { + let frontend_path = frontend_path(); + let left = left.to_tokens(context); + let right = right.to_tokens(context); + quote! { + { + let _var = #left; + let _value = #right; + #frontend_path::assign::expand(context, _value, _var) + } + } + } + Expression::Index { expr, index } => { + let expr = expr.to_tokens(context); + let index = index.to_tokens(context); + let index_fn = frontend_type("index"); + quote! { + { + let _array = #expr; + let _index = #index; + #index_fn::expand(context, _array, _index) + } + } + } + Expression::FunctionCall { + func, + args, + associated_type: None, + .. + } => { + let (args, arg_names) = map_args(args, context); + let (generics, path) = split_generics(func, context); + quote! { + { + #(#args)* + #path::expand #generics(context, #(#arg_names),*) + } + } + } + Expression::CompilerIntrinsic { func, args } => { + let (args, arg_names) = map_args(args, context); + let mut path = func.clone(); + let generics = core::mem::replace( + &mut path.segments.last_mut().unwrap().arguments, + PathArguments::None, + ); + quote! { + { + #(#args)* + #path::expand #generics(context, #(#arg_names),*) + } + } + } + Expression::FunctionCall { + args, + associated_type: Some((ty_path, func)), + .. + } => { + let (args, arg_names) = map_args(args, context); + let mut name = func.clone(); + name.ident = format_ident!("__expand_{}", name.ident); + quote! { + { + #(#args)* + #ty_path::#name(context, #(#arg_names),*) + } + } + } + Expression::MethodCall { + receiver, + method, + generics, + args, + .. + } => { + let method = format_ident!("__expand_{method}_method"); + let receiver = receiver + .as_const(context) + .unwrap_or_else(|| receiver.to_tokens(context)); + let (args, arg_names) = map_args(args, context); + quote! { + { + #(#args)* + #receiver.#method #generics(context, #(#arg_names),*) + } + } + } + Expression::Break => { + let path = frontend_path(); + quote![#path::branch::break_expand(context);] + } + Expression::Continue(span) => error!(*span, "Continue not supported yet"), + Expression::Return { expr, span, .. } => { + if expr.is_some() { + error!(*span, "Only void return is supported.") + } else { + quote![cubecl::frontend::branch::return_expand(context);] + } + } + Expression::Cast { from, to } => { + let cast = prelude_type("Cast"); + let from = from.to_tokens(context); + let to = quote_spanned![to.span()=> <#to as #cast>]; + quote![#to::__expand_cast_from(context, #from)] + } + Expression::ForLoop { + range, + unroll, + var_name, + var_ty, + block, + } => { + let for_ty = frontend_type("branch"); + + let range = range.to_tokens(context); + let unroll = unroll + .as_ref() + .and_then(|it| it.as_const(context)) + .unwrap_or(quote![false]); + let block = context.with_restored_closure_scope(|ctx| block.to_tokens(ctx)); + let var_ty = var_ty.as_ref().map(|it| quote![: #it]); + + quote! { + { + let _range = #range; + let _unroll = #unroll; + #for_ty::for_expand(context, _range, _unroll, |context, #var_name #var_ty| #block); + } + } + } + Expression::WhileLoop { condition, block } => { + let while_ty = frontend_type("branch"); + let condition = condition.to_tokens(context); + let block = context.with_restored_closure_scope(|ctx| block.to_tokens(ctx)); + + quote![#while_ty::while_loop_expand(context, |context| #condition, |context| #block);] + } + Expression::Loop(block) => { + let loop_ty = frontend_type("branch"); + let block = context.with_restored_closure_scope(|ctx| block.to_tokens(ctx)); + + quote![#loop_ty::loop_expand(context, |context| #block);] + } + Expression::If { + condition, + then_block, + else_branch, + .. + } if condition.is_const() => { + let as_const = condition.as_const(context).unwrap(); + let then_block = context.with_restored_scope(|ctx| then_block.to_tokens(ctx)); + let else_branch = else_branch + .as_ref() + .map(|it| context.with_restored_scope(|ctx| it.to_tokens(ctx))) + .map(|it| quote![else #it]); + quote![if #as_const #then_block #else_branch] + } + Expression::If { + condition, + then_block, + else_branch: Some(else_branch), + } => { + let path = frontend_path(); + let condition = condition.to_tokens(context); + let then_block = + context.with_restored_closure_scope(|ctx| then_block.to_tokens(ctx)); + let else_branch = + context.with_restored_closure_scope(|ctx| else_branch.to_tokens(ctx)); + quote! { + { + let _cond = #condition; + #path::branch::if_else_expand(context, _cond.into(), |context| #then_block, |context| #else_branch); + } + } + } + Expression::If { + condition, + then_block, + .. + } => { + let path = frontend_path(); + let condition = condition.to_tokens(context); + let then_block = + context.with_restored_closure_scope(|ctx| then_block.to_tokens(ctx)); + quote! { + { + let _cond = #condition; + #path::branch::if_expand(context, _cond.into(), |context| #then_block); + } + } + } + Expression::Path { path, .. } => quote![#path], + Expression::Range { + start, + end, + inclusive, + span, + } => { + let start = start + .as_const(context) + .unwrap_or_else(|| start.to_tokens(context)); + if let Some(end) = end { + let range = frontend_type("RangeExpand"); + let end = end + .as_const(context) + .unwrap_or_else(|| end.to_tokens(context)); + quote! { + { + let _start = #start; + let _end = #end; + #range::new(_start.into(), _end.into(), #inclusive) + } + } + } else { + error!(*span, "Slice range not yet supported") + } + } + + Expression::Array { span, .. } => { + if let Some(constant) = self.as_const(context) { + constant + } else { + error!(*span, "Array expressions can't be used at runtime") + } + } + Expression::Tuple { elements, .. } => { + if let Some(constant) = self.as_const(context) { + constant + } else { + let elements = elements.iter().map(|it| it.to_tokens(context)); + quote![(#(#elements),*)] + } + } + + Expression::Slice { span, .. } => { + error!(*span, "Slice expressions not yet implemented") + } + Expression::ArrayInit { init, len } => { + let init_ty = frontend_type("ArrayInit"); + let init = init.to_tokens(context); + let len = len.to_tokens(context); + + quote![#init_ty::new(#len, #init)] + } + Expression::VerbatimTerminated { tokens } => tokens.clone(), + Expression::Reference { inner } => { + if let Some(as_const) = inner.as_const(context) { + quote![&#as_const] + } else { + let inner = inner.to_tokens(context); + quote![#inner] + } + } + Expression::StructInit { path, fields } => { + let cube_type = prelude_type("CubeType"); + let fields = init_fields(fields, context); + let path_last = path.segments.last().unwrap(); + let turbofish = path_last.arguments.clone(); + let generics = match &turbofish { + PathArguments::None => None, + PathArguments::AngleBracketed(params) => { + let params = params.args.iter(); + Some(quote![<#(#params),*>]) + } + args => { + return error!( + args.span(), + "Fn generics not supported when constructing runtime structs" + ) + } + }; + + quote! { + { + type _Ty #generics = <#path as #cube_type>::ExpandType; + _Ty #turbofish { #(#fields),* } + } + } + } + Expression::Closure { params, body, .. } => { + let body = context.with_restored_closure_scope(|ctx| body.to_tokens(ctx)); + quote![|context, #(#params),*| #body] + } + Expression::Verbatim { tokens, .. } => tokens.clone(), + Expression::Block(block) => context.with_restored_scope(|ctx| block.to_tokens(ctx)), + } + } +} + +impl Block { + pub fn to_tokens(&self, context: &mut Context) -> TokenStream { + let inner: Vec<_> = self.inner.iter().map(|it| it.to_tokens(context)).collect(); + let ret = if let Some(ret) = self.ret.as_ref() { + let as_const = ret.as_const(context); + if let Some(as_const) = as_const { + quote![#as_const.__expand_runtime_method(context)] + } else { + ret.to_tokens(context) + } + } else { + quote![()] + }; + + quote! { + { + #(#inner)* + #ret + } + } + } +} + +fn split_generics(path: &Expression, context: &mut Context) -> (PathArguments, TokenStream) { + let mut path = match path { + Expression::Path { path, .. } => path.clone(), + _ => return (PathArguments::None, path.to_tokens(context)), + }; + let generics = if let Some(last) = path.segments.last_mut() { + core::mem::replace(&mut last.arguments, PathArguments::None) + } else { + PathArguments::None + }; + (generics, quote![#path]) +} + +fn map_args(args: &[Expression], context: &mut Context) -> (Vec, Vec) { + let names: Vec<_> = (0..args.len()).map(|i| format_ident!("_arg_{i}")).collect(); + let values = names + .iter() + .zip(args.iter()) + .map(|(i, value)| { + if matches!(value, Expression::Closure { .. }) { + quote![] + } else { + let tokens = value + .as_const(context) + .unwrap_or_else(|| value.to_tokens(context)); + quote_spanned![tokens.span()=> let #i = #tokens;] + } + }) + .collect(); + let names = names + .into_iter() + .zip(args.iter()) + .map(|(name, value)| { + if matches!(value, Expression::Closure { .. }) { + value.to_tokens(context) + } else { + quote![#name.into()] + } + }) + .collect(); + (values, names) +} + +/// Since we no longer (unnecessarily) init immutable locals, we do need to init all struct fields +/// because of interior mutability. +fn init_fields<'a>( + fields: &'a [(Member, Expression)], + context: &'a mut Context, +) -> impl Iterator + 'a { + fields.iter().map(|(pat, it)| { + let init = frontend_type("Init"); + let it = if let Some(as_const) = it.as_const(context) { + let expand_elem = frontend_type("ExpandElementTyped"); + quote_spanned![as_const.span()=> #expand_elem::from_lit(#as_const)] + } else { + it.to_tokens(context) + }; + quote! { + #pat: { + let _init = #it; + #init::init(_init, context) + } + } + }) +} diff --git a/crates/cubecl-macros/src/generate/kernel.rs b/crates/cubecl-macros/src/generate/kernel.rs new file mode 100644 index 00000000..b1954adb --- /dev/null +++ b/crates/cubecl-macros/src/generate/kernel.rs @@ -0,0 +1,260 @@ +use darling::usage::{CollectLifetimes as _, CollectTypeParams as _, GenericsExt as _, Purpose}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use std::iter; +use syn::Ident; + +use crate::{ + parse::kernel::{KernelFn, KernelParam, KernelSignature, Launch}, + paths::{core_type, prelude_path, prelude_type}, +}; + +impl KernelFn { + pub fn to_tokens_mut(&mut self) -> TokenStream { + let prelude_path = prelude_path(); + let sig = &self.sig; + let block = self + .context + .with_restored_scope(|ctx| self.block.to_tokens(ctx)); + + let out = quote! { + #sig { + use #prelude_path::IntoRuntime as _; + + #block + } + }; + + out + } +} + +impl ToTokens for KernelSignature { + fn to_tokens(&self, tokens: &mut TokenStream) { + let cube_context = prelude_type("CubeContext"); + let cube_type = prelude_type("CubeType"); + + let name = &self.name; + let generics = &self.generics; + let return_type = &self.returns; + let args = &self.parameters; + + let out = quote! { + fn #name #generics( + context: &mut #cube_context, + #(#args),* + ) -> <#return_type as #cube_type>::ExpandType + }; + tokens.extend(out); + } +} + +impl ToTokens for KernelParam { + fn to_tokens(&self, tokens: &mut TokenStream) { + let name = &self.name; + let ty = &self.normalized_ty; + tokens.extend(quote![#name: #ty]); + } +} + +impl Launch { + fn kernel_phantom_data(&self) -> Option { + let generics = self.kernel_generics.clone(); + let declared_lifetimes = generics.declared_lifetimes(); + let declared_type_params = generics.declared_type_params(); + + let used_lifetimes = self + .comptime_params() + .map(|param| ¶m.ty) + .collect_lifetimes_cloned(&Purpose::Declare.into(), &declared_lifetimes); + let used_type_params = self + .comptime_params() + .map(|param| ¶m.ty) + .collect_type_params_cloned(&Purpose::Declare.into(), &declared_type_params); + let lifetimes: Vec<_> = declared_lifetimes.difference(&used_lifetimes).collect(); + let type_params: Vec<_> = declared_type_params.difference(&used_type_params).collect(); + + (!lifetimes.is_empty() || !type_params.is_empty()) + .then(|| quote![__ty: ::core::marker::PhantomData<(#(#lifetimes,)* #(#type_params),*)>]) + } + + pub fn io_mappings(&self) -> TokenStream { + let launch_arg_expand = prelude_type("LaunchArgExpand"); + let expand_fn = |i, expand_name, vec_name, ty| { + quote! { + #i => ::std::sync::Arc::new(<#ty as #launch_arg_expand>::#expand_name(builder, settings.#vec_name(#i))) + } + }; + let inputs = self.runtime_inputs().enumerate().map(|(i, input)| { + expand_fn( + i, + format_ident!("expand"), + format_ident!("vectorization_input"), + input.ty_owned(), + ) + }); + let outputs = self.runtime_outputs().enumerate().map(|(i, output)| { + expand_fn( + i, + format_ident!("expand_output"), + format_ident!("vectorization_output"), + output.ty_owned(), + ) + }); + let map = quote![::std::collections::BTreeMap> = std::collections::BTreeMap::new()]; + let inputs_len = self.runtime_inputs().count(); + let outputs_len = self.runtime_outputs().count(); + let register_input = register_fn("register_input", inputs); + let register_output = register_fn("register_output", outputs); + + let insert_inputs = (inputs_len > 0).then(|| { + quote! { + for i in 0..#inputs_len { + inputs.insert(i, register_input(&mut builder, &self.settings, i)); + } + } + }); + let insert_outputs = (outputs_len > 0).then(|| { + quote! { + for i in 0..#outputs_len { + if !outputs.contains_key(&i) { + outputs.insert(i, register_output(&mut builder, &self.settings, i)); + } + } + } + }); + + let in_params = self + .runtime_inputs() + .enumerate() + .map(runtime_param("inputs")); + let out_params = self + .runtime_outputs() + .enumerate() + .map(runtime_param("outputs")); + + quote! { + let mut inputs: #map; + let mut outputs: #map; + + #register_input + #register_output + + #insert_inputs + for mapping in &self.settings.mappings { + let input = inputs.get(&mapping.pos_input).unwrap(); + outputs.insert(mapping.pos_output, input.clone()); + } + #insert_outputs + #(#in_params)* + #(#out_params)* + } + } + + fn define_body(&self) -> TokenStream { + let kernel_builder = prelude_type("KernelBuilder"); + let io_map = self.io_mappings(); + let runtime_args = self.runtime_params().map(|it| &it.name); + let comptime_args = self.comptime_params().map(|it| &it.name); + let (_, generics, _) = self.func.sig.generics.split_for_impl(); + let generics = generics.as_turbofish(); + + quote! { + let mut builder = #kernel_builder::default(); + #io_map + expand #generics(&mut builder.context, #(#runtime_args.clone(),)* #(self.#comptime_args.clone()),*); + builder.build(self.settings.clone()) + } + } + + pub fn kernel_definition(&self) -> TokenStream { + if self.args.is_launch() { + let kernel = core_type("Kernel"); + let kernel_settings = prelude_type("KernelSettings"); + let kernel_definition: syn::Path = prelude_type("KernelDefinition"); + let kernel_id = core_type("KernelId"); + + let kernel_name = self.kernel_name(); + let define = self.define_body(); + let kernel_doc = format!("{} Kernel", self.func.sig.name); + + let (generics, generic_names, where_clause) = self.kernel_generics.split_for_impl(); + let const_params: Vec<_> = self.comptime_params().collect(); + let param_names = self + .comptime_params() + .map(|param| param.name.clone()) + .collect::>(); + let phantom_data = self.kernel_phantom_data(); + let info = iter::once(format_ident!("settings")).chain(param_names.clone()); + let phantom_data_init = phantom_data + .as_ref() + .map(|_| quote![__ty: ::core::marker::PhantomData]); + + quote! { + #[doc = #kernel_doc] + pub struct #kernel_name #generics #where_clause { + settings: #kernel_settings, + #(#const_params,)* + #phantom_data + } + + impl #generics #kernel_name #generic_names #where_clause { + pub fn new(settings: #kernel_settings, #(#const_params),*) -> Self { + Self { + settings, + #(#param_names,)* + #phantom_data_init + } + } + } + + impl #generics #kernel for #kernel_name #generic_names #where_clause { + fn define(&self) -> #kernel_definition { + #define + } + + fn id(&self) -> #kernel_id { + #kernel_id::new::().info((#(self.#info.clone()),*)) + } + } + } + } else { + TokenStream::new() + } + } +} + +fn register_fn(name: &str, values: impl Iterator) -> TokenStream { + let kernel_settings = prelude_type("KernelSettings"); + let kernel_builder = prelude_type("KernelBuilder"); + + let name = format_ident!("{name}"); + quote! { + #[allow(unused)] + let #name = | + builder: &mut #kernel_builder, + settings: &#kernel_settings, + position: usize, + | -> ::std::sync::Arc { + match position { + #(#values,)* + _ => { + panic!("Input {position} is invalid"); + } + } + }; + } +} + +fn runtime_param(io_map: &str) -> impl FnMut((usize, &KernelParam)) -> TokenStream { + let cube_type = prelude_type("CubeType"); + let io_map = format_ident!("{io_map}"); + move |(i, input)| { + let name: &Ident = &input.name; + let ty = input.ty_owned(); + quote! { + let #name: &<#ty as #cube_type>::ExpandType = #io_map.get(&#i).unwrap().downcast_ref() + .expect("Output type should be correct. It could be caused by an invalid kernel input/output alias."); + } + } +} diff --git a/crates/cubecl-macros/src/generate/launch.rs b/crates/cubecl-macros/src/generate/launch.rs new file mode 100644 index 00000000..47604820 --- /dev/null +++ b/crates/cubecl-macros/src/generate/launch.rs @@ -0,0 +1,234 @@ +use ident_case::RenameRule; +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_quote, Ident}; + +use crate::{ + parse::kernel::{KernelParam, Launch}, + paths::{core_path, core_type, prelude_type}, +}; + +impl ToTokens for Launch { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let vis = &self.vis; + + let name = &self.func.sig.name; + let launch = self.launch(); + let launch_unchecked = self.launch_unchecked(); + let dummy = self.create_dummy_kernel(); + let kernel = self.kernel_definition(); + let mut func = self.func.clone(); + func.sig.name = format_ident!("expand"); + let func = func.to_tokens_mut(); + + let out = quote! { + #vis mod #name { + use super::*; + + #[allow(unused, clippy::all)] + pub #func + + #kernel + #launch + #launch_unchecked + #dummy + } + }; + + if self.args.debug.is_present() { + let file = syn::parse_file(&out.to_string()).unwrap(); + let tokens = prettyplease::unparse(&file); + panic!("{tokens}"); + } + tokens.extend(out); + } +} + +impl Launch { + fn launch(&self) -> TokenStream { + if self.args.launch.is_present() { + let compute_client = prelude_type("ComputeClient"); + let cube_count = prelude_type("CubeCount"); + let cube_dim = prelude_type("CubeDim"); + + let kernel_doc = format!( + "Launch the kernel [{}()] on the given runtime", + self.func.sig.name + ); + let generics = &self.launch_generics; + let args = self.launch_args(); + let body = self.launch_body(); + + quote! { + #[allow(clippy::too_many_arguments)] + #[doc = #kernel_doc] + pub fn launch #generics( + __client: &#compute_client<__R::Server, __R::Channel>, + __cube_count: #cube_count<__R::Server>, + __cube_dim: #cube_dim, + #(#args),* + ) -> () { + #body + launcher.launch(__cube_count, kernel, __client); + } + } + } else { + TokenStream::new() + } + } + + fn launch_unchecked(&self) -> TokenStream { + if self.args.launch_unchecked.is_present() { + let compute_client = prelude_type("ComputeClient"); + let cube_count = prelude_type("CubeCount"); + let cube_dim = prelude_type("CubeDim"); + + let kernel_doc = format!( + "Launch the kernel [{}()] on the given runtime", + self.func.sig.name + ); + let generics = &self.launch_generics; + let args = self.launch_args(); + let body = self.launch_body(); + + quote! { + #[allow(clippy::too_many_arguments)] + #[doc = #kernel_doc] + pub unsafe fn launch_unchecked #generics( + __client: &#compute_client<__R::Server, __R::Channel>, + __cube_count: #cube_count<__R::Server>, + __cube_dim: #cube_dim, + #(#args),* + ) -> () { + #body + launcher.launch_unchecked(__cube_count, kernel, __client); + } + } + } else { + TokenStream::new() + } + } + + fn launch_body(&self) -> TokenStream { + let kernel_launcher = prelude_type("KernelLauncher"); + + let registers_in = self.runtime_inputs().map(|arg| { + let name = &arg.name; + quote![#name.register(&mut launcher);] + }); + + let registers_out = self.runtime_outputs().map(|arg| { + let name = &arg.name; + quote![#name.register(&mut launcher);] + }); + + let settings = self.configure_settings(); + let kernel_name = self.kernel_name(); + let core_path = core_path(); + let kernel_generics = self.kernel_generics.split_for_impl(); + let kernel_generics = kernel_generics.1.as_turbofish(); + let comptime_args = self.comptime_params().map(|it| &it.name); + + quote! { + use #core_path::frontend::ArgSettings as _; + + #settings + let kernel = #kernel_name #kernel_generics::new(__settings, #(#comptime_args),*); + let mut launcher = #kernel_launcher::<__R>::default(); + #(#registers_in)* + #(#registers_out)* + } + } + + fn configure_settings(&self) -> TokenStream { + let kernel_settings = prelude_type("KernelSettings"); + let arg_settings = prelude_type("ArgSettings"); + + let input_configs = self.runtime_inputs().enumerate().map(|(i, arg)| { + let name = &arg.name; + quote![__settings = #arg_settings::<__R>::configure_input(&#name, #i, __settings);] + }); + let output_configs = self.runtime_outputs().enumerate().map(|(i, arg)| { + let name = &arg.name; + quote![__settings = #arg_settings::<__R>::configure_output(&#name, #i, __settings);] + }); + + quote! { + let mut __settings = #kernel_settings::default().cube_dim(__cube_dim); + #(#input_configs)* + #(#output_configs)* + } + } + + fn create_dummy_kernel(&self) -> TokenStream { + if self.args.create_dummy_kernel.is_present() { + let cube_count = prelude_type("CubeCount"); + let cube_dim = prelude_type("CubeDim"); + + let kernel_doc = format!( + "Launch the kernel [{}()] on the given runtime", + self.func.sig.name + ); + let generics = &self.launch_generics; + let (_, generic_names, _) = self.kernel_generics.split_for_impl(); + + let settings = self.configure_settings(); + let kernel_name = self.kernel_name(); + let core_path = core_path(); + let comptime_args = self.launch_args(); + let comptime_names = self.comptime_params().map(|it| &it.name); + + quote! { + #[allow(clippy::too_many_arguments)] + #[doc = #kernel_doc] + pub fn create_dummy_kernel #generics( + __cube_count: #cube_count<__R::Server>, + __cube_dim: #cube_dim, + #(#comptime_args),* + ) -> #kernel_name #generic_names { + use #core_path::frontend::ArgSettings as _; + + #settings + #kernel_name::new(__settings, #(#comptime_names),*) + } + } + } else { + TokenStream::new() + } + } + + pub fn runtime_inputs(&self) -> impl Iterator { + self.runtime_params().filter(|it| !it.is_mut) + } + + pub fn runtime_outputs(&self) -> impl Iterator { + self.runtime_params().filter(|it| it.is_mut) + } + + pub fn runtime_params(&self) -> impl Iterator { + self.func.sig.parameters.iter().filter(|it| !it.is_const) + } + + fn launch_args(&self) -> Vec { + let mut args = self.func.sig.parameters.clone(); + let runtime_arg = core_type("RuntimeArg"); + for arg in args.iter_mut().filter(|it| !it.is_const) { + let ty = arg.ty_owned(); + arg.normalized_ty = parse_quote![#runtime_arg<'kernel, #ty, __R>]; + } + args + } + + pub fn kernel_name(&self) -> Ident { + let kernel_name = RenameRule::PascalCase.apply_to_field(self.func.sig.name.to_string()); + format_ident!("{kernel_name}") + } + + pub fn comptime_params(&self) -> impl Iterator { + self.func + .sig + .parameters + .iter() + .filter(|param| param.is_const) + } +} diff --git a/crates/cubecl-macros/src/generate/mod.rs b/crates/cubecl-macros/src/generate/mod.rs new file mode 100644 index 00000000..9416d016 --- /dev/null +++ b/crates/cubecl-macros/src/generate/mod.rs @@ -0,0 +1,6 @@ +pub mod cube_trait; +pub mod cube_type; +pub mod expression; +pub mod kernel; +pub mod launch; +pub mod statement; diff --git a/crates/cubecl-macros/src/generate/statement.rs b/crates/cubecl-macros/src/generate/statement.rs new file mode 100644 index 00000000..ad02fc0e --- /dev/null +++ b/crates/cubecl-macros/src/generate/statement.rs @@ -0,0 +1,90 @@ +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::{spanned::Spanned, Token}; + +use crate::{expression::Expression, paths::frontend_type, scope::Context, statement::Statement}; + +impl Statement { + pub fn to_tokens(&self, context: &mut Context) -> TokenStream { + match self { + Statement::Local { + left, + init, + mutable, + ty, + } => { + let cube_type = frontend_type("CubeType"); + let name = match &**left { + Expression::Variable { name, .. } => name, + _ => panic!("Local is always variable or init"), + }; + let is_mut = *mutable || init.as_deref().map(is_mut_owned).unwrap_or(false); + let mutable = mutable.then(|| quote![mut]); + let init = if is_mut { + if let Some(as_const) = init.as_ref().and_then(|it| it.as_const(context)) { + let expand = frontend_type("ExpandElementTyped"); + Some(quote_spanned![as_const.span()=> #expand::from_lit(#as_const)]) + } else { + init.as_ref().map(|it| it.to_tokens(context)) + } + } else { + init.as_ref().map(|init| { + init.as_const(context) + .unwrap_or_else(|| init.to_tokens(context)) + }) + }; + let ty = ty + .as_ref() + .map(|ty| quote_spanned![ty.span()=> :<#ty as #cube_type>::ExpandType]); + + let init = match (is_mut, init) { + (true, Some(init)) => { + let init_ty = frontend_type("Init"); + let init_ty = quote_spanned![init.span()=> #init_ty::init(_init, context)]; + Some(quote! { + { + let _init = #init; + #init_ty + } + }) + } + (_, init) => init, + }; + + if let Some(init) = init { + quote![let #mutable #name #ty = #init;] + } else { + quote![let #mutable #name #ty;] + } + } + Statement::Group { statements } => { + let statements = statements.iter().map(|it| it.to_tokens(context)); + quote! { + #(#statements)* + } + } + Statement::Expression { + expression, + span, + terminated, + } => { + let terminator = terminated.then(|| Token![;](*span)); + if let Some(as_const) = expression.as_const(context) { + quote![#as_const #terminator] + } else { + let expression = expression.to_tokens(context); + quote![#expression #terminator] + } + } + Statement::Skip => TokenStream::new(), + } + } +} + +fn is_mut_owned(init: &Expression) -> bool { + match init { + Expression::Variable { is_ref, is_mut, .. } => *is_mut && !is_ref, + Expression::FieldAccess { base, .. } => is_mut_owned(base), + _ => false, + } +} diff --git a/crates/cubecl-macros/src/lib.rs b/crates/cubecl-macros/src/lib.rs index 4379ab9f..3d71118d 100644 --- a/crates/cubecl-macros/src/lib.rs +++ b/crates/cubecl-macros/src/lib.rs @@ -1,204 +1,116 @@ -#[macro_use] -extern crate derive_new; - -mod analyzer; -mod codegen_function; -mod codegen_trait; -mod codegen_type; -mod tracker; - -pub(crate) mod codegen_common; - -use analyzer::VariableAnalyzer; -use codegen_common::signature::{expand_sig, ExpandMode}; -use codegen_function::{codegen_launch, codegen_statement}; -use codegen_trait::{expand_trait_def, expand_trait_impl}; -use codegen_type::generate_cube_type; +use error::error_into_token_stream; +use generate::cube_type::generate_cube_type; +use parse::{ + cube_trait::{CubeTrait, CubeTraitImpl}, + helpers::{RemoveHelpers, ReplaceIndices}, + kernel::{from_tokens, Launch}, +}; use proc_macro::TokenStream; -use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta}; -use tracker::VariableTracker; - -enum CubeMode { - /// Generates the expanded version of the function - Default, - /// Panics and prints the generated code, useful when debugging - /// Use by writing #[cube(debug)] - Debug, -} - -// Derive macro to define a cube type that is launched with a kernel -#[proc_macro_derive(CubeLaunch)] -pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - - generate_cube_type(&input, true) -} - -// Derive macro to define a cube type that is not launched -#[proc_macro_derive(CubeType)] -pub fn module_derive_cube_type(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - - generate_cube_type(&input, false) -} - -struct SupportedAttributes { - mode: CubeMode, - launch: bool, - launch_unchecked: bool, -} - -/// Derive macro for the module. +use quote::quote; +use syn::{visit_mut::VisitMut, Item}; + +mod error; +mod expression; +mod generate; +mod operator; +mod parse; +mod paths; +mod scope; +mod statement; + +/// Mark a cube function, trait or implementation for expansion. +/// +/// # Arguments +/// * `launch` - generates a function to launch the kernel +/// * `launch_unchecked` - generates a launch function without checks +/// * `debug` - panics after generation to print the output to console +/// * `create_dummy_kernel` - Generates a function to create a kernel without launching it. Used for testing. +/// +/// # Example +/// +/// ``` +/// # use cubecl_macros::cube; +/// #[cube] +/// fn my_addition(a: u32, b: u32) -> u32 { +/// a + b +/// } +/// ``` #[proc_macro_attribute] -pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { - let args = parse_macro_input!(attr with Punctuated::::parse_terminated); - let attrs = parse_attributes(&args); - - let code: TokenStream = match syn::parse::(tokens).unwrap() { - syn::Item::Fn(func) => cube_fn(func, &attrs), - syn::Item::Impl(item) => expand_trait_impl(item).into(), - syn::Item::Trait(item) => expand_trait_def(item).into(), - _ => panic!("Cube annotations only supported for functions"), - }; - - match attrs.mode { - CubeMode::Default => code, - CubeMode::Debug => panic!("{code}"), - } -} - -fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { - let mut variable_tracker = VariableAnalyzer::create_tracker(&func); - - match codegen_cube( - &func, - &mut variable_tracker, - attrs.launch, - attrs.launch_unchecked, - ) { - Ok(code) => code.into(), - Err(err) => err.into(), +pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream { + match cube_impl(args, input.clone()) { + Ok(tokens) => tokens, + Err(e) => error_into_token_stream(e, input.into()).into(), } } -fn parse_attributes(args: &Punctuated) -> SupportedAttributes { - let mut mode = CubeMode::Default; - let mut launch = false; - let mut launch_unchecked = false; - - for arg in args.iter() { - match arg { - Meta::Path(path) => { - if let Some(ident) = path.get_ident().map(|id| id.to_string()) { - match ident.as_str() { - "debug" => { - mode = CubeMode::Debug; - } - "launch" => { - launch = true; - } - "launch_unchecked" => { - launch_unchecked = true; - } - _ => { - panic!("Attribute {ident} is not supported") - } - } - } else { - panic!("Only ident attribute supported"); - } - } - Meta::List(_) => panic!("No List attribute supported"), - Meta::NameValue(_) => panic!("No NameValue attribute supported"), +fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result { + let mut item: Item = syn::parse(input)?; + match item.clone() { + Item::Fn(kernel) => { + let args = from_tokens(args.into())?; + let kernel = Launch::from_item_fn(kernel, args)?; + RemoveHelpers.visit_item_mut(&mut item); + ReplaceIndices.visit_item_mut(&mut item); + + Ok(TokenStream::from(quote! { + #[allow(dead_code, clippy::too_many_arguments)] + #item + #kernel + })) } - } - - SupportedAttributes { - mode, - launch, - launch_unchecked, - } -} - -/// Generate the expanded version of a function marked with the cube macro -fn codegen_cube( - func: &syn::ItemFn, - variable_tracker: &mut VariableTracker, - launch: bool, - launch_unchecked: bool, -) -> Result { - let signature = expand_sig( - &func.sig, - &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import - // it from an outside module. - Some(variable_tracker), - ExpandMode::FuncImpl, - ); - let mut body = quote::quote! {}; + Item::Trait(kernel_trait) => { + let expand_trait = CubeTrait::from_item_trait(kernel_trait)?; - for statement in func.block.stmts.iter() { - let tokens = codegen_statement(statement, 0, variable_tracker); - body.extend(tokens); - } - - let is_in_error = !variable_tracker.errors.is_empty(); - - if is_in_error { - // When there is an error, we don't generate the expand method, since it's only going to - // create more errors that won't help fixing the issue. - - let mut code = quote::quote! { - #[allow(dead_code)] - #[allow(clippy::too_many_arguments)] - #func - }; - - for err in variable_tracker.errors.drain(..) { - code.extend(err.into_compile_error()); + Ok(TokenStream::from(quote! { + #expand_trait + })) } + Item::Impl(item_impl) if item_impl.trait_.is_some() => { + let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl)?; + let expand_impl = expand_impl.to_tokens_mut(); - return Err(code); + Ok(TokenStream::from(quote! { + #expand_impl + })) + } + item => Err(syn::Error::new_spanned( + item, + "`#[cube]` is only supported on traits and functions", + ))?, } +} - let launch_doc = if launch { - "and launch functions " - } else { - "function " - }; - - let mut launch = if launch { - codegen_launch(&func.sig, false) - } else { - quote::quote! {} - }; - - launch.extend(if launch_unchecked { - codegen_launch(&func.sig, true) - } else { - quote::quote! {} - }); - - let mod_name = &func.sig.ident; - let vis = &func.vis; - let doc = format!("Module containing the expand {launch_doc}of {mod_name}."); - - Ok(quote::quote! { - #[allow(dead_code)] - #[allow(clippy::too_many_arguments)] - #func +/// Derive macro to define a cube type that is launched with a kernel +#[proc_macro_derive(CubeLaunch, attributes(expand))] +pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream { + let input = syn::parse(input).unwrap(); + generate_cube_type(&input, true).into() +} - #[doc = #doc] - #[allow(clippy::too_many_arguments)] - #vis mod #mod_name { - use super::*; +/// Derive macro to define a cube type that is not launched +#[proc_macro_derive(CubeType, attributes(expand))] +pub fn module_derive_cube_type(input: TokenStream) -> TokenStream { + let input = syn::parse(input).unwrap(); - #launch + generate_cube_type(&input, false).into() +} - #[allow(unused_mut)] - #signature { - #body - } - } - }) +/// Mark the contents of this macro as compile time values, turning off all expansion for this code +/// and using it verbatim +/// +/// # Example +/// ``` +/// #use cubecl_macros::cube; +/// #fn some_rust_function(a: u32) -> u32 {} +/// #[cube] +/// fn do_stuff(input: u32) -> u32 { +/// let comptime_value = comptime! { some_rust_function(3) }; +/// input + comptime_value +/// } +/// ``` +#[proc_macro] +pub fn comptime(input: TokenStream) -> TokenStream { + let tokens: proc_macro2::TokenStream = input.into(); + quote![{ #tokens }].into() } diff --git a/crates/cubecl-macros/src/operator.rs b/crates/cubecl-macros/src/operator.rs new file mode 100644 index 00000000..e51f35df --- /dev/null +++ b/crates/cubecl-macros/src/operator.rs @@ -0,0 +1,125 @@ +use std::fmt::Display; + +/// An operator used in the intermediate representaion +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Operator { + // Arithmetic + /// Add (+) operator + Add, + /// Sub (-) operator + Sub, + /// Mul (*) operator + Mul, + /// Div (/) operator + Div, + /// Rem (%) operator + Rem, + + // Arithmetic Assign + /// Add assign (+=) operator + AddAssign, + /// Sub assign (-=) operator + SubAssign, + /// Mul assing (*=) operator + MulAssign, + /// Div assign (/=) operator + DivAssign, + /// Rem assign (%=) operator + RemAssign, + + // Comparison + /// Equals (==) operator + Eq, + /// Not equal (!=) operator + Ne, + /// Less than (<) operator + Lt, + /// Less than equals (<=) operator + Le, + /// Greater than equal (>=) operator + Ge, + /// Greater than (>) operator + Gt, + + // Boolean + /// And (&&) operator + And, + /// Or (||) operator + Or, + /// Bitwise XOR (^) operator + BitXor, + /// Bitwise And (&) operator + BitAnd, + /// Bitwise Or (|) operator + BitOr, + + // Boolean assign + /// Bitwise xor assign (^=) operator + BitXorAssign, + /// Bitwise and assign (&=) operator + BitAndAssign, + /// Bitwise or assign (|=) operator + BitOrAssign, + + /// Shift left (<<) operator + Shl, + /// Shift right (>>) operator + Shr, + /// Shift left assign (<<=) operator + ShlAssign, + /// Shift right assign (>>= operator) + ShrAssign, + + // Unary + /// Dereference operator (*) + Deref, + /// Not operator (!) + Not, + /// Negation unary operator (-) + Neg, +} + +impl Display for Operator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{self:?}")) + } +} + +impl Operator { + /// Whether this is an assign op, aka whether the output is the same as the left hand side + pub fn is_assign(&self) -> bool { + matches!( + self, + Operator::AddAssign + | Operator::SubAssign + | Operator::MulAssign + | Operator::DivAssign + | Operator::RemAssign + | Operator::BitXorAssign + | Operator::BitAndAssign + | Operator::BitOrAssign + | Operator::ShlAssign + | Operator::ShrAssign + ) + } + + /// Get the expanded op name for this operation + pub fn op_name(&self) -> String { + if self.is_assign() { + let name = self.to_string().to_lowercase(); + format!("{}_assign_op", &name[..name.len() - 6]) + } else { + self.to_string().to_lowercase() + } + } + + /// Get the expanded op name for this array operation + pub fn array_op_name(&self) -> String { + if self.is_assign() { + let name = self.to_string().to_lowercase(); + format!("{}_assign_array_op", &name[..name.len() - 6]) + } else { + self.to_string().to_lowercase() + } + } +} diff --git a/crates/cubecl-macros/src/parse/branch.rs b/crates/cubecl-macros/src/parse/branch.rs new file mode 100644 index 00000000..f3ee2a33 --- /dev/null +++ b/crates/cubecl-macros/src/parse/branch.rs @@ -0,0 +1,137 @@ +use quote::quote; +use syn::{spanned::Spanned, ExprForLoop, ExprIf, ExprLoop, ExprWhile, Ident}; + +use crate::{ + expression::{Block, Expression}, + operator::Operator, + scope::Context, + statement::{parse_pat, Statement}, +}; + +use super::helpers::Unroll; + +pub fn expand_for_loop(for_loop: ExprForLoop, context: &mut Context) -> syn::Result { + let span = for_loop.span(); + let unroll = Unroll::from_attributes(&for_loop.attrs, context)?.map(|it| it.value); + + let right = Expression::from_expr(*for_loop.expr.clone(), context) + .map_err(|_| syn::Error::new(span, "Unsupported for loop expression"))?; + let var = parse_pat(*for_loop.pat)?; + + if right.is_const() && !matches!(right, Expression::Range { .. }) { + return expand_for_in_loop(var.ident, right, for_loop.body, context); + } + + let block = context.with_scope(|context| { + context.push_variable( + var.ident.clone(), + var.ty.clone(), + false, + var.is_ref, + var.is_mut, + ); + Block::from_block(for_loop.body, context) + })?; + + Ok(Expression::ForLoop { + range: Box::new(right), + unroll: unroll.map(Box::new), + var_name: var.ident, + var_ty: var.ty, + block, + }) +} + +fn expand_for_in_loop( + var_name: Ident, + right: Expression, + block: syn::Block, + context: &mut Context, +) -> syn::Result { + let statements = block + .stmts + .into_iter() + .map(|stmt| Statement::from_stmt(stmt, context)) + .collect::, _>>()?; + + let right = right.to_tokens(context); + let statements = statements.into_iter().map(|it| it.to_tokens(context)); + let for_loop = Expression::VerbatimTerminated { + tokens: quote! { + for #var_name in #right { + #(#statements)* + } + }, + }; + Ok(for_loop) +} + +pub fn expand_while_loop(while_loop: ExprWhile, context: &mut Context) -> syn::Result { + let span = while_loop.span(); + + let condition = Expression::from_expr(*while_loop.cond, context) + .map_err(|_| syn::Error::new(span, "Unsupported while condition"))?; + let inverted = Expression::Unary { + input: Box::new(condition), + operator: Operator::Not, + ty: None, + }; + + let block = context.with_scope(|ctx| Block::from_block(while_loop.body, ctx))?; + Ok(Expression::WhileLoop { + condition: Box::new(inverted), + block, + }) +} + +pub fn expand_loop(loop_expr: ExprLoop, context: &mut Context) -> syn::Result { + let block = context.with_scope(|ctx| Block::from_block(loop_expr.body, ctx))?; + Ok(Expression::Loop(block)) +} + +pub fn expand_if(if_expr: ExprIf, context: &mut Context) -> syn::Result { + let span = if_expr.span(); + let condition = Expression::from_expr(*if_expr.cond, context) + .map_err(|_| syn::Error::new(span, "Unsupported while condition"))?; + + let then_block = context.with_scope(|ctx| Block::from_block(if_expr.then_branch, ctx))?; + let else_branch = if let Some((_, else_branch)) = if_expr.else_branch { + Some(context.with_scope(|ctx| Expression::from_expr(*else_branch, ctx))?) + } else { + None + }; + Ok(Expression::If { + condition: Box::new(condition), + then_block, + else_branch: else_branch.map(Box::new), + }) +} + +impl Block { + pub fn from_block(block: syn::Block, context: &mut Context) -> syn::Result { + let mut statements = block + .stmts + .into_iter() + .map(|stmt| Statement::from_stmt(stmt, context)) + .collect::, _>>()?; + // Pop implicit return if it exists so we can assign it as the block output + let ret = match statements.pop() { + Some(Statement::Expression { + expression, + terminated: false, + .. + }) => Some(expression), + Some(stmt) => { + statements.push(stmt); + None + } + _ => None, + }; + let ty = ret.as_ref().and_then(|ret| ret.ty()); + Ok(Self { + inner: statements, + ret, + ty, + }) + } +} diff --git a/crates/cubecl-macros/src/parse/cube_trait.rs b/crates/cubecl-macros/src/parse/cube_trait.rs new file mode 100644 index 00000000..cbd66355 --- /dev/null +++ b/crates/cubecl-macros/src/parse/cube_trait.rs @@ -0,0 +1,150 @@ +use quote::format_ident; +use syn::{ + visit_mut::VisitMut, Attribute, Generics, Ident, ImplItem, ItemImpl, ItemTrait, Path, Token, + TraitItem, Type, Visibility, +}; + +use super::{ + helpers::{RemoveHelpers, ReplaceIndices}, + kernel::{KernelFn, KernelSignature}, + StripBounds, StripDefault, +}; + +pub struct CubeTrait { + pub attrs: Vec, + pub vis: Visibility, + pub unsafety: Option, + pub name: Ident, + pub generics: Generics, + pub items: Vec, + pub original_trait: ItemTrait, +} + +pub struct CubeTraitImpl { + pub unsafety: Option, + pub struct_name: Type, + pub trait_name: Path, + pub generics: Generics, + pub items: Vec, + pub original_items: Vec, +} + +pub enum CubeTraitItem { + Fn(KernelSignature), + Other, +} + +pub enum CubeTraitImplItem { + Fn(KernelFn), + Other, +} + +impl CubeTraitItem { + pub fn from_trait_item(item: TraitItem) -> syn::Result { + let res = match item { + TraitItem::Fn(func) => { + let mut func = KernelSignature::from_trait_fn(func)?; + func.name = format_ident!("__expand_{}", func.name); + CubeTraitItem::Fn(func) + } + _ => CubeTraitItem::Other, + }; + Ok(res) + } + + pub fn func(&self) -> Option<&KernelSignature> { + match self { + CubeTraitItem::Fn(func) => Some(func), + CubeTraitItem::Other => None, + } + } +} + +impl CubeTraitImplItem { + pub fn from_impl_item(item: ImplItem) -> syn::Result { + let res = match item { + ImplItem::Fn(func) => { + let mut func = KernelFn::from_sig_and_block(func.sig, func.block)?; + func.sig.name = format_ident!("__expand_{}", func.sig.name); + CubeTraitImplItem::Fn(func) + } + _ => CubeTraitImplItem::Other, + }; + Ok(res) + } + + pub fn func(&mut self) -> Option<&mut KernelFn> { + match self { + CubeTraitImplItem::Fn(func) => Some(func), + CubeTraitImplItem::Other => None, + } + } +} + +impl CubeTrait { + pub fn from_item_trait(item: ItemTrait) -> syn::Result { + let mut original_trait = item.clone(); + RemoveHelpers.visit_item_trait_mut(&mut original_trait); + + let mut attrs = item.attrs; + attrs.retain(|attr| !attr.path().is_ident("cube")); + attrs.retain(|attr| !attr.path().is_ident("cube")); + let vis = item.vis; + let unsafety = item.unsafety; + let name = item.ident; + + let mut original_generic_names = item.generics.clone(); + StripBounds.visit_generics_mut(&mut original_generic_names); + + let mut generics = item.generics; + StripDefault.visit_generics_mut(&mut generics); + + let items = item + .items + .into_iter() + .map(CubeTraitItem::from_trait_item) + .collect::>()?; + + Ok(Self { + attrs, + vis, + unsafety, + name, + generics, + items, + original_trait, + }) + } +} + +impl CubeTraitImpl { + pub fn from_item_impl(mut item_impl: ItemImpl) -> syn::Result { + let items = item_impl + .items + .iter() + .cloned() + .map(CubeTraitImplItem::from_impl_item) + .collect::>()?; + + RemoveHelpers.visit_item_impl_mut(&mut item_impl); + ReplaceIndices.visit_item_impl_mut(&mut item_impl); + + let struct_name = *item_impl.self_ty; + let trait_name = item_impl.trait_.unwrap().1; + + let mut attrs = item_impl.attrs; + attrs.retain(|attr| !attr.path().is_ident("cube")); + let unsafety = item_impl.unsafety; + + let generics = item_impl.generics; + + Ok(Self { + unsafety, + struct_name, + trait_name, + generics, + items, + original_items: item_impl.items, + }) + } +} diff --git a/crates/cubecl-macros/src/parse/cube_type.rs b/crates/cubecl-macros/src/parse/cube_type.rs new file mode 100644 index 00000000..6dd7b72f --- /dev/null +++ b/crates/cubecl-macros/src/parse/cube_type.rs @@ -0,0 +1,59 @@ +use std::iter; + +use darling::{ast::Data, util::Flag, FromDeriveInput, FromField}; +use quote::format_ident; +use syn::{parse_quote, punctuated::Punctuated, Generics, Ident, Type, Visibility}; + +use crate::paths::prelude_type; + +#[derive(FromDeriveInput)] +#[darling(supports(struct_named), attributes(expand), map = unwrap_fields)] +pub struct TypeCodegen { + pub ident: Ident, + pub name_launch: Option, + pub name_expand: Option, + data: Data<(), TypeField>, + #[darling(skip)] + pub fields: Vec, + pub generics: Generics, + pub vis: Visibility, +} + +#[derive(FromField, Clone)] +#[darling(attributes(expand))] +pub struct TypeField { + pub vis: Visibility, + pub ident: Option, + pub ty: Type, + pub comptime: Flag, +} + +fn unwrap_fields(mut ty: TypeCodegen) -> TypeCodegen { + // This will be supported inline with the next darling release + let fields = ty.data.as_ref().take_struct().unwrap().fields; + ty.fields = fields.into_iter().cloned().collect(); + + let name = &ty.ident; + ty.name_expand + .get_or_insert_with(|| format_ident!("{name}Expand")); + ty.name_launch + .get_or_insert_with(|| format_ident!("{name}Launch")); + + ty +} + +impl TypeCodegen { + pub fn expanded_generics(&self) -> Generics { + let runtime = prelude_type("Runtime"); + let mut generics = self.generics.clone(); + generics.params.push(parse_quote![R: #runtime]); + let all = iter::once(parse_quote!['a]).chain(generics.params); + generics.params = Punctuated::from_iter(all); + generics + } + + pub fn assoc_generics(&self) -> Generics { + let runtime = prelude_type("Runtime"); + parse_quote![<'a, R: #runtime>] + } +} diff --git a/crates/cubecl-macros/src/parse/expression.rs b/crates/cubecl-macros/src/parse/expression.rs new file mode 100644 index 00000000..722dff21 --- /dev/null +++ b/crates/cubecl-macros/src/parse/expression.rs @@ -0,0 +1,470 @@ +use proc_macro2::Span; +use quote::{format_ident, quote, quote_spanned, ToTokens}; +use syn::{parse_quote, spanned::Spanned, Expr, Lit, LitInt, Path, PathSegment, RangeLimits, Type}; + +use crate::{ + expression::{is_intrinsic, Block, Expression}, + operator::Operator, + scope::{Context, ManagedVar}, +}; + +use super::{ + branch::{expand_for_loop, expand_if, expand_loop, expand_while_loop}, + operator::{parse_binop, parse_unop}, +}; + +impl Expression { + pub fn from_expr(expr: Expr, context: &mut Context) -> syn::Result { + let result = match expr.clone() { + Expr::Assign(assign) => { + let right = Self::from_expr(*assign.right, context)?; + Expression::Assigment { + ty: right.ty(), + left: Box::new(Self::from_expr(*assign.left, context)?), + right: Box::new(right), + } + } + Expr::Binary(binary) => { + let left = Self::from_expr(*binary.left, context)?; + let right = Self::from_expr(*binary.right, context)?; + if left.is_const() && right.is_const() { + Expression::Verbatim { + tokens: quote![#expr], + } + } else { + let ty = left.ty().or(right.ty()); + Expression::Binary { + left: Box::new(left), + operator: parse_binop(&binary.op)?, + right: Box::new(right), + ty, + } + } + } + Expr::Lit(literal) => { + let ty = lit_ty(&literal.lit)?; + Expression::Literal { + value: literal.lit, + ty, + } + } + Expr::Path(path) => { + let variable = path + .path + .get_ident() + .and_then(|ident| context.variable(ident)); + if let Some(ManagedVar { + name, + ty, + is_const, + is_keyword, + use_count, + is_ref, + is_mut, + }) = variable + { + if is_const { + Expression::ConstVariable { + name, + ty, + use_count, + } + } else if is_keyword { + Expression::Keyword { name } + } else { + Expression::Variable { + name, + ty, + is_ref, + is_mut, + use_count, + } + } + } else { + // If it's not in the scope, it's not a managed local variable. Treat it as an + // external value like a Rust `const`. + Expression::Path { path: path.path } + } + } + Expr::Unary(unary) => { + let input = Self::from_expr(*unary.expr, context)?; + let ty = input.ty(); + Expression::Unary { + input: Box::new(input), + operator: parse_unop(&unary.op)?, + ty, + } + } + Expr::Block(block) => { + let block = context.with_scope(|ctx| Block::from_block(block.block, ctx))?; + Expression::Block(block) + } + Expr::Break(_) => Expression::Break, + Expr::Call(call) => { + let func = Box::new(Expression::from_expr(*call.func, context)?); + let args = call + .args + .into_iter() + .map(|arg| Expression::from_expr(arg, context)) + .collect::, _>>()?; + match *func { + Expression::Path { path } if is_intrinsic(&path) => { + Expression::CompilerIntrinsic { func: path, args } + } + func => { + let associated_type = fn_associated_type(&func); + Expression::FunctionCall { + func: Box::new(func), + args, + associated_type, + } + } + } + } + Expr::MethodCall(method) => { + let receiver = Expression::from_expr(*method.receiver.clone(), context)?; + let args = method + .args + .iter() + .map(|arg| Expression::from_expr(arg.clone(), context)) + .collect::, _>>()?; + if receiver.is_const() + && args.iter().all(|arg| arg.is_const()) + && method.method != "runtime" + { + let receiver = receiver.as_const(context).unwrap(); + let method = &method.method; + let args = args.iter().map(|it| it.to_tokens(context)); + Expression::Verbatim { + tokens: quote![#receiver.#method(#(#args),*)], + } + } else { + Expression::MethodCall { + receiver: Box::new(receiver), + method: method.method, + generics: method.turbofish, + args, + } + } + } + Expr::Cast(cast) => { + let mut from_expr = *cast.expr; + // Flatten multicasts because they shouldn't exist on the GPU + while matches!(from_expr, Expr::Cast(_)) { + match from_expr { + Expr::Cast(cast) => from_expr = *cast.expr, + _ => unreachable!(), + } + } + let from = Expression::from_expr(from_expr, context)?; + if let Some(as_const) = from.as_const(context) { + Expression::Verbatim { tokens: as_const } + } else { + Expression::Cast { + from: Box::new(from), + to: *cast.ty, + } + } + } + Expr::Const(block) => Expression::Verbatim { + tokens: quote![#block], + }, + Expr::Continue(cont) => Expression::Continue(cont.span()), + Expr::ForLoop(for_loop) => expand_for_loop(for_loop, context)?, + Expr::While(while_loop) => expand_while_loop(while_loop, context)?, + Expr::Loop(loop_expr) => expand_loop(loop_expr, context)?, + Expr::If(if_expr) => expand_if(if_expr, context)?, + Expr::Range(range) => { + let span = range.span(); + let start = range + .start + .map(|start| Expression::from_expr(*start, context)) + .transpose()? + .unwrap_or_else(|| { + let lit = Lit::Int(LitInt::new("0", span)); + Expression::Literal { + value: lit, + ty: parse_quote![i32], + } + }); + let end = range + .end + .map(|end| Expression::from_expr(*end, context)) + .transpose()? + .map(Box::new); + Expression::Range { + start: Box::new(start), + end, + span, + inclusive: matches!(range.limits, RangeLimits::Closed(..)), + } + } + Expr::Field(field) => { + let base = Expression::from_expr(*field.base.clone(), context)?; + Expression::FieldAccess { + base: Box::new(base), + field: field.member, + } + } + Expr::Group(group) => Expression::from_expr(*group.expr, context)?, + Expr::Paren(paren) => Expression::from_expr(*paren.expr, context)?, + Expr::Return(ret) => { + let span = ret.expr.span(); + Expression::Return { + expr: ret + .expr + .map(|expr| Expression::from_expr(*expr, context)) + .transpose()? + .map(Box::new), + span, + _ty: context.return_type.clone(), + } + } + Expr::Array(array) => { + let span = array.span(); + let elements = array + .elems + .into_iter() + .map(|elem| Expression::from_expr(elem, context)) + .collect::>()?; + Expression::Array { elements, span } + } + Expr::Tuple(tuple) => { + let elements = tuple + .elems + .into_iter() + .map(|elem| Expression::from_expr(elem, context)) + .collect::>()?; + Expression::Tuple { elements } + } + Expr::Index(index) => { + let span = index.span(); + let expr = Expression::from_expr(*index.expr, context)?; + let index = Expression::from_expr(*index.index, context)?; + if is_slice(&index) { + let ranges = match index { + Expression::Array { elements, .. } => elements.clone(), + Expression::Tuple { elements, .. } => elements.clone(), + index => vec![index], + }; + Expression::Slice { + expr: Box::new(expr), + span, + _ranges: ranges, + } + } else { + let index = match index { + Expression::Array { elements, span } => { + generate_strided_index(&expr, elements, span)? + } + index => index, + }; + Expression::Index { + expr: Box::new(expr), + index: Box::new(index), + } + } + } + Expr::Repeat(repeat) => { + let span = repeat.span(); + let len = Expression::from_expr(*repeat.len, context)?; + if !len.is_const() { + Err(syn::Error::new( + span, + "Array initializer length must be known at compile time", + ))? + } + Expression::ArrayInit { + init: Box::new(Expression::from_expr(*repeat.expr, context)?), + len: Box::new(len), + } + } + Expr::Let(expr) => { + let span = expr.span(); + let elem = Expression::from_expr(*expr.expr.clone(), context)?; + if elem.is_const() { + Expression::Verbatim { + tokens: quote![#expr], + } + } else { + Err(syn::Error::new( + span, + "let bindings aren't yet supported at runtime", + ))? + } + } + Expr::Match(mat) => { + let span = mat.span(); + let elem = Expression::from_expr(*mat.expr.clone(), context)?; + if elem.is_const() { + Expression::Verbatim { + tokens: quote![#mat], + } + } else { + Err(syn::Error::new( + span, + "match expressions aren't yet supported at runtime", + ))? + } + } + Expr::Macro(mac) if is_comptime_macro(&mac.mac.path) => { + let tokens = mac.mac.tokens; + Expression::Verbatim { + tokens: quote![{ #tokens }], + } + } + Expr::Macro(mac) => Expression::Verbatim { + tokens: quote![#mac], + }, + Expr::Struct(init) => { + let fields = init + .fields + .clone() + .into_iter() + .map(|field| { + let member = field.member; + let value = Expression::from_expr(field.expr, context)?; + syn::Result::Ok((member, value)) + }) + .collect::, _>>()?; + Expression::StructInit { + path: init.path, + fields, + } + } + Expr::Unsafe(unsafe_expr) => Expression::Block( + context.with_scope(|ctx| Block::from_block(unsafe_expr.block, ctx))?, + ), + Expr::Infer(_) => Expression::Verbatim { tokens: quote![_] }, + Expr::Verbatim(verbatim) => Expression::Verbatim { tokens: verbatim }, + Expr::Reference(reference) => Expression::Reference { + inner: Box::new(Expression::from_expr(*reference.expr, context)?), + }, + Expr::Closure(expr) => { + let body = context.with_scope(|ctx| Expression::from_expr(*expr.body, ctx))?; + let body = Box::new(body); + let params = expr.inputs.into_iter().collect(); + Expression::Closure { params, body } + } + + Expr::Try(expr) => { + let span = expr.span(); + let expr = Expression::from_expr(*expr.expr, context)? + .as_const(context) + .ok_or_else(|| syn::Error::new(span, "? Operator not supported at runtime"))?; + Expression::Verbatim { + tokens: quote_spanned![span=> #expr?], + } + } + Expr::TryBlock(_) => Err(syn::Error::new_spanned( + expr, + "try_blocks is unstable and not supported in kernels", + ))?, + e => Err(syn::Error::new_spanned( + expr, + format!("Unsupported expression {e:?}"), + ))?, + }; + Ok(result) + } +} + +fn lit_ty(lit: &Lit) -> syn::Result { + let res = match lit { + Lit::Int(int) => (!int.suffix().is_empty()) + .then(|| int.suffix()) + .map(|suffix| format_ident!("{suffix}")) + .and_then(|ident| syn::parse2(quote![#ident]).ok()) + .unwrap_or_else(|| parse_quote![i32]), + Lit::Float(float) => (!float.suffix().is_empty()) + .then(|| float.suffix()) + .map(|suffix| format_ident!("{suffix}")) + .and_then(|ident| syn::parse2(quote![#ident]).ok()) + .unwrap_or_else(|| parse_quote![f32]), + Lit::Bool(_) => parse_quote![bool], + lit => Err(syn::Error::new_spanned( + lit, + format!("Unsupported literal type: {lit:?}"), + ))?, + }; + Ok(res) +} + +fn generate_strided_index( + tensor: &Expression, + elements: Vec, + span: Span, +) -> syn::Result { + let index_ty = elements + .first() + .unwrap() + .ty() + .unwrap_or_else(|| parse_quote![u32]); + let strided_indices = elements.into_iter().enumerate().map(|(i, elem)| { + let i = Lit::Int(LitInt::new(&i.to_string(), span)); + let stride = Expression::MethodCall { + receiver: Box::new(tensor.clone()), + method: format_ident!("stride"), + args: vec![Expression::Literal { + value: i, + ty: index_ty.clone(), + }], + generics: None, + }; + Expression::Binary { + left: Box::new(elem), + operator: Operator::Mul, + right: Box::new(stride), + ty: None, + } + }); + let sum = strided_indices + .reduce(|a, b| Expression::Binary { + left: Box::new(a), + operator: Operator::Add, + right: Box::new(b), + ty: None, + }) + .unwrap(); + Ok(sum) +} + +fn is_slice(index: &Expression) -> bool { + match index { + Expression::Range { .. } => true, + Expression::Array { elements, .. } => elements.iter().any(is_slice), + Expression::Tuple { elements, .. } => elements.iter().any(is_slice), + _ => false, + } +} + +fn fn_associated_type(path: &Expression) -> Option<(Path, PathSegment)> { + // All supported primitives. Primitives don't start with an uppercase letter + const PRIMITIVES: &[&str] = &["bool", "i32", "i64", "u32", "f16", "bf16", "f32", "f64"]; + if !matches!(path, Expression::Path { .. }) { + panic!("path: {path:?}"); + } + match path { + Expression::Path { path, .. } => { + let second_last = path.segments.iter().nth_back(1)?; + let name = second_last.ident.to_string(); + let ch = name.chars().next(); + let is_assoc = ch.map(|ch| ch.is_uppercase()).unwrap_or(false); + let is_primitive = PRIMITIVES.contains(&name.as_str()); + if is_assoc || is_primitive { + let mut path = path.clone(); + let name = path.segments.pop().unwrap().into_value(); + path.segments.pop_punct(); + Some((path, name)) + } else { + None + } + } + _ => None, + } +} + +fn is_comptime_macro(path: &Path) -> bool { + let path = path.to_token_stream().to_string(); + "::cubecl::comptime".ends_with(&path) +} diff --git a/crates/cubecl-macros/src/parse/helpers.rs b/crates/cubecl-macros/src/parse/helpers.rs new file mode 100644 index 00000000..a5318564 --- /dev/null +++ b/crates/cubecl-macros/src/parse/helpers.rs @@ -0,0 +1,189 @@ +use darling::FromMeta; +use syn::{ + parse_quote, + visit_mut::{self, VisitMut}, + Attribute, Expr, +}; + +use crate::{expression::Expression, paths::prelude_path, scope::Context}; + +pub struct Unroll { + pub value: Expression, +} + +impl Unroll { + pub fn from_attributes( + attrs: &[Attribute], + context: &mut Context, + ) -> syn::Result> { + #[derive(FromMeta)] + struct NameVal { + pub value: Expr, + } + + let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll")); + let attr = match attr { + Some(attr) => attr, + None => return Ok(None), + }; + + let res = match &attr.meta { + syn::Meta::Path(_) => Self { + value: Expression::from_expr(parse_quote![true], context).unwrap(), + }, + syn::Meta::List(list) => { + let expr = syn::parse2(list.tokens.clone())?; + let expr = Expression::from_expr(expr, context)?; + Self { value: expr } + } + meta => { + let expr = NameVal::from_meta(meta)?; + let expr = Expression::from_expr(expr.value, context)?; + Self { value: expr } + } + }; + Ok(Some(res)) + } + + pub fn unroll_expr(attrs: &[Attribute]) -> Option { + #[derive(FromMeta)] + struct NameVal { + pub value: Expr, + } + + let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll")); + let attr = match attr { + Some(attr) => attr, + None => return None, + }; + + match &attr.meta { + syn::Meta::Path(_) => None, + syn::Meta::List(list) => syn::parse2(list.tokens.clone()).ok(), + meta => Some(NameVal::from_meta(meta).ok()?.value), + } + } +} + +pub struct RemoveHelpers; + +impl VisitMut for RemoveHelpers { + fn visit_fn_arg_mut(&mut self, i: &mut syn::FnArg) { + match i { + syn::FnArg::Receiver(recv) => recv.attrs.retain(|it| !is_comptime_attr(it)), + syn::FnArg::Typed(typed) => typed.attrs.retain(|it| !is_comptime_attr(it)), + } + visit_mut::visit_fn_arg_mut(self, i); + } + + fn visit_expr_for_loop_mut(&mut self, i: &mut syn::ExprForLoop) { + let unroll = Unroll::unroll_expr(&i.attrs); + i.attrs.retain(|attr| !is_unroll_attr(attr)); + if let Some(unroll) = unroll { + i.body + .stmts + .insert(0, parse_quote![let __unroll = #unroll;]) + } + visit_mut::visit_expr_for_loop_mut(self, i); + } +} + +pub struct ReplaceIndices; +pub struct ReplaceIndex; +pub struct ReplaceIndexMut; + +impl VisitMut for ReplaceIndices { + fn visit_expr_assign_mut(&mut self, i: &mut syn::ExprAssign) { + ReplaceIndexMut.visit_expr_mut(&mut i.left); + ReplaceIndex.visit_expr_mut(&mut i.right); + visit_mut::visit_expr_assign_mut(self, i); + } + + fn visit_expr_binary_mut(&mut self, i: &mut syn::ExprBinary) { + match i.op { + syn::BinOp::AddAssign(_) + | syn::BinOp::SubAssign(_) + | syn::BinOp::MulAssign(_) + | syn::BinOp::DivAssign(_) + | syn::BinOp::RemAssign(_) + | syn::BinOp::BitXorAssign(_) + | syn::BinOp::BitAndAssign(_) + | syn::BinOp::BitOrAssign(_) + | syn::BinOp::ShlAssign(_) + | syn::BinOp::ShrAssign(_) => { + ReplaceIndexMut.visit_expr_mut(&mut i.left); + ReplaceIndex.visit_expr_mut(&mut i.right); + } + _ => {} + } + visit_mut::visit_expr_binary_mut(self, i) + } + + fn visit_expr_mut(&mut self, i: &mut syn::Expr) { + if matches!(i, Expr::Index(_)) { + ReplaceIndex.visit_expr_mut(i) + } + visit_mut::visit_expr_mut(self, i); + } + + fn visit_item_fn_mut(&mut self, i: &mut syn::ItemFn) { + let prelude_path = prelude_path(); + let import = parse_quote![use #prelude_path::{CubeIndex as _, CubeIndexMut as _};]; + i.block.stmts.insert(0, import); + visit_mut::visit_item_fn_mut(self, i); + } + + fn visit_impl_item_fn_mut(&mut self, i: &mut syn::ImplItemFn) { + let prelude_path = prelude_path(); + let import = parse_quote![use #prelude_path::{CubeIndex as _, CubeIndexMut as _};]; + i.block.stmts.insert(0, import); + visit_mut::visit_impl_item_fn_mut(self, i); + } + + fn visit_trait_item_fn_mut(&mut self, i: &mut syn::TraitItemFn) { + if let Some(block) = &mut i.default { + let prelude_path = prelude_path(); + let import = parse_quote![use #prelude_path::{CubeIndex as _, CubeIndexMut as _};]; + block.stmts.insert(0, import); + } + visit_mut::visit_trait_item_fn_mut(self, i); + } +} + +impl VisitMut for ReplaceIndex { + fn visit_expr_mut(&mut self, i: &mut Expr) { + if let Expr::Index(index) = i { + let inner = &index.expr; + let index = &index.index; + *i = parse_quote![*#inner.cube_idx(#index)] + } + visit_mut::visit_expr_mut(self, i); + } +} + +impl VisitMut for ReplaceIndexMut { + fn visit_expr_mut(&mut self, i: &mut syn::Expr) { + if let Expr::Index(index) = i { + let inner = &index.expr; + let index = &index.index; + *i = parse_quote![*#inner.cube_idx_mut(#index)] + } + visit_mut::visit_expr_mut(self, i); + } +} + +pub fn is_comptime_attr(attr: &Attribute) -> bool { + attr.path().is_ident("comptime") +} + +pub fn is_unroll_attr(attr: &Attribute) -> bool { + attr.path().is_ident("unroll") +} + +pub fn is_expr_attribute(attr: &Attribute) -> bool { + attr.path().is_ident("expr") +} + +pub fn is_helper(attr: &Attribute) -> bool { + is_comptime_attr(attr) || is_unroll_attr(attr) || is_expr_attribute(attr) +} diff --git a/crates/cubecl-macros/src/parse/kernel.rs b/crates/cubecl-macros/src/parse/kernel.rs new file mode 100644 index 00000000..3cc9eec8 --- /dev/null +++ b/crates/cubecl-macros/src/parse/kernel.rs @@ -0,0 +1,205 @@ +use crate::{ + expression::Block, + paths::prelude_type, + scope::Context, + statement::{parse_pat, Pattern}, +}; +use darling::{ast::NestedMeta, util::Flag, FromMeta}; +use proc_macro2::TokenStream; +use std::iter; +use syn::{ + parse_quote, punctuated::Punctuated, FnArg, Generics, Ident, ItemFn, Signature, TraitItemFn, + Type, Visibility, +}; + +use super::helpers::is_comptime_attr; + +#[derive(Default, FromMeta)] +pub(crate) struct KernelArgs { + pub launch: Flag, + pub launch_unchecked: Flag, + pub debug: Flag, + pub create_dummy_kernel: Flag, +} + +pub fn from_tokens(tokens: TokenStream) -> syn::Result { + let meta = NestedMeta::parse_meta_list(tokens)?; + T::from_list(&meta).map_err(syn::Error::from) +} + +impl KernelArgs { + pub fn is_launch(&self) -> bool { + self.launch.is_present() || self.launch_unchecked.is_present() + } +} + +pub struct Launch { + pub args: KernelArgs, + pub vis: Visibility, + pub func: KernelFn, + pub kernel_generics: Generics, + pub launch_generics: Generics, +} + +#[derive(Clone)] +pub struct KernelFn { + pub sig: KernelSignature, + pub block: Block, + pub context: Context, +} + +#[derive(Clone)] +pub struct KernelSignature { + pub name: Ident, + pub parameters: Vec, + pub returns: Type, + pub generics: Generics, +} + +#[derive(Clone)] +pub struct KernelParam { + pub name: Ident, + pub ty: Type, + pub normalized_ty: Type, + pub is_const: bool, + pub is_mut: bool, + pub is_ref: bool, +} + +impl KernelParam { + fn from_param(param: FnArg) -> syn::Result { + let param = match param { + FnArg::Typed(param) => param, + param => Err(syn::Error::new_spanned( + param, + "Can't use `cube` on methods", + ))?, + }; + let Pattern { + ident, + mut is_ref, + mut is_mut, + .. + } = parse_pat(*param.pat.clone())?; + let is_const = param.attrs.iter().any(is_comptime_attr); + let ty = *param.ty.clone(); + let normalized_ty = normalize_kernel_ty(*param.ty, is_const, &mut is_ref, &mut is_mut); + + Ok(Self { + name: ident, + ty, + normalized_ty, + is_const, + is_mut, + is_ref, + }) + } + + pub fn ty_owned(&self) -> Type { + strip_ref(self.ty.clone(), &mut false, &mut false) + } +} + +impl KernelSignature { + pub fn from_signature(sig: Signature) -> syn::Result { + let name = sig.ident; + let generics = sig.generics; + let returns = match sig.output { + syn::ReturnType::Default => parse_quote![()], + syn::ReturnType::Type(_, ty) => *ty, + }; + let parameters = sig + .inputs + .into_iter() + .map(KernelParam::from_param) + .collect::, _>>()?; + + Ok(KernelSignature { + generics, + name, + parameters, + returns, + }) + } + + pub fn from_trait_fn(function: TraitItemFn) -> syn::Result { + let name = function.sig.ident; + let generics = function.sig.generics; + let returns = match function.sig.output { + syn::ReturnType::Default => parse_quote![()], + syn::ReturnType::Type(_, ty) => *ty, + }; + let parameters = function + .sig + .inputs + .into_iter() + .map(KernelParam::from_param) + .collect::, _>>()?; + + Ok(Self { + generics, + name, + parameters, + returns, + }) + } +} + +impl KernelFn { + pub fn from_sig_and_block(sig: Signature, block: syn::Block) -> syn::Result { + let sig = KernelSignature::from_signature(sig)?; + + let mut context = Context::new(sig.returns.clone()); + context.extend(sig.parameters.clone()); + let block = context.with_scope(|ctx| Block::from_block(block, ctx))?; + + Ok(KernelFn { + sig, + block, + context, + }) + } +} + +impl Launch { + pub fn from_item_fn(function: ItemFn, args: KernelArgs) -> syn::Result { + let runtime = prelude_type("Runtime"); + + let vis = function.vis; + let func = KernelFn::from_sig_and_block(function.sig, *function.block)?; + let mut kernel_generics = func.sig.generics.clone(); + kernel_generics.params.push(parse_quote![__R: #runtime]); + let mut expand_generics = kernel_generics.clone(); + expand_generics.params = + Punctuated::from_iter(iter::once(parse_quote!['kernel]).chain(expand_generics.params)); + + Ok(Launch { + args, + vis, + func, + kernel_generics, + launch_generics: expand_generics, + }) + } +} + +fn normalize_kernel_ty(ty: Type, is_const: bool, is_ref: &mut bool, is_mut: &mut bool) -> Type { + let ty = strip_ref(ty, is_ref, is_mut); + let cube_type = prelude_type("CubeType"); + if is_const { + ty + } else { + parse_quote![<#ty as #cube_type>::ExpandType] + } +} + +fn strip_ref(ty: Type, is_ref: &mut bool, is_mut: &mut bool) -> Type { + match ty { + Type::Reference(reference) => { + *is_ref = true; + *is_mut = *is_mut || reference.mutability.is_some(); + *reference.elem + } + ty => ty, + } +} diff --git a/crates/cubecl-macros/src/parse/mod.rs b/crates/cubecl-macros/src/parse/mod.rs new file mode 100644 index 00000000..ae2f38ca --- /dev/null +++ b/crates/cubecl-macros/src/parse/mod.rs @@ -0,0 +1,59 @@ +use syn::{visit_mut::VisitMut, GenericParam, TypeParam}; + +pub mod branch; +pub mod cube_trait; +pub mod cube_type; +pub mod expression; +pub mod helpers; +pub mod kernel; +pub mod operator; + +pub struct StripDefault; +impl VisitMut for StripDefault { + fn visit_generics_mut(&mut self, i: &mut syn::Generics) { + for generic in i.params.iter_mut() { + match generic { + GenericParam::Lifetime(_) => {} + GenericParam::Type(ty) => { + ty.default.take(); + ty.eq_token.take(); + } + GenericParam::Const(con) => { + con.default.take(); + con.eq_token.take(); + } + } + } + } +} + +pub struct StripBounds; + +impl VisitMut for StripBounds { + fn visit_generics_mut(&mut self, i: &mut syn::Generics) { + for generic in i.params.iter_mut() { + match generic { + GenericParam::Lifetime(lifetime) => { + lifetime.attrs.clear(); + lifetime.bounds.clear(); + lifetime.colon_token.take(); + } + GenericParam::Type(ty) => { + ty.attrs.clear(); + ty.bounds.clear(); + ty.colon_token.take(); + } + GenericParam::Const(con) => { + *generic = GenericParam::Type(TypeParam { + attrs: Default::default(), + ident: con.ident.clone(), + colon_token: None, + bounds: Default::default(), + eq_token: None, + default: None, + }) + } + } + } + } +} diff --git a/crates/cubecl-macros/src/parse/operator.rs b/crates/cubecl-macros/src/parse/operator.rs new file mode 100644 index 00000000..4d2e36a3 --- /dev/null +++ b/crates/cubecl-macros/src/parse/operator.rs @@ -0,0 +1,48 @@ +use syn::{BinOp, UnOp}; + +use crate::operator::Operator; + +pub fn parse_binop(op: &BinOp) -> syn::Result { + let op = match op { + BinOp::Add(_) => Operator::Add, + BinOp::Sub(_) => Operator::Sub, + BinOp::Mul(_) => Operator::Mul, + BinOp::Div(_) => Operator::Div, + BinOp::Rem(_) => Operator::Rem, + BinOp::And(_) => Operator::And, + BinOp::Or(_) => Operator::Or, + BinOp::BitXor(_) => Operator::BitXor, + BinOp::BitAnd(_) => Operator::BitAnd, + BinOp::BitOr(_) => Operator::BitOr, + BinOp::Shl(_) => Operator::Shl, + BinOp::Shr(_) => Operator::Shr, + BinOp::Eq(_) => Operator::Eq, + BinOp::Lt(_) => Operator::Lt, + BinOp::Le(_) => Operator::Le, + BinOp::Ne(_) => Operator::Ne, + BinOp::Ge(_) => Operator::Ge, + BinOp::Gt(_) => Operator::Gt, + BinOp::AddAssign(_) => Operator::AddAssign, + BinOp::SubAssign(_) => Operator::SubAssign, + BinOp::MulAssign(_) => Operator::MulAssign, + BinOp::DivAssign(_) => Operator::DivAssign, + BinOp::RemAssign(_) => Operator::RemAssign, + BinOp::BitXorAssign(_) => Operator::BitXorAssign, + BinOp::BitAndAssign(_) => Operator::BitAndAssign, + BinOp::BitOrAssign(_) => Operator::BitOrAssign, + BinOp::ShlAssign(_) => Operator::ShlAssign, + BinOp::ShrAssign(_) => Operator::ShrAssign, + op => Err(syn::Error::new_spanned(op, "Unsupported operator"))?, + }; + Ok(op) +} + +pub fn parse_unop(op: &UnOp) -> syn::Result { + let op = match op { + UnOp::Deref(_) => Operator::Deref, + UnOp::Not(_) => Operator::Not, + UnOp::Neg(_) => Operator::Neg, + op => Err(syn::Error::new_spanned(op, "Unsupported operator"))?, + }; + Ok(op) +} diff --git a/crates/cubecl-macros/src/paths.rs b/crates/cubecl-macros/src/paths.rs new file mode 100644 index 00000000..1a50772a --- /dev/null +++ b/crates/cubecl-macros/src/paths.rs @@ -0,0 +1,54 @@ +use quote::format_ident; +use std::cell::LazyCell; +use syn::Path; + +#[allow(clippy::declare_interior_mutable_const)] +const CORE_PATH: LazyCell = LazyCell::new(|| Path::from(format_ident!("cubecl"))); +#[allow(clippy::declare_interior_mutable_const)] +const FRONTEND_PATH: LazyCell = LazyCell::new(|| { + let mut path = core_path(); + path.segments.push(format_ident!("frontend").into()); + path +}); +#[allow(clippy::declare_interior_mutable_const)] +const PRELUDE_PATH: LazyCell = LazyCell::new(|| { + let mut path = core_path(); + path.segments.push(format_ident!("prelude").into()); + path +}); + +pub fn frontend_path() -> Path { + #[allow(clippy::borrow_interior_mutable_const)] + FRONTEND_PATH.clone() +} + +pub fn prelude_path() -> Path { + #[allow(clippy::borrow_interior_mutable_const)] + PRELUDE_PATH.clone() +} + +pub fn core_path() -> Path { + #[allow(clippy::borrow_interior_mutable_const)] + CORE_PATH.clone() +} + +pub fn core_type(ty: &str) -> Path { + let mut path = core_path(); + let ident = format_ident!("{ty}"); + path.segments.push(ident.into()); + path +} + +pub fn frontend_type(ty: &str) -> Path { + let mut path = frontend_path(); + let ident = format_ident!("{ty}"); + path.segments.push(ident.into()); + path +} + +pub fn prelude_type(ty: &str) -> Path { + let mut path = prelude_path(); + let ident = format_ident!("{ty}"); + path.segments.push(ident.into()); + path +} diff --git a/crates/cubecl-macros/src/scope.rs b/crates/cubecl-macros/src/scope.rs new file mode 100644 index 00000000..24de9633 --- /dev/null +++ b/crates/cubecl-macros/src/scope.rs @@ -0,0 +1,238 @@ +use std::{ + collections::{HashMap, VecDeque}, + rc::Rc, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use quote::format_ident; +use syn::{parse_quote, Ident, Type}; + +use crate::parse::kernel::KernelParam; + +pub const KEYWORDS: [&str; 21] = [ + "ABSOLUTE_POS", + "ABSOLUTE_POS_X", + "ABSOLUTE_POS_Y", + "ABSOLUTE_POS_Z", + "UNIT_POS", + "UNIT_POS_X", + "UNIT_POS_Y", + "UNIT_POS_Z", + "CUBE_POS", + "CUBE_POS_X", + "CUBE_POS_Y", + "CUBE_POS_Z", + "CUBE_DIM", + "CUBE_DIM_X", + "CUBE_DIM_Y", + "CUBE_DIM_Z", + "CUBE_COUNT", + "CUBE_COUNT_X", + "CUBE_COUNT_Y", + "CUBE_COUNT_Z", + "SUBCUBE_DIM", +]; + +#[derive(Clone)] +pub struct Context { + pub return_type: Type, + scopes: Vec, + // Allows for global variable analysis + scope_history: HashMap>, +} + +impl Context { + pub fn new(return_type: Type) -> Self { + let mut root_scope = Scope::default(); + root_scope.variables.extend(KEYWORDS.iter().map(|it| { + let name = format_ident!("{it}"); + let ty = parse_quote![u32]; + ManagedVar { + name, + ty: Some(ty), + is_const: false, + is_ref: false, + is_mut: false, + is_keyword: true, + use_count: AtomicUsize::new(0).into(), + } + })); + Self { + return_type, + scopes: vec![root_scope], + scope_history: Default::default(), + } + } + + pub fn push_variable( + &mut self, + name: Ident, + ty: Option, + is_const: bool, + is_ref: bool, + is_mut: bool, + ) { + self.scopes + .last_mut() + .expect("Scopes must at least have root scope") + .variables + .push(ManagedVar { + name, + ty, + is_const, + is_ref, + is_mut, + is_keyword: false, + use_count: AtomicUsize::new(0).into(), + }); + } + + fn push_scope(&mut self) { + self.scopes.push(Scope::default()) + } + + fn pop_scope(&mut self) { + let scope = self.scopes.pop().expect("Can't pop root scope"); + self.scope_history + .entry(self.scopes.len()) + .or_default() + .push_back(scope); + } + + fn delete_scope(&mut self) { + self.scopes.pop(); + } + + pub fn with_scope(&mut self, with: impl FnOnce(&mut Self) -> T) -> T { + self.push_scope(); + let res = with(self); + self.pop_scope(); + res + } + + fn restore_scope(&mut self) { + let scope = self + .scope_history + .get_mut(&(self.scopes.len())) + .and_then(|it| it.pop_front()); + if let Some(scope) = scope { + self.scopes.push(scope); + } + } + + fn restore_mut_scope(&mut self) { + let scope = self + .scope_history + .get_mut(&(self.scopes.len())) + .and_then(|it| it.pop_front()); + if let Some(mut scope) = scope { + scope.is_mut = true; + self.scopes.push(scope); + } + } + + pub fn with_restored_scope(&mut self, with: impl FnOnce(&mut Self) -> T) -> T { + self.restore_scope(); + let res = with(self); + self.delete_scope(); + res + } + + /// Mutable closures (for loops) have different behaviour because outer vars must be cloned + pub fn with_restored_closure_scope(&mut self, with: impl FnOnce(&mut Self) -> T) -> T { + self.restore_mut_scope(); + let res = with(self); + self.delete_scope(); + res + } + + pub fn variable(&self, name: &Ident) -> Option { + // Walk through each scope backwards until we find the variable. + self.scopes + .iter() + .rev() + .flat_map(|scope| scope.variables.iter().rev()) + .find(|var| name == &var.name) + .map(|var| { + var.use_count.fetch_add(1, Ordering::AcqRel); + var.clone() + }) + } + + pub fn try_consume(&self, name: &Ident) -> bool { + // Find innermost closure scope if it exists + let mut_scope_idx = self + .scopes + .iter() + .enumerate() + .rev() + .find(|(_, scope)| scope.is_mut) + .map(|(i, _)| i); + let (level, var) = self + .scopes + .iter() + .enumerate() + .rev() + .flat_map(|(i, scope)| scope.variables.iter().map(move |it| (i, it))) + .find(|(_, var)| &var.name == name && var.use_count.load(Ordering::Acquire) > 0) + .unwrap_or_else(|| { + panic!( + "Trying to get use count of variable {name} that never existed.\nScopes: {:#?}\nHistory:{:#?}", + self.scopes, + self.scope_history + ); + }); + let count = var.use_count.fetch_sub(1, Ordering::AcqRel); + /* if level == 0 { + // Always clone outer vars since we can't see whether they're still used outside the + // function + false + } else */ + if let Some(mut_scope_idx) = mut_scope_idx { + // Always clone vars from outside closure, otherwise proceed as normal + level >= mut_scope_idx && count <= 1 + } else { + count <= 1 + } + } + + pub fn extend(&mut self, vars: impl IntoIterator) { + self.scopes + .last_mut() + .expect("Scopes must at least have root scope") + .variables + .extend(vars.into_iter().map(Into::into)) + } +} + +#[derive(Default, Clone, Debug)] +pub struct Scope { + variables: Vec, + /// Must clone outer vars + is_mut: bool, +} + +#[derive(Clone, Debug)] +pub struct ManagedVar { + pub name: Ident, + pub ty: Option, + pub is_const: bool, + pub is_ref: bool, + pub is_mut: bool, + pub is_keyword: bool, + pub use_count: Rc, +} + +impl From for ManagedVar { + fn from(value: KernelParam) -> Self { + ManagedVar { + name: value.name, + ty: Some(value.ty), + is_const: value.is_const, + is_keyword: false, + use_count: AtomicUsize::new(0).into(), + is_ref: value.is_ref, + is_mut: value.is_mut, + } + } +} diff --git a/crates/cubecl-macros/src/statement.rs b/crates/cubecl-macros/src/statement.rs new file mode 100644 index 00000000..6c5a8fc6 --- /dev/null +++ b/crates/cubecl-macros/src/statement.rs @@ -0,0 +1,228 @@ +use std::{ + rc::Rc, + sync::atomic::{AtomicUsize, Ordering}, +}; + +use crate::{expression::Expression, scope::Context}; +use proc_macro2::Span; +use quote::format_ident; +use syn::{ + spanned::Spanned, Ident, Index, Member, Pat, PatStruct, PatTuple, PatTupleStruct, Stmt, Type, + TypeReference, +}; + +#[derive(Clone, Debug)] +pub enum Statement { + Local { + left: Box, + init: Option>, + mutable: bool, + ty: Option, + }, + /// Group of statements generated by desugaring + Group { + statements: Vec, + }, + Expression { + expression: Box, + terminated: bool, + span: Span, + }, + Skip, +} + +impl Statement { + pub fn from_stmt(stmt: Stmt, context: &mut Context) -> syn::Result { + let statement = match stmt { + Stmt::Local(local) => { + let init = local + .init + .map(|init| Expression::from_expr(*init.expr, context)) + .transpose()? + .map(Box::new); + let Pattern { + ident, + ty, + is_ref, + is_mut, + } = match local.pat { + Pat::Struct(pat) => { + return desugar_struct_local(pat, *init.unwrap(), context); + } + Pat::Tuple(PatTuple { elems, .. }) + | Pat::TupleStruct(PatTupleStruct { elems, .. }) => { + return desugar_tuple_local(elems, *init.unwrap(), context) + } + pat => parse_pat(pat)?, + }; + let is_const = init.as_ref().map(|init| init.is_const()).unwrap_or(false); + let variable = Box::new(Expression::Variable { + name: ident.clone(), + is_ref, + is_mut, + ty: ty.clone(), + use_count: Rc::new(AtomicUsize::new(0)), + }); + + context.push_variable(ident, ty.clone(), is_const && !is_mut, is_ref, is_mut); + Self::Local { + left: variable, + init, + mutable: is_mut, + ty, + } + } + Stmt::Expr(expr, semi) => { + let span = expr.span(); + let expression = Box::new(Expression::from_expr(expr, context)?); + Statement::Expression { + terminated: semi.is_some() || !expression.needs_terminator(), + span, + expression, + } + } + Stmt::Item(_) => Statement::Skip, + stmt => Err(syn::Error::new_spanned(stmt, "Unsupported statement"))?, + }; + Ok(statement) + } +} + +pub struct Pattern { + pub ident: Ident, + pub ty: Option, + pub is_ref: bool, + pub is_mut: bool, +} + +pub fn parse_pat(pat: Pat) -> syn::Result { + let res = match pat { + Pat::Ident(ident) => Pattern { + ident: ident.ident, + ty: None, + is_ref: ident.by_ref.is_some(), + is_mut: ident.mutability.is_some(), + }, + Pat::Type(pat) => { + let ty = *pat.ty; + let is_ref = matches!(ty, Type::Reference(_)); + let ref_mut = matches!( + ty, + Type::Reference(TypeReference { + mutability: Some(_), + .. + }) + ); + let inner = parse_pat(*pat.pat)?; + Pattern { + ident: inner.ident, + ty: Some(ty), + is_ref: is_ref || inner.is_ref, + is_mut: ref_mut || inner.is_mut, + } + } + Pat::Wild(_) => Pattern { + ident: format_ident!("_"), + ty: None, + is_ref: false, + is_mut: false, + }, + pat => Err(syn::Error::new_spanned( + pat.clone(), + format!("Unsupported local pat: {pat:?}"), + ))?, + }; + Ok(res) +} + +fn desugar_struct_local( + pat: PatStruct, + init: Expression, + context: &mut Context, +) -> syn::Result { + let fields = pat + .fields + .into_iter() + .map(|field| { + let access = Expression::FieldAccess { + base: Box::new(init.clone()), + field: field.member, + }; + let Pattern { + ident, + ty, + is_ref, + is_mut, + } = parse_pat(*field.pat.clone())?; + context.push_variable(ident.clone(), ty.clone(), init.is_const(), is_ref, is_mut); + let statement = Statement::Local { + left: Box::new(Expression::Variable { + name: ident, + is_ref, + is_mut, + ty: ty.clone(), + use_count: AtomicUsize::new(0).into(), + }), + init: Some(Box::new(access)), + mutable: is_mut, + ty, + }; + Ok(statement) + }) + .collect::>>()?; + + match init { + Expression::Variable { use_count, .. } | Expression::ConstVariable { use_count, .. } => { + use_count.fetch_add(fields.len() - 1, Ordering::AcqRel); + } + _ => {} + } + + Ok(Statement::Group { statements: fields }) +} + +fn desugar_tuple_local( + elems: impl IntoIterator, + init: Expression, + context: &mut Context, +) -> syn::Result { + let fields = elems + .into_iter() + .enumerate() + .map(|(i, pat)| { + let access = Expression::FieldAccess { + base: Box::new(init.clone()), + field: Member::Unnamed(Index::from(i)), + }; + let Pattern { + ident, + ty, + is_ref, + is_mut, + } = parse_pat(pat.clone())?; + context.push_variable(ident.clone(), ty.clone(), init.is_const(), is_ref, is_mut); + let statement = Statement::Local { + left: Box::new(Expression::Variable { + name: ident, + ty: ty.clone(), + use_count: AtomicUsize::new(0).into(), + is_ref, + is_mut, + }), + init: Some(Box::new(access)), + mutable: is_mut, + ty, + }; + Ok(statement) + }) + .collect::>>()?; + + match init { + Expression::Variable { use_count, .. } | Expression::ConstVariable { use_count, .. } => { + use_count.fetch_add(fields.len() - 1, Ordering::AcqRel); + } + _ => {} + } + + Ok(Statement::Group { statements: fields }) +} diff --git a/crates/cubecl-macros/src/tracker.rs b/crates/cubecl-macros/src/tracker.rs deleted file mode 100644 index 7371cf10..00000000 --- a/crates/cubecl-macros/src/tracker.rs +++ /dev/null @@ -1,244 +0,0 @@ -use std::collections::HashMap; - -#[derive(new, Hash, PartialEq, Eq, Debug, Clone)] -/// Identifies a variable uniquely -pub struct VariableIdent { - name: String, - repeat: u8, - scope: u8, - field: Option, -} - -#[derive(new, Eq, PartialEq, Hash, Debug)] -/// Identifies a variable, with possible collisions when variables are redeclared -struct VariableKey { - name: String, - scope: u8, -} - -#[derive(Debug, Default)] -/// Tracks variable uses -pub(crate) struct VariableTracker { - scopes_declared: HashMap>, - analysis_repeats: HashMap, - codegen_repeats: HashMap, - variable_uses: HashMap, - pub errors: Vec, -} - -#[derive(Debug, Default)] -/// Encapsulates number of uses and whether this implies cloning -pub(crate) struct VariableUse { - pub num_used: usize, - pub is_comptime: bool, -} - -impl VariableUse { - pub fn should_clone(&self) -> bool { - self.num_used > 1 - } -} - -impl VariableTracker { - /// During analysis, tracks a variable declaration - pub(crate) fn analyze_declare(&mut self, name: String, scope: u8, is_comptime: bool) { - if let Some(scopes) = self.scopes_declared.get_mut(&name) { - if !scopes.contains(&scope) { - scopes.push(scope); - } - } else { - self.scopes_declared.insert(name.clone(), vec![scope]); - } - - let key = VariableKey::new(name.clone(), scope); - let repeat = if let Some(count) = self.analysis_repeats.get_mut(&key) { - *count += 1; - *count - } else { - self.analysis_repeats.insert(key, 0); - 0 - }; - - let analysis = VariableUse { - num_used: 1, - is_comptime, - }; - let variable_ident = VariableIdent::new(name, repeat, scope, None); - self.variable_uses.insert(variable_ident, analysis); - } - - /// During analysis, tracks a variable use - pub(crate) fn analyze_reuse(&mut self, ident: &syn::Ident, scope: u8, field: Option) { - let name = ident.to_string(); - - if name == "None" { - return; - } - - let scopes_declared = match self.scopes_declared.get(&name) { - Some(val) => val, - None => { - self.errors - .push(syn::Error::new_spanned(ident, "Variable not declared")); - return; - } - }; - - let scope = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .unwrap(); - let key = VariableKey::new(name.clone(), scope); - - // If the name and scope do not match a declared variable, - // then we are using a variable declared in a parent scope, and - // cloning must always happen, therefore no need for further analysis - if let Some(repeat) = self.analysis_repeats.get(&key) { - let variable = VariableIdent::new(name, *repeat, scope, field); - self.analyze(&variable); - } - } - - /// Increments variable use and its parent struct if need be - fn analyze(&mut self, variable_ident: &VariableIdent) { - match self.variable_uses.get_mut(variable_ident) { - Some(variable_use) => { - variable_use.num_used += 1; - } - None => { - // If variable was not inserted yet, it must be a field - if variable_ident.field.is_some() { - let mut parent_ident = variable_ident.clone(); - parent_ident.field = None; - let parent = self.variable_uses.get(&parent_ident).unwrap(); - - let attr_analysis = VariableUse { - num_used: 1, - is_comptime: parent.is_comptime, - }; - self.variable_uses - .insert(variable_ident.clone(), attr_analysis); - } else { - panic!("Variable not declared"); - } - } - }; - - // Whether a field was previously seen or not, we must increase the use of the parent struct - if variable_ident.field.is_some() { - let mut declaration_ident = variable_ident.clone(); - declaration_ident.field = None; - let declaration = self - .variable_uses - .get_mut(&declaration_ident) - .unwrap_or_else(|| panic!("Struct {:?} does not exist", declaration_ident)); - declaration.num_used += 1; - } - } - - /// During codegen, tracks a variable declaration. - /// This must be done again to know on what repeat a use occurs - pub(crate) fn codegen_declare(&mut self, name: String, scope: u8) { - let key = VariableKey::new(name.clone(), scope); - if let Some(count) = self.codegen_repeats.get_mut(&key) { - *count += 1; - } else { - self.codegen_repeats.insert(key, 0); - } - } - - /// During codegen, tracks a variable use. - pub(crate) fn codegen_reuse( - &mut self, - name: String, - scope: u8, - field: Option, - ) -> Result<(bool, bool), VariableReuseError> { - let scopes_declared = self - .scopes_declared - .get(&name) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - let scope_declared = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - - let key = VariableKey::new(name.clone(), scope_declared); - let repeat = self.codegen_repeats.get(&key).unwrap_or(&0); - let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone()); - - let should_clone_parent = if field.is_some() { - let struct_ident = VariableIdent::new(name.clone(), *repeat, scope_declared, None); - let parent_analysis = self - .variable_uses - .get_mut(&struct_ident) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope_declared, None))?; - - parent_analysis.num_used -= 1; - parent_analysis.should_clone() - } else { - false - }; - - let analysis = self - .variable_uses - .get_mut(&ident) - .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?; - - analysis.num_used -= 1; - let should_clone = - analysis.should_clone() || should_clone_parent || scope_declared != scope; - Ok((should_clone, analysis.is_comptime)) - } - - pub fn set_as_comptime( - &mut self, - name: String, - scope: u8, - field: Option, - ) -> Result<(), VariableReuseError> { - let scopes_declared = self - .scopes_declared - .get(&name) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - let scope_declared = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - - let key = VariableKey::new(name.clone(), scope_declared); - let repeat = self.codegen_repeats.get(&key).unwrap_or(&0); - let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone()); - - let analysis = self - .variable_uses - .get_mut(&ident) - .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?; - - analysis.is_comptime = true; - - Ok(()) - } -} - -#[derive(new, Debug)] -pub struct VariableNotFound { - _name: String, - _scope: u8, - _field: Option, -} - -#[derive(Debug)] -#[allow(dead_code)] -pub enum VariableReuseError { - VariableNotFound(VariableNotFound), -} - -impl From for VariableReuseError { - fn from(value: VariableNotFound) -> Self { - Self::VariableNotFound(value) - } -} diff --git a/crates/cubecl-macros/tests/cuda/common.rs b/crates/cubecl-macros/tests/cuda/common.rs new file mode 100644 index 00000000..60ab07aa --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/common.rs @@ -0,0 +1,69 @@ +use std::{io::Write, process::Command}; + +use cubecl_core::{ + client::ComputeClient, + prelude::{ArrayArg, TensorArg}, + server, Compiler, ExecutionMode, Kernel, Runtime, +}; +use cubecl_cuda::{CudaDevice, CudaRuntime}; + +type Client = ComputeClient<::Server, ::Channel>; +type Handle = server::Handle<::Server>; + +pub fn client() -> Client { + let device = CudaDevice::new(0); + CudaRuntime::client(&device) +} + +#[allow(unused)] +pub fn handle(client: &Client) -> Handle { + client.empty(1) +} + +#[allow(unused)] +pub fn tensor(tensor: &Handle) -> TensorArg<'_, CudaRuntime> { + unsafe { TensorArg::from_raw_parts(tensor, &[1], &[1], 1) } +} + +#[allow(unused)] +pub fn tensor_vec(tensor: &Handle, vec: u8) -> TensorArg<'_, CudaRuntime> { + unsafe { TensorArg::from_raw_parts(tensor, &[1], &[1], vec) } +} + +#[allow(unused)] +pub fn array(tensor: &Handle) -> ArrayArg<'_, CudaRuntime> { + unsafe { ArrayArg::from_raw_parts(tensor, 1, 1) } +} + +pub fn compile(kernel: impl Kernel) -> String { + let kernel = <::Compiler as Compiler>::compile( + kernel.define(), + ExecutionMode::Checked, + ) + .to_string(); + format_cpp_code(&kernel).unwrap() +} + +/// Format C++ code, useful when debugging. +fn format_cpp_code(code: &str) -> Result { + let mut child = Command::new("clang-format") + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn()?; + + { + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + stdin.write_all(code.as_bytes())?; + } + + let output = child.wait_with_output()?; + + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).into_owned()) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "clang-format failed", + )) + } +} diff --git a/crates/cubecl-macros/tests/cuda/main.rs b/crates/cubecl-macros/tests/cuda/main.rs new file mode 100644 index 00000000..3bd959eb --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/main.rs @@ -0,0 +1,115 @@ +use common::*; +use cubecl_core as cubecl; +use cubecl_core::{prelude::*, CubeCount, CubeDim}; +use cubecl_cuda::CudaRuntime; +use pretty_assertions::assert_eq; + +mod common; + +#[cube(launch_unchecked, create_dummy_kernel)] +pub fn slice_assign_kernel(input: &Tensor, output: &mut Tensor) { + if UNIT_POS == 0 { + let slice_1 = output.slice_mut(2, 3); + slice_1[0] = input[0]; + } +} + +#[test] +pub fn slice_assign() { + let client = client(); + let input = handle(&client); + let output = handle(&client); + + let kernel = slice_assign_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + tensor(&input), + tensor(&output), + ); + let expected = include_str!("slice_assign.cu").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +pub fn kernel_sum(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = cubecl_core::prelude::subcube_sum(val); + + if UNIT_POS == 0 { + output[0] = val2; + } +} + +#[test] +pub fn subcube_sum() { + let client = client(); + let output = handle(&client); + + let kernel = kernel_sum::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::new(4, 1, 1), + tensor(&output), + ); + let expected = include_str!("subcube_sum.cu").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +pub fn sequence_for_loop_kernel(output: &mut Array) { + if UNIT_POS != 0 { + return; + } + + let mut sequence = Sequence::::new(); + sequence.push(1.0); + sequence.push(4.0); + + for value in sequence { + output[0] += value; + } +} + +#[test] +pub fn sequence_for_loop() { + let client = client(); + let output = handle(&client); + + let kernel = sequence_for_loop_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::default(), + array(&output), + ); + let expected = include_str!("sequence_for_loop.cu").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +fn execute_unary_kernel(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { + if ABSOLUTE_POS < out.len() { + for i in 0..256u32 { + if i % 2 == 0 { + out[ABSOLUTE_POS] -= F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); + } else { + out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); + } + } + } +} + +#[test] +pub fn unary_bench() { + let client = client(); + let lhs = handle(&client); + let rhs = handle(&client); + let out = handle(&client); + + let kernel = execute_unary_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::default(), + tensor_vec(&lhs, 4), + tensor_vec(&rhs, 4), + tensor_vec(&out, 4), + ); + let expected = include_str!("unary_bench.cu").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} diff --git a/crates/cubecl-macros/tests/cuda/sequence_for_loop.cu b/crates/cubecl-macros/tests/cuda/sequence_for_loop.cu new file mode 100644 index 00000000..7f0630ba --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/sequence_for_loop.cu @@ -0,0 +1,49 @@ +typedef unsigned int uint; + +extern "C" __global__ void kernel(float output_0[], uint info[]) { + + int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * (blockDim.x * blockDim.y); + uint rank = info[0]; + uint rank_2 = rank * 2; + bool l_0_0; + float l_0_1; + l_0_0 = threadIdxGlobal != uint(0); + if (l_0_0) { + return; + } + uint l_0_2; + bool l_0_3; + l_0_2 = info[(1 * 2 * info[0]) + 1]; + l_0_3 = uint(0) < l_0_2; + if (l_0_3) { + l_0_1 = output_0[uint(0)]; + } else { + l_0_1 = float(0.0); + } + l_0_1 = l_0_1 + float(1.0); + uint l_0_4; + bool l_0_5; + l_0_4 = info[(1 * 2 * info[0]) + 1]; + l_0_5 = uint(0) < l_0_4; + if (l_0_5) { + output_0[uint(0)] = l_0_1; + } + uint l_0_6; + bool l_0_7; + l_0_6 = info[(1 * 2 * info[0]) + 1]; + l_0_7 = uint(0) < l_0_6; + if (l_0_7) { + l_0_1 = output_0[uint(0)]; + } else { + l_0_1 = float(0.0); + } + l_0_1 = l_0_1 + float(4.0); + uint l_0_8; + bool l_0_9; + l_0_8 = info[(1 * 2 * info[0]) + 1]; + l_0_9 = uint(0) < l_0_8; + if (l_0_9) { + output_0[uint(0)] = l_0_1; + } +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/cuda/slice_assign.cu b/crates/cubecl-macros/tests/cuda/slice_assign.cu new file mode 100644 index 00000000..08afd9e4 --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/slice_assign.cu @@ -0,0 +1,33 @@ +typedef unsigned int uint; + +extern "C" __global__ void kernel(float input_0[], float output_0[], + uint info[]) { + + int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * (blockDim.x * blockDim.y); + uint rank = info[0]; + uint rank_2 = rank * 2; + bool l_0_0; + float l_0_1; + l_0_0 = threadIdxGlobal == uint(0); + if (l_0_0) { + uint slice_1_0_length = uint(3) - uint(2); + float *slice_1_0 = output_0 + uint(2); + uint l_1_0; + bool l_1_1; + l_1_0 = info[(2 * 2 * info[0]) + 1]; + l_1_1 = uint(0) < l_1_0; + if (l_1_1) { + l_0_1 = input_0[uint(0)]; + } else { + l_0_1 = float(0.0); + } + uint l_1_2; + bool l_1_3; + l_1_2 = slice_1_0_length; + l_1_3 = uint(0) < l_1_2; + if (l_1_3) { + slice_1_0[uint(0)] = l_0_1; + } + } +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/cuda/subcube_sum.cu b/crates/cubecl-macros/tests/cuda/subcube_sum.cu new file mode 100644 index 00000000..addd20ab --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/subcube_sum.cu @@ -0,0 +1,40 @@ +typedef unsigned int uint; + +extern "C" __global__ void kernel(float output_0[], uint info[]) { + + int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * (blockDim.x * blockDim.y); + + int warpSizeChecked = min(warpSize, blockDim.x * blockDim.y * blockDim.z); + uint rank = info[0]; + uint rank_2 = rank * 2; + float l_0_0; + float l_0_1; + bool l_0_2; + uint l_0_3; + bool l_0_4; + l_0_3 = info[(1 * 2 * info[0]) + 1]; + l_0_4 = threadIdxGlobal < l_0_3; + if (l_0_4) { + l_0_0 = output_0[threadIdxGlobal]; + } else { + l_0_0 = float(0.0); + } + + l_0_1 = l_0_0; + { + for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) { + l_0_1 += __shfl_down_sync(0xFFFFFFFF, l_0_1, offset); + } + } + l_0_2 = threadIdxGlobal == uint(0); + if (l_0_2) { + uint l_1_0; + bool l_1_1; + l_1_0 = info[(1 * 2 * info[0]) + 1]; + l_1_1 = uint(0) < l_1_0; + if (l_1_1) { + output_0[uint(0)] = l_0_1; + } + } +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/cuda/unary_bench.cu b/crates/cubecl-macros/tests/cuda/unary_bench.cu new file mode 100644 index 00000000..675c8bfc --- /dev/null +++ b/crates/cubecl-macros/tests/cuda/unary_bench.cu @@ -0,0 +1,149 @@ +typedef unsigned int uint; + +struct __align__(16) float_4 { + float i_0; + float i_1; + float i_2; + float i_3; +}; + +extern "C" __global__ void kernel(float_4 input_0[], float_4 input_1[], + float_4 output_0[], uint info[]) { + + int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, + blockIdx.y * blockDim.y + threadIdx.y, + blockIdx.z * blockDim.z + threadIdx.z); + + uint idxGlobal = + (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x; + uint rank = info[0]; + uint rank_2 = rank * 2; + uint l_0_0; + bool l_0_1; + bool l_0_2; + float_4 l_0_3; + float_4 l_0_4; + l_0_0 = info[(3 * 2 * info[0]) + 3] / 4; + l_0_1 = idxGlobal < l_0_0; + if (l_0_1) { + + for (uint l_2_0 = uint(0); l_2_0 < uint(256); ++l_2_0) { + l_0_0 = l_2_0 % uint(2); + l_0_2 = l_0_0 == uint(0); + if (l_0_2) { + uint l_3_0; + bool l_3_1; + l_3_0 = info[(3 * 2 * info[0]) + 1] / 4; + l_3_1 = idxGlobal < l_3_0; + if (l_3_1) { + l_0_3 = input_0[idxGlobal]; + } else { + l_0_3.i_0 = float(0.0); + l_0_3.i_1 = float(0.0); + l_0_3.i_2 = float(0.0); + l_0_3.i_3 = float(0.0); + } + uint l_3_2; + bool l_3_3; + l_3_2 = info[(3 * 2 * info[0]) + 2] / 4; + l_3_3 = idxGlobal < l_3_2; + if (l_3_3) { + l_0_4 = input_1[idxGlobal]; + } else { + l_0_4.i_0 = float(0.0); + l_0_4.i_1 = float(0.0); + l_0_4.i_2 = float(0.0); + l_0_4.i_3 = float(0.0); + } + l_0_3.i_0 = l_0_3.i_0 * l_0_4.i_0; + l_0_3.i_1 = l_0_3.i_1 * l_0_4.i_1; + l_0_3.i_2 = l_0_3.i_2 * l_0_4.i_2; + l_0_3.i_3 = l_0_3.i_3 * l_0_4.i_3; + l_0_3.i_0 = cos(l_0_3.i_0); + l_0_3.i_1 = cos(l_0_3.i_1); + l_0_3.i_2 = cos(l_0_3.i_2); + l_0_3.i_3 = cos(l_0_3.i_3); + uint l_3_4; + bool l_3_5; + l_3_4 = info[(3 * 2 * info[0]) + 3] / 4; + l_3_5 = idxGlobal < l_3_4; + if (l_3_5) { + l_0_4 = output_0[idxGlobal]; + } else { + l_0_4.i_0 = float(0.0); + l_0_4.i_1 = float(0.0); + l_0_4.i_2 = float(0.0); + l_0_4.i_3 = float(0.0); + } + l_0_4.i_0 = l_0_4.i_0 - l_0_3.i_0; + l_0_4.i_1 = l_0_4.i_1 - l_0_3.i_1; + l_0_4.i_2 = l_0_4.i_2 - l_0_3.i_2; + l_0_4.i_3 = l_0_4.i_3 - l_0_3.i_3; + uint l_3_6; + bool l_3_7; + l_3_6 = info[(3 * 2 * info[0]) + 3] / 4; + l_3_7 = idxGlobal < l_3_6; + if (l_3_7) { + output_0[idxGlobal] = l_0_4; + } + } else { + uint l_3_0; + bool l_3_1; + l_3_0 = info[(3 * 2 * info[0]) + 1] / 4; + l_3_1 = idxGlobal < l_3_0; + if (l_3_1) { + l_0_4 = input_0[idxGlobal]; + } else { + l_0_4.i_0 = float(0.0); + l_0_4.i_1 = float(0.0); + l_0_4.i_2 = float(0.0); + l_0_4.i_3 = float(0.0); + } + uint l_3_2; + bool l_3_3; + l_3_2 = info[(3 * 2 * info[0]) + 2] / 4; + l_3_3 = idxGlobal < l_3_2; + if (l_3_3) { + l_0_3 = input_1[idxGlobal]; + } else { + l_0_3.i_0 = float(0.0); + l_0_3.i_1 = float(0.0); + l_0_3.i_2 = float(0.0); + l_0_3.i_3 = float(0.0); + } + l_0_4.i_0 = l_0_4.i_0 * l_0_3.i_0; + l_0_4.i_1 = l_0_4.i_1 * l_0_3.i_1; + l_0_4.i_2 = l_0_4.i_2 * l_0_3.i_2; + l_0_4.i_3 = l_0_4.i_3 * l_0_3.i_3; + l_0_4.i_0 = cos(l_0_4.i_0); + l_0_4.i_1 = cos(l_0_4.i_1); + l_0_4.i_2 = cos(l_0_4.i_2); + l_0_4.i_3 = cos(l_0_4.i_3); + uint l_3_4; + bool l_3_5; + l_3_4 = info[(3 * 2 * info[0]) + 3] / 4; + l_3_5 = idxGlobal < l_3_4; + if (l_3_5) { + l_0_3 = output_0[idxGlobal]; + } else { + l_0_3.i_0 = float(0.0); + l_0_3.i_1 = float(0.0); + l_0_3.i_2 = float(0.0); + l_0_3.i_3 = float(0.0); + } + l_0_3.i_0 = l_0_3.i_0 + l_0_4.i_0; + l_0_3.i_1 = l_0_3.i_1 + l_0_4.i_1; + l_0_3.i_2 = l_0_3.i_2 + l_0_4.i_2; + l_0_3.i_3 = l_0_3.i_3 + l_0_4.i_3; + uint l_3_6; + bool l_3_7; + l_3_6 = info[(3 * 2 * info[0]) + 3] / 4; + l_3_7 = idxGlobal < l_3_6; + if (l_3_7) { + output_0[idxGlobal] = l_0_3; + } + } + } + } +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/wgpu/common.rs b/crates/cubecl-macros/tests/wgpu/common.rs new file mode 100644 index 00000000..f7734d7a --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/common.rs @@ -0,0 +1,42 @@ +use cubecl_core::{ + client::ComputeClient, + prelude::{ArrayArg, TensorArg}, + server, Compiler, ExecutionMode, Kernel, Runtime, +}; +use cubecl_wgpu::{WgpuDevice, WgpuRuntime}; + +type Client = ComputeClient<::Server, ::Channel>; +type Handle = server::Handle<::Server>; + +pub fn client() -> Client { + let device = WgpuDevice::default(); + WgpuRuntime::client(&device) +} + +#[allow(unused)] +pub fn handle(client: &Client) -> Handle { + client.empty(1) +} + +#[allow(unused)] +pub fn tensor(tensor: &Handle) -> TensorArg<'_, WgpuRuntime> { + unsafe { TensorArg::from_raw_parts(tensor, &[1], &[1], 1) } +} + +#[allow(unused)] +pub fn tensor_vec(tensor: &Handle, vectorization: u8) -> TensorArg<'_, WgpuRuntime> { + unsafe { TensorArg::from_raw_parts(tensor, &[1], &[1], vectorization) } +} + +#[allow(unused)] +pub fn array(tensor: &Handle) -> ArrayArg<'_, WgpuRuntime> { + unsafe { ArrayArg::from_raw_parts(tensor, 1, 1) } +} + +pub fn compile(kernel: impl Kernel) -> String { + <::Compiler as Compiler>::compile( + kernel.define(), + ExecutionMode::Checked, + ) + .to_string() +} diff --git a/crates/cubecl-macros/tests/wgpu/main.rs b/crates/cubecl-macros/tests/wgpu/main.rs new file mode 100644 index 00000000..27941c1f --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/main.rs @@ -0,0 +1,115 @@ +use common::*; +use cubecl_core as cubecl; +use cubecl_core::{prelude::*, CubeCount, CubeDim}; +use cubecl_wgpu::WgpuRuntime; +use pretty_assertions::assert_eq; + +mod common; + +#[cube(launch_unchecked, create_dummy_kernel)] +pub fn slice_assign_kernel(input: &Tensor, output: &mut Tensor) { + if UNIT_POS == 0 { + let slice_1 = output.slice_mut(2, 3); + slice_1[0] = input[0]; + } +} + +#[test] +pub fn slice_assign() { + let client = client(); + let input = handle(&client); + let output = handle(&client); + + let kernel = slice_assign_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + tensor(&input), + tensor(&output), + ); + let expected = include_str!("slice_assign.wgsl").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +pub fn kernel_sum(output: &mut Tensor) { + let val = output[UNIT_POS]; + let val2 = cubecl_core::prelude::subcube_sum(val); + + if UNIT_POS == 0 { + output[0] = val2; + } +} + +#[test] +pub fn subcube_sum() { + let client = client(); + let output = handle(&client); + + let kernel = kernel_sum::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::new(4, 1, 1), + tensor(&output), + ); + let expected = include_str!("subcube_sum.wgsl").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +pub fn sequence_for_loop_kernel(output: &mut Array) { + if UNIT_POS != 0 { + return; + } + + let mut sequence = Sequence::::new(); + sequence.push(1.0); + sequence.push(4.0); + + for value in sequence { + output[0] += value; + } +} + +#[test] +pub fn sequence_for_loop() { + let client = client(); + let output = handle(&client); + + let kernel = sequence_for_loop_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::default(), + array(&output), + ); + let expected = include_str!("sequence_for_loop.wgsl").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} + +#[cube(launch, create_dummy_kernel)] +fn execute_unary_kernel(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { + if ABSOLUTE_POS < out.len() { + for i in 0..256u32 { + if i % 2 == 0 { + out[ABSOLUTE_POS] -= F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); + } else { + out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); + } + } + } +} + +#[test] +pub fn unary_bench() { + let client = client(); + let lhs = handle(&client); + let rhs = handle(&client); + let out = handle(&client); + + let kernel = execute_unary_kernel::create_dummy_kernel::( + CubeCount::Static(1, 1, 1), + CubeDim::default(), + tensor_vec(&lhs, 4), + tensor_vec(&rhs, 4), + tensor_vec(&out, 4), + ); + let expected = include_str!("unary_bench.wgsl").replace("\r\n", "\n"); + assert_eq!(compile(kernel), expected); +} diff --git a/crates/cubecl-macros/tests/wgpu/sequence_for_loop.wgsl b/crates/cubecl-macros/tests/wgpu/sequence_for_loop.wgsl new file mode 100644 index 00000000..dda059e8 --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/sequence_for_loop.wgsl @@ -0,0 +1,30 @@ +@group(0) +@binding(0) +var output_0_global: array; + +@group(0) +@binding(1) +var info: array; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(local_invocation_index) local_idx: u32, +) {let rank: u32 = info[0]; +var l_0_0: bool; +var l_0_1: f32; +l_0_0 = local_idx != 0u; +if l_0_0 { +return; +} +l_0_1 = output_0_global[0u]; +l_0_1 = l_0_1 + 1f; +output_0_global[0u] = f32(l_0_1); +l_0_1 = output_0_global[0u]; +l_0_1 = l_0_1 + 4f; +output_0_global[0u] = f32(l_0_1); +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/wgpu/slice_assign.wgsl b/crates/cubecl-macros/tests/wgpu/slice_assign.wgsl new file mode 100644 index 00000000..9abd5158 --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/slice_assign.wgsl @@ -0,0 +1,32 @@ +@group(0) +@binding(0) +var input_0_global: array; + +@group(0) +@binding(1) +var output_0_global: array; + +@group(0) +@binding(2) +var info: array; + +const WORKGROUP_SIZE_X = 1u; +const WORKGROUP_SIZE_Y = 1u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(1, 1, 1) +fn main( + @builtin(local_invocation_index) local_idx: u32, +) {let rank: u32 = info[0]; +var l_0_0: bool; +var l_0_1: f32; +l_0_0 = local_idx == 0u; +if l_0_0 { +let slice_1_0_offset = 2u; +let slice_1_0_length = 3u - 2u; +let slice_1_0_ptr = &output_0_global; +l_0_1 = input_0_global[0u]; +(*slice_1_0_ptr)[0u + slice_1_0_offset] = f32(l_0_1); +} +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/wgpu/subcube_sum.wgsl b/crates/cubecl-macros/tests/wgpu/subcube_sum.wgsl new file mode 100644 index 00000000..eb10db45 --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/subcube_sum.wgsl @@ -0,0 +1,27 @@ +@group(0) +@binding(0) +var output_0_global: array; + +@group(0) +@binding(1) +var info: array; + +const WORKGROUP_SIZE_X = 4u; +const WORKGROUP_SIZE_Y = 1u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(4, 1, 1) +fn main( + @builtin(local_invocation_index) local_idx: u32, +) {let rank: u32 = info[0]; +var l_0_0: f32; +var l_0_1: f32; +var l_0_2: bool; +l_0_0 = output_0_global[local_idx]; +l_0_1 = subgroupAdd(l_0_0); +l_0_2 = local_idx == 0u; +if l_0_2 { +output_0_global[0u] = f32(l_0_1); +} +} \ No newline at end of file diff --git a/crates/cubecl-macros/tests/wgpu/unary_bench.wgsl b/crates/cubecl-macros/tests/wgpu/unary_bench.wgsl new file mode 100644 index 00000000..d8684e82 --- /dev/null +++ b/crates/cubecl-macros/tests/wgpu/unary_bench.wgsl @@ -0,0 +1,59 @@ +@group(0) +@binding(0) +var input_0_global: array>; + +@group(0) +@binding(1) +var input_1_global: array>; + +@group(0) +@binding(2) +var output_0_global: array>; + +@group(0) +@binding(3) +var info: array; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(num_workgroups) num_workgroups: vec3, +) {let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; +let rank: u32 = info[0]; +var l_0_0: u32; +var l_0_1: bool; +var l_0_2: bool; +var l_0_3: vec4; +var l_0_4: vec4; +l_0_0 = arrayLength(&output_0_global); +l_0_1 = id < l_0_0; +if l_0_1 { + +for (var l_2_0: u32 = 0u; l_2_0 < 256u; l_2_0++) { +l_0_0 = l_2_0 % 2u; +l_0_2 = l_0_0 == 0u; +if l_0_2 { +l_0_3 = input_0_global[id]; +l_0_4 = input_1_global[id]; +l_0_3 = l_0_3 * l_0_4; +l_0_3 = cos(l_0_3); +l_0_4 = output_0_global[id]; +l_0_4 = l_0_4 - l_0_3; +output_0_global[id] = vec4(l_0_4); +} else { +l_0_4 = input_0_global[id]; +l_0_3 = input_1_global[id]; +l_0_4 = l_0_4 * l_0_3; +l_0_4 = cos(l_0_4); +l_0_3 = output_0_global[id]; +l_0_3 = l_0_3 + l_0_4; +output_0_global[id] = vec4(l_0_3); +} +} +} +} \ No newline at end of file diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index 4c09350f..a95a31e7 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -16,24 +16,24 @@ default = [ "cubecl-common/default", "cubecl-core/default", ] -std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] simple-memory-management = [] +std = ["cubecl-runtime/std", "cubecl-common/std", "cubecl-core/std"] [dependencies] +cubecl-common = { path = "../cubecl-common", version = "0.2.0" } +cubecl-core = { path = "../cubecl-core", version = "0.2.0" } cubecl-runtime = { path = "../cubecl-runtime", version = "0.2.0", default-features = false, features = [ "channel-mutex", ] } -cubecl-common = { path = "../cubecl-common", version = "0.2.0" } -cubecl-core = { path = "../cubecl-core", version = "0.2.0" } bytemuck = { workspace = true } -wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } pollster = { workspace = true } +wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } -log = { workspace = true } async-channel = { workspace = true } derive-new = { workspace = true } hashbrown = { workspace = true } +log = { workspace = true } [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.2.0", features = [ @@ -42,6 +42,7 @@ cubecl-core = { path = "../cubecl-core", version = "0.2.0", features = [ cubecl-linalg = { path = "../cubecl-linalg", version = "0.2.0", features = [ "export_tests", ] } +pretty_assertions = { workspace = true } [build-dependencies] cfg_aliases = "0.2.1" diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index ac8492b7..3d0b90ad 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -102,7 +102,7 @@ impl WgslCompiler { fn compile_item(item: cube::Item) -> Item { let elem = Self::compile_elem(item.elem); - match item.vectorization { + match item.vectorization.map(|it| it.get()).unwrap_or(1) { 1 => wgsl::Item::Scalar(elem), 2 => wgsl::Item::Vec2(elem), 3 => wgsl::Item::Vec3(elem), @@ -133,7 +133,7 @@ impl WgslCompiler { } } - fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable { + pub(crate) fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable { match value { cube::Variable::GlobalInputArray { id, item } => { wgsl::Variable::GlobalInputArray(id, Self::compile_item(item)) @@ -374,6 +374,7 @@ impl WgslCompiler { start: self.compile_variable(range_loop.start), end: self.compile_variable(range_loop.end), step: range_loop.step.map(|it| self.compile_variable(it)), + inclusive: range_loop.inclusive, instructions: self.compile_scope(&mut range_loop.scope), }) } @@ -749,6 +750,10 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + cube::Operator::Neg(op) => wgsl::Instruction::Negate { + input: self.compile_variable(op.input), + out: self.compile_variable(op.out), + }, cube::Operator::Normalize(op) => wgsl::Instruction::Normalize { input: self.compile_variable(op.input), out: self.compile_variable(op.out), diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index b6dd4e6e..005b6b6f 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -182,6 +182,7 @@ pub enum Instruction { start: Variable, end: Variable, step: Option, + inclusive: bool, instructions: Vec, }, And { @@ -308,6 +309,10 @@ pub enum Instruction { out: Variable, }, Subgroup(Subgroup), + Negate { + input: Variable, + out: Variable, + }, Normalize { input: Variable, out: Variable, @@ -394,9 +399,18 @@ impl Display for Instruction { Instruction::Modulo { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n")) } - Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!( - "{out} = {lhs} - {rhs} * floor({lhs} / {rhs});\n" - )), + Instruction::Remainder { lhs, rhs, out } => { + let f_type = match lhs.item() { + Item::Vec4(_) => Item::Vec4(Elem::F32), + Item::Vec3(_) => Item::Vec3(Elem::F32), + Item::Vec2(_) => Item::Vec2(Elem::F32), + Item::Scalar(_) => Item::Scalar(Elem::F32), + }; + let ty = lhs.item(); + f.write_fmt(format_args!( + "{out} = {lhs} - {rhs} * {ty}(floor({f_type}({lhs}) / {f_type}({rhs})));\n" + )) + } Instruction::Sub { lhs, rhs, out } => { if out.is_atomic() { assert_eq!(lhs, out, "Can't use regular sub on atomic"); @@ -531,16 +545,18 @@ impl Display for Instruction { start, end, step, + inclusive, instructions, } => { let increment = step .as_ref() .map(|step| format!("{i} += {step}")) .unwrap_or_else(|| format!("{i}++")); + let cmp = if *inclusive { "<=" } else { "<" }; f.write_fmt(format_args!( " -for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ +for (var {i}: u32 = {start}; {i} {cmp} {end}; {increment}) {{ " ))?; for instruction in instructions { @@ -671,6 +687,7 @@ for (var {i}: u32 = {start}; {i} < {end}; {increment}) {{ // For compatibility with cuda, only return old_value "{out} = atomicCompareExchangeWeak({lhs}, {cmp}, {value}).old_value;\n" )), + Instruction::Negate { input, out } => f.write_fmt(format_args!("{out} = -{input};\n")), Instruction::Normalize { input, out } => { f.write_fmt(format_args!("{out} = normalize({input});\n")) } diff --git a/crates/cubecl/Cargo.toml b/crates/cubecl/Cargo.toml index f861668c..9492749c 100644 --- a/crates/cubecl/Cargo.toml +++ b/crates/cubecl/Cargo.toml @@ -3,41 +3,44 @@ authors = ["nathanielsimard "] categories = ["science", "mathematics", "algorithms"] description = "Multi-platform high-performance compute language extension for Rust." edition.workspace = true -keywords = [ - "gpu", - "cuda", - "wgpu", - "gpgpu", - "tensor", -] +keywords = ["gpu", "cuda", "wgpu", "gpgpu", "tensor"] license.workspace = true name = "cubecl" readme.workspace = true repository = "https://github.com/tracel-ai/cubecl" -version.workspace = true rust-version = "1.79" +version.workspace = true [features] -default = ["std", "linalg", "cubecl-core/default", "cubecl-wgpu?/default", "cubecl-cuda?/default"] -std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"] -template = ["cubecl-core/template"] +default = [ + "std", + "linalg", + "cubecl-core/default", + "cubecl-wgpu?/default", + "cubecl-cuda?/default", +] linalg = ["dep:cubecl-linalg"] simple-memory-management = ["cubecl-wgpu?/simple-memory-management"] +std = ["cubecl-core/std", "cubecl-wgpu?/std", "cubecl-cuda?/std"] +template = ["cubecl-core/template"] # Runtimes -wgpu = ["cubecl-wgpu"] cuda = ["cubecl-cuda"] +wgpu = ["cubecl-wgpu"] [dependencies] cubecl-core = { path = "../cubecl-core", version = "0.2.0", default-features = false } -cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.2.0", default-features = false, optional = true } cubecl-cuda = { path = "../cubecl-cuda", version = "0.2.0", default-features = false, optional = true } cubecl-linalg = { path = "../cubecl-linalg", version = "0.2.0", default-features = false, optional = true } +cubecl-wgpu = { path = "../cubecl-wgpu", version = "0.2.0", default-features = false, optional = true } + +[dev-dependencies] +half = { workspace = true } [[bench]] -name = "matmul" harness = false +name = "matmul" [[bench]] -name = "unary" harness = false +name = "unary" diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs index 99ab027d..8e575c3f 100644 --- a/crates/cubecl/benches/unary.rs +++ b/crates/cubecl/benches/unary.rs @@ -1,16 +1,18 @@ -use cubecl::{calculate_cube_count_elemwise, prelude::*}; +use cubecl::{calculate_cube_count_elemwise, frontend, prelude::*}; use std::marker::PhantomData; +#[cfg(feature = "cuda")] +use half::f16; + use cubecl::benchmark::Benchmark; use cubecl::client::SyncType; -use cubecl::frontend::Float; use cubecl_linalg::tensor::TensorHandle; #[cube(launch)] fn execute(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { if ABSOLUTE_POS < out.len() { - for i in range(0, 256, Comptime::new(false)) { - if i % UInt::new(2) == UInt::new(0) { + for i in 0..256u32 { + if i % 2 == 0 { out[ABSOLUTE_POS] -= F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); } else { out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); @@ -86,7 +88,7 @@ enum MatmulKind { } #[allow(dead_code)] -fn run(device: R::Device, vectorization: u8) { +fn run(device: R::Device, vectorization: u8) { let bench = UnaryBench:: { shape: vec![32, 512, 2048], vectorization, @@ -100,11 +102,11 @@ fn run(device: R::Device, vectorization: u8) { fn main() { #[cfg(feature = "cuda")] - run::(Default::default(), 8); + run::(Default::default(), 8); #[cfg(feature = "cuda")] - run::(Default::default(), 4); + run::(Default::default(), 4); #[cfg(feature = "wgpu")] - run::(Default::default(), 1); + run::(Default::default(), 1); #[cfg(feature = "wgpu")] - run::(Default::default(), 4); + run::(Default::default(), 4); } diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs index 80d76c94..b6427a71 100644 --- a/examples/gelu/src/lib.rs +++ b/examples/gelu/src/lib.rs @@ -9,7 +9,7 @@ fn gelu_array(input: &Array, output: &mut Array) { #[cube] fn gelu_scalar(x: F) -> F { - x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0 + x * F::erf(x / F::new(2.0f32.sqrt()) + F::new(1.0)) / F::new(2.0) } pub fn launch(device: &R::Device) { @@ -19,7 +19,7 @@ pub fn launch(device: &R::Device) { let input_handle = client.create(f32::as_bytes(input)); unsafe { - gelu_array::launch_unchecked::( + gelu_array::launch_unchecked::( &client, CubeCount::Static(1, 1, 1), CubeDim::new(input.len() as u32, 1, 1), diff --git a/examples/normalization/src/lib.rs b/examples/normalization/src/lib.rs index d3f9c5d2..c7f288cd 100644 --- a/examples/normalization/src/lib.rs +++ b/examples/normalization/src/lib.rs @@ -14,7 +14,7 @@ pub fn launch(device: &R::Device) { let input_handle = client.create(f32::as_bytes(input)); unsafe { - norm_test::launch_unchecked::( + norm_test::launch_unchecked::( &client, CubeCount::Static(1, 1, 1), CubeDim::new(input.len() as u32, 1, 1), diff --git a/profiling/matmul-example/Cargo.toml b/profiling/matmul-example/Cargo.toml index bbfa1830..3af1aef1 100644 --- a/profiling/matmul-example/Cargo.toml +++ b/profiling/matmul-example/Cargo.toml @@ -1,18 +1,20 @@ [package] -name = "matmul-example" edition.workspace = true -version.workspace = true license.workspace = true +name = "matmul-example" readme.workspace = true +version.workspace = true [dependencies] +burn = { git = "https://github.com/tracel-ai/burn", optional = true, features = [ + "tch", +] } +burn-tensor = { git = "https://github.com/tracel-ai/burn", optional = true } cubecl = { version = "0.1.0", path = "../../crates/cubecl", features = [ - "cuda", "linalg", ], optional = true } -burn = { git = "https://github.com/tracel-ai/burn", optional = true, features = ["tch"] } -burn-tensor = { git = "https://github.com/tracel-ai/burn", optional = true } [features] burn-tch-cuda = ["burn"] -cube-cuda = ["cubecl"] +cube-cuda = ["cubecl/cuda"] +cube-wgpu = ["cubecl/wgpu"] diff --git a/profiling/matmul-example/src/main.rs b/profiling/matmul-example/src/main.rs index 08a28e42..6bb16f7a 100644 --- a/profiling/matmul-example/src/main.rs +++ b/profiling/matmul-example/src/main.rs @@ -3,6 +3,8 @@ fn main() { tch_gpu::run(); #[cfg(feature = "cube-cuda")] cube_cuda::run(); + #[cfg(feature = "cube-wgpu")] + cube_wgpu::run(); } #[cfg(feature = "burn-tch-cuda")] @@ -56,3 +58,40 @@ mod cube_cuda { tiling2d::launch(&client, tensor_a, tensor_b, tensor_c, Default::default()); } } + +#[cfg(feature = "cube-wgpu")] +mod cube_wgpu { + use cubecl::frontend::F32; + use cubecl::linalg::{matmul::tiling2d, tensor::TensorHandle}; + use cubecl::prelude::*; + use cubecl::wgpu::{WgpuDevice, WgpuRuntime}; + use cubecl::Runtime; + + pub fn run() { + let device = WgpuDevice::default(); + let client = WgpuRuntime::client(&device); + + let num_of_batch = 12; + let heigth = 1024; + let width = 1024; + + let tensor_values: Vec = (0..num_of_batch * heigth * width) + .map(|x| x as f32) + .collect(); + let tensor_a_handle = client.create(f32::as_bytes(&tensor_values)); + let tensor_b_handle = client.create(f32::as_bytes(&tensor_values)); + let tensor_c_handle = client.empty(12 * 1024 * 1024 * core::mem::size_of::()); + + let tensor_a_shape = vec![num_of_batch, heigth, width]; + let tensor_b_shape = vec![num_of_batch, heigth, width]; + let tensor_c_shape = vec![num_of_batch, heigth, width]; + + let tensor_a: TensorHandle = + TensorHandle::new_contiguous(tensor_a_shape, tensor_a_handle); + let tensor_b: TensorHandle = + TensorHandle::new_contiguous(tensor_b_shape, tensor_b_handle); + let tensor_c: TensorHandle = + TensorHandle::new_contiguous(tensor_c_shape, tensor_c_handle); + tiling2d::launch(&client, tensor_a, tensor_b, tensor_c, Default::default()); + } +}