Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zhijxu/fix softmax fp16 #20059

Merged
merged 3 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// The code below is mostly copied from Pytorch PersistentSoftmax.cuh

Expand Down Expand Up @@ -55,7 +55,6 @@
}
}


// The softmax_warp_* methods perform softmax forward and backward propagation on samples spanning the fast dimension.
// Each sample contains element_count scalar elements. element_count can be any integer value <= 1024.
// The template arguments have the following meaning:
Expand Down Expand Up @@ -163,7 +162,6 @@
}
}


// softmax_warp_forward uses register to store data in fp32 even when data is fp16, which will cause register resource oversubscription when data is large,
// and will lead to low CUDA warp occupancy and thus a poor kernel performance.
// softmax_warp_forward_resource_efficient is implemented to solve the issue, it caches data in original data type, and casts it into fp32 when needed,
Expand All @@ -176,17 +174,19 @@
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;

int local_idx = threadIdx.x;
src += blockIdx.x * stride + local_idx;
dst += blockIdx.x * stride + local_idx;
src += blockIdx.x * stride + local_idx;
dst += blockIdx.x * stride + local_idx;
extern __shared__ unsigned char smem[];
input_t (&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t (*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
input_t(&elements)[WARP_ITERATIONS][WARP_SIZE] = *reinterpret_cast<input_t(*)[WARP_ITERATIONS][WARP_SIZE]>(smem);
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
int element_index = local_idx + it * WARP_SIZE;
if (element_index < element_count) {
elements[it][local_idx] = src[it * WARP_SIZE];
} else {
elements[it][local_idx] = -std::numeric_limits<input_t>::infinity();
static_assert(std::numeric_limits<acc_t>::has_infinity,
"type of acc_t should have infinity to avoid infinity function return 0");
elements[it][local_idx] = static_cast<input_t>(-std::numeric_limits<acc_t>::infinity());

Check warning on line 189 in onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/math/softmax_warpwise_impl.cuh:189: Add #include <limits> for numeric_limits<> [build/include_what_you_use] [4]
}
}
// compute max_value
Expand Down
17 changes: 17 additions & 0 deletions orttraining/orttraining/test/training_ops/cuda/softmax_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo) {
TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
}

TEST(CudaKernelTest, Softmax_LargeTensor_LastAxis_Float16_NoPowerOfTwo2) {
// at fp16 case, when input is all -65504, the output can't be inf
std::vector<int64_t> X_dims{8192, 1, 1050};
std::vector<int64_t> Y_dims{8192, 1, 1050};
TestSoftmax<MLFloat16>(X_dims, Y_dims, 2, false, 1e-3, 1e-3);
CompareOpTester test("Softmax");
test.AddAttribute<int64_t>("axis", 1);

std::vector<MLFloat16> X_data(detail::SizeFromDims(X_dims), (MLFloat16)-65504.0f);
test.AddInput<MLFloat16>("X", X_dims, X_data);

std::vector<MLFloat16> Y_data = FillZeros<MLFloat16>(Y_dims);
test.AddOutput<MLFloat16>("Y", Y_dims, Y_data);

test.CompareWithCPU(kGpuExecutionProvider, 1e-4, 1e-4);
}

TEST(CudaKernelTest, Softmax_LargeTensor_AllAxis) {
std::vector<int64_t> X_dims{8, 16, 512};
std::vector<int64_t> Y_dims{8, 16, 512};
Expand Down
Loading