diff --git a/crates/cubecl-core/src/frontend/element/array.rs b/crates/cubecl-core/src/frontend/element/array.rs index b0520d6d1..e9ca88d56 100644 --- a/crates/cubecl-core/src/frontend/element/array.rs +++ b/crates/cubecl-core/src/frontend/element/array.rs @@ -3,7 +3,8 @@ use std::{marker::PhantomData, num::NonZero}; use crate::{ compute::{KernelBuilder, KernelLauncher}, frontend::CubeType, - ir::{Item, Vectorization}, + ir::{Branch, Item, RangeLoop, Vectorization}, + prelude::{CubeIndex, Iterable}, unexpanded, KernelSettings, Runtime, }; use crate::{ @@ -253,3 +254,56 @@ impl<'a, R: Runtime> ArrayHandleRef<'a, R> { } } } + +pub trait SizedContainer: + CubeIndex, Output = Self::Item> + CubeType +{ + type Item: CubeType>; +} + +impl>> SizedContainer for Array { + type Item = T; +} + +impl Iterator for &Array { + type Item = T; + + fn next(&mut self) -> Option { + unexpanded!() + } +} + +impl Iterable for ExpandElementTyped { + fn expand( + self, + context: &mut CubeContext, + mut body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + let index_ty = Item::new(u32::as_elem()); + let len: ExpandElement = self.clone().__expand_len_method(context).into(); + + let mut child = context.child(); + let i = child.scope.borrow_mut().create_local_undeclared(index_ty); + let i = ExpandElement::Plain(i); + + let item = index::expand(&mut child, self, i.clone().into()); + body(&mut child, item); + + context.register(Branch::RangeLoop(RangeLoop { + i: *i, + start: 0u32.into(), + end: *len, + step: None, + inclusive: false, + scope: child.into_scope(), + })); + } + + fn expand_unroll( + self, + _context: &mut CubeContext, + _body: impl FnMut(&mut CubeContext, ::ExpandType), + ) { + unimplemented!("Can't unroll array iterator") + } +} diff --git a/crates/cubecl-core/src/frontend/element/slice.rs b/crates/cubecl-core/src/frontend/element/slice.rs index 0ed569651..68c9afcc3 100644 --- a/crates/cubecl-core/src/frontend/element/slice.rs +++ b/crates/cubecl-core/src/frontend/element/slice.rs @@ -1,7 +1,8 @@ use std::marker::PhantomData; use super::{ - Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor, + Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, + SizedContainer, Tensor, }; use crate::{ frontend::indexation::Index, @@ -60,6 +61,25 @@ impl<'a, C: CubeType> Init for ExpandElementTyped> { } } +impl<'a, C: CubeType>> SizedContainer for Slice<'a, C> { + type Item = C; +} + +impl<'a, T: CubeType> Iterator for Slice<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + unexpanded!() + } +} +impl<'a, T: CubeType> Iterator for &Slice<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + unexpanded!() + } +} + pub trait SliceOperator: CubeType { type Expand: SliceOperatorExpand; diff --git a/crates/cubecl-core/src/frontend/element/tensor.rs b/crates/cubecl-core/src/frontend/element/tensor.rs index 94ba711e5..86fe20c6e 100644 --- a/crates/cubecl-core/src/frontend/element/tensor.rs +++ b/crates/cubecl-core/src/frontend/element/tensor.rs @@ -1,4 +1,4 @@ -use super::{ExpandElementBaseInit, ExpandElementTyped, LaunchArgExpand}; +use super::{ExpandElementBaseInit, ExpandElementTyped, LaunchArgExpand, SizedContainer}; use crate::{ frontend::{ indexation::Index, ArgSettings, CubeContext, CubePrimitive, CubeType, ExpandElement, @@ -243,3 +243,15 @@ impl ExpandElementTyped { ExpandElement::Plain(Variable::Rank).into() } } + +impl>> SizedContainer for Tensor { + type Item = T; +} + +impl Iterator for &Tensor { + type Item = T; + + fn next(&mut self) -> Option { + unexpanded!() + } +} diff --git a/crates/cubecl-core/src/frontend/operation/assignation.rs b/crates/cubecl-core/src/frontend/operation/assignation.rs index 6c782870f..87414635e 100644 --- a/crates/cubecl-core/src/frontend/operation/assignation.rs +++ b/crates/cubecl-core/src/frontend/operation/assignation.rs @@ -91,7 +91,7 @@ pub mod index { use super::*; - pub fn expand>( + pub fn expand>>( context: &mut CubeContext, array: ExpandElementTyped, index: ExpandElementTyped, diff --git a/crates/cubecl-core/src/runtime_tests/slice.rs b/crates/cubecl-core/src/runtime_tests/slice.rs index 87a4b1bcc..769d1c364 100644 --- a/crates/cubecl-core/src/runtime_tests/slice.rs +++ b/crates/cubecl-core/src/runtime_tests/slice.rs @@ -26,6 +26,19 @@ pub fn slice_len(input: &Array, output: &mut Array) { } } +#[cube(launch)] +pub fn slice_for(input: &Array, output: &mut Array) { + if UNIT_POS == 0 { + let mut sum = 0f32; + + for item in input.slice(2, 4) { + sum += item; + } + + output[0] = sum; + } +} + pub fn test_slice_select(client: ComputeClient) { let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); let output = client.empty(core::mem::size_of::()); @@ -86,6 +99,26 @@ pub fn test_slice_assign(client: ComputeClient(client: ComputeClient) { + let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); + let output = client.create(f32::as_bytes(&[0.0])); + + unsafe { + slice_for::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::from_raw_parts(&input, 5, 1), + ArrayArg::from_raw_parts(&output, 1, 1), + ) + }; + + let actual = client.read(output.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], 5.0); +} + #[allow(missing_docs)] #[macro_export] macro_rules! testgen_slice { @@ -109,5 +142,11 @@ macro_rules! testgen_slice { let client = TestRuntime::client(&Default::default()); cubecl_core::runtime_tests::slice::test_slice_len::(client); } + + #[test] + fn test_slice_for() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::slice::test_slice_for::(client); + } }; } diff --git a/crates/cubecl-core/tests/frontend/for_loop.rs b/crates/cubecl-core/tests/frontend/for_loop.rs index 18cc9681b..159fc88f8 100644 --- a/crates/cubecl-core/tests/frontend/for_loop.rs +++ b/crates/cubecl-core/tests/frontend/for_loop.rs @@ -17,9 +17,23 @@ pub fn for_loop(mut lhs: Array, rhs: F, end: u32, #[comptime] unrol } } +#[cube] +pub fn for_in_loop(input: &Array) -> F { + let mut sum = F::new(0.0); + + for item in input { + sum += item; + } + sum +} + mod tests { use cubecl::frontend::ExpandElement; - use cubecl_core::{cpa, ir::Item}; + use cubecl_core::{ + cpa, + ir::{Item, Variable}, + }; + use pretty_assertions::assert_eq; use super::*; @@ -35,7 +49,7 @@ mod tests { for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); + assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref(unroll)); } #[test] @@ -50,7 +64,22 @@ mod tests { for_loop::expand::(&mut context, lhs.into(), rhs.into(), end.into(), unroll); let scope = context.into_scope(); - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); + assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref(unroll)); + } + + #[test] + fn test_for_in_loop() { + let mut context = CubeContext::root(); + + let input = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); + + for_in_loop::expand::(&mut context, input.into()); + let scope = context.into_scope(); + + assert_eq!( + format!("{:#?}", scope.operations), + inline_macro_ref_for_in() + ); } fn inline_macro_ref(unroll: bool) -> String { @@ -76,6 +105,32 @@ mod tests { }) ); - format!("{:?}", scope.operations) + format!("{:#?}", scope.operations) + } + + fn inline_macro_ref_for_in() -> String { + let context = CubeContext::root(); + let item = Item::new(ElemType::as_elem()); + + let mut scope = context.into_scope(); + let input = scope.create_local_array(item, 4u32); + let sum = scope.create_local(item); + let end = scope.create_local(Item::new(u32::as_elem())); + let zero: Variable = ElemType::new(0.0).into(); + + // Kernel + let tmp1 = scope.create_local(item); + cpa!(scope, sum = zero); + cpa!(scope, end = len(input)); + + cpa!( + &mut scope, + range(0u32, end).for_each(|i, scope| { + cpa!(scope, tmp1 = input[i]); + cpa!(scope, sum = sum + tmp1); + }) + ); + + format!("{:#?}", scope.operations) } }