Skip to content

Commit

Permalink
frontend: Add support for numeric constants.
Browse files Browse the repository at this point in the history
  • Loading branch information
booti386 committed Sep 10, 2024
1 parent ff6302a commit c00de60
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ num-traits = { version = "0.2.19", default-features = false, features = [
proc-macro2 = "1.0.86"
syn = { version = "2.0.69", features = ["full", "extra-traits"] }
quote = "1.0.36"
paste = "1.0.15"

### For xtask crate ###
strum = {version = "0.26.3", features = ["derive"]}
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ serde = { workspace = true }
cubecl-macros = { path = "../cubecl-macros", version = "0.2.0" }
derive-new = { workspace = true }
num-traits = { workspace = true }
paste = { workspace = true }

log = { workspace = true }

Expand Down
22 changes: 12 additions & 10 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use half::{bf16, f16};

use crate::frontend::{
Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Normalize, Powf, Recip, Round, Sin, Sqrt, Tanh,
Ceil, Cos, Erf, Exp, FloatConstants, Floor, Log, Log1p, Normalize, Powf, Recip, Round, Sin,
Sqrt, Tanh,
};
use crate::frontend::{
ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit,
Expand All @@ -19,20 +20,21 @@ use crate::Runtime;
/// Floating point numbers. Used as input in float kernels
pub trait Float:
Numeric
+ Ceil
+ Cos
+ Erf
+ Exp
+ FloatConstants
+ Floor
+ Log
+ Log1p
+ Cos
+ Sin
+ Tanh
+ Normalize
+ Powf
+ Sqrt
+ Round
+ Floor
+ Ceil
+ Erf
+ Recip
+ Normalize
+ Round
+ Sin
+ Sqrt
+ Tanh
+ From<f32>
+ core::ops::Add<f32, Output = Self>
+ core::ops::Sub<f32, Output = Self>
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-core/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::compute::{KernelBuilder, KernelLauncher};
use crate::frontend::{
ComptimeType, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit,
ExpandElementTyped, Numeric,
ExpandElementTyped, IntConstants, Numeric,
};
use crate::ir::{ConstantScalarValue, Elem, IntKind, Variable, Vectorization};
use crate::Runtime;
Expand All @@ -14,8 +14,9 @@ use super::{
/// Signed integer. Used as input in int kernels
pub trait Int:
Numeric
+ std::ops::Rem<Output = Self>
+ IntConstants
+ From<i32>
+ std::ops::Rem<Output = Self>
+ core::ops::Add<i32, Output = Self>
+ core::ops::Sub<i32, Output = Self>
+ core::ops::Mul<i32, Output = Self>
Expand Down
7 changes: 4 additions & 3 deletions crates/cubecl-core/src/frontend/element/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::ir::{Item, Variable};
use crate::prelude::Clamp;
use crate::Runtime;
use crate::{
frontend::{index_assign, Abs, Max, Min, Remainder},
frontend::{index_assign, Abs, Max, Min, NumConstants, Remainder},
unexpanded,
};

Expand All @@ -18,13 +18,15 @@ use super::{
pub trait Numeric:
Copy
+ Abs
+ Clamp
+ Max
+ Min
+ Clamp
+ NumConstants
+ Remainder
+ ExpandElementBaseInit
+ CubePrimitive
+ LaunchArgExpand
+ From<u32>
+ std::ops::AddAssign
+ std::ops::SubAssign
+ std::ops::MulAssign
Expand All @@ -38,7 +40,6 @@ pub trait Numeric:
+ core::ops::IndexMut<UInt, Output = Self>
+ core::ops::Index<u32, Output = Self>
+ core::ops::IndexMut<u32, Output = Self>
+ From<u32>
+ std::ops::Add<u32, Output = Self>
+ std::ops::Sub<u32, Output = Self>
+ std::ops::Mul<u32, Output = Self>
Expand Down
235 changes: 235 additions & 0 deletions crates/cubecl-core/src/frontend/operation/constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
use crate::{
prelude::{
CubeContext, CubePrimitive, ExpandElementTyped, UInt, BF16, F16, F32, F64, I32, I64,
},
unexpanded,
};

use half::{bf16, f16};
use paste::paste;

macro_rules! tt_as_expr {
($t:expr) => {
$t
};
}

macro_rules! tt_as_ty {
(( $t:ty )) => {
$t
};
([ $t:ty ]) => {
$t
};
({ $t:ty }) => {
$t
};
($t:ty) => {
$t
};
}

macro_rules! impl_op {
($tr:ident { $( $name:ident $(-> { $($ret_type:tt)+ } )? ),* } => { $($type:ty | $expr:tt),* }) => {
paste! {
impl_op!(@default
[],
$tr,
{ $({ $name, [<lim_ $name:lower>], [<__expand_lim_ $name:lower>], $( { $($ret_type)+ } )? })* },
{ $($type, $expr);* }
);
}
};

// Default $ret_type to {Self}
(@default [$($acc:tt)*], $tr:ident, {}, $types:tt) => {
impl_op!(@internal
$tr,
{ $($acc);* },
$types
);
};
(@default [$($acc:tt)*], $tr:ident, { { $name:ident, $unexpand_name:ident, $expand_name:ident, } $($tail:tt)* }, $types:tt) => {
impl_op!(@default
[ $($acc)* { $name, $unexpand_name, $expand_name, {Self} } ],
$tr,
{ $($tail)* },
$types
);
};
(@default [$($acc:tt)*], $tr:ident, { { $name:ident, $unexpand_name:ident, $expand_name:ident, $ret_type:tt } $($tail:tt)* }, $types:tt) => {
impl_op!(@default
[ $($acc)* { $name, $unexpand_name, $expand_name, $ret_type } ],
$tr,
{ $($tail)* },
$types
);
};

// Generate trait declarations and impl blocks
(@internal $tr:ident, $values:tt, $types:tt) => {
impl_op!(@decls $tr, $values);
impl_op!(@impls $tr, $types, $values);
};

// Generate trait declarations
(@decls $tr:ident, { $({ $name:ident, $unexpand_name:ident, $expand_name:ident, { $ret_type:ty } });* }) => {
pub trait $tr: CubePrimitive + Sized {
$(
fn $unexpand_name() -> $ret_type {
unexpanded!()
}
fn $expand_name(_context: &mut CubeContext) -> ExpandElementTyped<$ret_type>;
)*
}
};

// Generate impl blocks
(@impls $tr:ident, { $($type:ty, $e:tt);* }, $values:tt) => {
$(
impl $tr for $type {
impl_op!(@impl $e, $values);
}
)*
};

// Generate impl block contents
(@impl $e:tt, { $({ $name:ident, $unexpand_name:ident, $expand_name:ident, $ret_type:tt });* }) => {
$(
fn $expand_name(_context: &mut CubeContext) -> ExpandElementTyped<tt_as_ty!($ret_type)> {
impl_op!(@implbody $e, $name, $ret_type)
}
)*
};

(@implbody { $({ $($select_type:tt)+ } => $e:tt),+; $default_e:tt }, $name:ident, $ret_type:tt) => {
impl_op!(@implbody $({ { $($select_type)+ }, $e }),+; $default_e, $name, $ret_type)
};

(@implbody { $select_type:tt, $e:tt } $(, $acc:tt)*; $default_e:tt, $name:ident, $ret_type:tt) => {{
macro_rules! __emit_on_match {
($select_type, $select_type) => { impl_op!(@implbody $e, $name, { $ret_type }) };
($select_type, $ret_type) => { impl_op!(@implbody $($acc),*; $default_e, $name, $ret_type) };
}

__emit_on_match!($select_type, $ret_type)
}};

(@implbody ; $default_e:tt, $name:ident, $($unused:tt)*) => {
impl_op!(@implbody $default_e, $name,)
};

(@implbody ( $($e:tt)* ), $name:ident, $($unused:tt)*) => {
ExpandElementTyped::from_lit(tt_as_expr!(impl_op!(@subst $name, [], $($e)*)))
};

// Handle @name substitution

// No more tokens to process, return the token tree
(@subst $name:ident, [$($acc:tt)*],) => {
$($acc)*
};

// Handle replacement
(@subst $name:ident, [$($acc:tt)*], @name $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* $name ],
$($tail)*
)
};

// Handle empty token trees
(@subst $name:ident, [$($acc:tt)*], () $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* () ],
$($tail)*
)
};
(@subst $name:ident, [$($acc:tt)*], [] $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* [] ],
$($tail)*
)
};
(@subst $name:ident, [$($acc:tt)*], {} $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* {} ],
$($tail)*
)
};

// Handle token trees
(@subst $name:ident, [$($acc:tt)*], ( $($head:tt)* ) $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* ( impl_op!(@subst $name, [], $($head)*) ) ],
$($tail)*
)
};
(@subst $name:ident, [$($acc:tt)*], [ $($head:tt)* ] $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* [ impl_op!(@subst $name, [], $($head)*) ] ],
$($tail)*
)
};
(@subst $name:ident, [$($acc:tt)*], { $($head:tt)* } $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* { impl_op!(@subst $name, [], $($head)*) } ],
$($tail)*
)
};

// Handle lone tokens
(@subst $name:ident, [$($acc:tt)*], $head:tt $($tail:tt)*) => {
impl_op!(@subst $name,
[ $($acc)* $head ],
$($tail)*
)
};
}

impl_op!(
NumConstants {
MAX,
MIN
} => {
BF16 | (bf16::@name.to_f32()),
F16 | (f16::@name.to_f32()),
F32 | (f32::@name),
F64 | (f64::@name),
I32 | (i32::@name),
I64 | (i64::@name),
UInt | (u32::@name)
}
);

impl_op!(
FloatConstants {
DIGITS -> {UInt},
EPSILON,
INFINITY,
MANTISSA_DIGITS -> {UInt},
MAX_10_EXP -> {I32},
MAX_EXP -> {I32},
MIN_10_EXP -> {I32},
MIN_EXP -> {I32},
MIN_POSITIVE,
NAN,
NEG_INFINITY,
RADIX -> {UInt}
} => {
BF16 | {{Self} => (bf16::@name.to_f32()); (bf16::@name)},
F16 | {{Self} => (f16::@name.to_f32()); (f16::@name)},
F32 | (f32::@name),
F64 | (f64::@name)
}
);

impl_op!(
IntConstants {
BITS -> {UInt}
} => {
I32 | (i32::@name),
I64 | (i64::@name),
UInt | (u32::@name)
}
);
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/frontend/operation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod base;
mod binary;
mod clamp;
mod cmp;
mod constants;
mod fma;
mod unary;

Expand All @@ -11,5 +12,6 @@ pub use base::*;
pub use binary::*;
pub use clamp::*;
pub use cmp::*;
pub use constants::*;
pub use fma::*;
pub use unary::*;

0 comments on commit c00de60

Please sign in to comment.