diff --git a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh index 054f9abc8c071..5e2cec464a86b 100644 --- a/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh +++ b/onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh @@ -184,8 +184,9 @@ __global__ void softmax_warp_forward_resource_efficient(output_t* dst, const inp if (element_index < element_count) { elements[it][local_idx] = src[it * WARP_SIZE]; } else { - static_assert(!std::is_same::value, "acc_t can no be half, as the infinity function will return 0 instead of inf"); - elements[it][local_idx] = (input_t)-std::numeric_limits::infinity(); + static_assert(std::numeric_limits::has_infinity, + "type of acc_t should have infinity to avoid infinity function return 0"); + elements[it][local_idx] = static_cast(-std::numeric_limits::infinity()); } } // compute max_value