From a483aa33fb75fa9bad46f243da839fe422320da6 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:41:44 +0100 Subject: [PATCH] Fixup --- crates/cubecl-core/src/runtime_tests/cmma.rs | 2 +- .../src/matmul/components/tile/accelerated.rs | 1 + crates/cubecl-reduce/src/test.rs | 9 +++++++-- crates/cubecl-spirv/src/cmma.rs | 4 ++-- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/cmma.rs b/crates/cubecl-core/src/runtime_tests/cmma.rs index e53475fe..8737c736 100644 --- a/crates/cubecl-core/src/runtime_tests/cmma.rs +++ b/crates/cubecl-core/src/runtime_tests/cmma.rs @@ -220,7 +220,7 @@ pub fn test_simple_tf32( b: Elem::Float(FloatKind::TF32), c: Elem::Float(FloatKind::F32), m: 16, - k: 16, + k: 8, n: 16, }) { // We can't execute the test, skip. diff --git a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs index e750f809..919641f7 100644 --- a/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs +++ b/crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs @@ -16,6 +16,7 @@ pub trait CmmaValid {} impl CmmaValid for (f16, f16) {} impl CmmaValid for (f16, f32) {} impl CmmaValid for (bf16, f32) {} +impl CmmaValid for (tf32, f32) {} macro_rules! instruction { ($name:ident, $m:expr, $n:expr, $k:expr) => { diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 1a5f75ac..c2268cc1 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -295,10 +295,10 @@ impl TestCase { naive_reduce_dim_kernel::launch_unchecked::( &client, self.cube_count.clone(), - self.cube_dim.clone(), + self.cube_dim, input_tensor, output_tensor, - ScalarArg::new(self.reduce_dim.clone()), + ScalarArg::new(self.reduce_dim), ); } @@ -311,6 +311,7 @@ impl TestCase { fn cpu_sum_dim(&self, values: &[F]) -> Vec { let mut expected = vec![F::new(0.0); self.num_output_values()]; + #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { let output_index = self.to_output_index(input_index); expected[output_index] += values[input_index]; @@ -320,6 +321,7 @@ impl TestCase { fn cpu_prod_dim(&self, values: &[F]) -> Vec { let mut expected = vec![F::new(1.0); self.num_output_values()]; + #[allow(clippy::needless_range_loop)] for value_index in 0..values.len() { let output_index = self.to_output_index(value_index); expected[output_index] *= values[value_index]; @@ -336,6 +338,7 @@ impl TestCase { fn cpu_argmax_dim(&self, values: &[F]) -> Vec { let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()]; + #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { let output_index = self.to_output_index(input_index); let (best, _) = expected[output_index]; @@ -350,6 +353,7 @@ impl TestCase { fn cpu_argmin_dim(&self, values: &[F]) -> Vec { let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()]; + #[allow(clippy::needless_range_loop)] for input_index in 0..values.len() { let output_index = self.to_output_index(input_index); let (best, _) = expected[output_index]; @@ -382,6 +386,7 @@ impl TestCase { .collect() } + #[allow(clippy::wrong_self_convention)] fn from_output_coordinate(&self, coordinate: Vec) -> usize { coordinate .into_iter() diff --git a/crates/cubecl-spirv/src/cmma.rs b/crates/cubecl-spirv/src/cmma.rs index 6b519338..177b7281 100644 --- a/crates/cubecl-spirv/src/cmma.rs +++ b/crates/cubecl-spirv/src/cmma.rs @@ -55,7 +55,7 @@ impl SpirvCompiler { let shift = stride_item.const_u32(self, line_size.trailing_zeros()); let stride_ty = stride_item.id(self); stride = self - .shift_left_logical(stride_ty, None, stride, shift) + .shift_right_logical(stride_ty, None, stride, shift) .unwrap(); } @@ -121,7 +121,7 @@ impl SpirvCompiler { let shift = stride_item.const_u32(self, line_size.trailing_zeros()); let stride_ty = stride_item.id(self); stride = self - .shift_left_logical(stride_ty, None, stride, shift) + .shift_right_logical(stride_ty, None, stride, shift) .unwrap(); }