diff --git a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
index 45edac3df2806..ad6ee1e0950e9 100644
--- a/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
+++ b/orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
@@ -70,6 +70,23 @@ TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
   TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
 }
 
+TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo2) {
+  // fp16 regression: an input filled with -65504 (lowest finite half) must not yield inf in the output.
+  std::vector<int64_t> X_dims{8192, 1, 1050};
+  std::vector<int64_t> Y_dims{8192, 1, 1050};
+  TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
+  CompareOpTester test("Softmax");
+  test.AddAttribute<int64_t>("axis", 2);  // last axis, consistent with the test name and the call above
+
+  std::vector<MLFloat16> X_data(detail::SizeFromDims(X_dims), static_cast<MLFloat16>(-65504.0f));
+  test.AddInput<MLFloat16>("X", X_dims, X_data);
+
+  std::vector<MLFloat16> Y_data = FillZeros<MLFloat16>(Y_dims);  // zero-filled; presumably a placeholder - confirm
+  test.AddOutput<MLFloat16>("Y", Y_dims, Y_data);
+
+  test.CompareWithCPU(kGpuExecutionProvider, 1e-4, 1e-4);
+}
+
 TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
   std::vector<int64_t> X_dims{8, 16, 512};
   std::vector<int64_t> Y_dims{8, 16, 512};