Merge pull request #30 from tracel-ai/feat/contiguous
Into Contiguous Kernel
nathanielsimard authored Jul 22, 2024
2 parents 89af40f + 3c53a87 commit b3c1799
Showing 4 changed files with 189 additions and 17 deletions.
58 changes: 43 additions & 15 deletions crates/cubecl-linalg/src/matmul/cmma/launch.rs
@@ -12,7 +12,7 @@ use crate::{
        base::cmma_kernel,
        config::{cmma_cube_count, cmma_cube_dim, CmmaConfig, CmmaLaunchConfig},
    },
-    tensor::{matrix_layout, MatrixLayout, TensorHandle},
+    tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle},
};

/// Matrix multiplication using [cooperative matrix-multiply and accumulate operations](cubecl_core::cmma).
@@ -28,8 +28,7 @@ pub fn matmul_cmma<R: Runtime, F: Float>(

#[derive(Debug)]
pub enum UnavailabilityReason {
-    TransposedInput, // TODO: Support that case.
-    NotMultipleOf4,  // TODO: Support that case.
+    NotMultipleOf4, // TODO: Support that case.
    HiglyPermutatedInput,
    ShapeMemoryLimitBusted,
    InvalidConfig(String),
@@ -43,15 +42,6 @@ pub fn check_cmma_availability<R: Runtime>(
    rhs: &TensorHandleRef<'_, R>,
    config: Option<&CmmaLaunchConfig>,
) -> Result<(), UnavailabilityReason> {
-    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
-        MatrixLayout::Contiguous => Ok(()),
-        MatrixLayout::MildlyPermuted {
-            transposed: _,
-            batch_swap: _,
-        } => Err(UnavailabilityReason::TransposedInput),
-        MatrixLayout::HighlyPermuted => Err(UnavailabilityReason::HiglyPermutatedInput),
-    };
-
    if !client.features().enabled(Feature::Cmma {
        a: Elem::Float(FloatKind::F16),
        b: Elem::Float(FloatKind::F16),
@@ -63,9 +53,6 @@
        return Err(UnavailabilityReason::CmmaInstructionsUnsupported);
    }

-    check_layout(lhs)?;
-    check_layout(rhs)?;
-
    let rank = lhs.shape.len();
    let m = lhs.shape[rank - 2];
    let k = lhs.shape[rank - 1];
@@ -105,6 +92,47 @@ pub fn matmul_cmma_ref<R: Runtime, F: Float>(
    lhs: TensorHandleRef<'_, R>,
    rhs: TensorHandleRef<'_, R>,
    out: TensorHandleRef<'_, R>,
+) {
+    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
+        MatrixLayout::Contiguous => true,
+        MatrixLayout::MildlyPermuted {
+            transposed: _,
+            batch_swap: _,
+        } => false,
+        MatrixLayout::HighlyPermuted => false,
+    };
+
+    let lhs_correct_layout = check_layout(&lhs);
+    let rhs_correct_layout = check_layout(&rhs);
+
+    match (lhs_correct_layout, rhs_correct_layout) {
+        (true, true) => matmul_cmma_ref_no_check::<R, F>(client, lhs, rhs, out),
+        (true, false) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            lhs,
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+        ),
+        (false, true) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            rhs,
+            out,
+        ),
+        (false, false) => matmul_cmma_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+        ),
+    }
+}
+
+fn matmul_cmma_ref_no_check<R: Runtime, F: Float>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    lhs: TensorHandleRef<'_, R>,
+    rhs: TensorHandleRef<'_, R>,
+    out: TensorHandleRef<'_, R>,
) {
    let rank = lhs.strides.len();

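The dispatch above funnels any permuted operand through `into_contiguous` before the CMMA kernel runs, since the kernel itself assumes plain row-major inputs. For intuition, the stride patterns that the `MatrixLayout` variants distinguish look like this (a standalone sketch; the real classification in `matrix_layout` from `cubecl_linalg::tensor` may apply finer rules):

```rust
fn main() {
    // Row-major contiguous [batch, m, k] = [2, 3, 4]: strides decrease
    // and the innermost stride is 1.
    let contiguous = [12usize, 4, 1];
    // The same buffer with the two matrix dimensions swapped (a
    // transposed view): the innermost stride now exceeds the next one.
    let transposed = [12usize, 1, 3];

    // A cheap transposition test on the two innermost dimensions only;
    // `matrix_layout` itself distinguishes more cases (batch swaps,
    // arbitrary permutations).
    let is_transposed = |s: &[usize]| s[s.len() - 1] > s[s.len() - 2];
    assert!(!is_transposed(&contiguous));
    assert!(is_transposed(&transposed));
}
```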
50 changes: 48 additions & 2 deletions crates/cubecl-linalg/src/matmul/tiling2d/launch.rs
@@ -7,7 +7,7 @@ use crate::{
        base::tiling2d_cube_kernel,
        config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
    },
-    tensor::{matrix_layout, MatrixLayout, TensorHandle},
+    tensor::{into_contiguous, matrix_layout, MatrixLayout, TensorHandle},
};

use super::config::Tiling2dConfig;
@@ -38,6 +38,51 @@ pub fn matmul_tiling_2d_ref<R: Runtime, F: Float>(
            <= <R::Compiler as Compiler>::max_shared_memory_size(),
        "Shared memory limit will be busted. "
    );
+    let check_layout = |tensor: &TensorHandleRef<'_, R>| match matrix_layout(tensor.strides) {
+        MatrixLayout::Contiguous => true,
+        MatrixLayout::MildlyPermuted {
+            transposed: _,
+            batch_swap: _,
+        } => true,
+        MatrixLayout::HighlyPermuted => false,
+    };
+    let lhs_correct_layout = check_layout(&lhs);
+    let rhs_correct_layout = check_layout(&rhs);
+
+    match (lhs_correct_layout, rhs_correct_layout) {
+        (true, true) => matmul_tiling_2d_ref_no_check::<R, F>(client, lhs, rhs, out, config),
+        (true, false) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            lhs,
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+            config,
+        ),
+        (false, true) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            rhs,
+            out,
+            config,
+        ),
+        (false, false) => matmul_tiling_2d_ref_no_check::<R, F>(
+            client,
+            into_contiguous::<R, F>(client, lhs).as_ref(),
+            into_contiguous::<R, F>(client, rhs).as_ref(),
+            out,
+            config,
+        ),
+    }
+}
+
+/// Matrix multiplication using tiling 2d algorithm.
+fn matmul_tiling_2d_ref_no_check<R: Runtime, F: Float>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    lhs: TensorHandleRef<'_, R>,
+    rhs: TensorHandleRef<'_, R>,
+    out: TensorHandleRef<'_, R>,
+    config: Tiling2dConfig,
+) {
    let rank = lhs.strides.len();

    let m = lhs.shape[rank - 2];
@@ -58,7 +103,8 @@ pub fn matmul_tiling_2d_ref<R: Runtime, F: Float>(
    let rhs_transposed = check_layout(rhs.strides);

    let vectorization = |shape: usize| {
-        [].into_iter()
+        [4, 2]
+            .into_iter()
            .filter(|v| shape % v == 0)
            .map(|v| v as u8)
            .next()
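The `[4, 2]` list above replaces an empty candidate array, so the tiling2d launcher can now actually pick a vector width: the first candidate that evenly divides the dimension wins. A standalone sketch of that selection, where the `unwrap_or(1)` fallback is an assumption since the tail of the expression sits outside the hunk:

```rust
/// First candidate width that divides `dim`, else scalar (1).
fn pick_vectorization(candidates: &[usize], dim: usize) -> u8 {
    candidates
        .iter()
        .copied()
        .find(|v| dim % v == 0)
        .map(|v| v as u8)
        .unwrap_or(1)
}

fn main() {
    assert_eq!(pick_vectorization(&[4, 2], 8), 4); // 4 divides 8
    assert_eq!(pick_vectorization(&[4, 2], 6), 2); // 4 does not, 2 does
    assert_eq!(pick_vectorization(&[4, 2], 7), 1); // neither: scalar fallback
}
```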
95 changes: 95 additions & 0 deletions crates/cubecl-linalg/src/tensor/contiguous.rs
@@ -0,0 +1,95 @@
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise, tensor_vectorization_factor, SUBCUBE_DIM_APPROX,
};

use cubecl::prelude::*;

use super::TensorHandle;

/// Returns the offset in `tensor` that corresponds to the linear index `offset_layout` taken over the strides of the `layout` tensor.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<N>,
    layout: &Tensor<L>,
    offset_layout: UInt,
    dim_start: UInt,
    dim_end: UInt,
    unroll: Comptime<bool>,
) -> UInt {
    let vectorization_factor = Comptime::vectorization(tensor);
    let vectorization_factor_runtime = Comptime::runtime(vectorization_factor);

    let offset_ref = offset_layout * vectorization_factor_runtime;
    let mut offset = UInt::new(0);

    for i in range(dim_start, dim_end, unroll) {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / vectorization_factor_runtime
}
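Stripped of the vectorization bookkeeping, the loop above decomposes a linear index into per-dimension coordinates using the `layout` strides, then re-linearizes those coordinates with the target tensor's own strides (the `% shape` also lets broadcast dimensions wrap). A host-side rendition for illustration:

```rust
/// Host-side sketch of the loop above, ignoring vectorization: map a
/// linear index over `layout_strides` to an offset in a tensor with
/// `tensor_strides`.
fn strided_offset(
    shape: &[usize],
    tensor_strides: &[usize],
    layout_strides: &[usize],
    offset_layout: usize,
) -> usize {
    (0..shape.len())
        .map(|i| ((offset_layout / layout_strides[i]) % shape[i]) * tensor_strides[i])
        .sum()
}

fn main() {
    // A 2x3 matrix stored transposed: shape [2, 3], strides [1, 2];
    // its contiguous layout has strides [3, 1]. Element (0, 1) is at
    // linear index 1 contiguously, but at offset 2 in the transposed buffer.
    assert_eq!(strided_offset(&[2, 3], &[1, 2], &[3, 1], 1), 2);
}
```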

#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &Tensor<N>,
    output: &mut Tensor<N>,
    rank: Comptime<Option<UInt>>,
) {
    let offset_output = ABSOLUTE_POS;

    if offset_output >= output.len() {
        return;
    }

    let offset_input = index_offset_with_layout::<N, N>(
        input,
        output,
        offset_output,
        UInt::new(0),
        Comptime::unwrap_or_else(rank, || output.rank()),
        Comptime::is_some(rank),
    );

    output[offset_output] = input[offset_input];
}
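The early return guards the tail: the cube count is presumably rounded up from the element count, so the last cube can contain units whose `ABSOLUTE_POS` lands past the end of the (vectorized) output, and those units must not write.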

/// Make a jit tensor contiguous.
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let vectorization_factor =
        tensor_vectorization_factor(&[4, 2], &input.shape, &input.strides, rank - 1);

    let num_elems: usize = input.shape.iter().product();
    let cube_count = calculate_cube_count_elemwise(
        num_elems / vectorization_factor as usize,
        SUBCUBE_DIM_APPROX,
    );
    let handle = client.empty(num_elems * E::as_elem().size());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_kernel::launch::<E, R>(
        &client,
        cube_count,
        CubeDim::default(),
        TensorArg::vectorized(
            vectorization_factor,
            &input.handle,
            &input.strides,
            &input.shape,
        ),
        TensorArg::vectorized(
            vectorization_factor,
            &output.handle,
            &output.strides,
            &output.shape,
        ),
        Some(UInt::new(rank as u32)),
    );

    output
}
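The output handle comes from `TensorHandle::new_contiguous`, which pairs the fresh buffer with dense row-major strides; that is what lets the kernel treat the output as its own `layout`. A sketch of that stride computation, assuming plain row-major packing:

```rust
/// Dense row-major strides: each stride is the product of all trailing
/// dimension sizes. Assumed behaviour of `TensorHandle::new_contiguous`;
/// sketch for illustration only.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
}
```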
3 changes: 3 additions & 0 deletions crates/cubecl-linalg/src/tensor/mod.rs
@@ -1,4 +1,7 @@
mod base;
+mod contiguous;
mod layout;
+
pub use base::*;
+pub use contiguous::*;
pub use layout::*;
