Skip to content

Commit

Permalink
Fix remaining tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Sep 8, 2024
1 parent df8c970 commit 02bc447
Show file tree
Hide file tree
Showing 17 changed files with 214 additions and 124 deletions.
20 changes: 20 additions & 0 deletions crates/cubecl-core/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,19 @@ macro_rules! tuple_init {
}
}
}
macro_rules! tuple_runtime {
($($P:ident),*) => {
impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
#[allow(non_snake_case)]
fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType {
let ($($P,)*) = self;
($(
$P.__expand_runtime_method(context),
)*)
}
}
}
}

tuple_cube_type!(P1);
tuple_cube_type!(P1, P2);
Expand All @@ -201,6 +214,13 @@ tuple_init!(P1, P2, P3, P4);
tuple_init!(P1, P2, P3, P4, P5);
tuple_init!(P1, P2, P3, P4, P5, P6);

tuple_runtime!(P1);
tuple_runtime!(P1, P2);
tuple_runtime!(P1, P2, P3);
tuple_runtime!(P1, P2, P3, P4);
tuple_runtime!(P1, P2, P3, P4, P5);
tuple_runtime!(P1, P2, P3, P4, P5, P6);

pub trait ExpandElementBaseInit: CubeType {
fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/element/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub trait Numeric:

fn __expand_from_vec<const D: usize>(
context: &mut CubeContext,
vec: [ExpandElementTyped<u32>; D],
vec: [u32; D],
) -> <Self as CubeType>::ExpandType {
let new_var = context.create_local(Item::vectorized(
Self::as_elem(),
Expand All @@ -91,7 +91,7 @@ pub trait Numeric:
let elem = Self::as_elem();

for (i, element) in vec.iter().enumerate() {
let var: Variable = elem.constant_from_i64(element.constant().unwrap().as_i64());
let var: Variable = elem.constant_from_i64(*element as i64);
let expand = ExpandElement::Plain(var);

index_assign::expand::<u32>(
Expand Down
21 changes: 12 additions & 9 deletions crates/cubecl-core/src/frontend/operation/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::num::NonZero;

use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization};
use crate::prelude::{CubeType, ExpandElementTyped};
use crate::{
Expand All @@ -23,6 +21,7 @@ where
let item_rhs = rhs.item();

let vectorization = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);

let item = Item::vectorized(item_lhs.elem, vectorization);

// We can only reuse rhs.
Expand Down Expand Up @@ -196,17 +195,21 @@ where
}

fn find_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization {
if lhs == rhs {
return lhs;
}
match (lhs, rhs) {
(None, None) => None,
(None, Some(rhs)) => Some(rhs),
(Some(lhs), None) => Some(lhs),
(Some(lhs), Some(rhs)) => {
let min = lhs.get().min(rhs.get());
let common = (0..=min)
.rev()
.find(|i| lhs.get() % i == 0 && rhs.get() % i == 0)
.unwrap_or(1);
NonZero::new(common)
(Some(_), Some(_)) => {
panic!("Auto-matching fixed vectorization currently unsupported");
// let min = lhs.get().min(rhs.get());
// let common = (0..=min)
// .rev()
// .find(|i| lhs.get() % i == 0 && rhs.get() % i == 0)
// .unwrap_or(1);
// NonZero::new(common)
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions crates/cubecl-core/tests/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ mod module_import;
mod ops;
mod parenthesis;
mod redeclare;
// mod reuse;
// mod shared_memory;
// mod r#struct;
// mod tensor;
// mod topology;
// mod r#trait;
mod reuse;
mod shared_memory;
mod r#struct;
mod tensor;
mod topology;
mod r#trait;

// mod tuple;
// mod vectorization;
mod tuple;
mod vectorization;
6 changes: 3 additions & 3 deletions crates/cubecl-core/tests/frontend/reuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ mod tests {
ir::{Branch, Elem, Item, Variable},
};

type ElemType = I32;
type ElemType = i32;
#[test]
fn cube_reuse_assign_test() {
let mut context = CubeContext::root();

let x = context.create_local(Item::new(ElemType::as_elem()));

reuse::__expand::<ElemType>(&mut context, x.into());
reuse::expand::<ElemType>(&mut context, x.into());
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign());
Expand All @@ -45,7 +45,7 @@ mod tests {

let x = context.create_local(Item::new(ElemType::as_elem()));

reuse_incr::__expand::<ElemType>(&mut context, x.into());
reuse_incr::expand::<ElemType>(&mut context, x.into());
let scope = context.into_scope();

assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr());
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/tests/frontend/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

#[cube]
pub fn shared_memory_read_write<T: Numeric>(sm_size: Comptime<u32>) {
pub fn shared_memory_read_write<T: Numeric>(#[comptime] sm_size: u32) {
let mut shared = SharedMemory::<T>::new(sm_size);
shared[0] = T::from_int(3);
let _ = shared[0];
Expand All @@ -15,13 +15,13 @@ mod tests {
ir::{Item, Variable},
};

type ElemType = F32;
type ElemType = f32;

#[test]
fn cube_support_shared_memory() {
let mut context = CubeContext::root();

shared_memory_read_write::__expand::<ElemType>(&mut context, 512);
shared_memory_read_write::expand::<ElemType>(&mut context, 512);
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()
Expand Down
10 changes: 5 additions & 5 deletions crates/cubecl-core/tests/frontend/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ mod tests {
ir::{Item, Variable},
};

type ElemType = F32;
type ElemType = f32;

#[test]
fn cube_new_struct_test() {
Expand All @@ -49,7 +49,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));

creator::__expand::<ElemType>(&mut context, x.into(), y.into());
creator::expand::<ElemType>(&mut context, x.into(), y.into());
let scope = context.into_scope();

assert_eq!(
Expand All @@ -69,7 +69,7 @@ mod tests {
first: x.into(),
second: y.into(),
};
state_receiver_with_reuse::__expand::<ElemType>(&mut context, expanded_state);
state_receiver_with_reuse::expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();

assert_eq!(
Expand All @@ -89,7 +89,7 @@ mod tests {
first: x.into(),
second: y.into(),
};
attribute_modifier_reuse_field::__expand::<ElemType>(&mut context, expanded_state);
attribute_modifier_reuse_field::expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();

assert_eq!(
Expand All @@ -109,7 +109,7 @@ mod tests {
first: x.into(),
second: y.into(),
};
attribute_modifier_reuse_struct::__expand::<ElemType>(&mut context, expanded_state);
attribute_modifier_reuse_struct::expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();

assert_eq!(
Expand Down
10 changes: 5 additions & 5 deletions crates/cubecl-core/tests/frontend/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ mod tests {
ir::{Item, Operation, Variable},
};

type ElemType = F32;
type ElemType = f32;

#[test]
fn cube_support_tensor_metadata() {
let mut context = CubeContext::root();
let input = context.input(0, Item::new(ElemType::as_elem()));

kernel::__expand::<ElemType>(&mut context, input.into());
kernel::expand::<ElemType>(&mut context, input.into());
assert_eq!(context.into_scope().operations, inline_macro_ref());
}

Expand All @@ -33,9 +33,9 @@ mod tests {

let mut scope = context.into_scope();
let input: Variable = input.into();
let x = scope.create_local(Item::new(UInt::as_elem()));
let y = scope.create_local(Item::new(UInt::as_elem()));
let z = scope.create_local(Item::new(UInt::as_elem()));
let x = scope.create_local(Item::new(u32::as_elem()));
let y = scope.create_local(Item::new(u32::as_elem()));
let z = scope.create_local(Item::new(u32::as_elem()));

cpa!(&mut scope, x = shape(input, 1u32));
cpa!(&mut scope, y = stride(input, 1u32));
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/tests/frontend/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use cubecl_core::prelude::*;

#[cube]
pub fn topology_kernel<T: Numeric>(input: Tensor<T>) {
let x = ABSOLUTE_POS + UInt::new(4);
let x = ABSOLUTE_POS + 4;
let _ = input[x];
}

Expand All @@ -14,14 +14,14 @@ mod tests {
ir::{Elem, Item, Variable},
};

type ElemType = F32;
type ElemType = f32;

#[test]
fn cube_support_topology() {
let mut context = CubeContext::root();
let input = context.input(0, Item::new(ElemType::as_elem()));

topology_kernel::__expand::<ElemType>(&mut context, input.into());
topology_kernel::expand::<ElemType>(&mut context, input.into());
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()
Expand Down
12 changes: 6 additions & 6 deletions crates/cubecl-core/tests/frontend/trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl MethodTypedStrategy for AddStrategy {
input_1: <T as CubeType>::ExpandType,
input_2: <T as CubeType>::ExpandType,
) -> <T as CubeType>::ExpandType {
add_strategy_operation::__expand::<T>(context, input_1, input_2)
add_strategy_operation::expand::<T>(context, input_1, input_2)
}
}

Expand All @@ -80,15 +80,15 @@ mod tests {
ir::{Item, Variable},
};

type ElemType = F32;
type ElemType = f32;
#[test]
fn cube_strategy_trait_add_test() {
let mut context = CubeContext::root();

let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));

with_strategy_trait::__expand::<AddStrategy, ElemType>(&mut context, x.into(), y.into());
with_strategy_trait::expand::<AddStrategy, ElemType>(&mut context, x.into(), y.into());
let scope = context.into_scope();

assert_eq!(
Expand All @@ -104,7 +104,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));

with_strategy_trait::__expand::<SubStrategy, ElemType>(&mut context, x.into(), y.into());
with_strategy_trait::expand::<SubStrategy, ElemType>(&mut context, x.into(), y.into());
let scope = context.into_scope();

assert_eq!(
Expand All @@ -120,7 +120,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));

two_strategy_traits::__expand::<SubStrategy, AddStrategy, ElemType>(
two_strategy_traits::expand::<SubStrategy, AddStrategy, ElemType>(
&mut context,
x.into(),
y.into(),
Expand All @@ -137,7 +137,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));

with_trait_generic_method::__expand::<AddStrategy, ElemType>(
with_trait_generic_method::expand::<AddStrategy, ElemType>(
&mut context,
x.into(),
y.into(),
Expand Down
22 changes: 11 additions & 11 deletions crates/cubecl-core/tests/frontend/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

#[cube]
pub fn tuple_const() -> (UInt, UInt) {
let x = UInt::new(0);
let y = UInt::new(1);
pub fn tuple_const() -> (u32, u32) {
let x = 0u32;
let y = 1u32;
(x, y)
}

#[cube]
pub fn tuple_destructuring() -> (UInt, UInt) {
let x = (UInt::new(0), UInt::new(1));
pub fn tuple_destructuring() -> (u32, u32) {
let x = (0u32, 1u32);
let (a, b) = x;
(a + UInt::new(1), b)
(a + 1, b)
}

mod tests {
Expand All @@ -21,12 +21,14 @@ mod tests {
cpa,
ir::{Elem, Item, Operation, Variable},
};
use pretty_assertions::assert_eq;

#[test]
#[ignore = "Empty body because of constant collapsing"]
fn cube_tuple_const_test() {
let mut context = CubeContext::root();

tuple_const::__expand(&mut context);
tuple_const::expand(&mut context);
let scope = context.into_scope();

assert_eq!(scope.operations, inline_macro_ref_tuple_const());
Expand All @@ -52,7 +54,7 @@ mod tests {
fn cube_tuple_destructuring() {
let mut context = CubeContext::root();

tuple_destructuring::__expand(&mut context);
tuple_destructuring::expand(&mut context);
let scope = context.into_scope();

assert_eq!(scope.operations, inline_macro_ref_tuple_destructuring());
Expand All @@ -65,12 +67,10 @@ mod tests {
let a = scope.create_local(Item::new(Elem::UInt));
let b = scope.create_local(Item::new(Elem::UInt));

let zero: Variable = 0u32.into();
let one: Variable = 1u32.into();

cpa!(scope, a = zero);
cpa!(scope, a = one);
cpa!(scope, b = one);
cpa!(scope, a = a + 1u32);

scope.operations
}
Expand Down
Loading

0 comments on commit 02bc447

Please sign in to comment.