From 03cece09bbca6f6b8c0b204c8b4114e506515db7 Mon Sep 17 00:00:00 2001
From: zhijxu
Date: Mon, 25 Mar 2024 15:48:41 +0800
Subject: [PATCH 1/3] fix float16 softmax

In the float16 dtype, std::numeric_limits::infinity() returns 0 instead of
inf.
---
 .../cuda/math/softmax_warpwise_impl.cuh      | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
index c1b3d6ada8b77..054f9abc8c071 100644
--- a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
+++ b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
@@ -1,18 +1,18 @@
 /**
-* Copyright (c) 2016-present, Facebook, Inc.
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 // The code below is mostly copied from Pytorch PersistentSoftmax.cuh
 
@@ -55,7 +55,6 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
   }
 }
 
-
 // The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
 // Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
 // The template arguments have the following meaning:
@@ -163,7 +162,6 @@ __global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batc
   }
 }
 
-
 // softmax_warp_forward uses register to store data in fp32 even when data is fp16, which will cause register resource oversubscription when data is large,
 // and will lead to low CUDA warp occupancy and thus a poor kernel performance.
 // softmax_warp_forward_resource_efficient is implemented to solve the issue, it caches data in original data type, and casts it into fp32 when needed,
@@ -176,17 +174,18 @@ __global__ void softmax_warp_forward_resource_efficient(output_t* dst, const inp
   constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
   int local_idx = threadIdx.x;
-  src += blockIdx.x * stride + local_idx;
-  dst += blockIdx.x * stride + local_idx;
+  src += blockIdx.x * stride + local_idx;
+  dst += blockIdx.x * stride + local_idx;
 
   extern __shared__ unsigned char smem[];
-  input_t (&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t (*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
+  input_t(&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t(*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
 #pragma unroll
   for (int it = 0; it < WARP_ITERATIONS; ++it) {
     int element_index = local_idx + it * WARP_SIZE;
     if (element_index < element_count) {
       elements[it][local_idx] = src[it * WARP_SIZE];
     } else {
-      elements[it][local_idx] = -std::numeric_limits<input_t>::infinity();
+      static_assert(!std::is_same<acc_t, half>::value, "acc_t cannot be half, as the infinity function will return 0 instead of inf");
+      elements[it][local_idx] = (input_t)-std::numeric_limits<acc_t>::infinity();
     }
   }
   // compute max_value
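The failure mode this commit describes comes from std::numeric_limits itself: for any type the standard library has no specialization for (such as a half-precision wrapper), the primary template applies, has_infinity is false, and infinity() returns a value-initialized object, i.e. 0. A minimal host-side sketch of that behavior, using a hypothetical my_half struct as a stand-in for the real half type (illustration only, not ORT or CUDA code):

    #include <iostream>
    #include <limits>

    // Hypothetical 16-bit float wrapper; the standard library has no
    // numeric_limits specialization for it, so the primary template is used.
    struct my_half {
      unsigned short bits = 0;
    };

    int main() {
      std::cout << std::boolalpha;
      // Specialized type: a real infinity exists.
      std::cout << std::numeric_limits<float>::has_infinity << "\n";    // true
      // Unspecialized type: no infinity, and infinity() is just T(), i.e. 0.
      std::cout << std::numeric_limits<my_half>::has_infinity << "\n";  // false
      my_half pad = std::numeric_limits<my_half>::infinity();
      std::cout << pad.bits << "\n";                                    // 0
      return 0;
    }

This property is also why patch 3 below replaces the half-specific check with a has_infinity assertion.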
From 7d776740c90575a575a324ba0f03e3fb35097767 Mon Sep 17 00:00:00 2001
From: zhijxu
Date: Mon, 25 Mar 2024 15:49:37 +0800
Subject: [PATCH 2/3] add float16 softmax test

---
 .../test/training_ops/cuda/softmax_test.cc   | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
index 45edac3df2806..ad6ee1e0950e9 100644
--- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
@@ -70,6 +70,23 @@ TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
   TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
 }
 
+TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo2) {
+  // in the fp16 case, when the input is all -65504, the output must not be inf
+  std::vector<int64_t> X_dims{8192, 1, 1050};
+  std::vector<int64_t> Y_dims{8192, 1, 1050};
+  TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
+  CompareOpTester test("Softmax");
+  test.AddAttribute<int64_t>("axis", 1);
+
+  std::vector<MLFloat16> X_data(detail::SizeFromDims(X_dims), (MLFloat16)-65504.0f);
+  test.AddInput<MLFloat16>("X", X_dims, X_data);
+
+  std::vector<MLFloat16> Y_data = FillZeros<MLFloat16>(Y_dims);
+  test.AddOutput<MLFloat16>("Y", Y_dims, Y_data);
+
+  test.CompareWithCPU(kGpuExecutionProvider, 1e-4, 1e-4);
+}
+
 TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
   std::vector<int64_t> X_dims{8, 16, 512};
   std::vector<int64_t> Y_dims{8, 16, 512};
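The regression input is deliberate: every element of the 1050-wide row is -65504, the most negative finite fp16 value, and the warp kernel pads each row out to the next power of two. A simplified CPU model of that padding step shows why the pad value matters; padded_softmax and the 2048-wide padding are assumptions for illustration, not the ORT kernel:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Softmax over the first n entries of a row padded to padded_n entries,
    // where every padded entry holds `pad`. This mirrors only the
    // max/exp/normalize structure of the warp kernel, nothing else.
    static std::vector<float> padded_softmax(int n, float value, int padded_n, float pad) {
      std::vector<float> x(padded_n, pad);
      for (int i = 0; i < n; ++i) x[i] = value;

      float row_max = x[0];
      for (float v : x) row_max = std::fmax(row_max, v);

      std::vector<float> y(padded_n);
      float sum = 0.0f;
      for (int i = 0; i < padded_n; ++i) {
        y[i] = std::exp(x[i] - row_max);
        sum += y[i];
      }
      for (int i = 0; i < n; ++i) y[i] /= sum;
      return y;
    }

    int main() {
      const int n = 1050;         // real elements, as in the test
      const int padded_n = 2048;  // assumed next power of two, for illustration
      auto with_inf_pad  = padded_softmax(n, -65504.0f, padded_n, -INFINITY);
      auto with_zero_pad = padded_softmax(n, -65504.0f, padded_n, 0.0f);
      std::printf("pad = -inf: y[0] = %g (uniform 1/1050 = %g)\n", with_inf_pad[0], 1.0 / n);
      std::printf("pad =    0: y[0] = %g (row max dragged to 0, result is wrong)\n", with_zero_pad[0]);
      return 0;
    }

In this model a -inf pad keeps the row max at -65504 and the result is the uniform 1/1050, while a 0 pad (what a zero "infinity" produces) drags the row max to 0 and the real outputs no longer sum to 1.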
inf"); - elements[it][local_idx] = (input_t)-std::numeric_limits::infinity(); + static_assert(std::numeric_limits::has_infinity, + "type of acc_t should have infinity to avoid infinity function return 0"); + elements[it][local_idx] = static_cast(-std::numeric_limits::infinity()); } } // compute max_value