Skip to content

Commit

Permalink
Implemented dot product (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
RianGoossens authored Sep 23, 2024
1 parent 447968e commit e5699a1
Show file tree
Hide file tree
Showing 13 changed files with 283 additions and 7 deletions.
1 change: 1 addition & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub trait Float:
+ Recip
+ Magnitude
+ Normalize
+ Dot
+ Into<Self::ExpandType>
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
Expand Down
42 changes: 41 additions & 1 deletion crates/cubecl-core/src/frontend/operation/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,46 @@ where
out
}

/// Expands a binary operation whose output item is dictated by the caller
/// instead of being derived from the operands.
///
/// `func` turns the assembled [`BinaryOperator`] into the [`Operator`] that is
/// registered on `context`.
///
/// The returned element is the operation's output. An operand is reused as the
/// output when it is mutable and its item already equals `out_item`; `lhs` is
/// tried first, then `rhs`. Otherwise a fresh local of `out_item` is created.
pub(crate) fn binary_expand_fixed_output<F>(
    context: &mut CubeContext,
    lhs: ExpandElement,
    rhs: ExpandElement,
    out_item: Item,
    func: F,
) -> ExpandElement
where
    F: Fn(BinaryOperator) -> Operator,
{
    // Copy the operand variables first: one of the elements may be moved into `out` below.
    let (lhs_var, rhs_var): (Variable, Variable) = (*lhs, *rhs);
    let (item_lhs, item_rhs) = (lhs.item(), rhs.item());

    // The result is intentionally discarded.
    // NOTE(review): presumably kept so mismatched operand vectorizations are rejected — confirm.
    let _ = find_vectorization(item_lhs.vectorization, item_rhs.vectorization);

    let lhs_reusable = lhs.can_mut() && item_lhs == out_item;
    let out = if lhs_reusable {
        lhs
    } else {
        let rhs_reusable = rhs.can_mut() && item_rhs == out_item;
        if rhs_reusable {
            rhs
        } else {
            context.create_local(out_item)
        }
    };

    let operation = func(BinaryOperator {
        lhs: lhs_var,
        rhs: rhs_var,
        out: *out,
    });
    context.register(operation);

    out
}

pub(crate) fn binary_expand_no_vec<F>(
context: &mut CubeContext,
lhs: ExpandElement,
Expand Down Expand Up @@ -170,7 +210,7 @@ where
out
}

pub fn fixed_output_unary_expand<F>(
pub fn unary_expand_fixed_output<F>(
context: &mut CubeContext,
input: ExpandElement,
out_item: Item,
Expand Down
53 changes: 51 additions & 2 deletions crates/cubecl-core/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::frontend::CubeType;
use crate::frontend::{CubeContext, CubePrimitive, ExpandElementTyped};
use crate::frontend::{CubeContext, CubePrimitive, ExpandElement, ExpandElementTyped};
use crate::ir::Operator;
use crate::{frontend::operation::base::binary_expand, unexpanded};
use crate::{
frontend::operation::base::{binary_expand, binary_expand_fixed_output},
unexpanded,
};
use half::{bf16, f16};

pub mod add {
Expand Down Expand Up @@ -210,6 +213,37 @@ macro_rules! impl_binary_func {
}
}

// Declares a binary-function trait whose result has a fixed vectorization.
//
// Arguments, in order: trait name, unexpanded method name, name of the
// associated expand function, name of the inherent expand method generated on
// `ExpandElementTyped<T>`, the operator constructor, the vectorization written
// into the output item, and the list of types that implement the trait.
//
// The output item is the left operand's item with its `vectorization` field
// overwritten by `$out_vectorization`.
macro_rules! impl_binary_func_fixed_output_vectorization {
    (
        $trait_name:ident,
        $method_name:ident,
        $func_name_expand:ident,
        $method_name_expand:ident,
        $operator:expr,
        $out_vectorization: expr,
        $($type:ty),*
    ) => {
        pub trait $trait_name: CubeType + Sized {
            fn $method_name(self, _rhs: Self) -> Self {
                unexpanded!()
            }

            fn $func_name_expand(
                context: &mut CubeContext,
                lhs: ExpandElementTyped<Self>,
                rhs: ExpandElementTyped<Self>,
            ) -> ExpandElementTyped<Self> {
                let lhs_element: ExpandElement = lhs.into();
                let rhs_element: ExpandElement = rhs.into();

                // Same item as the left operand, except for the vectorization.
                let out_item = {
                    let mut item = lhs_element.item();
                    item.vectorization = $out_vectorization;
                    item
                };

                binary_expand_fixed_output(context, lhs_element, rhs_element, out_item, $operator)
                    .into()
            }
        }

        $(
            impl $trait_name for $type {}

            impl ExpandElementTyped<$type> {
                pub fn $method_name_expand(
                    self,
                    context: &mut CubeContext,
                    rhs: ExpandElementTyped<$type>,
                ) -> ExpandElementTyped<$type> {
                    // Method form: forwards to the trait's expand function, which none of
                    // the generated (empty) impls override.
                    <$type as $trait_name>::$func_name_expand(context, self, rhs)
                }
            }
        )*
    };
}

impl_binary_func!(
Powf,
powf,
Expand Down Expand Up @@ -263,3 +297,18 @@ impl_binary_func!(
i64,
u32
);
// Dot product.
//
// Generates the `Dot` trait (`dot`, `__expand_dot`) and the inherent
// `__expand_dot_method` on `ExpandElementTyped<T>` for every type listed
// below, all lowering to `Operator::Dot`.
// The output item's vectorization is set to `None`.
// NOTE(review): `None` presumably means a scalar (non-vectorized) result — confirm.
impl_binary_func_fixed_output_vectorization!(
Dot,
dot,
__expand_dot,
__expand_dot_method,
Operator::Dot,
None,
f16,
bf16,
f32,
f64,
i32,
i64,
u32
);
8 changes: 4 additions & 4 deletions crates/cubecl-core/src/frontend/operation/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
unexpanded,
};

use super::base::{fixed_output_unary_expand, unary_expand};
use super::base::{unary_expand, unary_expand_fixed_output};

pub mod not {
use super::*;
Expand Down Expand Up @@ -50,7 +50,7 @@ macro_rules! impl_unary_func {
}
}

macro_rules! impl_fixed_out_vectorization_unary_func {
macro_rules! impl_unary_func_fixed_out_vectorization {
($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
pub trait $trait_name: CubePrimitive + Sized {
#[allow(unused_variables)]
Expand All @@ -62,7 +62,7 @@ macro_rules! impl_fixed_out_vectorization_unary_func {
let expand_element: ExpandElement = x.into();
let mut item = expand_element.item();
item.vectorization = $out_vectorization;
fixed_output_unary_expand(context, expand_element, item, $operator).into()
unary_expand_fixed_output(context, expand_element, item, $operator).into()
}
}

Expand Down Expand Up @@ -158,7 +158,7 @@ impl_unary_func!(
f32,
f64
);
impl_fixed_out_vectorization_unary_func!(
impl_unary_func_fixed_out_vectorization!(
Magnitude,
magnitude,
__expand_magnitude,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub enum Operator {
AtomicCompareAndSwap(CompareAndSwapOperator),
Magnitude(UnaryOperator),
Normalize(UnaryOperator),
Dot(BinaryOperator),
}

/// All metadata that can be accessed in a shader.
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-core/src/ir/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ impl ScopeProcessing {
Operator::Normalize(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &op.out);
}
Operator::Dot(op) => {
sanitize_constant_scalar_ref_var(&mut op.lhs, &op.out);
sanitize_constant_scalar_ref_var(&mut op.rhs, &op.out);
}
},
Operation::Metadata(op) => match op {
Metadata::Stride { dim, .. } => {
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-core/src/ir/vectorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ impl Operator {
Operator::AtomicXor(op) => Operator::AtomicXor(op.vectorize(vectorization)),
Operator::Magnitude(op) => Operator::Magnitude(op.vectorize(vectorization)),
Operator::Normalize(op) => Operator::Normalize(op.vectorize(vectorization)),
Operator::Dot(op) => Operator::Dot(op.vectorize(vectorization)),
}
}
}
Expand Down
136 changes: 136 additions & 0 deletions crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use crate as cubecl;

use cubecl::prelude::*;
use cubecl_runtime::server::Handle;

/// Reads `output` back through `client`, reinterprets the bytes as `f32`s and
/// asserts that every value matches `expected`.
///
/// Two values match when their absolute difference is strictly below
/// `epsilon`, or when both are NaN.
///
/// Values are paired with `zip`, so only the common prefix of the two
/// sequences is compared; a length mismatch is not reported.
///
/// # Panics
///
/// Panics on the first mismatching pair, printing its index and both full
/// sequences.
pub(crate) fn assert_equals_approx<R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    output: Handle<<R as Runtime>::Server>,
    expected: &[f32],
    epsilon: f32,
) {
    let bytes = client.read(output.binding());
    let actual = f32::from_bytes(&bytes);

    for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
        let difference = (got - want).abs();
        let both_nan = got.is_nan() && want.is_nan();

        assert!(
            difference < epsilon || both_nan,
            "Values differ more than epsilon: actual={}, expected={}, difference={}, epsilon={}
index: {}
actual: {:?}
expected: {:?}",
            got,
            want,
            difference,
            epsilon,
            index,
            actual,
            expected
        );
    }
}

// Generates `pub fn $test_name<R: Runtime>(client)`, a runtime test for a
// binary cube function.
//
// Arguments:
// - `$test_name`: name of the generated test function.
// - `$float_type`: identifier used as the kernel's generic `Float` parameter.
// - `$binary_func`: the function under test, called as `$binary_func(lhs, rhs)`.
// - a bracketed, comma-separated list of cases, each giving the vectorization
//   used for both inputs, the vectorization of the output, the two input
//   arrays and the expected output array.
//
// Every case is launched separately and checked with `assert_equals_approx`
// using an epsilon of 0.001.
macro_rules! test_binary_impl {
(
$test_name:ident,
$float_type:ident,
$binary_func:expr,
[$({
input_vectorization: $input_vectorization:expr,
out_vectorization: $out_vectorization:expr,
lhs: $lhs:expr,
rhs: $rhs:expr,
expected: $expected:expr
}),*]) => {
pub fn $test_name<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
// Kernel under test: applies `$binary_func` to the operands found at
// `ABSOLUTE_POS` and stores the result at the same position in `output`.
#[cube(launch_unchecked)]
fn test_function<$float_type: Float>(lhs: &Array<$float_type>, rhs: &Array<$float_type>, output: &mut Array<$float_type>) {
// Positions at or beyond `rhs.len()` do nothing.
if ABSOLUTE_POS < rhs.len() {
output[ABSOLUTE_POS] = $binary_func(lhs[ABSOLUTE_POS], rhs[ABSOLUTE_POS]);
}
}

$(
{
let lhs = &$lhs;
let rhs = &$rhs;
// All buffers hold `f32` data: the kernel is generic, but the launch
// below always instantiates it with `f32`.
let output_handle = client.empty($expected.len() * core::mem::size_of::<f32>());
let lhs_handle = client.create(f32::as_bytes(lhs));
let rhs_handle = client.create(f32::as_bytes(rhs));

// NOTE(review): soundness of the unchecked launch presumably relies on the
// lengths and vectorizations passed below matching the buffers — confirm.
unsafe {
test_function::launch_unchecked::<f32, R>(
&client,
CubeCount::Static(1, 1, 1),
// x dimension: the number of scalar inputs divided by the input vectorization.
CubeDim::new((lhs.len() / $input_vectorization as usize) as u32, 1, 1),
ArrayArg::from_raw_parts(&lhs_handle, lhs.len(), $input_vectorization),
ArrayArg::from_raw_parts(&rhs_handle, rhs.len(), $input_vectorization),
ArrayArg::from_raw_parts(&output_handle, $expected.len(), $out_vectorization),
)
};

assert_equals_approx::<R>(&client, output_handle, &$expected, 0.001);
}
)*
}
};
}

// Runtime test for `F::dot` with input vectorizations 1, 2 and 4.
// The output vectorization is 1 in every case.
test_binary_impl!(
test_dot,
F,
F::dot,
[
// Vectorization 1: each result is the product of one lhs/rhs pair.
{
input_vectorization: 1,
out_vectorization: 1,
lhs: [1., -3.1, -2.4, 15.1],
rhs: [-1., 23.1, -1.4, 5.1],
expected: [-1.0, -71.61, 3.36, 77.01]
},
// Vectorization 2: -1.0 + -71.61 = -72.61 and 3.36 + 77.01 = 80.37.
{
input_vectorization: 2,
out_vectorization: 1,
lhs: [1., -3.1, -2.4, 15.1],
rhs: [-1., 23.1, -1.4, 5.1],
expected: [-72.61, 80.37]
},
// Vectorization 4: all four products summed, -72.61 + 80.37 = 7.76.
{
input_vectorization: 4,
out_vectorization: 1,
lhs: [1., -3.1, -2.4, 15.1],
rhs: [-1., 23.1, -1.4, 5.1],
expected: [7.76]
},
// Two groups of four; the second group is the first with lhs and rhs
// swapped, so both results are 7.76.
{
input_vectorization: 4,
out_vectorization: 1,
lhs: [1., -3.1, -2.4, 15.1, -1., 23.1, -1.4, 5.1],
rhs: [-1., 23.1, -1.4, 5.1, 1., -3.1, -2.4, 15.1],
expected: [7.76, 7.76]
}

]
);

// Exported test generator: expands to a `binary` module containing one
// `#[test]` per binary runtime test defined in
// `cubecl_core::runtime_tests::binary`.
//
// The expansion refers to `TestRuntime`, which it expects to reach through
// `use super::*`, so that name must be visible at the expansion site.
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_binary {
() => {
mod binary {
use super::*;

// Wraps one runtime test function in a `#[test]` that builds a client
// from the default device and passes it to the test.
macro_rules! add_test {
($test_name:ident) => {
#[test]
fn $test_name() {
let client = TestRuntime::client(&Default::default());
cubecl_core::runtime_tests::binary::$test_name::<TestRuntime>(client);
}
};
}

add_test!(test_dot);
}
};
}
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod assign;
pub mod binary;
pub mod cmma;
pub mod different_rank;
pub mod launch;
Expand All @@ -22,6 +23,7 @@ macro_rules! testgen_all {
cubecl_core::testgen_topology!();
cubecl_core::testgen_sequence!();
cubecl_core::testgen_unary!();
cubecl_core::testgen_binary!();
cubecl_core::testgen_different_rank!();
};
}
1 change: 1 addition & 0 deletions crates/cubecl-cuda/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ impl CudaCompiler {
gpu::Operator::Magnitude(op) => {
instructions.push(Instruction::Magnitude(self.compile_unary(op)))
}
gpu::Operator::Dot(op) => instructions.push(Instruction::Dot(self.compile_binary(op))),
};
}

Expand Down
24 changes: 24 additions & 0 deletions crates/cubecl-cuda/src/compiler/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ pub enum Instruction {
Negate(UnaryInstruction),
Magnitude(UnaryInstruction),
Normalize(UnaryInstruction),
Dot(BinaryInstruction),
}

impl Display for Instruction {
Expand Down Expand Up @@ -392,6 +393,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
}
Instruction::Normalize(inst) => Normalize::format(f, &inst.input, &inst.out),
Instruction::Magnitude(inst) => Magnitude::format(f, &inst.input, &inst.out),
Instruction::Dot(inst) => Dot::format(f, &inst.lhs, &inst.rhs, &inst.out),
}
}
}
Expand Down Expand Up @@ -534,3 +536,25 @@ impl Normalize {
f.write_fmt(format_args!("}}\n"))
}
}

/// Formatter for the dot-product instruction.
struct Dot;

impl Dot {
    /// Writes the dot product of `lhs` and `rhs` into `f` as one assignment:
    ///
    /// ```text
    /// out = lhs[0] * rhs[0] + lhs[1] * rhs[1] + …;
    /// ```
    ///
    /// The number of terms is the vectorization of `lhs`.
    ///
    /// The sum is emitted as a single expression, not as `out = 0.0;`
    /// followed by one `out += …;` per component. The frontend may reuse an
    /// operand as the output variable, so zeroing `out` first could destroy
    /// an operand before it is read. In a single assignment every operand is
    /// read before `out` is written. This form also needs no floating-point
    /// literal, so it does not depend on the element type.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying formatter.
    fn format(
        f: &mut core::fmt::Formatter<'_>,
        lhs: &Variable,
        rhs: &Variable,
        out: &Variable,
    ) -> core::fmt::Result {
        let num = lhs.item().vectorization;

        f.write_fmt(format_args!("{out} = "))?;

        let mut wrote_term = false;
        for i in 0..num {
            if wrote_term {
                f.write_str(" + ")?;
            }
            let lhs_i = lhs.index(i);
            let rhs_i = rhs.index(i);
            f.write_fmt(format_args!("{lhs_i} * {rhs_i}"))?;
            wrote_term = true;
        }

        // With no components the previous output was just `out = 0.0;`; keep that.
        if !wrote_term {
            f.write_str("0.0")?;
        }

        f.write_str(";\n")
    }
}
Loading

0 comments on commit e5699a1

Please sign in to comment.