Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/bool cast #303

Merged
merged 5 commits
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_runtime::server::Handle;

#[track_caller]
pub(crate) fn assert_equals_approx<
R: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Expand Down
55 changes: 41 additions & 14 deletions crates/cubecl-core/src/runtime_tests/plane.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::Feature;
use std::fmt::Display;

use crate::{self as cubecl};
use crate::{runtime_tests::binary::assert_equals_approx, Feature};
use cubecl::prelude::*;

#[cube(launch)]
Expand Down Expand Up @@ -74,7 +76,10 @@ pub fn kernel_broadcast<F: Float>(output: &mut Tensor<F>) {
}
}

pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
pub fn test_plane_sum<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -108,7 +113,10 @@ pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
);
}

pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_prod<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -147,7 +155,10 @@ pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_max<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -183,7 +194,10 @@ pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_min<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -219,7 +233,10 @@ pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_all<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -257,7 +274,10 @@ pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_any<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -295,7 +315,10 @@ pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_elect<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -324,7 +347,10 @@ pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_broadcast<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -356,7 +382,11 @@ pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
fn test_plane_operation<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Launch,
>(
input: &[F],
expected: &[F],
vectorization: u8,
Expand All @@ -380,10 +410,7 @@ fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
);
}

let actual = client.read_one(handle.binding());
let actual = F::from_bytes(&actual);

assert_eq!(actual, expected);
assert_equals_approx::<TestRuntime, F>(&client, handle, expected, 1e-5);
}

#[allow(missing_docs)]
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl WmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::F32),
vec![(16, 8, 16)],
vec![(16, 16, 8)],
));
}
result
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-cuda/tests/sequence_for_loop.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void sequence_for_loop_kernel(float output_0[],
uint info[]) {
extern "C" __global__ void sequence_for_loop_kernel(float output_0[],
uint info[]) {

int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * (blockDim.x * blockDim.y);
Expand Down
22 changes: 11 additions & 11 deletions crates/cubecl-spirv/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ impl Item {
b.select(ty, out_id, obj, one, zero).unwrap()
}
(Elem::Int(_, _), Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.i_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
convert_int(
Expand All @@ -254,8 +254,8 @@ impl Item {
b.convert_s_to_f(ty, out_id, obj).unwrap()
}
(Elem::Float(_), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Float(_), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
b.convert_f_to_u(ty, out_id, obj).unwrap()
Expand Down Expand Up @@ -325,7 +325,7 @@ impl Elem {
let ty = self.id(b);
match self {
Elem::Void => unreachable!(),
Elem::Bool if value.as_u64() == 1 => b.constant_true(ty),
Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
Elem::Bool => b.constant_false(ty),
_ => match value {
ConstVal::Bit32(val) => b.constant_bit32(ty, val),
Expand Down Expand Up @@ -418,7 +418,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
let val = val.as_const().unwrap();

let value = match (val, item.elem()) {
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
}
Expand All @@ -432,7 +432,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
ConstVal::from_float(val as f64, 32)
}
(core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
ConstVal::from_bool(val == 1.0)
ConstVal::from_bool(val != 0.0)
}
(core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
Expand All @@ -446,7 +446,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
ConstVal::from_float(val, 32)
}
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val, width)
}
Expand Down Expand Up @@ -479,7 +479,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
(Elem::Bool, Elem::Float(width)) => ConstVal::from_float(val.as_u32() as f64, width),
(Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() == 1),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
(Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
(Elem::Int(w_in, true), Elem::Int(width, _)) => {
ConstVal::from_uint(val.as_int(w_in) as u64, width)
Expand All @@ -494,8 +494,8 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Int(in_w, true), Elem::Relaxed) => {
ConstVal::from_float(val.as_int(in_w) as f64, 32)
}
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) == 1.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) == 1.0),
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) != 0.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) != 0.0),
(Elem::Float(in_w), Elem::Int(out_w, false)) => {
ConstVal::from_uint(val.as_float(in_w) as u64, out_w)
}
Expand Down
58 changes: 50 additions & 8 deletions crates/cubecl-spirv/src/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,59 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
}
Plane::All(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Any(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Broadcast(op) => {
self.capabilities.insert(Capability::GroupNonUniformBallot);
Expand Down