Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into refactor/wgpu-v23
Browse files Browse the repository at this point in the history
  • Loading branch information
AsherJingkongChen committed Nov 25, 2024
2 parents ba78b17 + 7133da7 commit e04d19e
Show file tree
Hide file tree
Showing 194 changed files with 7,894 additions and 2,995 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ cfg-if = "1.0.0"

### For xtask crate ###
strum = { version = "0.26.3", features = ["derive"] }
tracel-xtask = { version = "~1.1" }
tracel-xtask = { version = "~1.1.6" }

portable-atomic = { version = "1.9" }
pretty_assertions = "1.4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ Therefore, each kind of variable also has its own axis-independent variable, whi
| UNIT_POS_X | threadIdx.x | local_invocation_id.x |
| UNIT_POS_Y | threadIdx.y | local_invocation_id.y |
| UNIT_POS_Z | threadIdx.z | local_invocation_id.z |
| SUBCUBE_DIM | warpSize | subgroup_size |
| PLANE_DIM | warpSize | subgroup_size |
| ABSOLUTE_POS | N/A | N/A |
| ABSOLUTE_POS_X | N/A | global_id.x |
| ABSOLUTE_POS_Y | N/A | global_id.y |
Expand Down
7 changes: 6 additions & 1 deletion crates/cubecl-core/src/codegen/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@ pub trait CompilerRepresentation: Display {
pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
/// The representation for the compiled code.
type Representation: CompilerRepresentation;
type CompilationOptions: Send + Default + core::fmt::Debug;

/// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation;
fn compile(
kernel: KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation;
/// The size of the given element in bytes.
fn elem_size(elem: Elem) -> usize;
fn local_allocator() -> impl LocalAllocator;
Expand Down
10 changes: 10 additions & 0 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct KernelExpansion {
pub inputs: Vec<InputInfo>,
pub outputs: Vec<OutputInfo>,
pub scope: Scope,
pub kernel_name: String,
}

/// Simply indicate the output that can be replaced by the input.
Expand Down Expand Up @@ -55,6 +56,7 @@ pub struct KernelSettings {
vectorization_partial: Vec<VectorizationPartial>,
pub cube_dim: CubeDim,
pub reading_strategy: Vec<(u16, ReadingStrategy)>,
pub kernel_name: String,
}

impl core::fmt::Display for KernelSettings {
Expand Down Expand Up @@ -193,6 +195,13 @@ impl KernelSettings {
self.cube_dim = cube_dim;
self
}

/// Set kernel name.
#[allow(dead_code)]
pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
self.kernel_name = name.as_ref().to_string();
self
}
}

#[allow(dead_code)]
Expand Down Expand Up @@ -331,6 +340,7 @@ impl KernelIntegrator {
named,
cube_dim: settings.cube_dim,
body: self.expansion.scope,
kernel_name: self.expansion.kernel_name,
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/compute/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ impl KernelBuilder {
scope: self.context.into_scope(),
inputs: self.inputs,
outputs: self.outputs,
kernel_name: settings.kernel_name.clone(),
})
.integrate(settings)
}
Expand Down
68 changes: 58 additions & 10 deletions crates/cubecl-core/src/compute/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,37 @@ use cubecl_runtime::ExecutionMode;

/// A kernel, compiled in the target language
pub struct CompiledKernel<C: Compiler> {
pub name: Option<&'static str>,
/// The name of the kernel entrypoint.

/// For example
///
/// ```text
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>() {}
/// ```
///
/// would have the entrypoint name "gelu_array".
pub entrypoint_name: String,

/// A fully qualified debug name of the kernel.
///
/// For example
///
/// ```text
/// #[cube(launch)]
/// fn gelu_array<F: Float, R: Runtime>() {}
/// ```
///
/// would have a debug name such as
///
/// ```text
/// gelu::gelu_array::GeluArray<
/// cubecl_core::frontend::element::float::F32,
/// cubecl_cuda::runtime::CudaRuntime,
/// >
/// ```
pub debug_name: Option<&'static str>,

/// Source code of the kernel
pub source: String,
/// In-memory representation of the kernel
Expand Down Expand Up @@ -48,7 +78,7 @@ impl<C: Compiler> Display for CompiledKernel<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("\n[START_KERNEL_COMPILATION]")?;

if let Some(name) = self.name {
if let Some(name) = self.debug_name {
if name.len() <= 32 {
f.write_fmt(format_args!("\nname: {name}"))?;
} else {
Expand Down Expand Up @@ -172,7 +202,11 @@ pub trait CubeTask<C: Compiler>: Send + Sync {
/// Identifier for the kernel, used for caching kernel compilation.
fn id(&self) -> KernelId;
/// Compile the kernel into source
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C>;
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C>;
fn name(&self) -> &'static str {
core::any::type_name::<Self>()
}
Expand All @@ -186,14 +220,20 @@ pub struct KernelTask<C: Compiler, K: Kernel> {
}

impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
let gpu_ir = self.kernel_definition.define();
let entrypoint_name = gpu_ir.kernel_name.clone();
let cube_dim = gpu_ir.cube_dim;
let lower_level_ir = C::compile(gpu_ir, mode);
let lower_level_ir = C::compile(gpu_ir, compilation_options, mode);
let shared_mem_bytes = lower_level_ir.shared_memory_size();

CompiledKernel {
name: Some(core::any::type_name::<K>()),
entrypoint_name,
debug_name: Some(core::any::type_name::<K>()),
source: lower_level_ir.to_string(),
repr: Some(lower_level_ir),
cube_dim,
Expand All @@ -212,8 +252,12 @@ impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
}

impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
Expand All @@ -225,8 +269,12 @@ impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
}

impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
fn compile(&self, mode: ExecutionMode) -> CompiledKernel<C> {
self.as_ref().compile(mode)
fn compile(
&self,
compilation_options: &C::CompilationOptions,
mode: ExecutionMode,
) -> CompiledKernel<C> {
self.as_ref().compile(compilation_options, mode)
}

fn id(&self) -> KernelId {
Expand Down
20 changes: 8 additions & 12 deletions crates/cubecl-core/src/frontend/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ impl<C: CubePrimitive> Matrix<C> {
n: u32,
k: u32,
layout: MatrixLayout,
value: &Slice<'_, C>,
value: &Slice<C>,
stride: u32,
) -> Self {
Matrix { _c: PhantomData }
Expand Down Expand Up @@ -227,7 +227,7 @@ impl<C: CubePrimitive> Matrix<C> {
n: ExpandElementTyped<u32>,
k: ExpandElementTyped<u32>,
layout: MatrixLayout,
value: ExpandElementTyped<Slice<'static, C>>,
value: ExpandElementTyped<Slice<C>>,
stride: ExpandElementTyped<u32>,
) -> MatrixExpand<C> {
let mat = Self::__expand_uninitialized(context, ident, m, n, k, layout);
Expand Down Expand Up @@ -262,11 +262,7 @@ pub mod fill {

/// Load the matrix with the provided array using the stride.
#[allow(unused_variables)]
pub fn load<C: CubePrimitive, V: CubePrimitive>(
mat: &Matrix<C>,
value: &Slice<'_, V>,
stride: u32,
) {
pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
unexpanded!()
}

Expand All @@ -279,7 +275,7 @@ pub mod load {
pub fn expand<C: CubePrimitive, V: CubePrimitive>(
context: &mut CubeContext,
mat: MatrixExpand<C>,
value: ExpandElementTyped<Slice<'static, V>>,
value: ExpandElementTyped<Slice<V>>,
stride: ExpandElementTyped<u32>,
) {
let stride: ExpandElement = stride.into();
Expand All @@ -305,7 +301,7 @@ pub mod load {
#[allow(unused_variables)]
pub fn load_with_layout<C: CubeType>(
mat: &Matrix<C>,
value: &Slice<'_, C>,
value: &Slice<C>,
stride: u32,
layout: MatrixLayout,
) {
Expand All @@ -321,7 +317,7 @@ pub mod load_with_layout {
pub fn expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand<C>,
value: ExpandElementTyped<Slice<'static, C>>,
value: ExpandElementTyped<Slice<C>>,
stride: ExpandElementTyped<u32>,
layout: MatrixLayout,
) {
Expand All @@ -341,7 +337,7 @@ pub mod load_with_layout {
/// Store the matrix in the given array following the given stride and layout.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive, O: CubePrimitive>(
output: &mut SliceMut<'_, O>,
output: &mut SliceMut<O>,
mat: &Matrix<C>,
stride: u32,
layout: MatrixLayout,
Expand All @@ -357,7 +353,7 @@ pub mod store {
#[allow(unused_variables)]
pub fn expand<C: CubePrimitive, O: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<SliceMut<'static, O>>,
output: ExpandElementTyped<SliceMut<O>>,
mat: MatrixExpand<C>,
stride: ExpandElementTyped<u32>,
layout: MatrixLayout,
Expand Down
122 changes: 122 additions & 0 deletions crates/cubecl-core/src/frontend/container/line/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,128 @@ where
}
}

impl<P> core::ops::BitAnd<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitAnd<P, Output = P>,
{
type Output = Self;

fn bitand(self, rhs: Self) -> Self::Output {
Self::new(self.val & rhs.val)
}
}

impl<P> core::ops::BitOr<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitOr<P, Output = P>,
{
type Output = Self;

fn bitor(self, rhs: Self) -> Self::Output {
Self::new(self.val | rhs.val)
}
}

impl<P> core::ops::BitXor<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitXor<P, Output = P>,
{
type Output = Self;

fn bitxor(self, rhs: Self) -> Self::Output {
Self::new(self.val ^ rhs.val)
}
}

impl<P> core::ops::Not for Line<P>
where
P: CubePrimitive,
P: core::ops::Not<Output = P>,
{
type Output = Self;

fn not(self) -> Self::Output {
Self::new(!self.val)
}
}

impl<P> core::ops::Shl<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::Shl<P, Output = P>,
{
type Output = Self;

fn shl(self, rhs: Self) -> Self::Output {
Self::new(self.val << rhs.val)
}
}

impl<P> core::ops::Shr<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::Shr<P, Output = P>,
{
type Output = Self;

fn shr(self, rhs: Self) -> Self::Output {
Self::new(self.val >> rhs.val)
}
}

impl<P> core::ops::BitAndAssign<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitAndAssign,
{
fn bitand_assign(&mut self, rhs: Self) {
self.val &= rhs.val;
}
}

impl<P> core::ops::BitOrAssign<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitOrAssign,
{
fn bitor_assign(&mut self, rhs: Self) {
self.val |= rhs.val;
}
}

impl<P> core::ops::BitXorAssign<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::BitXorAssign,
{
fn bitxor_assign(&mut self, rhs: Self) {
self.val ^= rhs.val;
}
}

impl<P> core::ops::ShlAssign<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::ShlAssign,
{
fn shl_assign(&mut self, rhs: Self) {
self.val <<= rhs.val;
}
}

impl<P> core::ops::ShrAssign<Self> for Line<P>
where
P: CubePrimitive,
P: core::ops::ShrAssign,
{
fn shr_assign(&mut self, rhs: Self) {
self.val >>= rhs.val;
}
}

impl<P: CubePrimitive + Abs> Abs for Line<P> {}
impl<P: CubePrimitive + Max> Max for Line<P> {}
impl<P: CubePrimitive + Min> Min for Line<P> {}
Expand Down
Loading

0 comments on commit e04d19e

Please sign in to comment.