From c00de60052af985e054f68824e0cc8228274c272 Mon Sep 17 00:00:00 2001 From: Guillaume Charifi-Hoareau Date: Tue, 10 Sep 2024 08:19:28 +0200 Subject: [PATCH] frontend: Add support for numeric constants. --- Cargo.toml | 1 + crates/cubecl-core/Cargo.toml | 1 + .../cubecl-core/src/frontend/element/float.rs | 22 +- .../cubecl-core/src/frontend/element/int.rs | 5 +- .../src/frontend/element/numeric.rs | 7 +- .../src/frontend/operation/constants.rs | 235 ++++++++++++++++++ .../cubecl-core/src/frontend/operation/mod.rs | 2 + 7 files changed, 258 insertions(+), 15 deletions(-) create mode 100644 crates/cubecl-core/src/frontend/operation/constants.rs diff --git a/Cargo.toml b/Cargo.toml index 0b78e5633..4c1ca5f24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ num-traits = { version = "0.2.19", default-features = false, features = [ proc-macro2 = "1.0.86" syn = { version = "2.0.69", features = ["full", "extra-traits"] } quote = "1.0.36" +paste = "1.0.15" ### For xtask crate ### strum = {version = "0.26.3", features = ["derive"]} diff --git a/crates/cubecl-core/Cargo.toml b/crates/cubecl-core/Cargo.toml index d2e4b0f66..d72323a16 100644 --- a/crates/cubecl-core/Cargo.toml +++ b/crates/cubecl-core/Cargo.toml @@ -28,6 +28,7 @@ serde = { workspace = true } cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" } derive-new = { workspace = true } num-traits = { workspace = true } +paste = { workspace = true } log = { workspace = true } diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index 4201c20fe..f3ae2419d 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -1,7 +1,8 @@ use half::{bf16, f16}; use crate::frontend::{ - Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Normalize, Powf, Recip, Round, Sin, Sqrt, Tanh, + Ceil, Cos, Erf, Exp, FloatConstants, Floor, Log, Log1p, Normalize, Powf, Recip, Round, Sin, + Sqrt, Tanh, }; use crate::frontend::{ ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, @@ -19,20 +20,21 @@ use crate::Runtime; /// Floating point numbers. Used as input in float kernels pub trait Float: Numeric + + Ceil + + Cos + + Erf + Exp + + FloatConstants + + Floor + Log + Log1p - + Cos - + Sin - + Tanh + + Normalize + Powf - + Sqrt - + Round - + Floor - + Ceil - + Erf + Recip - + Normalize + + Round + + Sin + + Sqrt + + Tanh + From + core::ops::Add + core::ops::Sub diff --git a/crates/cubecl-core/src/frontend/element/int.rs b/crates/cubecl-core/src/frontend/element/int.rs index 7579ea79d..fe4fc936b 100644 --- a/crates/cubecl-core/src/frontend/element/int.rs +++ b/crates/cubecl-core/src/frontend/element/int.rs @@ -1,7 +1,7 @@ use crate::compute::{KernelBuilder, KernelLauncher}; use crate::frontend::{ ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, - ExpandElementTyped, Numeric, + ExpandElementTyped, IntConstants, Numeric, }; use crate::ir::{ConstantScalarValue, Elem, IntKind, Variable, Vectorization}; use crate::Runtime; @@ -14,8 +14,9 @@ use super::{ /// Signed integer. Used as input in int kernels pub trait Int: Numeric - + std::ops::Rem + + IntConstants + From + + std::ops::Rem + core::ops::Add + core::ops::Sub + core::ops::Mul diff --git a/crates/cubecl-core/src/frontend/element/numeric.rs b/crates/cubecl-core/src/frontend/element/numeric.rs index 0d57aa5a0..5b3b84535 100644 --- a/crates/cubecl-core/src/frontend/element/numeric.rs +++ b/crates/cubecl-core/src/frontend/element/numeric.rs @@ -4,7 +4,7 @@ use crate::ir::{Item, Variable}; use crate::prelude::Clamp; use crate::Runtime; use crate::{ - frontend::{index_assign, Abs, Max, Min, Remainder}, + frontend::{index_assign, Abs, Max, Min, NumConstants, Remainder}, unexpanded, }; @@ -18,13 +18,15 @@ use super::{ pub trait Numeric: Copy + Abs + + Clamp + Max + Min - + Clamp + + NumConstants + Remainder + ExpandElementBaseInit + CubePrimitive + LaunchArgExpand + + From + std::ops::AddAssign + std::ops::SubAssign + std::ops::MulAssign @@ -38,7 +40,6 @@ pub trait Numeric: + core::ops::IndexMut + core::ops::Index + core::ops::IndexMut - + From + std::ops::Add + std::ops::Sub + std::ops::Mul diff --git a/crates/cubecl-core/src/frontend/operation/constants.rs b/crates/cubecl-core/src/frontend/operation/constants.rs new file mode 100644 index 000000000..453fa2318 --- /dev/null +++ b/crates/cubecl-core/src/frontend/operation/constants.rs @@ -0,0 +1,235 @@ +use crate::{ + prelude::{ + CubeContext, CubePrimitive, ExpandElementTyped, UInt, BF16, F16, F32, F64, I32, I64, + }, + unexpanded, +}; + +use half::{bf16, f16}; +use paste::paste; + +macro_rules! tt_as_expr { + ($t:expr) => { + $t + }; +} + +macro_rules! tt_as_ty { + (( $t:ty )) => { + $t + }; + ([ $t:ty ]) => { + $t + }; + ({ $t:ty }) => { + $t + }; + ($t:ty) => { + $t + }; +} + +macro_rules! impl_op { + ($tr:ident { $( $name:ident $(-> { $($ret_type:tt)+ } )? ),* } => { $($type:ty | $expr:tt),* }) => { + paste! { + impl_op!(@default + [], + $tr, + { $({ $name, [], [<__expand_lim_ $name:lower>], $( { $($ret_type)+ } )? })* }, + { $($type, $expr);* } + ); + } + }; + + // Default $ret_type to {Self} + (@default [$($acc:tt)*], $tr:ident, {}, $types:tt) => { + impl_op!(@internal + $tr, + { $($acc);* }, + $types + ); + }; + (@default [$($acc:tt)*], $tr:ident, { { $name:ident, $unexpand_name:ident, $expand_name:ident, } $($tail:tt)* }, $types:tt) => { + impl_op!(@default + [ $($acc)* { $name, $unexpand_name, $expand_name, {Self} } ], + $tr, + { $($tail)* }, + $types + ); + }; + (@default [$($acc:tt)*], $tr:ident, { { $name:ident, $unexpand_name:ident, $expand_name:ident, $ret_type:tt } $($tail:tt)* }, $types:tt) => { + impl_op!(@default + [ $($acc)* { $name, $unexpand_name, $expand_name, $ret_type } ], + $tr, + { $($tail)* }, + $types + ); + }; + + // Generate trait declarations and impl blocks + (@internal $tr:ident, $values:tt, $types:tt) => { + impl_op!(@decls $tr, $values); + impl_op!(@impls $tr, $types, $values); + }; + + // Generate trait declarations + (@decls $tr:ident, { $({ $name:ident, $unexpand_name:ident, $expand_name:ident, { $ret_type:ty } });* }) => { + pub trait $tr: CubePrimitive + Sized { + $( + fn $unexpand_name() -> $ret_type { + unexpanded!() + } + fn $expand_name(_context: &mut CubeContext) -> ExpandElementTyped<$ret_type>; + )* + } + }; + + // Generate impl blocks + (@impls $tr:ident, { $($type:ty, $e:tt);* }, $values:tt) => { + $( + impl $tr for $type { + impl_op!(@impl $e, $values); + } + )* + }; + + // Generate impl block contents + (@impl $e:tt, { $({ $name:ident, $unexpand_name:ident, $expand_name:ident, $ret_type:tt });* }) => { + $( + fn $expand_name(_context: &mut CubeContext) -> ExpandElementTyped { + impl_op!(@implbody $e, $name, $ret_type) + } + )* + }; + + (@implbody { $({ $($select_type:tt)+ } => $e:tt),+; $default_e:tt }, $name:ident, $ret_type:tt) => { + impl_op!(@implbody $({ { $($select_type)+ }, $e }),+; $default_e, $name, $ret_type) + }; + + (@implbody { $select_type:tt, $e:tt } $(, $acc:tt)*; $default_e:tt, $name:ident, $ret_type:tt) => {{ + macro_rules! __emit_on_match { + ($select_type, $select_type) => { impl_op!(@implbody $e, $name, { $ret_type }) }; + ($select_type, $ret_type) => { impl_op!(@implbody $($acc),*; $default_e, $name, $ret_type) }; + } + + __emit_on_match!($select_type, $ret_type) + }}; + + (@implbody ; $default_e:tt, $name:ident, $($unused:tt)*) => { + impl_op!(@implbody $default_e, $name,) + }; + + (@implbody ( $($e:tt)* ), $name:ident, $($unused:tt)*) => { + ExpandElementTyped::from_lit(tt_as_expr!(impl_op!(@subst $name, [], $($e)*))) + }; + + // Handle @name substitution + + // No more tokens to process, return the token tree + (@subst $name:ident, [$($acc:tt)*],) => { + $($acc)* + }; + + // Handle replacement + (@subst $name:ident, [$($acc:tt)*], @name $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* $name ], + $($tail)* + ) + }; + + // Handle empty token trees + (@subst $name:ident, [$($acc:tt)*], () $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* () ], + $($tail)* + ) + }; + (@subst $name:ident, [$($acc:tt)*], [] $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* [] ], + $($tail)* + ) + }; + (@subst $name:ident, [$($acc:tt)*], {} $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* {} ], + $($tail)* + ) + }; + + // Handle token trees + (@subst $name:ident, [$($acc:tt)*], ( $($head:tt)* ) $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* ( impl_op!(@subst $name, [], $($head)*) ) ], + $($tail)* + ) + }; + (@subst $name:ident, [$($acc:tt)*], [ $($head:tt)* ] $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* [ impl_op!(@subst $name, [], $($head)*) ] ], + $($tail)* + ) + }; + (@subst $name:ident, [$($acc:tt)*], { $($head:tt)* } $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* { impl_op!(@subst $name, [], $($head)*) } ], + $($tail)* + ) + }; + + // Handle lone tokens + (@subst $name:ident, [$($acc:tt)*], $head:tt $($tail:tt)*) => { + impl_op!(@subst $name, + [ $($acc)* $head ], + $($tail)* + ) + }; +} + +impl_op!( + NumConstants { + MAX, + MIN + } => { + BF16 | (bf16::@name.to_f32()), + F16 | (f16::@name.to_f32()), + F32 | (f32::@name), + F64 | (f64::@name), + I32 | (i32::@name), + I64 | (i64::@name), + UInt | (u32::@name) + } +); + +impl_op!( + FloatConstants { + DIGITS -> {UInt}, + EPSILON, + INFINITY, + MANTISSA_DIGITS -> {UInt}, + MAX_10_EXP -> {I32}, + MAX_EXP -> {I32}, + MIN_10_EXP -> {I32}, + MIN_EXP -> {I32}, + MIN_POSITIVE, + NAN, + NEG_INFINITY, + RADIX -> {UInt} + } => { + BF16 | {{Self} => (bf16::@name.to_f32()); (bf16::@name)}, + F16 | {{Self} => (f16::@name.to_f32()); (f16::@name)}, + F32 | (f32::@name), + F64 | (f64::@name) + } +); + +impl_op!( + IntConstants { + BITS -> {UInt} + } => { + I32 | (i32::@name), + I64 | (i64::@name), + UInt | (u32::@name) + } +); diff --git a/crates/cubecl-core/src/frontend/operation/mod.rs b/crates/cubecl-core/src/frontend/operation/mod.rs index 06273444b..ea0c3436d 100644 --- a/crates/cubecl-core/src/frontend/operation/mod.rs +++ b/crates/cubecl-core/src/frontend/operation/mod.rs @@ -3,6 +3,7 @@ mod base; mod binary; mod clamp; mod cmp; +mod constants; mod fma; mod unary; @@ -11,5 +12,6 @@ pub use base::*; pub use binary::*; pub use clamp::*; pub use cmp::*; +pub use constants::*; pub use fma::*; pub use unary::*;