Skip to content

Commit

Permalink
Merge pull request #32 from tracel-ai/fix/comptime-runtime-value
Browse files Browse the repository at this point in the history
Fix comptime::runtime(value) for float and int types
  • Loading branch information
nathanielsimard authored Jul 22, 2024
2 parents b3c1799 + d89825d commit e686da9
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Comptime<T> {
}

/// Type that can be used within [Comptime].
pub trait ComptimeType: CubeType {
pub trait ComptimeType: CubeType + Into<ExpandElement> {
/// Create the expand type from the normal type.
fn into_expand(self) -> Self::ExpandType;
}
Expand Down
7 changes: 7 additions & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ macro_rules! impl_float {
}
}

impl From<$type> for ExpandElement {
fn from(value: $type) -> Self {
let constant = $type::as_elem().from_constant(value.val.into());
ExpandElement::Plain(constant)
}
}

impl Numeric for $type {
type Primitive = $primitive;
}
Expand Down
7 changes: 7 additions & 0 deletions crates/cubecl-core/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ macro_rules! impl_int {
}
}

impl From<$type> for ExpandElement {
fn from(value: $type) -> Self {
let constant = $type::as_elem().from_constant(value.val.into());
ExpandElement::Plain(constant)
}
}

impl Numeric for $type {
type Primitive = $primitive;
}
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-core/tests/frontend/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ pub fn comptime_else_then_if<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: C
}
}

#[cube]
pub fn comptime_float() {
let comptime_float = Comptime::new(F32::new(0.0));
let _runtime_float = Comptime::runtime(comptime_float);
}

#[cube]
pub fn comptime_elsif<T: Numeric>(lhs: T, cond1: Comptime<bool>, cond2: Comptime<bool>) {
if Comptime::get(cond1) {
Expand Down

0 comments on commit e686da9

Please sign in to comment.