From 816e22689e73a6ac893392cbf885ac9dc1ac1ab9 Mon Sep 17 00:00:00 2001
From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com>
Date: Wed, 27 Mar 2024 11:37:10 +0800
Subject: [PATCH] Zhijxu/fix softmax fp16 (#20059)

With fp16 input, softmax returns NaN in some cases. The reason is that
for the float16 dtype, std::numeric_limits::infinity() returns 0
instead of inf.
---
 .../cuda/math/softmax_warpwise_impl.cuh       | 40 +++++++++----------
 .../test/training_ops/cuda/softmax_test.cc    | 17 ++++++++
 2 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
index c1b3d6ada8b77..5e2cec464a86b 100644
--- a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
+++ b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
@@ -1,18 +1,18 @@
 /**
-* Copyright (c) 2016-present, Facebook, Inc.
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 // The code below is mostly copied from PyTorch PersistentSoftmax.cuh
 
@@ -55,7 +55,6 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) {
   }
 }
 
-
 // The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
 // Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
 // The template arguments have the following meaning:
@@ -163,7 +162,6 @@ __global__ void softmax_warp_forward(output_t* dst, const input_t* src, int batc
   }
 }
 
-
 // softmax_warp_forward uses registers to store data in fp32 even when the data is fp16, which causes register oversubscription when the data is large
 // and leads to low CUDA warp occupancy and thus poor kernel performance.
 // softmax_warp_forward_resource_efficient is implemented to solve this: it caches data in the original data type and casts it to fp32 only when needed,
@@ -176,17 +174,19 @@ __global__ void softmax_warp_forward_resource_efficient(output_t* dst, const inp
   constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
   int local_idx = threadIdx.x;
-  src += blockIdx.x * stride + local_idx;
-  dst += blockIdx.x * stride + local_idx;
+  src += blockIdx.x * stride + local_idx;
+  dst += blockIdx.x * stride + local_idx;
 
   extern __shared__ unsigned char smem[];
-  input_t (&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t (*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
+  input_t(&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t(*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
 #pragma unroll
   for (int it = 0; it < WARP_ITERATIONS; ++it) {
     int element_index = local_idx + it * WARP_SIZE;
     if (element_index < element_count) {
       elements[it][local_idx] = src[it * WARP_SIZE];
     } else {
-      elements[it][local_idx] = -std::numeric_limits<input_t>::infinity();
+      static_assert(std::numeric_limits<acc_t>::has_infinity,
+                    "type of acc_t should have infinity to avoid infinity() returning 0");
+      elements[it][local_idx] = static_cast<input_t>(-std::numeric_limits<acc_t>::infinity());
     }
   }
   // compute max_value
diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
index 45edac3df2806..ad6ee1e0950e9 100644
--- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
@@ -70,6 +70,23 @@ TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
   TestSoftmax<MLFloat16, MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
 }
 
+TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo2) {
+  // in the fp16 case, when the input is all -65504 the output must not contain inf
+  std::vector<int64_t> X_dims{8192, 1, 1050};
+  std::vector<int64_t> Y_dims{8192, 1, 1050};
+  TestSoftmax<MLFloat16, MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
+  CompareOpTester test("Softmax");
+  test.AddAttribute<int64_t>("axis", 1);
+
+  std::vector<MLFloat16> X_data(detail::SizeFromDims(X_dims), (MLFloat16)-65504.0f);
+  test.AddInput<MLFloat16>("X", X_dims, X_data);
+
+  std::vector<MLFloat16> Y_data = FillZeros<MLFloat16>(Y_dims);
+  test.AddOutput<MLFloat16>("Y", Y_dims, Y_data);
+
+  test.CompareWithCPU(kGpuExecutionProvider, 1e-4, 1e-4);
+}
+
 TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
   std::vector<int64_t> X_dims{8, 16, 512};
   std::vector<int64_t> Y_dims{8, 16, 512};
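---

Reviewer note, not part of the patch: the root cause is that std::numeric_limits is only specialized for standard arithmetic types. For CUDA's __half the primary template applies, where has_infinity is false and infinity() returns a value-initialized T(), i.e. 0. The padding lanes therefore held 0 instead of -inf; when every real element is -65504, the row max becomes 0 (from the padding), every exp(x - max) underflows to 0, and the normalization divides zero by zero, producing NaN. The snippet below is a minimal host-side sketch of the failure and of the fix pattern (compute -infinity in acc_t = float, then cast to the storage type); it assumes only the CUDA toolkit's cuda_fp16.h.

#include <cuda_fp16.h>
#include <iostream>
#include <limits>

int main() {
  // No specialization for __half, so the primary template applies:
  // has_infinity == false and infinity() returns __half(), i.e. 0.
  std::cout << std::numeric_limits<__half>::has_infinity << "\n";  // prints 0
  __half bad = std::numeric_limits<__half>::infinity();
  std::cout << __half2float(bad) << "\n";  // prints 0, not inf

  // The patch's pattern: take -infinity in a type that really has one
  // (float, i.e. acc_t) and cast it to the storage type afterwards.
  static_assert(std::numeric_limits<float>::has_infinity, "float has infinity");
  __half good = __float2half(-std::numeric_limits<float>::infinity());
  std::cout << __half2float(good) << "\n";  // prints -inf
  return 0;
}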
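A related note on the register-pressure comment: the sketch below is a hypothetical one-warp-per-row kernel, not the patch's softmax_warp_forward_resource_efficient (which caches in dynamic shared memory via the smem reinterpret_cast above). It illustrates the same idea of keeping cached elements in the narrow source type and widening to fp32 only at the point of arithmetic, and its out-of-range padding uses the float-infinity-then-cast trick the patch introduces.

#include <cmath>
#include <cuda_fp16.h>

// One warp (32 threads) per row; requires row_len <= 32 * ITEMS.
template <int ITEMS>
__global__ void row_max_half_cached(const __half* src, float* row_max, int row_len) {
  const __half* row = src + static_cast<size_t>(blockIdx.x) * row_len;
  __half cached[ITEMS];          // cache in the source type: 2 bytes per element, not 4
  float thread_max = -INFINITY;  // accumulate in fp32, which does have an infinity
#pragma unroll
  for (int i = 0; i < ITEMS; ++i) {
    int idx = threadIdx.x + i * 32;
    // Padding lanes get -inf computed in float and cast to half, mirroring
    // static_cast<input_t>(-std::numeric_limits<acc_t>::infinity()) in the patch.
    cached[i] = (idx < row_len) ? row[idx] : __float2half(-INFINITY);
    thread_max = fmaxf(thread_max, __half2float(cached[i]));  // widen only to compute
  }
  // Butterfly reduction of the per-thread maxima across the warp.
  for (int offset = 16; offset > 0; offset >>= 1)
    thread_max = fmaxf(thread_max, __shfl_xor_sync(0xffffffffu, thread_max, offset));
  if (threadIdx.x == 0) row_max[blockIdx.x] = thread_max;
}

For the 1050-element rows in the new test, a launch like row_max_half_cached<33><<<8192, 32>>>(src, out, 1050) would cover each row (33 * 32 = 1056 >= 1050).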