Skip to content

Commit

Permalink
Fix: cuda support different ranks (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Sep 16, 2024
1 parent 443b5c8 commit 0c494b6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 42 deletions.
61 changes: 19 additions & 42 deletions crates/cubecl-core/src/compute/launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ impl<R: Runtime> TensorState<R> {
Self::adjust_rank(metadata, bindings.len() - 1, rank);
}

Self::register_strides(tensor.strides, tensor.shape, None, metadata);
Self::register_shape(tensor.shape, None, metadata);
Self::register_strides(tensor.strides, tensor.shape, metadata);
Self::register_shape(tensor.shape, metadata);

if R::require_array_lengths() {
let len = calculate_num_elems_dyn_rank(tensor.shape);
Expand All @@ -227,62 +227,39 @@ impl<R: Runtime> TensorState<R> {
let strides_old = &metadata[stride_index..stride_index + old_rank];
let shape_old = &metadata[shape_index..shape_index + old_rank];

Self::register_strides(
strides_old,
shape_old,
Some(old_rank as u32),
&mut updated_metadata,
);
Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata);
Self::register_strides(strides_old, shape_old, &mut updated_metadata);
Self::register_shape(shape_old, &mut updated_metadata);
}

core::mem::swap(&mut updated_metadata, metadata);
}

fn register_strides<T: ToPrimitive>(
strides: &[T],
shape: &[T],
old_rank: Option<u32>,
output: &mut Vec<u32>,
) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;
fn register_strides<T: ToPrimitive>(strides: &[T], shape: &[T], output: &mut Vec<u32>) {
let old_rank = output[0] as usize;
let rank_diff = i32::abs(old_rank as i32 - shape.len() as i32) as usize;

if rank_diff > 0 {
let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::<u32>();
if rank_diff > 0 {
let padded_strides = shape.iter().map(|a| a.to_u32().unwrap()).sum::<u32>();

for _ in 0..rank_diff {
output.push(padded_strides);
}
for _ in 0..rank_diff {
output.push(padded_strides);
}

old_rank as usize
} else {
output[0] as usize // same as current.
};
}

for stride in strides.iter().take(old_rank) {
output.push(stride.to_u32().unwrap());
}
}

fn register_shape<T: ToPrimitive>(shape: &[T], old_rank: Option<u32>, output: &mut Vec<u32>) {
let old_rank = if let Some(old_rank) = old_rank {
let rank = output[0];
let rank_diff = i32::abs(old_rank as i32 - rank as i32) as usize;
fn register_shape<T: ToPrimitive>(shape: &[T], output: &mut Vec<u32>) {
let old_rank = output[0] as usize;
let rank_diff = i32::abs(old_rank as i32 - shape.len() as i32) as usize;

if rank_diff > 0 {
for _ in 0..rank_diff {
output.push(1);
}
if rank_diff > 0 {
for _ in 0..rank_diff {
output.push(1);
}

old_rank as usize
} else {
output[0] as usize // same as current
};

}
for elem in shape.iter().take(old_rank) {
output.push(elem.to_u32().unwrap());
}
Expand Down
101 changes: 101 additions & 0 deletions crates/cubecl-core/src/runtime_tests/different_rank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use crate as cubecl;

use cubecl::prelude::*;

/// Element-wise addition kernel used by the different-rank runtime tests.
///
/// Writes `lhs[ABSOLUTE_POS] + rhs[ABSOLUTE_POS]` into `output[ABSOLUTE_POS]`.
/// All three tensors are indexed with the same `ABSOLUTE_POS`; the kernel body
/// itself never reads a shape or a stride, so the tensors may have different
/// ranks as long as the position is valid for each of them.
// NOTE(review): there is no bounds check on `ABSOLUTE_POS` here — presumably the
// launch configuration is expected to keep every position in range; confirm.
#[cube(launch)]
pub fn kernel_different_rank(lhs: &Tensor<f32>, rhs: &Tensor<f32>, output: &mut Tensor<f32>) {
output[ABSOLUTE_POS] = lhs[ABSOLUTE_POS] + rhs[ABSOLUTE_POS];
}

/// Runs the different-rank addition test where the first argument has the
/// highest rank: `lhs` is rank 3, `rhs` is rank 1 and the output is rank 2.
///
/// Every shape describes 8 elements, matching the 8-element buffers created by
/// [`test_kernel_different_rank`].
pub fn test_kernel_different_rank_first_biggest<R: Runtime>(
    client: ComputeClient<R::Server, R::Channel>,
) {
    let shape_lhs = vec![2, 2, 2];
    let shape_rhs = vec![8];
    let shape_out = vec![2, 4];

    // Row-major (contiguous) strides for each shape. For `[2, 2, 2]` that is
    // `[4, 2, 1]`; the previous `[8, 4, 1]` addressed offsets past the end of
    // the 8-element buffer.
    let strides_lhs = vec![4, 2, 1];
    let strides_rhs = vec![1];
    let strides_out = vec![4, 1];

    test_kernel_different_rank::<R>(
        client,
        (shape_lhs, shape_rhs, shape_out),
        (strides_lhs, strides_rhs, strides_out),
    );
}

/// Runs the different-rank addition test where the last argument has the
/// highest rank: `lhs` is rank 2, `rhs` is rank 1 and the output is rank 3.
///
/// Every shape describes 8 elements, matching the 8-element buffers created by
/// [`test_kernel_different_rank`].
pub fn test_kernel_different_rank_last_biggest<R: Runtime>(
    client: ComputeClient<R::Server, R::Channel>,
) {
    let shape_lhs = vec![2, 4];
    let shape_rhs = vec![8];
    let shape_out = vec![2, 2, 2];

    // Row-major (contiguous) strides for each shape. For `[2, 2, 2]` that is
    // `[4, 2, 1]`; the previous `[8, 4, 1]` addressed offsets past the end of
    // the 8-element buffer.
    let strides_lhs = vec![4, 1];
    let strides_rhs = vec![1];
    let strides_out = vec![4, 2, 1];

    test_kernel_different_rank::<R>(
        client,
        (shape_lhs, shape_rhs, shape_out),
        (strides_lhs, strides_rhs, strides_out),
    );
}

/// Shared driver for the different-rank tests.
///
/// Uploads two 8-element `f32` inputs and a zeroed 8-element output, wraps each
/// buffer in a `TensorArg` using the caller-provided shapes and strides, launches
/// [`kernel_different_rank`] on a single cube and asserts that the output equals
/// the element-wise sum of the inputs.
///
/// # Panics
///
/// Panics if the data read back from the output buffer differs from the
/// expected sums.
fn test_kernel_different_rank<R: Runtime>(
    client: ComputeClient<R::Server, R::Channel>,
    (shape_lhs, shape_rhs, shape_out): (Vec<usize>, Vec<usize>, Vec<usize>),
    (strides_lhs, strides_rhs, strides_out): (Vec<usize>, Vec<usize>, Vec<usize>),
) {
    let vectorization_factor = 2;

    // Host-side data: the two operands and the sums the kernel must produce.
    let lhs_values: [f32; 8] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let rhs_values: [f32; 8] = [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let expected: [f32; 8] = [3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0];

    let buffer_lhs = client.create(f32::as_bytes(&lhs_values));
    let buffer_rhs = client.create(f32::as_bytes(&rhs_values));
    let buffer_out = client.create(f32::as_bytes(&[0.0f32; 8]));

    // SAFETY: each buffer above holds 8 `f32` values and the callers in this
    // file pass shapes that describe 8 elements.
    // NOTE(review): the exact contract of `from_raw_parts` is not visible here —
    // confirm against its documentation.
    let arg_lhs = unsafe {
        TensorArg::from_raw_parts(&buffer_lhs, &strides_lhs, &shape_lhs, vectorization_factor)
    };
    let arg_rhs = unsafe {
        TensorArg::from_raw_parts(&buffer_rhs, &strides_rhs, &shape_rhs, vectorization_factor)
    };
    let arg_out = unsafe {
        TensorArg::from_raw_parts(&buffer_out, &strides_out, &shape_out, vectorization_factor)
    };

    kernel_different_rank::launch::<R>(
        &client,
        CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        arg_lhs,
        arg_rhs,
        arg_out,
    );

    // Read the output buffer back and reinterpret the raw bytes as `f32`.
    let raw_output = client.read(buffer_out.binding());
    let actual = f32::from_bytes(&raw_output);

    assert_eq!(actual, &expected);
}

/// Generates the `#[test]` functions for the different-rank runtime tests.
///
/// The expansion refers to `TestRuntime` and pulls in `super::*`, so both must
/// resolve at the call site; each generated test builds a client with
/// `TestRuntime::client(&Default::default())` and forwards it to the matching
/// function in `cubecl_core::runtime_tests::different_rank`.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_different_rank {
() => {
use super::*;

// Highest-rank tensor passed first (rank-3 lhs).
#[test]
fn test_kernel_different_rank_first_biggest() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::different_rank::test_kernel_different_rank_first_biggest::<
TestRuntime,
>(client);
}

// Highest-rank tensor passed last (rank-3 output).
#[test]
fn test_kernel_different_rank_last_biggest() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::different_rank::test_kernel_different_rank_last_biggest::<
TestRuntime,
>(client);
}
};
}
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod assign;
pub mod cmma;
pub mod different_rank;
pub mod launch;
pub mod sequence;
pub mod slice;
Expand All @@ -21,5 +22,6 @@ macro_rules! testgen_all {
cubecl_core::testgen_topology!();
cubecl_core::testgen_sequence!();
cubecl_core::testgen_unary!();
cubecl_core::testgen_different_rank!();
};
}

0 comments on commit 0c494b6

Please sign in to comment.