Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/bool cast #303

Merged
merged 5 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{self as cubecl, as_type};
use cubecl::prelude::*;
use cubecl_runtime::server::Handle;

#[track_caller]
pub(crate) fn assert_equals_approx<
R: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Expand Down
55 changes: 41 additions & 14 deletions crates/cubecl-core/src/runtime_tests/plane.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::Feature;
use std::fmt::Display;

use crate::{self as cubecl};
use crate::{runtime_tests::binary::assert_equals_approx, Feature};
use cubecl::prelude::*;

#[cube(launch)]
Expand Down Expand Up @@ -74,7 +76,10 @@ pub fn kernel_broadcast<F: Float>(output: &mut Tensor<F>) {
}
}

pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
pub fn test_plane_sum<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -108,7 +113,10 @@ pub fn test_plane_sum<TestRuntime: Runtime, F: Float + CubeElement + Sized>(
);
}

pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_prod<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -147,7 +155,10 @@ pub fn test_plane_prod<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_max<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -183,7 +194,10 @@ pub fn test_plane_max<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_min<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -219,7 +233,10 @@ pub fn test_plane_min<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_all<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -257,7 +274,10 @@ pub fn test_plane_all<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_any<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -295,7 +315,10 @@ pub fn test_plane_any<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_elect<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -324,7 +347,10 @@ pub fn test_plane_elect<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
pub fn test_plane_broadcast<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
>(
client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
vectorization: u8,
) {
Expand Down Expand Up @@ -356,7 +382,11 @@ pub fn test_plane_broadcast<TestRuntime: Runtime, F: Float + CubeElement>(
);
}

fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
fn test_plane_operation<
TestRuntime: Runtime,
F: Float + num_traits::Float + CubeElement + Display,
Launch,
>(
input: &[F],
expected: &[F],
vectorization: u8,
Expand All @@ -380,10 +410,7 @@ fn test_plane_operation<TestRuntime: Runtime, F: Float + CubeElement, Launch>(
);
}

let actual = client.read_one(handle.binding());
let actual = F::from_bytes(&actual);

assert_eq!(actual, expected);
assert_equals_approx::<TestRuntime, F>(&client, handle, expected, 1e-5);
}

#[allow(missing_docs)]
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl WmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::TF32),
gpu::Elem::Float(gpu::FloatKind::F32),
vec![(16, 8, 16)],
vec![(16, 16, 8)],
));
}
result
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-cuda/tests/constant_array.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <mma.h>
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void kernel(float output_0[], uint info[]) {
extern "C" __global__ void constant_array_kernel(float output_0[],
uint info[]) {

int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-cuda/tests/plane_sum.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include <mma.h>
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void kernel(float output_0[], uint info[]) {
extern "C" __global__ void kernel_sum(float output_0[], uint info[]) {

int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * (blockDim.x * blockDim.y);
Expand All @@ -18,6 +19,7 @@ extern "C" __global__ void kernel(float output_0[], uint info[]) {
l_0_0 = (threadIdxGlobal < l_0_3) ? output_0[threadIdxGlobal] : float(0);

l_0_1 = l_0_0;

{
for (int offset = 1; offset < warpSizeChecked; offset *= 2) {
l_0_1 += __shfl_xor_sync(-1, l_0_1, offset);
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-cuda/tests/sequence_for_loop.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <mma.h>
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void kernel(float output_0[], uint info[]) {
extern "C" __global__ void sequence_for_loop_kernel(float output_0[],
uint info[]) {

int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * (blockDim.x * blockDim.y);
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-cuda/tests/slice_assign.cu
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#include <mma.h>
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef unsigned long long int uint64;
typedef long long int int64;

extern "C" __global__ void kernel(float input_0[], float output_0[],
uint info[]) {
extern "C" __global__ void slice_assign_kernel(float input_0[],
float output_0[], uint info[]) {

int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x +
threadIdx.z * (blockDim.x * blockDim.y);
Expand Down
7 changes: 5 additions & 2 deletions crates/cubecl-cuda/tests/unary_bench.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <mma.h>
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned int uint;
Expand All @@ -11,8 +12,10 @@ struct __align__(16) float_4 {
float i_3;
};

extern "C" __global__ void kernel(float_4 input_0[], float_4 input_1[],
float_4 output_0[], uint info[]) {
extern "C" __global__ void execute_unary_kernel(float_4 input_0[],
float_4 input_1[],
float_4 output_0[],
uint info[]) {

int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x,
blockIdx.y * blockDim.y + threadIdx.y,
Expand Down
22 changes: 11 additions & 11 deletions crates/cubecl-spirv/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ impl Item {
b.select(ty, out_id, obj, one, zero).unwrap()
}
(Elem::Int(_, _), Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.i_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
convert_int(
Expand All @@ -254,8 +254,8 @@ impl Item {
b.convert_s_to_f(ty, out_id, obj).unwrap()
}
(Elem::Float(_), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
let one = self.const_u32(b, 1);
b.i_equal(ty, out_id, obj, one).unwrap()
let zero = self.const_u32(b, 0);
b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
}
(Elem::Float(_), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
b.convert_f_to_u(ty, out_id, obj).unwrap()
Expand Down Expand Up @@ -325,7 +325,7 @@ impl Elem {
let ty = self.id(b);
match self {
Elem::Void => unreachable!(),
Elem::Bool if value.as_u64() == 1 => b.constant_true(ty),
Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
Elem::Bool => b.constant_false(ty),
_ => match value {
ConstVal::Bit32(val) => b.constant_bit32(ty, val),
Expand Down Expand Up @@ -418,7 +418,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
let val = val.as_const().unwrap();

let value = match (val, item.elem()) {
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
}
Expand All @@ -432,7 +432,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
ConstVal::from_float(val as f64, 32)
}
(core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
ConstVal::from_bool(val == 1.0)
ConstVal::from_bool(val != 0.0)
}
(core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val as u64, width)
Expand All @@ -446,7 +446,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
ConstVal::from_float(val, 32)
}
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val == 1),
(core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
(core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
ConstVal::from_uint(val, width)
}
Expand Down Expand Up @@ -479,7 +479,7 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
(Elem::Bool, Elem::Float(width)) => ConstVal::from_float(val.as_u32() as f64, width),
(Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() == 1),
(Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
(Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
(Elem::Int(w_in, true), Elem::Int(width, _)) => {
ConstVal::from_uint(val.as_int(w_in) as u64, width)
Expand All @@ -494,8 +494,8 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
(Elem::Int(in_w, true), Elem::Relaxed) => {
ConstVal::from_float(val.as_int(in_w) as f64, 32)
}
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) == 1.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) == 1.0),
(Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) != 0.0),
(Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) != 0.0),
(Elem::Float(in_w), Elem::Int(out_w, false)) => {
ConstVal::from_uint(val.as_float(in_w) as u64, out_w)
}
Expand Down
58 changes: 50 additions & 8 deletions crates/cubecl-spirv/src/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,59 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
}
Plane::All(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_all(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_all(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Any(op) => {
self.capabilities.insert(Capability::GroupNonUniformVote);
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
match out.vectorization_factor() {
1 => {
self.compile_unary_op(op, out, |b, _, ty, input, out| {
b.group_non_uniform_any(ty, Some(out), subgroup, input)
.unwrap();
});
}
vec => {
let elem_ty = self.compile_item(op.input.item).elem().id(self);
let bool_ty = self.type_bool();

self.compile_unary_op(op, out, |b, _, ty, input, out| {
let ids = (0..vec)
.map(|i| {
let elem_i = b
.composite_extract(elem_ty, None, input, vec![i as u32])
.unwrap();
b.group_non_uniform_any(bool_ty, None, subgroup, elem_i)
.unwrap()
})
.collect::<Vec<_>>();
b.composite_construct(ty, Some(out), ids).unwrap();
});
}
};
}
Plane::Broadcast(op) => {
self.capabilities.insert(Capability::GroupNonUniformBallot);
Expand Down