-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Matmul: customizable cube dispatch (#273)
* wip swizzle dispatch * fix swizzle dispatch * fmt * minor
- Loading branch information
Showing
11 changed files
with
826 additions
and
225 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
135 changes: 135 additions & 0 deletions
135
crates/cubecl-linalg/src/matmul/components/batch/cube_dispatch.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
use cubecl_core::prelude::*; | ||
use cubecl_core::{self as cubecl}; | ||
use std::fmt::Debug; | ||
use std::hash::Hash; | ||
|
||
use crate::matmul::components::batch::shared::swizzle; | ||
|
||
#[cube] | ||
/// Strategy that maps a cube's launch position to the indices it works on. | ||
/// | ||
/// Every method is an associated function (there is no `self`). The `max_*` | ||
/// functions receive the cube count as a comptime `(u32, u32, u32)` tuple. | ||
pub trait CubeDispatch: Clone + Copy + 'static + Send + Sync + Debug + Hash + Eq { | ||
/// Returns the `(x, y)` indices assigned to the current cube. | ||
fn x_y_indices() -> (u32, u32); | ||
/// Returns the batch index assigned to the current cube. | ||
fn batch_index() -> u32; | ||
/// Returns the component of `cube_count` that this strategy associates with x. | ||
/// NOTE(review): presumably the exclusive upper bound of the x index - confirm. | ||
fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> u32; | ||
/// Returns the component of `cube_count` that this strategy associates with y. | ||
/// NOTE(review): presumably the exclusive upper bound of the y index - confirm. | ||
fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> u32; | ||
/// Returns the component of `cube_count` that this strategy associates with batches. | ||
fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> u32; | ||
} | ||
|
||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] | ||
/// Operates on data further along the m dimension as `cube_pos_x` increases, | ||
/// and further along the n dimension as `cube_pos_y` increases. | ||
/// | ||
/// The cube position is used unchanged: `x_y_indices` yields `(CUBE_POS_X, CUBE_POS_Y)`. | ||
pub struct NaturalDispatch; | ||
|
||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] | ||
/// Operates on data further along the m dimension as `cube_pos_y` increases, | ||
/// and further along the n dimension as `cube_pos_x` increases. | ||
/// | ||
/// The cube position is swapped: `x_y_indices` yields `(CUBE_POS_Y, CUBE_POS_X)`. | ||
pub struct TransposedDispatch; | ||
|
||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] | ||
/// Processes data in a swizzled pattern, prioritizing cubes along the x-axis first. | ||
/// | ||
/// The cube position is linearized as `CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X` | ||
/// and then handed to `swizzle` to obtain the `(x, y)` indices. | ||
/// | ||
/// # Generics | ||
/// - W: Width of a swizzle column (forwarded as the last argument of `swizzle`) | ||
pub struct SwizzleNaturalDispatch<const W: u32>; | ||
|
||
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] | ||
/// Processes data in a swizzled pattern, prioritizing cubes along the y-axis first. | ||
/// | ||
/// The cube position is linearized as `CUBE_POS_X * CUBE_COUNT_Y + CUBE_POS_Y` | ||
/// and then handed to `swizzle` to obtain the `(x, y)` indices. | ||
/// | ||
/// # Generics | ||
/// - W: Width of a swizzle column (forwarded as the last argument of `swizzle`) | ||
pub struct SwizzleTransposedDispatch<const W: u32>; | ||
|
||
#[cube] | ||
// Natural dispatch: the cube position and cube count are used without remapping. | ||
impl CubeDispatch for NaturalDispatch { | ||
// x comes from the cube's X position, y from its Y position. | ||
fn x_y_indices() -> (u32, u32) { | ||
(CUBE_POS_X, CUBE_POS_Y) | ||
} | ||
|
||
// The batch index is the cube's Z position. | ||
fn batch_index() -> u32 { | ||
CUBE_POS_Z | ||
} | ||
|
||
// x is tied to the first component of the cube count. | ||
fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.0 | ||
} | ||
|
||
// y is tied to the second component of the cube count. | ||
fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.1 | ||
} | ||
|
||
// Batches are tied to the third component of the cube count. | ||
fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.2 | ||
} | ||
} | ||
|
||
#[cube] | ||
// Transposed dispatch: the X and Y axes of the launch grid swap roles; Z is unchanged. | ||
impl CubeDispatch for TransposedDispatch { | ||
// x comes from the cube's Y position, y from its X position. | ||
fn x_y_indices() -> (u32, u32) { | ||
(CUBE_POS_Y, CUBE_POS_X) | ||
} | ||
|
||
// The batch index is the cube's Z position. | ||
fn batch_index() -> u32 { | ||
CUBE_POS_Z | ||
} | ||
|
||
// x follows the Y axis, so it is tied to the second component of the cube count. | ||
fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.1 | ||
} | ||
|
||
// y follows the X axis, so it is tied to the first component of the cube count. | ||
fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.0 | ||
} | ||
|
||
// Batches are tied to the third component of the cube count. | ||
fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.2 | ||
} | ||
} | ||
|
||
#[cube] | ||
// Swizzled dispatch over a grid linearized with the X position varying fastest. | ||
impl<const W: u32> CubeDispatch for SwizzleNaturalDispatch<W> { | ||
// Flattens the cube position into one linear index and lets `swizzle` turn it into (x, y). | ||
// NOTE(review): `swizzle` is defined in `batch::shared`; its exact mapping is not visible here. | ||
fn x_y_indices() -> (u32, u32) { | ||
// The number of cubes along X serves as the `height` argument of `swizzle`. | ||
let height = CUBE_COUNT_X; | ||
// Linear index: each step in Y skips `height` cubes, X is the offset within that run. | ||
let nth_cube = CUBE_POS_Y * height + CUBE_POS_X; | ||
swizzle(nth_cube, height, W) | ||
} | ||
|
||
// The batch index is the cube's Z position. | ||
fn batch_index() -> u32 { | ||
CUBE_POS_Z | ||
} | ||
|
||
// x is tied to the first component of the cube count. | ||
fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.0 | ||
} | ||
|
||
// y is tied to the second component of the cube count. | ||
fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.1 | ||
} | ||
|
||
// Batches are tied to the third component of the cube count. | ||
fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.2 | ||
} | ||
} | ||
|
||
#[cube] | ||
// Swizzled dispatch over a grid linearized with the Y position varying fastest. | ||
impl<const W: u32> CubeDispatch for SwizzleTransposedDispatch<W> { | ||
// Flattens the cube position into one linear index and lets `swizzle` turn it into (x, y). | ||
// NOTE(review): `swizzle` is defined in `batch::shared`; its exact mapping is not visible here. | ||
fn x_y_indices() -> (u32, u32) { | ||
// The number of cubes along Y serves as the `height` argument of `swizzle`. | ||
let height = CUBE_COUNT_Y; | ||
// Linear index: each step in X skips `height` cubes, Y is the offset within that run. | ||
let nth_cube = CUBE_POS_X * height + CUBE_POS_Y; | ||
swizzle(nth_cube, height, W) | ||
} | ||
|
||
// The batch index is the cube's Z position. | ||
fn batch_index() -> u32 { | ||
CUBE_POS_Z | ||
} | ||
|
||
// x is tied to the second component of the cube count (axes swapped relative to the grid). | ||
fn max_x(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.1 | ||
} | ||
|
||
// y is tied to the first component of the cube count (axes swapped relative to the grid). | ||
fn max_y(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.0 | ||
} | ||
|
||
// Batches are tied to the third component of the cube count. | ||
fn max_batches(#[comptime] cube_count: (u32, u32, u32)) -> u32 { | ||
cube_count.2 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.