diff --git a/crates/cubecl-core/src/compute/kernel.rs b/crates/cubecl-core/src/compute/kernel.rs index 7b950ec03..8a5316f6f 100644 --- a/crates/cubecl-core/src/compute/kernel.rs +++ b/crates/cubecl-core/src/compute/kernel.rs @@ -1,4 +1,4 @@ -use std::marker::PhantomData; +use std::{fmt::Debug, marker::PhantomData}; use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel}; use alloc::sync::Arc; @@ -78,6 +78,15 @@ pub enum CubeCount { Dynamic(Binding), } +impl Debug for CubeCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")), + CubeCount::Dynamic(_) => f.write_str("binding"), + } + } +} + impl Clone for CubeCount { fn clone(&self) -> Self { match self { diff --git a/crates/cubecl-core/src/runtime_tests/assign.rs b/crates/cubecl-core/src/runtime_tests/assign.rs new file mode 100644 index 000000000..08dfd77f2 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/assign.rs @@ -0,0 +1,44 @@ +use crate as cubecl; + +use cubecl::prelude::*; + +#[cube(launch)] +pub fn kernel_assign(output: &mut Array, vectorization: Comptime) { + if UNIT_POS == UInt::new(0) { + let item = F32::vectorized(5.0, Comptime::get(vectorization)); + output[0] = item; + } +} + +pub fn test_kernel_assign_scalar(client: ComputeClient) { + let handle = client.create(f32::as_bytes(&[0.0, 1.0])); + + let vectorization = 2; + + kernel_assign::launch::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::default(), + ArrayArg::vectorized(vectorization, &handle, 2), + UInt::new(vectorization as u32), + ); + + let actual = client.read(handle.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], 5.0); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_assign { + () => { + use super::*; + + #[test] + fn test_assign_scalar() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::assign::test_kernel_assign_scalar::(client); + } + }; +} diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index d275d5831..005c5bb3c 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -1,7 +1,9 @@ +pub mod assign; pub mod cmma; pub mod launch; pub mod slice; pub mod subcube; +pub mod topology; #[allow(missing_docs)] #[macro_export] @@ -13,5 +15,7 @@ macro_rules! testgen_all { cubecl_core::testgen_launch!(); cubecl_core::testgen_cmma!(); cubecl_core::testgen_slice!(); + cubecl_core::testgen_assign!(); + cubecl_core::testgen_topology!(); }; } diff --git a/crates/cubecl-core/src/runtime_tests/topology.rs b/crates/cubecl-core/src/runtime_tests/topology.rs new file mode 100644 index 000000000..1b2df08e3 --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/topology.rs @@ -0,0 +1,57 @@ +use crate as cubecl; + +use cubecl::prelude::*; + +#[cube(launch)] +pub fn kernel_absolute_pos(output1: &mut Array, output2: &mut Array) { + if ABSOLUTE_POS >= output1.len() { + return; + } + + output1[ABSOLUTE_POS] = ABSOLUTE_POS; + output2[ABSOLUTE_POS] = ABSOLUTE_POS; +} + +pub fn test_kernel_topology_absolute_pos(client: ComputeClient) { + let cube_count = (3, 5, 7); + let cube_dim = (16, 16, 1); + let extra: u32 = 3u32; + + let length = + (cube_count.0 * cube_count.1 * cube_count.2 * cube_dim.0 * cube_dim.1 * cube_dim.2) + extra; + let handle1 = client.empty(length as usize * core::mem::size_of::()); + let handle2 = client.empty(length as usize * core::mem::size_of::()); + + kernel_absolute_pos::launch::( + &client, + CubeCount::Static(cube_count.0, cube_count.1, cube_count.2), + CubeDim::new(cube_dim.0, cube_dim.1, cube_dim.2), + ArrayArg::new(&handle1, length as usize), + ArrayArg::new(&handle2, length as usize), + ); + + let actual = client.read(handle1.binding()); + let actual = u32::from_bytes(&actual); + let mut expect: Vec = (0..length - extra).collect(); + expect.push(0); + expect.push(0); + expect.push(0); + + assert_eq!(actual, &expect); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_topology { + () => { + use super::*; + + #[test] + fn test_topology_scalar() { + let client = TestRuntime::client(&Default::default()); + cubecl_core::runtime_tests::topology::test_kernel_topology_absolute_pos::( + client, + ); + } + }; +} diff --git a/crates/cubecl-cuda/src/compiler/base.rs b/crates/cubecl-cuda/src/compiler/base.rs index ded8b3f02..c63cd8632 100644 --- a/crates/cubecl-cuda/src/compiler/base.rs +++ b/crates/cubecl-cuda/src/compiler/base.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use cubecl_core::{ ir::{self as gpu, ConstantScalarValue}, Compiler, @@ -22,6 +24,7 @@ pub struct CudaCompiler { stride: bool, num_inputs: usize, num_outputs: usize, + items: HashSet, } impl Compiler for CudaCompiler { @@ -86,6 +89,7 @@ impl CudaCompiler { wmma_activated: self.wmma, bf16: self.bf16, f16: self.f16, + items: self.items, } } @@ -548,13 +552,10 @@ impl CudaCompiler { } fn compile_item(&mut self, item: gpu::Item) -> super::Item { - match item.vectorization { - 4 => super::Item::Vec4(self.compile_elem(item.elem)), - 3 => super::Item::Vec3(self.compile_elem(item.elem)), - 2 => super::Item::Vec2(self.compile_elem(item.elem)), - 1 => super::Item::Scalar(self.compile_elem(item.elem)), - _ => panic!("Vectorization factor unsupported {:?}", item.vectorization), - } + let item = super::Item::new(self.compile_elem(item.elem), item.vectorization.into()); + self.items.insert(item); + self.items.insert(item.optimized()); + item } fn compile_elem(&mut self, value: gpu::Elem) -> super::Elem { diff --git a/crates/cubecl-cuda/src/compiler/binary.rs b/crates/cubecl-cuda/src/compiler/binary.rs index e665b3e7f..6b22437c1 100644 --- a/crates/cubecl-cuda/src/compiler/binary.rs +++ b/crates/cubecl-cuda/src/compiler/binary.rs @@ -1,4 +1,4 @@ -use super::{Component, Elem, InstructionSettings, Item, Variable}; +use super::{Component, Elem, Variable}; use std::fmt::Display; pub trait Binary { @@ -8,37 +8,8 @@ pub trait Binary { rhs: &Variable, out: &Variable, ) -> std::fmt::Result { - let item = out.item(); - let settings = Self::settings(*item.elem()); - - match item { - Item::Vec4(elem) => { - if settings.native_vec4 && lhs.item() == rhs.item() { - Self::format_native_vec4(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec4(f, lhs, rhs, out, elem) - } - } - Item::Vec3(elem) => { - if settings.native_vec3 && lhs.item() == rhs.item() { - Self::format_native_vec3(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec3(f, lhs, rhs, out, elem) - } - } - Item::Vec2(elem) => { - if settings.native_vec2 && lhs.item() == rhs.item() { - Self::format_native_vec2(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec2(f, lhs, rhs, out, elem) - } - } - Item::Scalar(elem) => Self::format_scalar(f, *lhs, *rhs, *out, elem), - } - } - - fn settings(_elem: Elem) -> InstructionSettings { - InstructionSettings::default() + let item = out.item().de_optimized(); + Self::unroll_vec(f, lhs, rhs, out, item.elem, item.vectorization) } fn format_scalar( @@ -53,66 +24,6 @@ pub trait Binary { Rhs: Component, Out: Component; - fn format_native_vec4( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn format_native_vec3( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn format_native_vec2( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn unroll_vec2( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, lhs, rhs, out, elem, 2) - } - - fn unroll_vec3( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, lhs, rhs, out, elem, 3) - } - - fn unroll_vec4( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, lhs, rhs, out, elem, 4) - } - fn unroll_vec( f: &mut std::fmt::Formatter<'_>, lhs: &Variable, @@ -121,10 +32,21 @@ pub trait Binary { elem: Elem, index: usize, ) -> core::fmt::Result { + if index == 1 { + return Self::format_scalar(f, *lhs, *rhs, *out, elem); + } + + let optimized = Variable::optimized_args([*lhs, *rhs, *out]); + let [lhs, rhs, out] = optimized.args; + let (is_optimized, index) = match optimized.optimization_factor { + Some(factor) => (true, index / factor), + None => (false, index), + }; + for i in 0..index { - let lhsi = lhs.index(i); - let rhsi = rhs.index(i); - let outi = out.index(i); + let lhsi = lhs.index(i, is_optimized); + let rhsi = rhs.index(i, is_optimized); + let outi = out.index(i, is_optimized); Self::format_scalar(f, lhsi, rhsi, outi, elem)?; } @@ -158,11 +80,6 @@ macro_rules! operator { ) -> std::fmt::Result { f.write_fmt(format_args!("{out} = {lhs} {} {rhs};\n", $op)) } - - #[allow(unused_variables)] - fn settings(elem: Elem) -> InstructionSettings { - $vectorization - } } }; } @@ -192,11 +109,6 @@ macro_rules! function { ) -> std::fmt::Result { f.write_fmt(format_args!("{out} = {}({lhs}, {rhs});\n", $op)) } - - #[allow(unused_variables)] - fn settings(elem: Elem) -> InstructionSettings { - $vectorization - } } }; } @@ -232,31 +144,14 @@ impl Binary for IndexAssign { lhs: Lhs, rhs: Rhs, out: Out, - elem: Elem, + _elem: Elem, ) -> std::fmt::Result where Lhs: Component, Rhs: Component, Out: Component, { - let elem_rhs = rhs.elem(); - // Cast only when necessary. - if elem != elem_rhs { - if let Elem::Bool = elem_rhs { - match rhs.item() { - Item::Vec4(_) => { - f.write_fmt(format_args!("{out}[{lhs}] = make_uint4({elem}({rhs}.x), {elem}({rhs}.y), {elem}({rhs}.z), {elem}({rhs}.w));\n")) - }, - Item::Vec3(_) => todo!(), - Item::Vec2(_) => todo!(), - Item::Scalar(_) => todo!(), - } - } else { - f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n")) - } - } else { - f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) - } + f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) } fn unroll_vec( @@ -267,9 +162,13 @@ impl Binary for IndexAssign { elem: Elem, index: usize, ) -> std::fmt::Result { + if index == 1 { + return Self::format_scalar(f, *lhs, *rhs, *out, elem); + } + for i in 0..index { - let lhsi = lhs.index(i); - let rhsi = rhs.index(i); + let lhsi = lhs.index(i, false); + let rhsi = rhs.index(i, false); Self::format_scalar(f, lhsi, rhsi, *out, elem)?; } @@ -292,13 +191,9 @@ impl Binary for IndexAssign { }; let elem = out.elem(); + let item = lhs.item(); - match lhs.item() { - Item::Vec4(_) => Self::unroll_vec4(f, lhs, rhs, out, elem), - Item::Vec3(_) => Self::unroll_vec3(f, lhs, rhs, out, elem), - Item::Vec2(_) => Self::unroll_vec2(f, lhs, rhs, out, elem), - Item::Scalar(_) => Self::format_scalar(f, *lhs, *rhs, *out, elem), - } + Self::unroll_vec(f, lhs, rhs, out, elem, item.vectorization) } } @@ -375,8 +270,8 @@ impl IndexVector { } }; - let out = out.index(index); - let lhs = lhs.index(index); + let out = out.index(index, false); + let lhs = lhs.index(index, false); f.write_fmt(format_args!("{out} = {lhs};\n")) } @@ -397,8 +292,8 @@ impl IndexAssignVector { } }; - let out = out.index(index); - let rhs = rhs.index(index); + let out = out.index(index, false); + let rhs = rhs.index(index, false); f.write_fmt(format_args!("{out} = {rhs};\n")) } diff --git a/crates/cubecl-cuda/src/compiler/element.rs b/crates/cubecl-cuda/src/compiler/element.rs index ff797a24b..22ce1ec16 100644 --- a/crates/cubecl-cuda/src/compiler/element.rs +++ b/crates/cubecl-cuda/src/compiler/element.rs @@ -4,30 +4,32 @@ use std::fmt::Display; use super::Fragment; -#[derive(Debug, Clone, PartialEq, Eq, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] pub enum Elem { F32, F16, + F162, BF16, + BF162, I32, U32, Bool, } -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum Item { - Vec4(Elem), - Vec3(Elem), - Vec2(Elem), - Scalar(Elem), +#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)] +pub struct Item { + pub(crate) elem: Elem, + pub(crate) vectorization: usize, } impl Display for Elem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Elem::F16 => f.write_str("__half"), + Elem::F162 => f.write_str("__half2"), Elem::F32 => f.write_str("float"), Elem::BF16 => f.write_str("__nv_bfloat16"), + Elem::BF162 => f.write_str("__nv_bfloat162"), Elem::I32 => f.write_str("int"), Elem::U32 => f.write_str("uint"), Elem::Bool => f.write_str("bool"), @@ -37,33 +39,16 @@ impl Display for Elem { impl Display for Item { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Item::Vec4(elem) => match elem { - Elem::F32 => f.write_str("float4"), - Elem::I32 => f.write_str("int4"), - Elem::U32 => f.write_str("uint4"), - Elem::Bool => f.write_str("bool4"), - Elem::BF16 => f.write_str("__nv_bfloat164"), - Elem::F16 => f.write_str("__half4"), - }, - Item::Vec3(elem) => match elem { - Elem::F32 => f.write_str("float3"), - Elem::I32 => f.write_str("int3"), - Elem::U32 => f.write_str("uint3"), - Elem::Bool => f.write_str("bool3"), - Elem::BF16 => f.write_str("__nv_bfloat164"), - Elem::F16 => f.write_str("__half3"), - }, - Item::Vec2(elem) => match elem { - Elem::F32 => f.write_str("float2"), - Elem::I32 => f.write_str("int2"), - Elem::U32 => f.write_str("uint2"), - Elem::Bool => f.write_str("bool2"), - Elem::BF16 => f.write_str("__nv_bfloat162"), - Elem::F16 => f.write_str("__half2"), - }, - Item::Scalar(elem) => f.write_fmt(format_args!("{elem}")), + if 1 == self.vectorization { + return f.write_fmt(format_args!("{}", self.elem)); } + + if self.is_vec_native() { + let elem = self.optimized().elem; + return f.write_fmt(format_args!("{elem}")); + } + + return f.write_fmt(format_args!("{}_{}", self.elem, self.vectorization)); } } @@ -95,38 +80,38 @@ impl Component for Variable { item, depth: _, } => *item, - Variable::ConstantScalar(_, e) => Item::Scalar(*e), - Variable::GlobalScalar(_, e, _) => Item::Scalar(*e), - Variable::IdxGlobal => Item::Scalar(Elem::U32), - Variable::ThreadIdxGlobal => Item::Scalar(Elem::U32), - Variable::ThreadIdxX => Item::Scalar(Elem::U32), - Variable::ThreadIdxY => Item::Scalar(Elem::U32), - Variable::ThreadIdxZ => Item::Scalar(Elem::U32), - Variable::Rank => Item::Scalar(Elem::U32), + Variable::ConstantScalar(_, e) => Item::scalar(*e), + Variable::GlobalScalar(_, e, _) => Item::scalar(*e), + Variable::IdxGlobal => Item::scalar(Elem::U32), + Variable::ThreadIdxGlobal => Item::scalar(Elem::U32), + Variable::ThreadIdxX => Item::scalar(Elem::U32), + Variable::ThreadIdxY => Item::scalar(Elem::U32), + Variable::ThreadIdxZ => Item::scalar(Elem::U32), + Variable::Rank => Item::scalar(Elem::U32), Variable::LocalScalar { id: _, elem, depth: _, - } => Item::Scalar(*elem), - Variable::BlockIdxX => Item::Scalar(Elem::U32), - Variable::BlockIdxY => Item::Scalar(Elem::U32), - Variable::BlockIdxZ => Item::Scalar(Elem::U32), - Variable::AbsoluteIdxX => Item::Scalar(Elem::U32), - Variable::AbsoluteIdxY => Item::Scalar(Elem::U32), - Variable::AbsoluteIdxZ => Item::Scalar(Elem::U32), - Variable::BlockDimX => Item::Scalar(Elem::U32), - Variable::BlockDimY => Item::Scalar(Elem::U32), - Variable::BlockDimZ => Item::Scalar(Elem::U32), - Variable::GridDimX => Item::Scalar(Elem::U32), - Variable::GridDimY => Item::Scalar(Elem::U32), - Variable::GridDimZ => Item::Scalar(Elem::U32), + } => Item::scalar(*elem), + Variable::BlockIdxX => Item::scalar(Elem::U32), + Variable::BlockIdxY => Item::scalar(Elem::U32), + Variable::BlockIdxZ => Item::scalar(Elem::U32), + Variable::AbsoluteIdxX => Item::scalar(Elem::U32), + Variable::AbsoluteIdxY => Item::scalar(Elem::U32), + Variable::AbsoluteIdxZ => Item::scalar(Elem::U32), + Variable::BlockDimX => Item::scalar(Elem::U32), + Variable::BlockDimY => Item::scalar(Elem::U32), + Variable::BlockDimZ => Item::scalar(Elem::U32), + Variable::GridDimX => Item::scalar(Elem::U32), + Variable::GridDimY => Item::scalar(Elem::U32), + Variable::GridDimZ => Item::scalar(Elem::U32), Variable::LocalArray(_, e, _, _) => *e, - Variable::WarpSize => Item::Scalar(Elem::U32), + Variable::WarpSize => Item::scalar(Elem::U32), Variable::WmmaFragment { id: _, frag, depth: _, - } => Item::Scalar(frag.elem), + } => Item::scalar(frag.elem), } } } @@ -241,7 +226,82 @@ impl Display for Variable { } } +#[derive(new)] +pub struct OptimizedArgs { + pub args: [Variable; N], + pub optimization_factor: Option, +} + impl Variable { + pub fn is_optimized(&self) -> bool { + self.item().is_optimized() + } + + pub fn optimized_args(args: [Self; N]) -> OptimizedArgs { + let args_after = args.map(|a| a.optimized()); + + let item_reference_after = args_after[0].item(); + + let is_optimized = args_after + .iter() + .all(|var| var.elem() == item_reference_after.elem && var.is_optimized()); + + if is_optimized { + let vectorization_before = args + .iter() + .map(|var| var.item().vectorization) + .max() + .unwrap(); + let vectorization_after = args_after + .iter() + .map(|var| var.item().vectorization) + .max() + .unwrap(); + + OptimizedArgs::new(args_after, Some(vectorization_before / vectorization_after)) + } else { + OptimizedArgs::new(args, None) + } + } + + pub fn optimized(&self) -> Self { + match self { + Variable::GlobalInputArray(id, item) => { + Variable::GlobalInputArray(*id, item.optimized()) + } + Variable::GlobalOutputArray(id, item) => { + Variable::GlobalOutputArray(*id, item.optimized()) + } + Variable::Local { id, item, depth } => Variable::Local { + id: *id, + item: item.optimized(), + depth: *depth, + }, + Variable::Slice { id, item, depth } => Variable::Slice { + id: *id, + item: item.optimized(), + depth: *depth, + }, + Variable::SharedMemory(id, item, size) => { + let before = item.vectorization; + let item = item.optimized(); + let after = item.vectorization; + let scaling = (before / after) as u32; + + Variable::SharedMemory(*id, item, size / scaling) + } + Variable::LocalArray(id, item, vec, size) => { + let before = item.vectorization; + let item = item.optimized(); + let after = item.vectorization; + let scaling = (before / after) as u32; + + Variable::LocalArray(*id, item.optimized(), *vec, size / scaling) + } + _ => *self, + } + } + pub fn is_always_scalar(&self) -> bool { match self { Variable::GlobalScalar(_, _, _) => true, @@ -292,54 +352,100 @@ impl Variable { } } - pub fn index(&self, index: usize) -> IndexedVariable { - IndexedVariable { var: *self, index } + pub fn index(&self, index: usize, optimized: bool) -> IndexedVariable { + IndexedVariable { + var: *self, + index, + optimized, + } } } #[derive(Debug, Clone)] pub struct IndexedVariable { var: Variable, + optimized: bool, index: usize, } impl Display for IndexedVariable { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let var = &self.var; - let item = self.var.item(); - - match item { - Item::Vec4(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - 2 => f.write_fmt(format_args!("{var}.z"))?, - 3 => f.write_fmt(format_args!("{var}.w"))?, - _ => unreachable!(), - }, - Item::Vec3(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - 2 => f.write_fmt(format_args!("{var}.z"))?, - _ => unreachable!(), - }, - Item::Vec2(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - _ => unreachable!(), - }, - Item::Scalar(_) => f.write_fmt(format_args!("{var}"))?, - } - Ok(()) + if self.var.item().vectorization > 1 { + if self.optimized { + let item = self.var.item(); + f.write_fmt(format_args!( + "(reinterpret_cast<{item}*>(&{var}))->i_{}", + self.index + )) + } else { + f.write_fmt(format_args!("{var}.i_{}", self.index)) + } + } else { + f.write_fmt(format_args!("{var}")) + } } } impl Item { pub fn elem(&self) -> &Elem { - match self { - Item::Vec4(e) => e, - Item::Vec3(e) => e, - Item::Vec2(e) => e, - Item::Scalar(e) => e, + &self.elem + } + + pub fn de_optimized(&self) -> Self { + match self.elem { + Elem::F162 => Item::new(Elem::F16, self.vectorization * 2), + Elem::BF162 => Item::new(Elem::BF16, self.vectorization * 2), + _ => *self, + } + } + + pub fn new(elem: Elem, vectorization: usize) -> Self { + Self { + elem, + vectorization, + } + } + pub fn scalar(elem: Elem) -> Self { + Self { + elem, + vectorization: 1, + } + } + + pub fn is_optimized(&self) -> bool { + matches!(self.elem, Elem::F162 | Elem::BF162) + } + + pub fn is_vec_native(&self) -> bool { + match &self.elem { + Elem::F16 => self.vectorization == 2, + Elem::BF16 => self.vectorization == 2, + Elem::F162 => self.vectorization == 1, + Elem::BF162 => self.vectorization == 1, + _ => false, + } + } + + pub fn optimized(&self) -> Item { + if self.vectorization == 1 { + return *self; + } + + if self.vectorization % 2 != 0 { + return *self; + } + + match self.elem { + Elem::F16 => Item { + elem: Elem::F162, + vectorization: self.vectorization / 2, + }, + Elem::BF16 => Item { + elem: Elem::BF162, + vectorization: self.vectorization / 2, + }, + _ => *self, } } } @@ -347,9 +453,11 @@ impl Item { impl Elem { pub fn size(&self) -> usize { match self { - Self::F32 => core::mem::size_of::(), Self::F16 => core::mem::size_of::(), + Self::F162 => 2 * core::mem::size_of::(), + Self::BF162 => 2 * core::mem::size_of::(), Self::BF16 => core::mem::size_of::(), + Self::F32 => core::mem::size_of::(), Self::I32 => core::mem::size_of::(), Self::U32 => core::mem::size_of::(), Self::Bool => core::mem::size_of::(), diff --git a/crates/cubecl-cuda/src/compiler/instruction.rs b/crates/cubecl-cuda/src/compiler/instruction.rs index 8d6ae6e10..4e30af5c0 100644 --- a/crates/cubecl-cuda/src/compiler/instruction.rs +++ b/crates/cubecl-cuda/src/compiler/instruction.rs @@ -262,16 +262,14 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{ Variable::GlobalOutputArray(index, _) => *index as usize + num_inputs, _ => panic!("Can only know the len of a global array."), } + 1; - let factor = match input.item() { - super::Item::Vec4(_) => 4, - super::Item::Vec3(_) => 3, - super::Item::Vec2(_) => 2, - super::Item::Scalar(_) => { - return f.write_fmt(format_args!( - "{out} = info[({offset} * 2 * info[0]) + {index}];\n" - )) - } - }; + let factor = input.item().vectorization; + + if factor == 1 { + return f.write_fmt(format_args!( + "{out} = info[({offset} * 2 * info[0]) + {index}];\n" + )); + } + f.write_fmt(format_args!( "{out} = info[({offset} * 2 * info[0]) + {index}] / {factor};\n" )) @@ -293,18 +291,13 @@ impl Fma { c: &Variable, out: &Variable, ) -> core::fmt::Result { - let num = match out.item() { - super::Item::Vec4(_) => 4, - super::Item::Vec3(_) => 3, - super::Item::Vec2(_) => 2, - super::Item::Scalar(_) => 1, - }; + let num = out.item().vectorization; for i in 0..num { - let ai = a.index(i); - let bi = b.index(i); - let ci = c.index(i); - let outi = out.index(i); + let ai = a.index(i, false); + let bi = b.index(i, false); + let ci = c.index(i, false); + let outi = out.index(i, false); f.write_fmt(format_args!("{outi} = fma({ai}, {bi}, {ci});\n"))?; } diff --git a/crates/cubecl-cuda/src/compiler/kernel.rs b/crates/cubecl-cuda/src/compiler/kernel.rs index 0251b7945..c239d154f 100644 --- a/crates/cubecl-cuda/src/compiler/kernel.rs +++ b/crates/cubecl-cuda/src/compiler/kernel.rs @@ -1,6 +1,6 @@ use super::{Body, Item}; use cubecl_core::{ir::CubeDim, CompilerRepresentation}; -use std::fmt::Display; +use std::{collections::HashSet, fmt::Display}; #[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding { @@ -50,6 +50,7 @@ pub struct ComputeKernel { pub wmma_activated: bool, pub bf16: bool, pub f16: bool, + pub items: HashSet, } impl CompilerRepresentation for ComputeKernel { @@ -57,13 +58,7 @@ impl CompilerRepresentation for ComputeKernel { let mut current = 0usize; for var in self.body.shared_memories.iter() { - let factor = match var.item { - Item::Vec4(_) => 4, - Item::Vec3(_) => 3, - Item::Vec2(_) => 2, - Item::Scalar(_) => 1, - }; - + let factor = var.item.vectorization; let elem_size_bytes = var.item.elem().size(); current += (var.size as usize) * factor * elem_size_bytes; } @@ -89,30 +84,39 @@ impl Display for ComputeKernel { f.write_str("using namespace nvcuda;\n")?; } - if self.f16 { - f.write_str( - " -extern \"C\" struct __half4 { - __half x; - __half y; - __half z; - __half w; -}; -", - )?; - } - - f.write_fmt(format_args!( + f.write_str( " typedef unsigned int uint; + ", + )?; -extern \"C\" struct bool4 {{ - bool x; - bool y; - bool z; - bool w; -}}; + for item in self.items.iter() { + if item.is_vec_native() { + continue; + } + + let elem = item.elem; + let size = item.vectorization; + let alignment = elem.size() * size; + if size > 1 { + f.write_fmt(format_args!( + " +struct __align__({alignment}) {item} {{" + ))?; + + for i in 0..size { + f.write_fmt(format_args!( + " + {elem} i_{i};" + ))?; + } + + f.write_str("\n};\n")?; + } + } + f.write_fmt(format_args!( + " extern \"C\" __global__ void kernel( ", diff --git a/crates/cubecl-cuda/src/compiler/unary.rs b/crates/cubecl-cuda/src/compiler/unary.rs index 641c4835d..e7237c87a 100644 --- a/crates/cubecl-cuda/src/compiler/unary.rs +++ b/crates/cubecl-cuda/src/compiler/unary.rs @@ -1,4 +1,4 @@ -use super::{Component, Elem, InstructionSettings, Item, Variable}; +use super::{Component, Elem, Variable}; use std::fmt::Display; pub trait Unary { @@ -8,36 +8,8 @@ pub trait Unary { out: &Variable, ) -> std::fmt::Result { let item = out.item(); - let settings = Self::settings(*item.elem()); - match item { - Item::Vec4(elem) => { - if settings.native_vec4 { - Self::format_native_vec4(f, input, out, elem) - } else { - Self::unroll_vec4(f, input, out, elem) - } - } - Item::Vec3(elem) => { - if settings.native_vec3 { - Self::format_native_vec3(f, input, out, elem) - } else { - Self::unroll_vec3(f, input, out, elem) - } - } - Item::Vec2(elem) => { - if settings.native_vec2 { - Self::format_native_vec2(f, input, out, elem) - } else { - Self::unroll_vec2(f, input, out, elem) - } - } - Item::Scalar(elem) => Self::format_scalar(f, *input, *out, elem), - } - } - - fn settings(_elem: Elem) -> InstructionSettings { - InstructionSettings::default() + Self::unroll_vec(f, input, out, item.elem, item.vectorization) } fn format_scalar( @@ -50,60 +22,6 @@ pub trait Unary { Input: Component, Out: Component; - fn format_native_vec4( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn format_native_vec3( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn format_native_vec2( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn unroll_vec2( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, input, out, elem, 2) - } - - fn unroll_vec3( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, input, out, elem, 3) - } - - fn unroll_vec4( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::unroll_vec(f, input, out, elem, 4) - } - fn unroll_vec( f: &mut std::fmt::Formatter<'_>, input: &Variable, @@ -111,9 +29,16 @@ pub trait Unary { elem: Elem, index: usize, ) -> std::fmt::Result { + let optimized = Variable::optimized_args([*input, *out]); + let [input, out] = optimized.args; + let (is_optimized, index, elem) = match optimized.optimization_factor { + Some(factor) => (true, index / factor, out.elem()), + None => (false, index, elem), + }; + for i in 0..index { - let inputi = input.index(i); - let outi = out.index(i); + let inputi = input.index(i, is_optimized); + let outi = out.index(i, is_optimized); Self::format_scalar(f, inputi, outi, elem)?; } @@ -131,9 +56,16 @@ macro_rules! function { f: &mut std::fmt::Formatter<'_>, input: Input, out: Out, - _elem: Elem, + elem: Elem, ) -> std::fmt::Result { - f.write_fmt(format_args!("{out} = {}({input});\n", $func)) + match elem { + Elem::F16 => f.write_fmt(format_args!("{out} = h{}({input});\n", $func)), + Elem::F162 => f.write_fmt(format_args!("{out} = h2{}({input});\n", $func)), + Elem::BF16 => f.write_fmt(format_args!("{out} = h{}({input});\n", $func)), + Elem::BF162 => f.write_fmt(format_args!("{out} = h2{}({input});\n", $func)), + Elem::F32 => f.write_fmt(format_args!("{out} = __{}f({input});\n", $func)), + _ => f.write_fmt(format_args!("{out} = {}({input});\n", $func)), + } } } }; diff --git a/crates/cubecl-linalg/src/tensor/base.rs b/crates/cubecl-linalg/src/tensor/base.rs index 5716c711c..3cd6c93e8 100644 --- a/crates/cubecl-linalg/src/tensor/base.rs +++ b/crates/cubecl-linalg/src/tensor/base.rs @@ -109,7 +109,7 @@ where R: Runtime, E: Numeric, { - pub fn zeros(client: ComputeClient, shape: Vec) -> Self { + pub fn zeros(client: &ComputeClient, shape: Vec) -> Self { let num_elements: usize = shape.iter().product(); let size = E::as_elem().size(); @@ -125,7 +125,7 @@ where ); init::zeros_array::launch::( - &client, + client, cube_count, CubeDim::default(), ArrayArg::new(&handle, num_elements), @@ -141,8 +141,10 @@ pub(crate) mod init { #[cube(launch)] pub fn zeros_array(output: &mut Array) { - if ABSOLUTE_POS < output.len() { - output[ABSOLUTE_POS] = C::from_int(0); + if ABSOLUTE_POS >= output.len() { + return; } + + output[ABSOLUTE_POS] = C::from_int(0); } } diff --git a/crates/cubecl/Cargo.toml b/crates/cubecl/Cargo.toml index 83c0935ee..0d129d7bb 100644 --- a/crates/cubecl/Cargo.toml +++ b/crates/cubecl/Cargo.toml @@ -36,3 +36,7 @@ cubecl-linalg = { path = "../cubecl-linalg", version = "0.1.1", default-features [[bench]] name = "matmul" harness = false + +[[bench]] +name = "unary" +harness = false diff --git a/crates/cubecl/benches/matmul.rs b/crates/cubecl/benches/matmul.rs index ef81c4d8a..962df6169 100644 --- a/crates/cubecl/benches/matmul.rs +++ b/crates/cubecl/benches/matmul.rs @@ -13,9 +13,9 @@ impl Benchmark for MatmulBench { fn prepare(&self) -> Self::Args { let (b, m, k, n) = (self.b, self.m, self.k, self.n); let client = R::client(&self.device); - let lhs = TensorHandle::zeros(client.clone(), vec![b, m, k]); - let rhs = TensorHandle::zeros(client.clone(), vec![b, k, n]); - let out = TensorHandle::zeros(client.clone(), vec![b, m, n]); + let lhs = TensorHandle::zeros(&client, vec![b, m, k]); + let rhs = TensorHandle::zeros(&client, vec![b, k, n]); + let out = TensorHandle::zeros(&client, vec![b, m, n]); (lhs, rhs, out) } @@ -82,9 +82,13 @@ fn run(device: R::Device, kind: MatmulKind) { fn main() { #[cfg(feature = "wgpu")] run::(Default::default(), MatmulKind::Tiling2d); + #[cfg(feature = "cuda")] run::(Default::default(), MatmulKind::Tiling2d); - + #[cfg(feature = "cuda")] + run::(Default::default(), MatmulKind::Tiling2d); #[cfg(feature = "cuda")] run::(Default::default(), MatmulKind::Cmma); + #[cfg(feature = "cuda")] + run::(Default::default(), MatmulKind::Cmma); } diff --git a/crates/cubecl/benches/unary.rs b/crates/cubecl/benches/unary.rs new file mode 100644 index 000000000..03c53f3de --- /dev/null +++ b/crates/cubecl/benches/unary.rs @@ -0,0 +1,103 @@ +use cubecl::{calculate_cube_count_elemwise, prelude::*}; +use std::marker::PhantomData; + +use cubecl::benchmark::Benchmark; +use cubecl::client::SyncType; +use cubecl::frontend::Float; +use cubecl_linalg::tensor::TensorHandle; + +#[cube(launch)] +fn execute(lhs: &Tensor, rhs: &Tensor, out: &mut Tensor) { + if ABSOLUTE_POS < out.len() { + for _ in range(0, 256, Comptime::new(false)) { + out[ABSOLUTE_POS] += F::cos(lhs[ABSOLUTE_POS] * rhs[ABSOLUTE_POS]); + } + } +} + +impl Benchmark for UnaryBench { + type Args = (TensorHandle, TensorHandle, TensorHandle); + + fn prepare(&self) -> Self::Args { + let client = R::client(&self.device); + let lhs = TensorHandle::zeros(&client, self.shape.clone()); + let rhs = TensorHandle::zeros(&client, self.shape.clone()); + let out = TensorHandle::zeros(&client, self.shape.clone()); + + (lhs, rhs, out) + } + + fn execute(&self, (lhs, rhs, out): Self::Args) { + let num_elems: usize = out.shape.iter().product(); + + let cube_count = + calculate_cube_count_elemwise::(num_elems / self.vectorization as usize, 16); + + execute::launch::( + &self.client, + cube_count, + CubeDim::new(16, 16, 1), + TensorArg::vectorized(self.vectorization, &lhs.handle, &lhs.strides, &lhs.shape), + TensorArg::vectorized(self.vectorization, &rhs.handle, &rhs.strides, &rhs.shape), + TensorArg::vectorized(self.vectorization, &out.handle, &out.strides, &out.shape), + ) + } + + fn num_samples(&self) -> usize { + 100 + } + + fn name(&self) -> String { + format!( + "unary-{}-{}-{:?}", + R::name(), + E::as_elem(), + self.vectorization + ) + .to_lowercase() + } + + fn sync(&self) { + self.client.sync(SyncType::Wait); + } +} + +#[allow(dead_code)] +struct UnaryBench { + shape: Vec, + vectorization: u8, + device: R::Device, + client: ComputeClient, + _e: PhantomData, +} + +#[allow(dead_code)] +#[derive(Debug)] +enum MatmulKind { + Tiling2d, + Cmma, +} + +#[allow(dead_code)] +fn run(device: R::Device, vectorization: u8) { + let bench = UnaryBench:: { + shape: vec![32, 512, 2048], + vectorization, + client: R::client(&device), + device, + _e: PhantomData, + }; + println!("{}", bench.name()); + println!("{}", bench.run()); +} + +fn main() { + #[cfg(feature = "cuda")] + run::(Default::default(), 8); + #[cfg(feature = "cuda")] + run::(Default::default(), 4); + #[cfg(feature = "wgpu")] + run::(Default::default(), 1); + #[cfg(feature = "wgpu")] + run::(Default::default(), 4); +} diff --git a/profiling/matmul-example/Cargo.toml b/profiling/matmul-example/Cargo.toml index 260689939..bbfa1830e 100644 --- a/profiling/matmul-example/Cargo.toml +++ b/profiling/matmul-example/Cargo.toml @@ -10,9 +10,9 @@ cubecl = { version = "0.1.0", path = "../../crates/cubecl", features = [ "cuda", "linalg", ], optional = true } -burn = { version = "0.13.2", optional = true, features = ["tch"] } -burn-tensor = { version = "0.13.2", optional = true } +burn = { git = "https://github.com/tracel-ai/burn", optional = true, features = ["tch"] } +burn-tensor = { git = "https://github.com/tracel-ai/burn", optional = true } [features] -burn-tch-cuda = ["burn", "burn-tensor"] +burn-tch-cuda = ["burn"] cube-cuda = ["cubecl"]