Skip to content

Commit

Permalink
Fix CMMA stride on SPIR-V (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge authored Nov 28, 2024
1 parent c096a40 commit 0df1761
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub fn test_simple_tf32<R: Runtime>(
b: Elem::Float(FloatKind::TF32),
c: Elem::Float(FloatKind::F32),
m: 16,
k: 16,
k: 8,
n: 16,
}) {
// We can't execute the test, skip.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub trait CmmaValid<I: Numeric, O: Numeric> {}
impl CmmaValid<f16, f16> for (f16, f16) {}
impl CmmaValid<f16, f32> for (f16, f32) {}
impl CmmaValid<bf16, f32> for (bf16, f32) {}
impl CmmaValid<tf32, f32> for (tf32, f32) {}

macro_rules! instruction {
($name:ident, $m:expr, $n:expr, $k:expr) => {
Expand Down
9 changes: 7 additions & 2 deletions crates/cubecl-reduce/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ impl TestCase {
naive_reduce_dim_kernel::launch_unchecked::<I, O, K, R>(
&client,
self.cube_count.clone(),
self.cube_dim.clone(),
self.cube_dim,
input_tensor,
output_tensor,
ScalarArg::new(self.reduce_dim.clone()),
ScalarArg::new(self.reduce_dim),
);
}

Expand All @@ -311,6 +311,7 @@ impl TestCase {

fn cpu_sum_dim<F: Float>(&self, values: &[F]) -> Vec<F> {
let mut expected = vec![F::new(0.0); self.num_output_values()];
#[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
expected[output_index] += values[input_index];
Expand All @@ -320,6 +321,7 @@ impl TestCase {

fn cpu_prod_dim<F: Float>(&self, values: &[F]) -> Vec<F> {
let mut expected = vec![F::new(1.0); self.num_output_values()];
#[allow(clippy::needless_range_loop)]
for value_index in 0..values.len() {
let output_index = self.to_output_index(value_index);
expected[output_index] *= values[value_index];
Expand All @@ -336,6 +338,7 @@ impl TestCase {

fn cpu_argmax_dim<F: Float>(&self, values: &[F]) -> Vec<u32> {
let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()];
#[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
let (best, _) = expected[output_index];
Expand All @@ -350,6 +353,7 @@ impl TestCase {

fn cpu_argmin_dim<F: Float>(&self, values: &[F]) -> Vec<u32> {
let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()];
#[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
let (best, _) = expected[output_index];
Expand Down Expand Up @@ -382,6 +386,7 @@ impl TestCase {
.collect()
}

#[allow(clippy::wrong_self_convention)]
fn from_output_coordinate(&self, coordinate: Vec<usize>) -> usize {
coordinate
.into_iter()
Expand Down
23 changes: 21 additions & 2 deletions crates/cubecl-spirv/src/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,17 @@ impl<T: SpirvTarget> SpirvCompiler<T> {

let value = self.compile_variable(value);
let stride = self.compile_variable(stride);
let stride = self.read(&stride);
let stride_item = stride.item();
let mut stride = self.read(&stride);

if let Item::Vector(_, line_size) = value.item() {
let shift = stride_item.const_u32(self, line_size.trailing_zeros());
let stride_ty = stride_item.id(self);
stride = self
.shift_right_logical(stride_ty, None, stride, shift)
.unwrap();
}

let layout = layout
.and_then(compile_layout)
.or(mat.layout)
Expand Down Expand Up @@ -101,11 +111,20 @@ impl<T: SpirvTarget> SpirvCompiler<T> {

let out = self.compile_variable(out);
let stride = self.compile_variable(stride);
let stride = self.read(&stride);
let stride_item = stride.item();
let mut stride = self.read(&stride);
let layout = compile_layout(layout).unwrap_or(CooperativeMatrixLayout::RowMajorKHR);
let memory_layout = self.const_u32(layout as u32);
let ptr = self.deref_slice(&out);

if let Item::Vector(_, line_size) = out.item() {
let shift = stride_item.const_u32(self, line_size.trailing_zeros());
let stride_ty = stride_item.id(self);
stride = self
.shift_right_logical(stride_ty, None, stride, shift)
.unwrap();
}

self.cooperative_matrix_store_khr(ptr, mat_obj, memory_layout, Some(stride), None, vec![])
.unwrap();
}
Expand Down

0 comments on commit 0df1761

Please sign in to comment.