Skip to content

Commit

Permalink
Fix/bool cast (#303)
Browse files: browse the repository at this point in the history
  • Loading branch information
wingertge authored Nov 26, 2024
1 parent 82ae5bc commit e794e7b
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 36 deletions.
1 change: 1 addition & 0 deletions crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_runtime::server::Handle;

#[track_caller]
pub(crate) fn assert_equals_approx<
R: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Expand Down
55 changes: 41 additions & 14 deletions crates/cubecl-core/src/runtime_tests/plane.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::Feature;
use std::fmt::Display;

use crate::{self as cubecl};
use crate::{runtime_tests::binary::assert_equals_approx, Feature};
use cubecl::prelude::*;

#[cube(launch)]
Expand Down Expand Up @@ -74,7 +76,10 @@ pub fn kernel_broadcast<F: Float>(output: &mut Tensor<F>) {
}
}

pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
pub fn test_plane_sum<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -108,7 +113,10 @@ pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
);
}

pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_prod<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -147,7 +155,10 @@ pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_max<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -183,7 +194,10 @@ pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_min<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -219,7 +233,10 @@ pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_all<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -257,7 +274,10 @@ pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_any<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -295,7 +315,10 @@ pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_elect<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -324,7 +347,10 @@ pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_broadcast<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -356,7 +382,11 @@ pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
fn test_plane_operation<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Launch,
>(
input: &[F],
expected: &[F],
vectorization: u8,
Expand All @@ -380,10 +410,7 @@ fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
);
}

let actual = client.read_one(handle.binding());
let actual = F::from_bytes(&actual);

assert_eq!(actual, expected);
assert_equals_approx::<TestRuntime, F>(&client, handle, expected, 1e-5);
}

#[allow(missing_docs)]
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl WmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::F32),
vec![(16, 8, 16)],
vec![(16, 16, 8)],
));
}
result
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-cuda/tests/sequence_for_loop.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void sequence_for_loop_kernel(float output_0[],
uint info[]) {
extern "C" __global__ void sequence_for_loop_kernel(float output_0[],
uint info[]) {

int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * (blockDim.x * blockDim.y);
Expand Down
22 changes: 11 additions & 11 deletions crates/cubecl-spirv/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ impl Item {
b.select(ty, out_id, obj, one, zero).unwrap()
}
(Elem::Int(_, _), Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.i_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
convert_int(
Expand All @@ -254,8 +254,8 @@ impl Item {
b.convert_s_to_f(ty, out_id, obj).unwrap()
}
(Elem::Float(_), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Float(_), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
b.convert_f_to_u(ty, out_id, obj).unwrap()
Expand Down Expand Up @@ -325,7 +325,7 @@ impl Elem {
let ty = self.id(b);
match self {
Elem::Void => unreachable!(),
Elem::Bool if value.as_u64() == 1 => b.constant_true(ty),
Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
Elem::Bool => b.constant_false(ty),
_ => match value {
ConstVal::Bit32(val) => b.constant_bit32(ty, val),
Expand Down Expand Up @@ -418,7 +418,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
let val = val.as_const().unwrap();

let value = match (val, item.elem()) {
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
}
Expand All @@ -432,7 +432,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
ConstVal::from_float(val as f64, 32)
}
(core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
ConstVal::from_bool(val == 1.0)
ConstVal::from_bool(val != 0.0)
}
(core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
Expand All @@ -446,7 +446,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
ConstVal::from_float(val, 32)
}
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val, width)
}
Expand Down Expand Up @@ -479,7 +479,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
(Elem::Bool, Elem::Float(width)) => ConstVal::from_float(val.as_u32() as f64, width),
(Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() == 1),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
(Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
(Elem::Int(w_in, true), Elem::Int(width, _)) => {
ConstVal::from_uint(val.as_int(w_in) as u64, width)
Expand All @@ -494,8 +494,8 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Int(in_w, true), Elem::Relaxed) => {
ConstVal::from_float(val.as_int(in_w) as f64, 32)
}
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) == 1.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) == 1.0),
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) != 0.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) != 0.0),
(Elem::Float(in_w), Elem::Int(out_w, false)) => {
ConstVal::from_uint(val.as_float(in_w) as u64, out_w)
}
Expand Down
58 changes: 50 additions & 8 deletions crates/cubecl-spirv/src/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,59 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
}
Plane::All(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Any(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Broadcast(op) => {
self.capabilities.insert(Capability::GroupNonUniformBallot);
Expand Down

0 comments on commit e794e7b

Please sign in to comment.