From f816850554909330692d5291ef3638b3c2142dd2 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 22 Nov 2024 22:35:52 +0100 Subject: [PATCH 1/3] Fix bool cast to consider all non-zero as true --- .../cubecl-core/src/runtime_tests/binary.rs | 1 + crates/cubecl-core/src/runtime_tests/plane.rs | 55 ++++++++++++++----- crates/cubecl-cuda/tests/common.rs | 1 + crates/cubecl-spirv/src/item.rs | 22 ++++---- 4 files changed, 54 insertions(+), 25 deletions(-) diff --git a/crates/cubecl-core/src/runtime_tests/binary.rs b/crates/cubecl-core/src/runtime_tests/binary.rs index 60e32a5f2..5013ed2e2 100644 --- a/crates/cubecl-core/src/runtime_tests/binary.rs +++ b/crates/cubecl-core/src/runtime_tests/binary.rs @@ -5,6 +5,7 @@ use crate::{self as cubecl, as_type}; use cubecl::prelude::*; use cubecl_runtime::server::Handle; +#[track_caller] pub(crate) fn assert_equals_approx< R: Runtime, F: Float + num_traits::Float + CubeElement + Display, diff --git a/crates/cubecl-core/src/runtime_tests/plane.rs b/crates/cubecl-core/src/runtime_tests/plane.rs index 2849805fa..ae82ec3fa 100644 --- a/crates/cubecl-core/src/runtime_tests/plane.rs +++ b/crates/cubecl-core/src/runtime_tests/plane.rs @@ -1,5 +1,7 @@ -use crate::Feature; +use std::fmt::Display; + use crate::{self as cubecl}; +use crate::{runtime_tests::binary::assert_equals_approx, Feature}; use cubecl::prelude::*; #[cube(launch)] @@ -74,7 +76,10 @@ pub fn kernel_broadcast(output: &mut Tensor) { } } -pub fn test_plane_sum( +pub fn test_plane_sum< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -108,7 +113,10 @@ pub fn test_plane_sum( ); } -pub fn test_plane_prod( +pub fn test_plane_prod< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -147,7 +155,10 @@ pub fn test_plane_prod( ); } -pub fn test_plane_max( +pub fn test_plane_max< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -183,7 +194,10 @@ pub fn test_plane_max( ); } -pub fn test_plane_min( +pub fn test_plane_min< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -219,7 +233,10 @@ pub fn test_plane_min( ); } -pub fn test_plane_all( +pub fn test_plane_all< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -257,7 +274,10 @@ pub fn test_plane_all( ); } -pub fn test_plane_any( +pub fn test_plane_any< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -295,7 +315,10 @@ pub fn test_plane_any( ); } -pub fn test_plane_elect( +pub fn test_plane_elect< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -324,7 +347,10 @@ pub fn test_plane_elect( ); } -pub fn test_plane_broadcast( +pub fn test_plane_broadcast< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, +>( client: ComputeClient, vectorization: u8, ) { @@ -356,7 +382,11 @@ pub fn test_plane_broadcast( ); } -fn test_plane_operation( +fn test_plane_operation< + TestRuntime: Runtime, + F: Float + num_traits::Float + CubeElement + Display, + Launch, +>( input: &[F], expected: &[F], vectorization: u8, @@ -380,10 +410,7 @@ fn test_plane_operation( ); } - let actual = client.read_one(handle.binding()); - let actual = F::from_bytes(&actual); - - assert_eq!(actual, expected); + assert_equals_approx::(&client, handle, expected, 1e-5); } #[allow(missing_docs)] diff --git a/crates/cubecl-cuda/tests/common.rs b/crates/cubecl-cuda/tests/common.rs index 813f91b81..537866bf5 100644 --- a/crates/cubecl-cuda/tests/common.rs +++ b/crates/cubecl-cuda/tests/common.rs @@ -37,6 +37,7 @@ pub fn array() -> ArrayCompilationArg { pub fn compile(kernel: impl Kernel) -> String { let kernel = <::Compiler as Compiler>::compile( kernel.define(), + &Default::default(), ExecutionMode::Checked, ) .to_string(); diff --git a/crates/cubecl-spirv/src/item.rs b/crates/cubecl-spirv/src/item.rs index a0dcb260a..65726ee83 100644 --- a/crates/cubecl-spirv/src/item.rs +++ b/crates/cubecl-spirv/src/item.rs @@ -235,8 +235,8 @@ impl Item { b.select(ty, out_id, obj, one, zero).unwrap() } (Elem::Int(_, _), Elem::Bool) => { - let one = self.const_u32(b, 1); - b.i_equal(ty, out_id, obj, one).unwrap() + let zero = self.const_u32(b, 0); + b.i_not_equal(ty, out_id, obj, zero).unwrap() } (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => { convert_int( @@ -254,8 +254,8 @@ impl Item { b.convert_s_to_f(ty, out_id, obj).unwrap() } (Elem::Float(_), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => { - let one = self.const_u32(b, 1); - b.i_equal(ty, out_id, obj, one).unwrap() + let zero = self.const_u32(b, 0); + b.f_unord_not_equal(ty, out_id, obj, zero).unwrap() } (Elem::Float(_), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => { b.convert_f_to_u(ty, out_id, obj).unwrap() @@ -325,7 +325,7 @@ impl Elem { let ty = self.id(b); match self { Elem::Void => unreachable!(), - Elem::Bool if value.as_u64() == 1 => b.constant_true(ty), + Elem::Bool if value.as_u64() != 0 => b.constant_true(ty), Elem::Bool => b.constant_false(ty), _ => match value { ConstVal::Bit32(val) => b.constant_bit32(ty, val), @@ -418,7 +418,7 @@ impl SpirvCompiler { let val = val.as_const().unwrap(); let value = match (val, item.elem()) { - (core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val == 1), + (core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0), (core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => { ConstVal::from_uint(val as u64, width) } @@ -432,7 +432,7 @@ impl SpirvCompiler { ConstVal::from_float(val as f64, 32) } (core::ConstantScalarValue::Float(val, _), Elem::Bool) => { - ConstVal::from_bool(val == 1.0) + ConstVal::from_bool(val != 0.0) } (core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => { ConstVal::from_uint(val as u64, width) @@ -446,7 +446,7 @@ impl SpirvCompiler { (core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => { ConstVal::from_float(val, 32) } - (core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val == 1), + (core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0), (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => { ConstVal::from_uint(val, width) } @@ -479,7 +479,7 @@ impl SpirvCompiler { (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width), (Elem::Bool, Elem::Float(width)) => ConstVal::from_float(val.as_u32() as f64, width), (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32), - (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() == 1), + (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0), (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width), (Elem::Int(w_in, true), Elem::Int(width, _)) => { ConstVal::from_uint(val.as_int(w_in) as u64, width) @@ -494,8 +494,8 @@ impl SpirvCompiler { (Elem::Int(in_w, true), Elem::Relaxed) => { ConstVal::from_float(val.as_int(in_w) as f64, 32) } - (Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) == 1.0), - (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) == 1.0), + (Elem::Float(in_w), Elem::Bool) => ConstVal::from_bool(val.as_float(in_w) != 0.0), + (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32) != 0.0), (Elem::Float(in_w), Elem::Int(out_w, false)) => { ConstVal::from_uint(val.as_float(in_w) as u64, out_w) } From 7398decb239ba78bb7fb8d8644792b818fa11b83 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Fri, 22 Nov 2024 23:50:52 +0100 Subject: [PATCH 2/3] Fix tests and plane_all/any --- crates/cubecl-cuda/tests/constant_array.cu | 4 +- crates/cubecl-cuda/tests/plane_sum.cu | 4 +- crates/cubecl-cuda/tests/sequence_for_loop.cu | 4 +- crates/cubecl-cuda/tests/slice_assign.cu | 5 +- crates/cubecl-cuda/tests/unary_bench.cu | 7 ++- crates/cubecl-spirv/src/subgroup.rs | 58 ++++++++++++++++--- 6 files changed, 67 insertions(+), 15 deletions(-) diff --git a/crates/cubecl-cuda/tests/constant_array.cu b/crates/cubecl-cuda/tests/constant_array.cu index 5723d5b70..d7dbbd23d 100644 --- a/crates/cubecl-cuda/tests/constant_array.cu +++ b/crates/cubecl-cuda/tests/constant_array.cu @@ -1,10 +1,12 @@ +#include typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; typedef unsigned long long int uint64; typedef long long int int64; -extern "C" __global__ void kernel(float output_0[], uint info[]) { +extern "C" __global__ void constant_array_kernel(float output_0[], + uint info[]) { int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, diff --git a/crates/cubecl-cuda/tests/plane_sum.cu b/crates/cubecl-cuda/tests/plane_sum.cu index f675e88b2..522159a41 100644 --- a/crates/cubecl-cuda/tests/plane_sum.cu +++ b/crates/cubecl-cuda/tests/plane_sum.cu @@ -1,10 +1,11 @@ +#include typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; typedef unsigned long long int uint64; typedef long long int int64; -extern "C" __global__ void kernel(float output_0[], uint info[]) { +extern "C" __global__ void kernel_sum(float output_0[], uint info[]) { int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y); @@ -18,6 +19,7 @@ extern "C" __global__ void kernel(float output_0[], uint info[]) { l_0_0 = (threadIdxGlobal < l_0_3) ? output_0[threadIdxGlobal] : float(0); l_0_1 = l_0_0; + { for (int offset = 1; offset < warpSizeChecked; offset *= 2) { l_0_1 += __shfl_xor_sync(-1, l_0_1, offset); diff --git a/crates/cubecl-cuda/tests/sequence_for_loop.cu b/crates/cubecl-cuda/tests/sequence_for_loop.cu index 4b87281a3..7baa227ce 100644 --- a/crates/cubecl-cuda/tests/sequence_for_loop.cu +++ b/crates/cubecl-cuda/tests/sequence_for_loop.cu @@ -1,10 +1,12 @@ +#include typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; typedef unsigned long long int uint64; typedef long long int int64; -extern "C" __global__ void kernel(float output_0[], uint info[]) { +extern "C" __global__ void sequence_for_loop_kernel(float output_0[], + uint info[]) { int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y); diff --git a/crates/cubecl-cuda/tests/slice_assign.cu b/crates/cubecl-cuda/tests/slice_assign.cu index b7648a2ad..ef992e947 100644 --- a/crates/cubecl-cuda/tests/slice_assign.cu +++ b/crates/cubecl-cuda/tests/slice_assign.cu @@ -1,11 +1,12 @@ +#include typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; typedef unsigned long long int uint64; typedef long long int int64; -extern "C" __global__ void kernel(float input_0[], float output_0[], - uint info[]) { +extern "C" __global__ void slice_assign_kernel(float input_0[], + float output_0[], uint info[]) { int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y); diff --git a/crates/cubecl-cuda/tests/unary_bench.cu b/crates/cubecl-cuda/tests/unary_bench.cu index 7042eadc8..653553cc8 100644 --- a/crates/cubecl-cuda/tests/unary_bench.cu +++ b/crates/cubecl-cuda/tests/unary_bench.cu @@ -1,3 +1,4 @@ +#include typedef unsigned char uint8; typedef unsigned short uint16; typedef unsigned int uint; @@ -11,8 +12,10 @@ struct __align__(16) float_4 { float i_3; }; -extern "C" __global__ void kernel(float_4 input_0[], float_4 input_1[], - float_4 output_0[], uint info[]) { +extern "C" __global__ void execute_unary_kernel(float_4 input_0[], + float_4 input_1[], + float_4 output_0[], + uint info[]) { int3 absoluteIdx = make_int3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, diff --git a/crates/cubecl-spirv/src/subgroup.rs b/crates/cubecl-spirv/src/subgroup.rs index 517bcc533..ac7a4dd4b 100644 --- a/crates/cubecl-spirv/src/subgroup.rs +++ b/crates/cubecl-spirv/src/subgroup.rs @@ -20,17 +20,59 @@ impl SpirvCompiler { } Plane::All(op) => { self.capabilities.insert(Capability::GroupNonUniformVote); - self.compile_unary_op(op, out, |b, _, ty, input, out| { - b.group_non_uniform_all(ty, Some(out), subgroup, input) - .unwrap(); - }); + match out.vectorization_factor() { + 1 => { + self.compile_unary_op(op, out, |b, _, ty, input, out| { + b.group_non_uniform_all(ty, Some(out), subgroup, input) + .unwrap(); + }); + } + vec => { + let elem_ty = self.compile_item(op.input.item).elem().id(self); + let bool_ty = self.type_bool(); + + self.compile_unary_op(op, out, |b, _, ty, input, out| { + let ids = (0..vec) + .map(|i| { + let elem_i = b + .composite_extract(elem_ty, None, input, vec![i as u32]) + .unwrap(); + b.group_non_uniform_all(bool_ty, None, subgroup, elem_i) + .unwrap() + }) + .collect::>(); + b.composite_construct(ty, Some(out), ids).unwrap(); + }); + } + }; } Plane::Any(op) => { self.capabilities.insert(Capability::GroupNonUniformVote); - self.compile_unary_op(op, out, |b, _, ty, input, out| { - b.group_non_uniform_any(ty, Some(out), subgroup, input) - .unwrap(); - }); + match out.vectorization_factor() { + 1 => { + self.compile_unary_op(op, out, |b, _, ty, input, out| { + b.group_non_uniform_any(ty, Some(out), subgroup, input) + .unwrap(); + }); + } + vec => { + let elem_ty = self.compile_item(op.input.item).elem().id(self); + let bool_ty = self.type_bool(); + + self.compile_unary_op(op, out, |b, _, ty, input, out| { + let ids = (0..vec) + .map(|i| { + let elem_i = b + .composite_extract(elem_ty, None, input, vec![i as u32]) + .unwrap(); + b.group_non_uniform_any(bool_ty, None, subgroup, elem_i) + .unwrap() + }) + .collect::>(); + b.composite_construct(ty, Some(out), ids).unwrap(); + }); + } + }; } Plane::Broadcast(op) => { self.capabilities.insert(Capability::GroupNonUniformBallot); From 1474ce5f8b259f21cb7c5a52c8fba8e6a24b05ba Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:13:13 +0100 Subject: [PATCH 3/3] Fix TF32 feature registration --- crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs index 17b5fdf7e..dc725b23b 100644 --- a/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs +++ b/crates/cubecl-cpp/src/cuda/wmma/cuda_compiler.rs @@ -91,7 +91,7 @@ impl WmmaCompiler> for CudaWmmaCompiler { gpu::Elem::Float(gpu::FloatKind::TF32), gpu::Elem::Float(gpu::FloatKind::TF32), gpu::Elem::Float(gpu::FloatKind::F32), - vec![(16, 8, 16)], + vec![(16, 16, 8)], )); } result