Skip to content

Commit

Permalink
Merge branch 'main' into matmul/pipelined
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Nov 28, 2024
2 parents 6b4409e + 553e6a2 commit 4f4da29
Show file tree
Hide file tree
Showing 14 changed files with 236 additions and 22 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/codegen/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ where
}
}

impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
impl<K, R, E1, E2, E3> Execution<'_, K, R, (&[E1], &[E2], &[E3])>
where
K: Kernel + 'static,
R: Runtime,
Expand Down
1 change: 0 additions & 1 deletion crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use cubecl_runtime::ExecutionMode;
/// A kernel, compiled in the target language
pub struct CompiledKernel<C: Compiler> {
/// The name of the kernel entrypoint.
/// For example
///
/// ```text
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/frontend/container/array/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub enum ArrayArg<'a, R: Runtime> {
},
}

impl<'a, R: Runtime> ArgSettings<R> for ArrayArg<'a, R> {
impl<R: Runtime> ArgSettings<R> for ArrayArg<'_, R> {
/// Registers this array argument on the given kernel launcher.
fn register(&self, launcher: &mut KernelLauncher<R>) {
// All the work is delegated to the launcher's array registration.
launcher.register_array(self)
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/container/sequence/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct SequenceArg<'a, R: Runtime, T: LaunchArg> {
pub values: Vec<T::RuntimeArg<'a, R>>,
}

impl<'a, R: Runtime, T: LaunchArg> Default for SequenceArg<'a, R, T> {
impl<R: Runtime, T: LaunchArg> Default for SequenceArg<'_, R, T> {
/// Returns the default value, which is whatever `Self::new()` builds.
fn default() -> Self {
Self::new()
}
Expand Down Expand Up @@ -72,7 +72,7 @@ impl<C: LaunchArg> LaunchArg for Sequence<C> {
}
}

impl<'a, R: Runtime, T: LaunchArg> ArgSettings<R> for SequenceArg<'a, R, T> {
impl<R: Runtime, T: LaunchArg> ArgSettings<R> for SequenceArg<'_, R, T> {
/// Registers every value held by the sequence on the launcher, in iteration order.
fn register(&self, launcher: &mut crate::prelude::KernelLauncher<R>) {
    for value in self.values.iter() {
        value.register(launcher);
    }
}
Expand Down
57 changes: 56 additions & 1 deletion crates/cubecl-core/src/frontend/container/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ pub struct Tensor<T: CubeType> {
/// Module that contains the implementation details of the metadata functions.
mod metadata {
use super::*;
use crate::{ir::Instruction, prelude::Array};
use crate::{
ir::{BinaryOperator, Instruction, Operator},
prelude::Array,
};

impl<T: CubeType> Tensor<T> {
/// Obtain the stride of input at dimension dim
Expand All @@ -31,6 +34,14 @@ mod metadata {
unexpanded!()
}

/// Obtain the coordinate corresponding to the given `index` of the tensor at dimension `dim`.
///
/// A coordinate is a list of indices corresponding to the multi-dimensional position of an element in the tensor.
/// The `dim` element in a coordinate is the position along the `dim` dimension of the tensor.
///
/// The expansion of this method computes `(index / stride(dim)) % shape(dim)`.
pub fn coordinate<I: Index, D: Index>(&self, _index: I, _dim: D) -> u32 {
// Placeholder body: the actual logic lives in the `__expand_coordinate*` functions.
unexpanded!()
}

/// The number of vectorized elements in the tensor.
///
/// # Warning
Expand Down Expand Up @@ -76,6 +87,16 @@ mod metadata {
expand.__expand_shape_method(context, dim)
}

// Expand function of [coordinate](Tensor::coordinate).
//
// Forwards `index` and `dim` unchanged to `__expand_coordinate_method` on the expanded tensor
// and returns its result.
pub fn __expand_coordinate<I: Index, D: Index>(
context: &mut CubeContext,
expand: ExpandElementTyped<Tensor<T>>,
index: ExpandElementTyped<u32>,
dim: ExpandElementTyped<u32>,
) -> ExpandElementTyped<u32> {
expand.__expand_coordinate_method(context, index, dim)
}

// Expand function of [len](Tensor::len).
pub fn __expand_len<C: Index>(
context: &mut CubeContext,
Expand Down Expand Up @@ -138,6 +159,40 @@ mod metadata {
out.into()
}

// Expand method of [coordinate](Tensor::coordinate).
//
// Registers two instructions on `context` that together compute
// `coordinate = (index / stride(dim)) % shape(dim)` and returns the local
// binding holding the result.
pub fn __expand_coordinate_method(
self,
context: &mut CubeContext,
index: ExpandElementTyped<u32>,
dim: ExpandElementTyped<u32>,
) -> ExpandElementTyped<u32> {
let index: ExpandElement = index.into();
// Look up the stride and the shape of the tensor along `dim`; both lookups
// consume their receiver and argument, hence the clones.
let stride = self.clone().__expand_stride_method(context, dim.clone());
let shape = self.clone().__expand_shape_method(context, dim.clone());

// Compute `num_strides = index / stride`.
// The result is stored in a fresh local binding with a `u32` item.
let num_strides = context.create_local_binding(Item::new(u32::as_elem()));
context.register(Instruction::new(
Operator::Div(BinaryOperator {
lhs: *index,
rhs: stride.expand.into(),
}),
num_strides.clone().into(),
));

// Compute `coordinate = num_strides % shape`.
// Taking the remainder keeps only the position along `dim`.
let coordinate = context.create_local_binding(Item::new(u32::as_elem()));
context.register(Instruction::new(
Operator::Modulo(BinaryOperator {
lhs: *num_strides,
rhs: shape.expand.into(),
}),
coordinate.clone().into(),
));

coordinate.into()
}

// Expand method of [len](Tensor::len).
pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/frontend/container/tensor/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct TensorHandleRef<'a, R: Runtime> {
pub runtime: PhantomData<R>,
}

impl<'a, R: Runtime> core::fmt::Debug for TensorHandleRef<'a, R> {
impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(
f,
Expand Down Expand Up @@ -153,7 +153,7 @@ impl<'a, R: Runtime> TensorArg<'a, R> {
}
}

impl<'a, R: Runtime> ArgSettings<R> for TensorArg<'a, R> {
impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
/// Registers this tensor argument on the given kernel launcher.
fn register(&self, launcher: &mut KernelLauncher<R>) {
// All the work is delegated to the launcher's tensor registration.
launcher.register_tensor(self)
}
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/frontend/element/float/tensor_float.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#![allow(clippy::transmute_int_to_float)] // Not yet stable in previous version. To be removed when
#![allow(clippy::transmute_float_to_int)] // prev=1.83.

use bytemuck::{Pod, Zeroable};
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use half::f16;
Expand Down
82 changes: 76 additions & 6 deletions crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cubecl::{
ir::{Elem, FloatKind},
prelude::*,
};
use half::f16;
use half::{bf16, f16};

#[cube(launch)]
/// Executes Out = Lhs @ Rhs.T
Expand Down Expand Up @@ -88,7 +88,7 @@ pub fn kernel_simple_tf32(lhs: &Array<tf32>, rhs: &Array<tf32>, out: &mut Array<
}

#[cube(launch)]
pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
pub fn cast_matrix_f16(input: &Array<f32>, out: &mut Array<f16>) {
let acc = unsafe {
cmma::Matrix::<f32>::uninitialized(
cmma::MatrixIdent::Accumulator,
Expand All @@ -110,6 +110,29 @@ pub fn cast_matrix(input: &Array<f32>, out: &mut Array<f16>) {
);
}

#[cube(launch)]
/// Loads `input` into an `f32` accumulator matrix, casts it to `bf16` and stores it in `out`.
///
/// Mirrors the `f16` cast kernel, with `bf16` as the destination element type.
pub fn cast_matrix_bf16(input: &Array<f32>, out: &mut Array<bf16>) {
// The accumulator is created uninitialized with 16/16/16 dimensions and is filled by
// the `load_with_layout` call right below before it is read by the cast.
let acc = unsafe {
cmma::Matrix::<f32>::uninitialized(
cmma::MatrixIdent::Accumulator,
16,
16,
16,
cmma::MatrixLayout::Undefined,
)
};
// Row-major load; the `16` is presumably the row stride of the source slice (TODO confirm).
cmma::load_with_layout(&acc, &input.to_slice(), 16, cmma::MatrixLayout::RowMajor);

// Convert the f32 accumulator into a bf16 matrix.
let output = cmma::cast::<f32, bf16>(&acc);

// Row-major store with the same `16` argument as the load.
cmma::store(
&mut out.to_slice_mut(),
&output,
16,
cmma::MatrixLayout::RowMajor,
);
}

pub fn test_simple_1<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
Expand Down Expand Up @@ -174,7 +197,7 @@ pub fn test_simple_1<R: Runtime>(
assert_eq!(expected, actual);
}

pub fn test_cmma_cast_acc<R: Runtime>(
pub fn test_cmma_cast_f16<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
) {
Expand All @@ -195,7 +218,7 @@ pub fn test_cmma_cast_acc<R: Runtime>(
let out = client.empty(core::mem::size_of::<f16>() * 256);

unsafe {
cast_matrix::launch::<R>(
cast_matrix_f16::launch::<R>(
&client,
CubeCount::Static(1, 1, 1),
cube_dimensions,
Expand All @@ -211,6 +234,43 @@ pub fn test_cmma_cast_acc<R: Runtime>(
assert_eq!(actual, expected);
}

/// Runtime test: casts 256 `f32` values (`0.0..=255.0`) to `bf16` with the
/// `cast_matrix_bf16` kernel and compares the result with a host-side
/// `bf16::from_f32` conversion.
///
/// The test returns early, without asserting anything, when the client does not
/// report the 16x16x16 CMMA feature with `bf16` inputs and an `f32` accumulator.
pub fn test_cmma_cast_bf16<R: Runtime>(
    client: ComputeClient<R::Server, R::Channel>,
    cube_dimensions: CubeDim,
) {
    if !client.properties().feature_enabled(Feature::Cmma {
        a: Elem::Float(FloatKind::BF16),
        b: Elem::Float(FloatKind::BF16),
        c: Elem::Float(FloatKind::F32),
        m: 16,
        k: 16,
        n: 16,
    }) {
        // We can't execute the test, skip.
        return;
    }

    let input: Vec<f32> = (0..256).map(|i| i as f32).collect();
    let input = client.create(f32::as_bytes(&input));
    // The kernel writes `bf16` values, so the buffer is sized from `bf16` (not `f16`).
    let out = client.empty(core::mem::size_of::<bf16>() * 256);

    // The element counts below (256) match the two buffers created just above.
    unsafe {
        cast_matrix_bf16::launch::<R>(
            &client,
            CubeCount::Static(1, 1, 1),
            cube_dimensions,
            ArrayArg::from_raw_parts::<f32>(&input, 256, 1),
            // Element type must agree with the kernel's `out: &mut Array<bf16>`.
            ArrayArg::from_raw_parts::<bf16>(&out, 256, 1),
        )
    };

    let actual = client.read_one(out.binding());
    let actual = bf16::from_bytes(&actual);
    let expected: Vec<bf16> = (0..256).map(|i| bf16::from_f32(i as f32)).collect();

    assert_eq!(actual, expected);
}

pub fn test_simple_tf32<R: Runtime>(
client: ComputeClient<R::Server, R::Channel>,
cube_dimensions: CubeDim,
Expand Down Expand Up @@ -305,10 +365,20 @@ macro_rules! testgen_cmma {
}

#[test]
fn test_cmma_cast_acc() {
fn test_cmma_cast_f16() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(32, 1, 1);
cubecl_core::runtime_tests::cmma::test_cmma_cast_f16::<TestRuntime>(
client,
cube_dimensions,
);
}

#[test]
fn test_cmma_cast_bf16() {
let client = TestRuntime::client(&Default::default());
let cube_dimensions = CubeDim::new(32, 1, 1);
cubecl_core::runtime_tests::cmma::test_cmma_cast_acc::<TestRuntime>(
cubecl_core::runtime_tests::cmma::test_cmma_cast_bf16::<TestRuntime>(
client,
cube_dimensions,
);
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod metadata;
pub mod plane;
pub mod sequence;
pub mod slice;
pub mod tensor;
pub mod topology;
pub mod unary;

Expand Down Expand Up @@ -103,6 +104,7 @@ macro_rules! testgen_untyped {
cubecl_core::testgen_topology!();

cubecl_core::testgen_constants!();
cubecl_core::testgen_tensor_indexing!();
};
}

Expand Down
61 changes: 61 additions & 0 deletions crates/cubecl-core/src/runtime_tests/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use crate as cubecl;
use cubecl::prelude::*;

#[cube(launch)]
/// Writes, for each unit, the coordinate of `input` obtained with `UNIT_POS_X` as the
/// index and `UNIT_POS_Y` as the dimension, into `output[UNIT_POS]`.
pub fn tensor_coordinate(input: &Tensor<f32>, output: &mut Array<u32>) {
// The X position of the unit selects the index, the Y position selects the dimension.
let index = UNIT_POS_X;
let dim = UNIT_POS_Y;
output[UNIT_POS] = input.coordinate(index, dim);
}

/// Runtime test for `Tensor::coordinate`.
///
/// Launches the `tensor_coordinate` kernel on a tensor of shape `[2, 2, 3]` with strides
/// `[2, 1, 4]`, using one unit per `(index, dim)` pair, and checks the coordinates that
/// come back for every line size reported by the runtime.
pub fn test_tensor_coordinate<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    let stride = [2, 1, 4];
    let shape = [2, 2, 3];

    let num_elems = shape.iter().product::<usize>();
    let input = client.empty(num_elems * core::mem::size_of::<f32>());

    // One row per dimension, one column per index: reading a column top to bottom
    // gives the full coordinate of that index, i.e.
    // [0,0,0], [0,1,0] ... [1,1,2].
    let expected = vec![
        0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, //
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, //
        0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, //
    ];

    // One output value per (dimension, index) pair.
    let num_coords = shape.len() * num_elems;

    // The result is independent of the line size.
    for &line_size in R::supported_line_sizes() {
        let output = client.empty(num_coords * core::mem::size_of::<u32>());

        unsafe {
            let cube_count = CubeCount::Static(1, 1, 1);
            // X spans the indices, Y spans the dimensions.
            let cube_dim = CubeDim::new(num_elems as u32, shape.len() as u32, 1);
            let input_arg = TensorArg::from_raw_parts::<f32>(&input, &stride, &shape, line_size);
            let output_arg = ArrayArg::from_raw_parts::<u32>(&output, num_coords, 1);

            tensor_coordinate::launch::<R>(&client, cube_count, cube_dim, input_arg, output_arg)
        };

        let bytes = client.read_one(output.binding());
        let actual = u32::from_bytes(&bytes);

        assert_eq!(actual, expected);
    }
}

// Generates the tensor indexing tests for the runtime named `TestRuntime` at the
// expansion site. Each generated test builds a default client and forwards it to the
// matching function in `cubecl_core::runtime_tests::tensor`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_tensor_indexing {
() => {
use super::*;

#[test]
fn test_tensor_coordinate() {
// `TestRuntime` must be in scope where the macro is expanded.
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::tensor::test_tensor_coordinate::<TestRuntime>(client);
}
};
}
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ struct EnsureBoolArg<'a, V: Display, D: Dialect> {
elem: &'a Elem<D>,
}

impl<'a, V: Display, D: Dialect> Display for EnsureBoolArg<'a, V, D> {
impl<V: Display, D: Dialect> Display for EnsureBoolArg<'_, V, D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.elem != &Elem::Bool {
write!(f, "bool({})", self.var)
Expand Down
Loading

0 comments on commit 4f4da29

Please sign in to comment.