From 62c4cd7bda75460ac44418357d3308a31d84edff Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 13 Oct 2023 07:51:37 -0700
Subject: [PATCH] [CUDA] Fix SkipLayerNorm strict mode when skip has broadcast (#17896)

In SLN strict mode, the current code (#16510) does not handle skip broadcasting correctly. There are two issues:
(1) the skip-related parameters are not passed to the CUDA kernel in strict mode;
(2) the strict-mode kernel itself has bugs when skip is broadcast (for example, cuWelfordMuSigma2 does not handle skip broadcasting).

Here we remove support for skip broadcasting in strict mode; the operator now returns an error message stating that strict mode only supports input and skip of the same shape.

Other changes:
* skip_size was misleading when there is no broadcasting. Change it to the correct value.
* Refactor the code to be more efficient: (1) no need to check for broadcasting inside the kernel; (2) remove one local buffer (load input directly into sum_v to save a local buffer copy).
* Compute input + bias + skip instead of input + skip + bias. This order follows the common pattern in transformer models (it assumes graph fusion distinguishes input and skip correctly; the fusion code needs a double check later).
* Update the unit tests so that strict mode is triggered in each test case (except skip-broadcast cases) for higher test coverage.

### Motivation and Context
SLN strict mode does not support skip broadcasting, but the current code silently runs (and the kernel might fail).

---
 .../contrib_ops/cuda/bert/skip_layer_norm.cc  |  31 +++--
 .../cuda/bert/skip_layer_norm_impl.cu         | 128 +++++++++---------
 .../cuda/bert/skip_layer_norm_impl.h          |  23 ++--
 .../core/providers/cuda/nn/layer_norm_impl.cu |  52 +++----
 .../core/providers/cuda/nn/layer_norm_impl.h  |   4 +-
 .../test/contrib_ops/skiplayernorm_op_test.cc |  81 +++++++----
 6 files changed, 171 insertions(+), 148 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
index 78174181acdc8..3299bc2cb11de 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -3,6 +3,7 @@
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/nn/layer_norm_impl.h"
+#include "core/common/narrow.h"
 #include "skip_layer_norm.h"
 #include "skip_layer_norm_impl.h"
 #include "contrib_ops/cpu/skip_layer_norm_helper.h"
@@ -50,6 +51,11 @@
 template
 Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const {
   const Tensor* input = ctx->Input(0);
   const Tensor* skip = ctx->Input(1);
+  if (strict_ && skip->Shape() != input->Shape()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True");
+  }
+
   const Tensor* gamma = ctx->Input(2);
   const Tensor* beta = Simplified ? nullptr : ctx->Input(3);
@@ -57,16 +63,13 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const
   Tensor* output = ctx->Output(0, input->Shape());
-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
-  Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
+  // Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization).
+ Tensor* sum_output = ctx->Output(3, input->Shape()); const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - const auto& skip_dims = skip->Shape().GetDims(); - size_t skip_dims_size = skip_dims.size(); - int hidden_size = static_cast(input_dims[input_dims_size - 1]); + int hidden_size = onnxruntime::narrow(input_dims[input_dims_size - 1]); ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, skip, @@ -76,12 +79,15 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const hidden_size, input_dims_size)); - const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; - const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); + int row_count = onnxruntime::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); + if (row_count == 0) { + return Status::OK(); + } - int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); typedef typename ToCudaType::MappedType CudaT; + const int skip_size = onnxruntime::narrow(skip->Shape().Size()); + if (strict_) { HostApplyLayerNorm( GetDeviceProp(), @@ -97,21 +103,20 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); } else { LaunchSkipLayerNormKernel( Stream(ctx), reinterpret_cast(output->MutableData()), - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr, reinterpret_cast(input->Data()), reinterpret_cast(skip->Data()), + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(gamma->Data()), (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? 
reinterpret_cast(bias->Data()) : nullptr, epsilon_, hidden_size, row_count, - skip_broadcasted, skip_size); } diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index bfecacf4fb717..224c2fa38e8c1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -66,45 +66,53 @@ int NextSize(int x) { } template -bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, const int ld, const int next_size) { +bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias, + const T* gamma, const T* beta, const int ld, const int next_size) { constexpr int alignment = std::alignment_of>::value; - return ld % NumUnroll == 0 && reinterpret_cast(output) % alignment == 0 && - reinterpret_cast(skip_input_bias_add_output) % alignment == 0 && - reinterpret_cast(input) % alignment == 0 && reinterpret_cast(skip) % alignment == 0 && - reinterpret_cast(gamma) % alignment == 0 && reinterpret_cast(beta) % alignment == 0 && - reinterpret_cast(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize && + return ld % NumUnroll == 0 && + reinterpret_cast(output) % alignment == 0 && + reinterpret_cast(sum_output) % alignment == 0 && + reinterpret_cast(input) % alignment == 0 && + reinterpret_cast(skip) % alignment == 0 && + reinterpret_cast(bias) % alignment == 0 && + reinterpret_cast(gamma) % alignment == 0 && + reinterpret_cast(beta) % alignment == 0 && + next_size / NumUnroll >= kMinBlockSize && next_size / NumUnroll <= kMaxBlockSize; } } // namespace template __global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, - const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + const int ld, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; + const bool has_bias = (bias != nullptr); + // Reduce sum of x and x^2, and the results are divided by ld. KeyValuePairSum pair_sum; - // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; - const T val = (bias == nullptr) ? 
input[idx] + skip_data : input[idx] + skip_data + bias[i]; + T val = input[idx]; + if (has_bias) { + val += bias[i]; + } + val += skip[idx % skip_size]; const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); - if (skip_input_bias_add_output != nullptr) { - skip_input_bias_add_output[idx] = val; + if (sum_output != nullptr) { + sum_output[idx] = val; } output[idx] = val; } + if (Simplified) { SimplifiedLayerNorm(thread_data.value, ld, offset, gamma, epsilon, output); return; @@ -115,27 +123,24 @@ __global__ void SkipLayerNormKernel( // Vectorized kernel template __global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { + T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon, + int ld, int skip_size) { const T rld = T(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + const int idx = blockIdx.x * ld + threadIdx.x * ILP; using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP]; + T skip_v[ILP], bias_v[ILP], sum_v[ILP]; - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); + // load input to sum_v + VecT* sum_val = reinterpret_cast(&sum_v); + *sum_val = *reinterpret_cast(&input[idx]); VecT* skip_val = reinterpret_cast(&skip_v); - if (skip_broadcasted) { - *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - } else { - *skip_val = *reinterpret_cast(&skip[idx]); - } + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); - if (hasBias) { + const bool has_bias = (bias != nullptr); + if (has_bias) { VecT* bias_val = reinterpret_cast(&bias_v); *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); } @@ -145,59 +150,52 @@ __global__ void SkipLayerNormKernelSmall( if (ILP * threadIdx.x < ld) { T rldval_sum = T(0.f); T rldvalsq_sum = T(0.f); + const bool has_sum_output = (sum_output != nullptr); + #pragma unroll for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? 
skip_v[i] + bias_v[i] : skip_v[i]; - - if (hasSkipInputBiasAdditionOutput) { - skip_input_bias_add_output_v[i] = input_v[i]; + if (has_bias) { + sum_v[i] += bias_v[i]; } + sum_v[i] += skip_v[i]; - const T rldval = rld * input_v[i]; + const T rldval = rld * sum_v[i]; rldval_sum += rldval; - rldvalsq_sum += rldval * input_v[i]; + rldvalsq_sum += rldval * sum_v[i]; } - if (hasSkipInputBiasAdditionOutput) { - *(reinterpret_cast(&skip_input_bias_add_output[idx])) = *reinterpret_cast(&skip_input_bias_add_output_v); + if (has_sum_output) { + *(reinterpret_cast(&sum_output[idx])) = *reinterpret_cast(&sum_v); } thread_data = cub::KeyValuePair(rldval_sum, rldvalsq_sum); } if (Simplified) { - SimplifiedLayerNormSmall(input_v, thread_data.value, ld, idx, gamma, epsilon, output); + SimplifiedLayerNormSmall(sum_v, thread_data.value, ld, idx, gamma, epsilon, output); return; } - LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); + LayerNormSmall(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output); } template void LaunchSkipLayerNormKernel( - cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) { - if (row_count == 0) { - return; - } - - bool hasBias = (bias == nullptr) ? false : true; - bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; - + cudaStream_t stream, T* output, T* sum_output, + const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon, + int ld, int row_count, int skip_size) { const int next_size = NextSize(ld); const int grid_size = row_count; - bool flag_vec2 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); - bool flag_vec4 = - CanVectorized(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size); + bool flag_vec2 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); + bool flag_vec4 = CanVectorized(output, sum_output, input, skip, bias, gamma, beta, ld, next_size); + +#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ + SkipLayerNormKernelSmall<<>>( \ + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) - switch (next_size) { -#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ - SkipLayerNormKernelSmall \ - <<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) #define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) + output, sum_output, input, skip, bias, gamma, beta, maybe2half(epsilon), ld, skip_size) + #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ if (flag_vec4) { \ @@ -215,6 +213,8 @@ void LaunchSkipLayerNormKernel( } \ } \ } break + + switch (next_size) { CASE_NEXT_SIZE(kSizes[0]); CASE_NEXT_SIZE(kSizes[1]); CASE_NEXT_SIZE(kSizes[2]); @@ -222,18 +222,18 @@ void LaunchSkipLayerNormKernel( CASE_NEXT_SIZE(kSizes[4]); CASE_NEXT_SIZE(kSizes[5]); CASE_NEXT_SIZE(kSizes[6]); + } + #undef CASE_NEXT_SIZE #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL - } } -#define SKIPLAYERNORM_IMPL(T, Simplified) \ - template void LaunchSkipLayerNormKernel(cudaStream_t stream, T 
* output, \ - T * skip_input_bias_add_output, \ - const T* input, const T* skip, const T* gamma, \ - const T* beta, const T* bias, float epsilon, \ - int ld, int row_count, bool skip_broadcasted, int skip_size); +#define SKIPLAYERNORM_IMPL(T, Simplified) \ + template void LaunchSkipLayerNormKernel(cudaStream_t stream, T * output, T * sum_output, \ + const T* input, const T* skip, const T* bias, \ + const T* gamma, const T* beta, float epsilon, \ + int ld, int row_count, int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); SKIPLAYERNORM_IMPL(half, true); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index ffb5850c827fe..9727dd6236ec8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -11,18 +11,17 @@ namespace cuda { template void LaunchSkipLayerNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int row_count, // number of rows. That is total number of elements divided by hidden size. - bool skip_broadcasted, // determines if broadcasting should be implemented - int skip_size); // determines size of the skip tensor + T* output, // normalized output tensor + T* sum_output, // sum of the input and skip (and bias if it exists) tensors output + const T* input, // input tensor + const T* skip, // skip tensor + const T* bias, // bias tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int row_count, // number of rows. That is total number of elements divided by hidden size. + int skip_size); // number of elements of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 4cc560a1178ef..679b8b6b78886 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -104,17 +104,17 @@ __device__ void cuWelfordMuSigma2( const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; + const T* skip_vals = (skip != nullptr) ? 
skip + i1 * n2 : nullptr; int l = 4 * thrx; for (; l + 3 < n2; l += 4 * numx) { for (int k = 0; k < 4; ++k) { U curr = static_cast(lvals[l + k]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l + k]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l + k]); } @@ -124,11 +124,11 @@ __device__ void cuWelfordMuSigma2( for (; l < n2; ++l) { U curr = static_cast(lvals[l]); - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[l]); } - if (skip_vals != NULL) { + if (skip_vals != nullptr) { curr += static_cast(skip_vals[l]); } @@ -301,7 +301,7 @@ namespace { // { // extern __device__ void error(void); // error(); -// return NULL; +// return nullptr; // } // }; // https://github.com/NVIDIA/apex/issues/246 @@ -338,9 +338,7 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* __restrict__ skip_input_bias_add_output) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -350,38 +348,35 @@ __global__ void cuApplyLayerNorm( U* buf = shared.getPointer(); U mu, sigma2; cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, skip, bias); - const T* lvals = vals + i1 * n2; - const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL; - V* ovals = output_vals + i1 * n2; - T* skip_input_bias_add_ovals = (skip_input_bias_add_output != NULL) ? skip_input_bias_add_output + i1 * n2 : NULL; + const int offset = i1 * n2; + const T* lvals = vals + offset; + const T* skip_vals = (skip != nullptr) ? skip + offset : nullptr; + + V* ovals = output_vals + offset; + T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr; U c_inv_std_dev = rsqrt(sigma2 + epsilon); const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); - - - if (bias != NULL) { + if (bias != nullptr) { curr += static_cast(bias[i]); } - if (skip_vals != NULL && skip_broadcasted) { - int skip_i = i % skip_size; - curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor - }else if (skip_vals != NULL){ + if (skip_vals != nullptr) { curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != NULL) ? (U)gamma[i] : (U)1; - U beta_i = (beta != NULL) ? (U)beta[i] : (U)0; + U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; + U beta_i = (beta != nullptr) ? 
(U)beta[i] : (U)0; if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { ovals[i] = static_cast(gamma_i * c_inv_std_dev * (curr - mu) + beta_i); } - if (skip_input_bias_add_ovals != NULL) { + if (skip_input_bias_add_ovals != nullptr) { skip_input_bias_add_ovals[i] = static_cast(curr); } } @@ -418,9 +413,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output, - const bool skip_broadcasted, - const int skip_size) { + T* skip_input_bias_add_output) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -452,17 +445,14 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output, - skip_broadcasted, - skip_size); + skip, bias, skip_input_bias_add_output); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ - const int skip_size); + const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index d0d5db8ba3587..e3952eefae35d 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,9 +43,7 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr, - const bool skip_broadcasted = false, - const int skip_size = 0); + T* skip_input_bias_add_output = nullptr); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 2395532198805..bb56a5aba7f65 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -11,14 +11,15 @@ namespace onnxruntime { namespace test { constexpr float epsilon_ = 1e-12f; -static void RunTest( +static void RunOneTest( + bool strict, const std::vector& input_data, const std::vector& skip_data, const std::vector& gamma_data, const std::vector& beta_data, const std::vector& bias_data, const std::vector& output_data, - const std::vector& skip_input_bias_add_output_data, + const std::vector& sum_output_data, float epsilon, int batch_size, int sequence_length, @@ -27,7 +28,6 @@ static void RunTest( bool no_beta = false, bool simplified = false, bool use_token_count = false, - bool strict = false, bool broadcast_skip = false, bool no_batch_size = false) { // Input and output shapes @@ -82,14 +82,14 @@ static void RunTest( test.AddOutput("output", output_dims, output_data); - if (skip_input_bias_add_output_data.size() != 0) { + if (sum_output_data.size() != 0) { // The second and third outputs are reserved for something else test.AddOptionalOutputEdge(); test.AddOptionalOutputEdge(); test.AddOutput("skip_input_bias_add_output", output_dims, - skip_input_bias_add_output_data); + sum_output_data); } if (cpu_ep != nullptr) { @@ -117,14 +117,19 @@ static void RunTest( test.AddOutput("output", output_dims, ToFloat16(output_data)); - if 
(skip_input_bias_add_output_data.size() != 0) { + // Use larger threshold for fp16 + if (use_float16) { + test.SetOutputAbsErr("output", 0.01f); + } + + if (sum_output_data.size() != 0) { // The second and third outputs are reserved for something else test.AddOptionalOutputEdge(); test.AddOptionalOutputEdge(); test.AddOutput("skip_input_bias_add_output", output_dims, - ToFloat16(skip_input_bias_add_output_data)); + ToFloat16(sum_output_data)); } if (dml_ep != nullptr) { @@ -151,6 +156,36 @@ static void RunTest( } } +static void RunTest( + const std::vector& input_data, + const std::vector& skip_data, + const std::vector& gamma_data, + const std::vector& beta_data, + const std::vector& bias_data, + const std::vector& output_data, + const std::vector& sum_output_data, + float epsilon, + int batch_size, + int sequence_length, + int hidden_size, + bool use_float16 = false, + bool no_beta = false, + bool simplified = false, + bool use_token_count = false, + bool broadcast_skip = false, + bool no_batch_size = false) { + RunOneTest(false, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data, + epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified, + use_token_count, broadcast_skip, no_batch_size); + + // strict mode does not support skip broadcasting. + if (!broadcast_skip) { + RunOneTest(true, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data, + epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified, + use_token_count, broadcast_skip, no_batch_size); + } +} + TEST(SkipLayerNormTest, SkipLayerNormNullInput) { int batch_size = 1; int sequence_length = 0; @@ -359,8 +394,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec) { true /*use_float16*/, false /*no_beta*/, false /*simplified*/, - false /*use_token_count*/, - true /*strict*/); + false /*use_token_count*/); } TEST(SkipLayerNormTest, SkipLayerNormBatch1_NoBeta) { @@ -648,8 +682,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec_token_count) { true /*use_float16*/, false /*no_beta*/, false /*simplified*/, - true /*use_token_count*/, - true /*strict*/); + true /*use_token_count*/); } TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) { @@ -776,13 +809,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { batch_size, sequence_length, hidden_size, - false, - false, - false, - false, - false, - false, - true); + false, // use_float16 + false, // no_beta + false, // simplified + false, // use_token_count + true, // broadcast_skip + true); // no_batch_size } TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { @@ -823,13 +855,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { batch_size, sequence_length, hidden_size, - false, - false, - false, - false, - false, - true, - false); + false, // use_float16 + false, // no_beta + false, // simplified + false, // use_token_count + true, // broadcast_skip + false); // no_batch_size } #endif
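
For reference, below is a minimal host-side C++ sketch (not the actual CUDA kernel) of the per-row computation after this change: sum = input + bias + skip, followed by layer normalization over the last dimension. The `% skip_size` indexing mirrors the broadcast support that the non-strict kernel keeps; strict mode now rejects any skip whose shape differs from the input. The function name `SkipLayerNormReference` and its signature are illustrative only and assume gamma/beta are always provided and output buffers are pre-sized.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference for the SkipLayerNorm math; not part of the patch.
void SkipLayerNormReference(const std::vector<float>& input,   // [row_count, hidden_size]
                            const std::vector<float>& skip,    // skip_size elements
                            const std::vector<float>& bias,    // hidden_size elements (may be empty)
                            const std::vector<float>& gamma,   // hidden_size elements
                            const std::vector<float>& beta,    // hidden_size elements
                            float epsilon,
                            int row_count, int hidden_size,
                            std::vector<float>& output,        // [row_count, hidden_size], pre-sized
                            std::vector<float>& sum_output) {  // [row_count, hidden_size], pre-sized
  const size_t skip_size = skip.size();
  for (int row = 0; row < row_count; ++row) {
    const size_t offset = static_cast<size_t>(row) * hidden_size;

    // sum = input + bias + skip (bias added before skip, matching the new kernel order).
    float mean = 0.f;
    for (int i = 0; i < hidden_size; ++i) {
      const size_t idx = offset + i;
      float val = input[idx];
      if (!bias.empty()) val += bias[i];
      val += skip[idx % skip_size];  // broadcast via modulo indexing (non-strict mode only)
      sum_output[idx] = val;
      mean += val;
    }
    mean /= hidden_size;

    // Variance of the summed values for this row.
    float variance = 0.f;
    for (int i = 0; i < hidden_size; ++i) {
      const float d = sum_output[offset + i] - mean;
      variance += d * d;
    }
    variance /= hidden_size;

    // Normalize, then apply gamma and beta.
    const float inv_std_dev = 1.f / std::sqrt(variance + epsilon);
    for (int i = 0; i < hidden_size; ++i) {
      output[offset + i] = (sum_output[offset + i] - mean) * inv_std_dev * gamma[i] + beta[i];
    }
  }
}
```

When skip has the same shape as input, `idx % skip_size` is the identity, so the strict-mode (HostApplyLayerNorm) and non-strict paths compute the same result.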