From 0c494b6d6ec3cc28635e0c6d7a4935008c3c8b8b Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 16 Sep 2024 12:13:49 -0400 Subject: [PATCH] Fix: cuda support different ranks (#124) --- crates/cubecl-core/src/compute/launcher.rs | 61 ++++------- .../src/runtime_tests/different_rank.rs | 101 ++++++++++++++++++ crates/cubecl-core/src/runtime_tests/mod.rs | 2 + 3 files changed, 122 insertions(+), 42 deletions(-) create mode 100644 crates/cubecl-core/src/runtime_tests/different_rank.rs diff --git a/crates/cubecl-core/src/compute/launcher.rs b/crates/cubecl-core/src/compute/launcher.rs index 40750c0f..1356802a 100644 --- a/crates/cubecl-core/src/compute/launcher.rs +++ b/crates/cubecl-core/src/compute/launcher.rs @@ -205,8 +205,8 @@ impl TensorState { Self::adjust_rank(metadata, bindings.len() - 1, rank); } - Self::register_strides(tensor.strides, tensor.shape, None, metadata); - Self::register_shape(tensor.shape, None, metadata); + Self::register_strides(tensor.strides, tensor.shape, metadata); + Self::register_shape(tensor.shape, metadata); if R::require_array_lengths() { let len = calculate_num_elems_dyn_rank(tensor.shape); @@ -227,62 +227,39 @@ impl TensorState { let strides_old = &metadata[stride_index..stride_index + old_rank]; let shape_old = &metadata[shape_index..shape_index + old_rank]; - Self::register_strides( - strides_old, - shape_old, - Some(old_rank as u32), - &mut updated_metadata, - ); - Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata); + Self::register_strides(strides_old, shape_old, &mut updated_metadata); + Self::register_shape(shape_old, &mut updated_metadata); } core::mem::swap(&mut updated_metadata, metadata); } - fn register_strides( - strides: &[T], - shape: &[T], - old_rank: Option, - output: &mut Vec, - ) { - let old_rank = if let Some(old_rank) = old_rank { - let rank = output[0]; - let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; + fn register_strides(strides: &[T], shape: &[T], output: &mut Vec) { + let old_rank = output[0] as usize; + let rank_diff = i32::abs(old_rank as i32 - shape.len() as i32) as usize; - if rank_diff > 0 { - let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::(); + if rank_diff > 0 { + let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::(); - for _ in 0..rank_diff { - output.push(padded_strides); - } + for _ in 0..rank_diff { + output.push(padded_strides); } - - old_rank as usize - } else { - output[0] as usize // same as current. - }; + } for stride in strides.iter().take(old_rank) { output.push(stride.to_u32().unwrap()); } } - fn register_shape(shape: &[T], old_rank: Option, output: &mut Vec) { - let old_rank = if let Some(old_rank) = old_rank { - let rank = output[0]; - let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize; + fn register_shape(shape: &[T], output: &mut Vec) { + let old_rank = output[0] as usize; + let rank_diff = i32::abs(old_rank as i32 - shape.len() as i32) as usize; - if rank_diff > 0 { - for _ in 0..rank_diff { - output.push(1); - } + if rank_diff > 0 { + for _ in 0..rank_diff { + output.push(1); } - - old_rank as usize - } else { - output[0] as usize // same as current - }; - + } for elem in shape.iter().take(old_rank) { output.push(elem.to_u32().unwrap()); } diff --git a/crates/cubecl-core/src/runtime_tests/different_rank.rs b/crates/cubecl-core/src/runtime_tests/different_rank.rs new file mode 100644 index 00000000..1a54ba4a --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/different_rank.rs @@ -0,0 +1,101 @@ +use crate as cubecl; + +use cubecl::prelude::*; + +#[cube(launch)] +pub fn kernel_different_rank(lhs: &Tensor, rhs: &Tensor, output: &mut Tensor) { + output[ABSOLUTE_POS] = lhs[ABSOLUTE_POS] + rhs[ABSOLUTE_POS]; +} + +pub fn test_kernel_different_rank_first_biggest( + client: ComputeClient, +) { + let shape_lhs = vec![2, 2, 2]; + let shape_rhs = vec![8]; + let shape_out = vec![2, 4]; + + let strides_lhs = vec![8, 4, 1]; + let strides_rhs = vec![1]; + let strides_out = vec![4, 1]; + + test_kernel_different_rank::( + client, + (shape_lhs, shape_rhs, shape_out), + (strides_lhs, strides_rhs, strides_out), + ); +} + +pub fn test_kernel_different_rank_last_biggest( + client: ComputeClient, +) { + let shape_lhs = vec![2, 4]; + let shape_rhs = vec![8]; + let shape_out = vec![2, 2, 2]; + + let strides_lhs = vec![4, 1]; + let strides_rhs = vec![1]; + let strides_out = vec![8, 4, 1]; + + test_kernel_different_rank::( + client, + (shape_lhs, shape_rhs, shape_out), + (strides_lhs, strides_rhs, strides_out), + ); +} + +fn test_kernel_different_rank( + client: ComputeClient, + (shape_lhs, shape_rhs, shape_out): (Vec, Vec, Vec), + (strides_lhs, strides_rhs, strides_out): (Vec, Vec, Vec), +) { + let vectorisation = 2; + + let handle_lhs = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])); + let handle_rhs = client.create(f32::as_bytes(&[3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])); + let handle_out = client.create(f32::as_bytes(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])); + + let lhs = + unsafe { TensorArg::from_raw_parts(&handle_lhs, &strides_lhs, &shape_lhs, vectorisation) }; + let rhs = + unsafe { TensorArg::from_raw_parts(&handle_rhs, &strides_rhs, &shape_rhs, vectorisation) }; + let out = + unsafe { TensorArg::from_raw_parts(&handle_out, &strides_out, &shape_out, vectorisation) }; + + kernel_different_rank::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::default(), + lhs, + rhs, + out, + ); + + let actual = client.read(handle_out.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual, &[3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0]); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_different_rank { + () => { + use super::*; + + #[test] + fn test_kernel_different_rank_first_biggest() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::different_rank::test_kernel_different_rank_first_biggest::< + TestRuntime, + >(client); + } + + #[test] + fn test_kernel_different_rank_last_biggest() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::different_rank::test_kernel_different_rank_last_biggest::< + TestRuntime, + >(client); + } + }; +} diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 31879f38..08ba577c 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -1,5 +1,6 @@ pub mod assign; pub mod cmma; +pub mod different_rank; pub mod launch; pub mod sequence; pub mod slice; @@ -21,5 +22,6 @@ macro_rules! testgen_all { cubecl_core::testgen_topology!(); cubecl_core::testgen_sequence!(); cubecl_core::testgen_unary!(); + cubecl_core::testgen_different_rank!(); }; }