Skip to content

Commit

Permalink
Feat/compilation arg (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 21, 2024
1 parent 16a79fc commit 447968e
Show file tree
Hide file tree
Showing 21 changed files with 505 additions and 258 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub struct KernelSettings {
pub mappings: Vec<InplaceMapping>,
vectorization_global: Option<Vectorization>,
vectorization_partial: Vec<VectorizationPartial>,
cube_dim: CubeDim,
pub cube_dim: CubeDim,
pub reading_strategy: Vec<(u16, ReadingStrategy)>,
}

Expand Down
16 changes: 16 additions & 0 deletions crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ impl KernelBuilder {
variable
}

/// Register an output that reuses the resource of the input at the given position.
///
/// The visibility of that input is switched to [`Visibility::ReadWrite`] and the
/// variable bound to the input is returned, so no new output binding is pushed.
///
/// # Panics
///
/// Panics if no input is registered at `position`, or if the input registered there
/// is not an array.
pub fn inplace_output(&mut self, position: u16) -> ExpandElement {
    // An out-of-range position means nothing was registered there.
    let Some(input) = self.inputs.get_mut(position as usize) else {
        panic!("No input found at position {position}");
    };

    // Only array inputs carry a visibility that can be upgraded for in-place writes.
    let InputInfo::Array { visibility, item } = input else {
        panic!("Input at position {position} is not an array and can't be used as an inplace output");
    };

    *visibility = Visibility::ReadWrite;
    self.context.input(position, *item)
}

/// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
pub fn input_array(&mut self, item: Item) -> ExpandElement {
self.inputs.push(InputInfo::Array {
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/frontend/container/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod sequence;

pub use sequence::*;
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{
use crate::frontend::{
branch::Iterable, indexation::Index, CubeContext, CubeType, ExpandElementTyped, Init,
IntoRuntime,
};
Expand Down Expand Up @@ -34,6 +34,12 @@ impl<T: CubeType> Sequence<T> {
self.values.push(value);
}

/// Obtain the sequence length.
///
/// The body is only the `unexpanded!()` placeholder.
/// NOTE(review): presumably calls are rewritten to `__expand_len_method` when the
/// kernel is expanded — confirm against the macro.
// Clippy would otherwise ask for a matching `is_empty` next to a public `len`.
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> u32 {
unexpanded!()
}

/// Get the variable at the given position in the sequence.
#[allow(unused_variables, clippy::should_implement_trait)]
pub fn index<I: Index>(&self, index: I) -> &T {
Expand Down Expand Up @@ -70,7 +76,7 @@ impl<T: CubeType> Sequence<T> {
pub struct SequenceExpand<T: CubeType> {
// We clone the expand type during the compilation phase, but for register reuse, not for
// copying data. To achieve the intended behavior, we have to share the same underlying values.
values: Rc<RefCell<Vec<T::ExpandType>>>,
pub(super) values: Rc<RefCell<Vec<T::ExpandType>>>,
}

impl<T: CubeType> Iterable<T> for SequenceExpand<T> {
Expand Down Expand Up @@ -149,6 +155,11 @@ impl<T: CubeType> SequenceExpand<T> {
.as_usize();
self.values.borrow()[index].clone()
}

/// Return the number of values currently held by the sequence, as a `u32`.
///
/// The context is accepted but not used.
pub fn __expand_len_method(&self, _context: &mut CubeContext) -> u32 {
    self.values.borrow().len() as u32
}
}

impl<T: CubeType> IntoRuntime for Sequence<T> {
Expand Down
107 changes: 107 additions & 0 deletions crates/cubecl-core/src/frontend/container/sequence/launch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use std::{cell::RefCell, rc::Rc};

use crate::{
compute::KernelBuilder,
prelude::{ArgSettings, LaunchArg, LaunchArgExpand},
Runtime,
};

use super::{Sequence, SequenceExpand};

/// Runtime argument of a [`Sequence`]: the runtime arguments of its elements, kept in
/// the order they were pushed.
pub struct SequenceArg<'a, R: Runtime, T: LaunchArg> {
    values: Vec<T::RuntimeArg<'a, R>>,
}

impl<'a, R: Runtime, T: LaunchArg> SequenceArg<'a, R, T> {
    /// Create a sequence argument that holds no element yet.
    pub fn new() -> Self {
        let values = Vec::new();
        Self { values }
    }

    /// Append the runtime argument of one more element at the end of the sequence.
    pub fn push(&mut self, arg: T::RuntimeArg<'a, R>) {
        self.values.push(arg);
    }
}

impl<'a, R: Runtime, T: LaunchArg> Default for SequenceArg<'a, R, T> {
    /// Equivalent to [`SequenceArg::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Compilation argument of a [`Sequence`]: the compilation argument of every element,
/// in order.
///
/// The trait implementations below are written by hand and only operate on the
/// `Vec<C::CompilationArg>` field.
/// NOTE(review): presumably this avoids `#[derive]` putting bounds on `C` itself — confirm.
pub struct SequenceCompilationArg<C: LaunchArg> {
    values: Vec<C::CompilationArg>,
}

impl<C: LaunchArg> Clone for SequenceCompilationArg<C> {
    fn clone(&self) -> Self {
        let values = self.values.clone();
        Self { values }
    }
}

impl<C: LaunchArg> core::hash::Hash for SequenceCompilationArg<C> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::hash::Hash::hash(&self.values, state);
    }
}

impl<C: LaunchArg> core::cmp::PartialEq for SequenceCompilationArg<C> {
    fn eq(&self, other: &Self) -> bool {
        self.values == other.values
    }
}

impl<C: LaunchArg> core::fmt::Debug for SequenceCompilationArg<C> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "SequenceCompilationArg {:?}", self.values)
    }
}

impl<C: LaunchArg> core::cmp::Eq for SequenceCompilationArg<C> {}

impl<C: LaunchArg> LaunchArg for Sequence<C> {
    type RuntimeArg<'a, R: Runtime> = SequenceArg<'a, R, C>;

    /// Build the compilation argument of the sequence by computing the compilation
    /// argument of each element, preserving their order.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        let mut values = Vec::with_capacity(runtime_arg.values.len());
        for value in &runtime_arg.values {
            values.push(C::compilation_arg(value));
        }

        SequenceCompilationArg { values }
    }
}

impl<'a, R: Runtime, T: LaunchArg> ArgSettings<R> for SequenceArg<'a, R, T> {
    /// Register every element of the sequence on the launcher, in the order they were
    /// pushed.
    fn register(&self, launcher: &mut crate::prelude::KernelLauncher<R>) {
        for arg in &self.values {
            arg.register(launcher);
        }
    }
}

impl<C: LaunchArg> LaunchArgExpand for Sequence<C> {
    type CompilationArg = SequenceCompilationArg<C>;

    /// Expand every element through `C::expand`, in order, and gather the results in a
    /// fresh shared sequence.
    fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> SequenceExpand<C> {
        let mut expanded = Vec::with_capacity(arg.values.len());
        for value in &arg.values {
            expanded.push(C::expand(value, builder));
        }

        SequenceExpand {
            values: Rc::new(RefCell::new(expanded)),
        }
    }

    /// Expand every element through `C::expand_output`, in order, and gather the
    /// results in a fresh shared sequence.
    fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> SequenceExpand<C> {
        let mut expanded = Vec::with_capacity(arg.values.len());
        for value in &arg.values {
            expanded.push(C::expand_output(value, builder));
        }

        SequenceExpand {
            values: Rc::new(RefCell::new(expanded)),
        }
    }
}
5 changes: 5 additions & 0 deletions crates/cubecl-core/src/frontend/container/sequence/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod base;
mod launch;

pub use base::*;
pub use launch::*;
77 changes: 41 additions & 36 deletions crates/cubecl-core/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
frontend::CubeType,
ir::{Branch, Item, RangeLoop, Vectorization},
prelude::{CubeIndex, Iterable},
unexpanded, KernelSettings, Runtime,
unexpanded, Runtime,
};
use crate::{
frontend::{indexation::Index, CubeContext},
Expand All @@ -14,14 +14,20 @@ use crate::{

use super::{
ArgSettings, CubePrimitive, ExpandElement, ExpandElementBaseInit, ExpandElementTyped,
LaunchArg, LaunchArgExpand, TensorHandleRef,
IntoRuntime, LaunchArg, LaunchArgExpand, TensorHandleRef,
};

/// A contiguous array of elements.
///
/// The struct stores no data: `_val` is a [`PhantomData`] marker that only ties the
/// type to its element type `E`.
pub struct Array<E> {
_val: PhantomData<E>,
}

/// [`IntoRuntime`] is implemented for [`Array`] only to satisfy the trait: the
/// conversion itself is not supported.
impl<E: CubePrimitive> IntoRuntime for Array<E> {
// Always panics through `unimplemented!` with the message below.
fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
unimplemented!("Array can't exist at compile time")
}
}

/// The expand type of an [`Array`] is an [`ExpandElementTyped`] tagged with the array type.
impl<C: CubeType> CubeType for Array<C> {
type ExpandType = ExpandElementTyped<Array<C>>;
}
Expand Down Expand Up @@ -127,24 +133,51 @@ impl<E: CubeType> Array<E> {

impl<C: CubePrimitive> LaunchArg for Array<C> {
    type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>;

    /// Derive the compilation argument from the runtime argument.
    ///
    /// * A handle yields a non-inplace argument carrying the handle's vectorization
    ///   factor.
    /// * An alias yields an inplace argument pointing at the aliased input position,
    ///   with no vectorization.
    ///
    /// # Panics
    ///
    /// Panics if the vectorization factor of a handle is zero, or if the aliased input
    /// position does not fit in a `u16`.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
        match runtime_arg {
            ArrayArg::Handle {
                vectorization_factor,
                ..
            } => ArrayCompilationArg {
                inplace: None,
                vectorisation: Vectorization::Some(
                    NonZero::new(*vectorization_factor)
                        .expect("Array vectorization factor must be non-zero"),
                ),
            },
            ArrayArg::Alias { input_pos } => ArrayCompilationArg {
                // A checked conversion: a plain `as u16` would silently wrap a larger
                // position and alias the wrong input.
                inplace: Some(
                    u16::try_from(*input_pos).expect("Aliased input position must fit in a u16"),
                ),
                vectorisation: Vectorization::None,
            },
        }
    }
}

/// Compilation argument of an [`Array`].
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ArrayCompilationArg {
// `Some(position)` when the array aliases the input registered at `position`
// (set from `ArrayArg::Alias`); `None` for a regular handle.
inplace: Option<u16>,
// Vectorization used to build the array's item during expansion.
vectorisation: Vectorization,
}

impl<C: CubePrimitive> LaunchArgExpand for Array<C> {
type CompilationArg = ArrayCompilationArg;

fn expand(
arg: &Self::CompilationArg,
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Array<C>> {
builder
.input_array(Item::vectorized(C::as_elem(), vectorization))
.input_array(Item::vectorized(C::as_elem(), arg.vectorisation))
.into()
}
fn expand_output(
arg: &Self::CompilationArg,
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Array<C>> {
builder
.output_array(Item::vectorized(C::as_elem(), vectorization))
.into()
match arg.inplace {
Some(id) => builder.inplace_output(id).into(),
None => builder
.output_array(Item::vectorized(C::as_elem(), arg.vectorisation))
.into(),
}
}
}

Expand Down Expand Up @@ -179,34 +212,6 @@ impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
launcher.register_array(handle)
}
}

fn configure_input(&self, position: usize, settings: KernelSettings) -> KernelSettings {
match self {
Self::Handle {
handle: _,
vectorization_factor,
} => settings.vectorize_input(position, NonZero::new(*vectorization_factor)),
Self::Alias { input_pos: _ } => {
panic!("Not yet supported, only output can be aliased for now.");
}
}
}

fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings {
match self {
Self::Handle {
handle: _,
vectorization_factor,
} => settings.vectorize_output(position, NonZero::new(*vectorization_factor)),
Self::Alias { input_pos } => {
settings.mappings.push(crate::InplaceMapping {
pos_input: *input_pos,
pos_output: position,
});
settings
}
}
}
}

impl<'a, R: Runtime> ArrayArg<'a, R> {
Expand Down
18 changes: 7 additions & 11 deletions crates/cubecl-core/src/frontend/element/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ use super::{
};
use crate::{
frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement},
ir::{
BinaryOperator, CompareAndSwapOperator, Elem, IntKind, Item, Operator, UnaryOperator,
Vectorization,
},
ir::{BinaryOperator, CompareAndSwapOperator, Elem, IntKind, Item, Operator, UnaryOperator},
prelude::KernelBuilder,
unexpanded,
};
Expand Down Expand Up @@ -306,11 +303,12 @@ macro_rules! impl_atomic_int {
}

impl LaunchArgExpand for $type {
type CompilationArg = ();

fn expand(
_: &Self::CompilationArg,
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Self> {
assert_eq!(vectorization, None, "Attempted to vectorize a scalar");
builder.scalar(Elem::AtomicInt(IntKind::$inner_type)).into()
}
}
Expand Down Expand Up @@ -357,11 +355,9 @@ impl ExpandElementBaseInit for AtomicU32 {
}

impl LaunchArgExpand for AtomicU32 {
fn expand(
builder: &mut KernelBuilder,
vectorization: Vectorization,
) -> ExpandElementTyped<Self> {
assert_eq!(vectorization, None, "Attempted to vectorize a scalar");
type CompilationArg = ();

fn expand(_: &Self::CompilationArg, builder: &mut KernelBuilder) -> ExpandElementTyped<Self> {
builder.scalar(Elem::AtomicUInt).into()
}
}
Expand Down
Loading

0 comments on commit 447968e

Please sign in to comment.