Skip to content

Commit

Permalink
Fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Nov 28, 2024
1 parent 29e5557 commit a483aa3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/cubecl-core/src/runtime_tests/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ pub fn test_simple_tf32<R: Runtime>(
b: Elem::Float(FloatKind::TF32),
c: Elem::Float(FloatKind::F32),
m: 16,
- k: 16,
+ k: 8,
n: 16,
}) {
// We can't execute the test, skip.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub trait CmmaValid<I: Numeric, O: Numeric> {}
impl CmmaValid<f16, f16> for (f16, f16) {}
impl CmmaValid<f16, f32> for (f16, f32) {}
impl CmmaValid<bf16, f32> for (bf16, f32) {}
+ impl CmmaValid<tf32, f32> for (tf32, f32) {}

macro_rules! instruction {
($name:ident, $m:expr, $n:expr, $k:expr) => {
Expand Down
9 changes: 7 additions & 2 deletions crates/cubecl-reduce/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ impl TestCase {
naive_reduce_dim_kernel::launch_unchecked::<I, O, K, R>(
&client,
self.cube_count.clone(),
- self.cube_dim.clone(),
+ self.cube_dim,
input_tensor,
output_tensor,
- ScalarArg::new(self.reduce_dim.clone()),
+ ScalarArg::new(self.reduce_dim),
);
}

Expand All @@ -311,6 +311,7 @@ impl TestCase {

fn cpu_sum_dim<F: Float>(&self, values: &[F]) -> Vec<F> {
let mut expected = vec![F::new(0.0); self.num_output_values()];
+ #[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
expected[output_index] += values[input_index];
Expand All @@ -320,6 +321,7 @@ impl TestCase {

fn cpu_prod_dim<F: Float>(&self, values: &[F]) -> Vec<F> {
let mut expected = vec![F::new(1.0); self.num_output_values()];
+ #[allow(clippy::needless_range_loop)]
for value_index in 0..values.len() {
let output_index = self.to_output_index(value_index);
expected[output_index] *= values[value_index];
Expand All @@ -336,6 +338,7 @@ impl TestCase {

fn cpu_argmax_dim<F: Float>(&self, values: &[F]) -> Vec<u32> {
let mut expected = vec![(F::MIN, 0_u32); self.num_output_values()];
+ #[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
let (best, _) = expected[output_index];
Expand All @@ -350,6 +353,7 @@ impl TestCase {

fn cpu_argmin_dim<F: Float>(&self, values: &[F]) -> Vec<u32> {
let mut expected = vec![(F::MAX, 0_u32); self.num_output_values()];
+ #[allow(clippy::needless_range_loop)]
for input_index in 0..values.len() {
let output_index = self.to_output_index(input_index);
let (best, _) = expected[output_index];
Expand Down Expand Up @@ -382,6 +386,7 @@ impl TestCase {
.collect()
}

+ #[allow(clippy::wrong_self_convention)]
fn from_output_coordinate(&self, coordinate: Vec<usize>) -> usize {
coordinate
.into_iter()
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-spirv/src/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
let shift = stride_item.const_u32(self, line_size.trailing_zeros());
let stride_ty = stride_item.id(self);
stride = self
- .shift_left_logical(stride_ty, None, stride, shift)
+ .shift_right_logical(stride_ty, None, stride, shift)
.unwrap();
}

Expand Down Expand Up @@ -121,7 +121,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
let shift = stride_item.const_u32(self, line_size.trailing_zeros());
let stride_ty = stride_item.id(self);
stride = self
- .shift_left_logical(stride_ty, None, stride, shift)
+ .shift_right_logical(stride_ty, None, stride, shift)
.unwrap();
}

Expand Down

0 comments on commit a483aa3

Please sign in to comment.