diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
index 78174181acdc8..3299bc2cb11de 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -3,6 +3,7 @@
 
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/nn/layer_norm_impl.h"
+#include "core/common/narrow.h"
 #include "skip_layer_norm.h"
 #include "skip_layer_norm_impl.h"
 #include "contrib_ops/cpu/skip_layer_norm_helper.h"
@@ -50,6 +51,11 @@ template <typename T, bool Simplified>
 Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const {
   const Tensor* input = ctx->Input<Tensor>(0);
   const Tensor* skip = ctx->Input<Tensor>(1);
+  if (strict_ && skip->Shape() != input->Shape()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True");
+  }
+
   const Tensor* gamma = ctx->Input<Tensor>(2);
 
   const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3);
@@ -57,16 +63,13 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
 
   Tensor* output = ctx->Output(0, input->Shape());
 
-  // For inferencing, we support one more optional output which is the sum
-  // of the input and skip tensors
-  Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
+  // Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization).
+  Tensor* sum_output = ctx->Output(3, input->Shape());
 
   const auto& input_dims = input->Shape().GetDims();
   size_t input_dims_size = input_dims.size();
-  const auto& skip_dims = skip->Shape().GetDims();
-  size_t skip_dims_size = skip_dims.size();
 
-  int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);
+  int hidden_size = onnxruntime::narrow<int>(input_dims[input_dims_size - 1]);
 
   ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
                                                                                         skip,
@@ -76,12 +79,15 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
                                                                                         hidden_size,
                                                                                         input_dims_size));
 
-  const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
-  const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
+  int row_count = onnxruntime::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
+  if (row_count == 0) {
+    return Status::OK();
+  }
 
-  int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
   typedef typename ToCudaType<T>::MappedType CudaT;
 
+  const int skip_size = onnxruntime::narrow<int>(skip->Shape().Size());
+
   if (strict_) {
     HostApplyLayerNorm<CudaT, float, CudaT, Simplified>(
         GetDeviceProp(),
@@ -97,21 +103,20 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
         (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,  // beta
         reinterpret_cast<const CudaT*>(skip->Data<T>()),                                // skip or residual to add
         (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
-        skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
+        sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
   } else {
     LaunchSkipLayerNormKernel<CudaT, Simplified>(
         Stream(ctx),
         reinterpret_cast<CudaT*>(output->MutableData<T>()),
-        skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
+        sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
         reinterpret_cast<const CudaT*>(input->Data<T>()),
         reinterpret_cast<const CudaT*>(skip->Data<T>()),
+        (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
        reinterpret_cast<const CudaT*>(gamma->Data<T>()),
        (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
-        (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
        epsilon_,
        hidden_size,
        row_count,
-        skip_broadcasted,
        skip_size);
   }
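A note on the onnxruntime::narrow<int> calls introduced above: unlike the static_cast<int> / gsl::narrow<int> they replace, narrow fails loudly when an int64_t dimension does not fit in an int instead of silently truncating. A minimal stand-in sketch follows; the real helper lives in core/common/narrow.h and additionally handles sign mismatches, so the simplified round-trip check and the name narrow_checked here are assumptions of this illustration.

#include <cstdint>
#include <stdexcept>

// Simplified stand-in for onnxruntime::narrow: cast, then verify the value
// survives a round trip; reject lossy conversions instead of truncating.
// (The real helper also guards mixed-sign conversions and aborts rather
// than throwing.)
template <typename To, typename From>
To narrow_checked(From value) {
  const To result = static_cast<To>(value);
  if (static_cast<From>(result) != value) {
    throw std::runtime_error("narrowing conversion lost information");
  }
  return result;
}

int main() {
  const int64_t hidden_dim = 768;
  const int hidden_size = narrow_checked<int>(hidden_dim);  // OK: fits in int
  // narrow_checked<int>(int64_t{1} << 40);  // would throw instead of wrapping
  return hidden_size == 768 ? 0 : 1;
}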
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index bfecacf4fb717..224c2fa38e8c1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -66,45 +66,53 @@ int NextSize(int x) {
 }
 
 template <typename T, int NumUnroll>
-bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
-                   const T* beta, const T* bias, const int ld, const int next_size) {
+bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
+                   const T* gamma, const T* beta, const int ld, const int next_size) {
   constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
-  return ld % NumUnroll == 0 && reinterpret_cast<uint64_t>(output) % alignment == 0 &&
-         reinterpret_cast<uint64_t>(skip_input_bias_add_output) % alignment == 0 &&
-         reinterpret_cast<uint64_t>(input) % alignment == 0 && reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
-         reinterpret_cast<uint64_t>(gamma) % alignment == 0 && reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
-         reinterpret_cast<uint64_t>(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize &&
+  return ld % NumUnroll == 0 &&
+         reinterpret_cast<uint64_t>(output) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(input) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
+         reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
+         next_size / NumUnroll >= kMinBlockSize &&
          next_size / NumUnroll <= kMaxBlockSize;
 }
 }  // namespace
 
 template <typename T, unsigned TPB, bool Simplified>
 __global__ void SkipLayerNormKernel(
-    const int ld, const T* input, const T* skip,
-    const T* beta, const T* gamma, const T* bias,
-    const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) {
+    T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
+    const int ld, int skip_size) {
   const T reverse_ld = T(1.f / ld);
   const int offset = blockIdx.x * ld;
+  const bool has_bias = (bias != nullptr);
 
+  // Reduce sum of x and x^2, and the results are divided by ld.
   KeyValuePairSum pair_sum;
-  // reduce x and x^2
   cub::KeyValuePair<T, T> thread_data(0, 0);
 
   for (int i = threadIdx.x; i < ld; i += TPB) {
     const int idx = offset + i;
-    const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx];
-    const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i];
+    T val = input[idx];
+    if (has_bias) {
+      val += bias[i];
+    }
+    val += skip[idx % skip_size];
 
     const T rldval = reverse_ld * val;
     thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
 
-    if (skip_input_bias_add_output != nullptr) {
-      skip_input_bias_add_output[idx] = val;
+    if (sum_output != nullptr) {
+      sum_output[idx] = val;
     }
 
     output[idx] = val;
   }
+
   if (Simplified) {
     SimplifiedLayerNorm<T, TPB>(thread_data.value, ld, offset, gamma, epsilon, output);
     return;
@@ -115,27 +123,24 @@ __global__ void SkipLayerNormKernel(
 // Vectorized kernel
 template <typename T, unsigned TPB, int ILP, bool Simplified>
 __global__ void SkipLayerNormKernelSmall(
-    const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
-    const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
-    bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) {
+    T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
+    int ld, int skip_size) {
   const T rld = T(1.f / ld);
-  const int idx = blockIdx.x * ld + threadIdx.x * ILP;  // grid_size = n / ld
+  const int idx = blockIdx.x * ld + threadIdx.x * ILP;
 
   using VecT = aligned_vector<T, ILP>;
-  T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
+  T skip_v[ILP], bias_v[ILP], sum_v[ILP];
 
-  VecT* input_val = reinterpret_cast<VecT*>(&input_v);
-  *input_val = *reinterpret_cast<const VecT*>(&input[idx]);
+  // load input to sum_v
+  VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
+  *sum_val = *reinterpret_cast<const VecT*>(&input[idx]);
 
   VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
-  if (skip_broadcasted) {
-    *skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
-  } else {
-    *skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
-  }
+  *skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
 
-  if (hasBias) {
+  const bool has_bias = (bias != nullptr);
+  if (has_bias) {
     VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
     *bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
   }
@@ -145,59 +150,52 @@ __global__ void SkipLayerNormKernelSmall(
   if (ILP * threadIdx.x < ld) {
     T rldval_sum = T(0.f);
     T rldvalsq_sum = T(0.f);
+    const bool has_sum_output = (sum_output != nullptr);
+
 #pragma unroll
     for (int i = 0; i < ILP; i++) {
-      input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
-
-      if (hasSkipInputBiasAdditionOutput) {
-        skip_input_bias_add_output_v[i] = input_v[i];
+      if (has_bias) {
+        sum_v[i] += bias_v[i];
       }
+      sum_v[i] += skip_v[i];
 
-      const T rldval = rld * input_v[i];
+      const T rldval = rld * sum_v[i];
       rldval_sum += rldval;
-      rldvalsq_sum += rldval * input_v[i];
+      rldvalsq_sum += rldval * sum_v[i];
     }
 
-    if (hasSkipInputBiasAdditionOutput) {
-      *(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
+    if (has_sum_output) {
+      *(reinterpret_cast<VecT*>(&sum_output[idx])) = *reinterpret_cast<VecT*>(&sum_v);
    }
 
     thread_data = cub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
   }
 
   if (Simplified) {
-    SimplifiedLayerNormSmall<T, TPB, ILP>(input_v, thread_data.value, ld, idx, gamma, epsilon, output);
+    SimplifiedLayerNormSmall<T, TPB, ILP>(sum_v, thread_data.value, ld, idx, gamma, epsilon, output);
     return;
   }
-  LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
+  LayerNormSmall<T, TPB, ILP>(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output);
 }
 
 template <typename T, bool Simplified>
 void LaunchSkipLayerNormKernel(
-    cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
-    const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) {
-  if (row_count == 0) {
-    return;
-  }
-
-  bool hasBias = (bias == nullptr) ? false : true;
-  bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;
-
+    cudaStream_t stream, T* output, T* sum_output,
+    const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon,
+    int ld, int row_count, int skip_size) {
   const int next_size = NextSize(ld);
   const int grid_size = row_count;
-  bool flag_vec2 =
-      CanVectorized<T, 2>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
-  bool flag_vec4 =
-      CanVectorized<T, 4>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
+  bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
+  bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
+
+#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll)                                                  \
+  SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
+      output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
 
-  switch (next_size) {
-#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll)                                                          \
-  SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified>                                                \
-      <<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, \
-                                             skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size)
 #define LAUNCH_SKIP_LAYER_NORM_KERNEL()                                                       \
   SkipLayerNormKernel<T, kMaxBlockSize, Simplified><<<grid_size, kMaxBlockSize, 0, stream>>>( \
-      ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size)
+      output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
+
 #define CASE_NEXT_SIZE(next_size_value) \
   case next_size_value: {               \
     if (flag_vec4) {                    \
@@ -215,6 +213,8 @@ void LaunchSkipLayerNormKernel(
       }                                 \
     }                                   \
   } break
+
+  switch (next_size) {
     CASE_NEXT_SIZE(kSizes[0]);
     CASE_NEXT_SIZE(kSizes[1]);
     CASE_NEXT_SIZE(kSizes[2]);
@@ -222,18 +222,18 @@ void LaunchSkipLayerNormKernel(
     CASE_NEXT_SIZE(kSizes[4]);
     CASE_NEXT_SIZE(kSizes[5]);
     CASE_NEXT_SIZE(kSizes[6]);
+  }
+
 #undef CASE_NEXT_SIZE
 #undef LAUNCH_SKIP_LAYER_NORM_KERNEL
 #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
-  }
 }
 
-#define SKIPLAYERNORM_IMPL(T, Simplified)                                                                \
-  template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output,               \
-                                                         T * skip_input_bias_add_output,                \
-                                                         const T* input, const T* skip, const T* gamma, \
-                                                         const T* beta, const T* bias, float epsilon,   \
-                                                         int ld, int row_count, bool skip_broadcasted, int skip_size);
+#define SKIPLAYERNORM_IMPL(T, Simplified)                                                                 \
+  template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, T * sum_output, \
+                                                         const T* input, const T* skip, const T* bias,    \
+                                                         const T* gamma, const T* beta, float epsilon,    \
+                                                         int ld, int row_count, int skip_size);
 SKIPLAYERNORM_IMPL(float, true);
 SKIPLAYERNORM_IMPL(float, false);
 SKIPLAYERNORM_IMPL(half, true);
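The key simplification in the kernels above: skip_size is now the total element count of the skip tensor, so skip[idx % skip_size] covers both a broadcast skip (fewer rows than the input) and a same-shape skip, where the modulo is a no-op. This is what lets the separate skip_broadcasted flag be deleted. A CPU reference of the sum both kernels compute, for illustration only (the function name SkipBiasSum is hypothetical, not part of the codebase):

#include <cstddef>
#include <vector>

// CPU reference for the fused input + bias + skip sum. With
// skip_size == skip.size(), "idx % skip_size" wraps around when the skip
// tensor is smaller than the input (broadcast) and reduces to plain
// indexing when the shapes match.
std::vector<float> SkipBiasSum(const std::vector<float>& input,
                               const std::vector<float>& skip,
                               const std::vector<float>& bias,  // hidden_size elements, or empty
                               std::size_t hidden_size) {
  std::vector<float> sum(input.size());
  const std::size_t skip_size = skip.size();
  for (std::size_t idx = 0; idx < input.size(); ++idx) {
    float val = input[idx];
    if (!bias.empty()) {
      val += bias[idx % hidden_size];  // bias is indexed by position within the row
    }
    val += skip[idx % skip_size];
    sum[idx] = val;
  }
  return sum;
}

int main() {
  // Two rows of hidden_size 4; skip has one row and is broadcast to both.
  const std::vector<float> input{1, 2, 3, 4, 5, 6, 7, 8};
  const std::vector<float> skip{10, 20, 30, 40};
  const std::vector<float> bias;  // no bias
  const std::vector<float> sum = SkipBiasSum(input, skip, bias, 4);
  return sum[4] == 15.f ? 0 : 1;  // second row, first element: 5 + 10
}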
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
index ffb5850c827fe..9727dd6236ec8 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -11,18 +11,17 @@ namespace cuda {
 template <typename T, bool Simplified>
 void LaunchSkipLayerNormKernel(
     cudaStream_t stream,
-    T* output,                      // normalized output tensor
-    T* skip_input_bias_add_output,  // sum of the input and skip (and bias if it exists) tensors output
-    const T* input,                 // input tensor
-    const T* skip,                  // skip tensor
-    const T* gamma,                 // Layer normalization gamma tensor
-    const T* beta,                  // Layer normalization beta tensor
-    const T* bias,                  // Layer normalization beta tensor
-    float epsilon,                  // Layer normalization epsilon
-    int hidden_size,                // hidden size, it is the leading dimension (ld)
-    int row_count,                  // number of rows. That is total number of elements divided by hidden size.
-    bool skip_broadcasted,          // determines if broadcasting should be implemented
-    int skip_size);                 // determines size of the skip tensor
+    T* output,        // normalized output tensor
+    T* sum_output,    // sum of the input and skip (and bias if it exists) tensors output
+    const T* input,   // input tensor
+    const T* skip,    // skip tensor
+    const T* bias,    // bias tensor
+    const T* gamma,   // Layer normalization gamma tensor
+    const T* beta,    // Layer normalization beta tensor
+    float epsilon,    // Layer normalization epsilon
+    int hidden_size,  // hidden size, it is the leading dimension (ld)
+    int row_count,    // number of rows. That is total number of elements divided by hidden size.
+    int skip_size);   // number of elements of the skip tensor
 
 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
index 4cc560a1178ef..679b8b6b78886 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -104,17 +104,17 @@ __device__ void cuWelfordMuSigma2(
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     const T* lvals = vals + i1 * n2;
-    const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL;
+    const T* skip_vals = (skip != nullptr) ? skip + i1 * n2 : nullptr;
     int l = 4 * thrx;
     for (; l + 3 < n2; l += 4 * numx) {
       for (int k = 0; k < 4; ++k) {
         U curr = static_cast<U>(lvals[l + k]);
 
-        if (bias != NULL) {
+        if (bias != nullptr) {
           curr += static_cast<U>(bias[l + k]);
         }
 
-        if (skip_vals != NULL) {
+        if (skip_vals != nullptr) {
          curr += static_cast<U>(skip_vals[l + k]);
        }
 
@@ -124,11 +124,11 @@ __device__ void cuWelfordMuSigma2(
     for (; l < n2; ++l) {
       U curr = static_cast<U>(lvals[l]);
 
-      if (bias != NULL) {
+      if (bias != nullptr) {
         curr += static_cast<U>(bias[l]);
       }
 
-      if (skip_vals != NULL) {
+      if (skip_vals != nullptr) {
         curr += static_cast<U>(skip_vals[l]);
       }
 
@@ -301,7 +301,7 @@ namespace {
 //   {
 //     extern __device__ void error(void);
 //     error();
-//     return NULL;
+//     return nullptr;
 //   }
 // };
 // https://github.com/NVIDIA/apex/issues/246
@@ -338,9 +338,7 @@ __global__ void cuApplyLayerNorm(
     const V* __restrict__ beta,
     const T* __restrict__ skip,
     const T* __restrict__ bias,
-    T* __restrict__ skip_input_bias_add_output,
-    const bool skip_broadcasted,
-    const int skip_size) {
+    T* __restrict__ skip_input_bias_add_output) {
   // Assumptions:
   // 1) blockDim.x == GPU_WARP_SIZE
   // 2) Tensors are contiguous
@@ -350,38 +348,35 @@ __global__ void cuApplyLayerNorm(
     U* buf = shared.getPointer();
     U mu, sigma2;
     cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf, skip, bias);
-    const T* lvals = vals + i1 * n2;
-    const T* skip_vals = (skip != NULL) ? skip + i1 * n2 : NULL;
-    V* ovals = output_vals + i1 * n2;
-    T* skip_input_bias_add_ovals = (skip_input_bias_add_output != NULL) ? skip_input_bias_add_output + i1 * n2 : NULL;
+    const int offset = i1 * n2;
+    const T* lvals = vals + offset;
+    const T* skip_vals = (skip != nullptr) ? skip + offset : nullptr;
+
+    V* ovals = output_vals + offset;
+    T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr;
     U c_inv_std_dev = rsqrt(sigma2 + epsilon);
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     for (int i = thrx; i < n2; i += numx) {
       U curr = static_cast<U>(lvals[i]);
-
-
-      if (bias != NULL) {
+      if (bias != nullptr) {
         curr += static_cast<U>(bias[i]);
       }
 
-      if (skip_vals != NULL && skip_broadcasted) {
-        int skip_i = i % skip_size;
-        curr += static_cast<U>(skip_vals[skip_i]);  //Calculates index for the second dimension of the skip tensor
-      }else if (skip_vals != NULL){
+      if (skip_vals != nullptr) {
         curr += static_cast<U>(skip_vals[i]);
       }
 
-      U gamma_i = (gamma != NULL) ? (U)gamma[i] : (U)1;
-      U beta_i = (beta != NULL) ? (U)beta[i] : (U)0;
+      U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1;
+      U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0;
       if (simplified) {
         ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * curr);
       } else {
         ovals[i] = static_cast<V>(gamma_i * c_inv_std_dev * (curr - mu) + beta_i);
       }
 
-      if (skip_input_bias_add_ovals != NULL) {
+      if (skip_input_bias_add_ovals != nullptr) {
         skip_input_bias_add_ovals[i] = static_cast<T>(curr);
       }
     }
@@ -418,9 +413,7 @@ void HostApplyLayerNorm(
     const V* beta,
     const T* skip,
     const T* bias,
-    T* skip_input_bias_add_output,
-    const bool skip_broadcasted,
-    const int skip_size) {
+    T* skip_input_bias_add_output) {
   const int maxGridY = prop.maxGridSize[1];
   const int warp_size = prop.warpSize;
   ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST);
@@ -452,17 +445,14 @@ void HostApplyLayerNorm(
       n1, n2,
       U(epsilon),
       gamma, beta,
-      skip, bias, skip_input_bias_add_output,
-      skip_broadcasted,
-      skip_size);
+      skip, bias, skip_input_bias_add_output);
 }
 
 #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified)                                                                    \
   template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output,   \
                                                         U* mean, U* inv_std_dev, const T* input, int n1, int n2,      \
                                                         double epsilon, const V* gamma, const V* beta, const T* skip, \
-                                                        const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \
-                                                        const int skip_size);
+                                                        const T* bias, T* skip_input_bias_add_output);
 
 LAYERNORM_LINEAR_IMPL(float, float, float, true)
 LAYERNORM_LINEAR_IMPL(half, float, half, true)
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
index d0d5db8ba3587..e3952eefae35d 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
@@ -43,9 +43,7 @@ void HostApplyLayerNorm(
     const V* beta,
     const T* skip = nullptr,
     const T* bias = nullptr,
-    T* skip_input_bias_add_output = nullptr,
-    const bool skip_broadcasted = false,
-    const int skip_size = 0);
+    T* skip_input_bias_add_output = nullptr);
 
 }  // namespace cuda
 }  // namespace onnxruntime
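For reference, the normalization that both HostApplyLayerNorm and the fused SkipLayerNorm kernels apply to the summed value x = input + bias + skip is standard LayerNorm, or an RMS-style normalization when Simplified/simplified is set (no mean subtraction, no beta). A scalar sketch of one row, mirroring the kernel math above (the function name LayerNormRow is hypothetical, for illustration only):

#include <cmath>
#include <cstddef>
#include <vector>

// One-row reference for the normalization applied after the skip/bias sum.
// "simplified" mirrors the Simplified template argument: the kernels then
// normalize by E[x^2] directly, as in SimplifiedLayerNorm.
void LayerNormRow(const std::vector<float>& x, const std::vector<float>& gamma,
                  const std::vector<float>& beta, float epsilon, bool simplified,
                  std::vector<float>& out) {
  const std::size_t ld = x.size();
  double sum = 0.0, sq_sum = 0.0;
  for (float v : x) {
    sum += v;
    sq_sum += static_cast<double>(v) * v;
  }
  const double mu = sum / static_cast<double>(ld);
  const double sigma2 = simplified ? sq_sum / ld             // mean of squares
                                   : sq_sum / ld - mu * mu;  // population variance
  const double inv_std_dev = 1.0 / std::sqrt(sigma2 + epsilon);  // rsqrt(sigma2 + epsilon)
  for (std::size_t i = 0; i < ld; ++i) {
    out[i] = simplified
                 ? static_cast<float>(gamma[i] * inv_std_dev * x[i])
                 : static_cast<float>(gamma[i] * inv_std_dev * (x[i] - mu) + beta[i]);
  }
}

int main() {
  std::vector<float> x{1, 2, 3, 4}, gamma(4, 1.f), beta(4, 0.f), out(4);
  LayerNormRow(x, gamma, beta, 1e-12f, /*simplified=*/false, out);
  return 0;
}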
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
index 2395532198805..bb56a5aba7f65 100644
--- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -11,14 +11,15 @@ namespace onnxruntime {
 namespace test {
 
 constexpr float epsilon_ = 1e-12f;
 
-static void RunTest(
+static void RunOneTest(
+    bool strict,
     const std::vector<float>& input_data,
     const std::vector<float>& skip_data,
     const std::vector<float>& gamma_data,
     const std::vector<float>& beta_data,
     const std::vector<float>& bias_data,
     const std::vector<float>& output_data,
-    const std::vector<float>& skip_input_bias_add_output_data,
+    const std::vector<float>& sum_output_data,
     float epsilon,
     int batch_size,
     int sequence_length,
@@ -27,7 +28,6 @@ static void RunTest(
     bool no_beta = false,
     bool simplified = false,
     bool use_token_count = false,
-    bool strict = false,
     bool broadcast_skip = false,
     bool no_batch_size = false) {
   // Input and output shapes
@@ -82,14 +82,14 @@ static void RunTest(
 
     test.AddOutput<float>("output", output_dims, output_data);
 
-    if (skip_input_bias_add_output_data.size() != 0) {
+    if (sum_output_data.size() != 0) {
       // The second and third outputs are reserved for something else
       test.AddOptionalOutputEdge<float>();
       test.AddOptionalOutputEdge<float>();
 
       test.AddOutput<float>("skip_input_bias_add_output",
                             output_dims,
-                            skip_input_bias_add_output_data);
+                            sum_output_data);
     }
 
     if (cpu_ep != nullptr) {
@@ -117,14 +117,19 @@ static void RunTest(
 
     test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
 
-    if (skip_input_bias_add_output_data.size() != 0) {
+    // Use larger threshold for fp16
+    if (use_float16) {
+      test.SetOutputAbsErr("output", 0.01f);
+    }
+
+    if (sum_output_data.size() != 0) {
       // The second and third outputs are reserved for something else
       test.AddOptionalOutputEdge<MLFloat16>();
       test.AddOptionalOutputEdge<MLFloat16>();
 
       test.AddOutput<MLFloat16>("skip_input_bias_add_output",
                                 output_dims,
-                                ToFloat16(skip_input_bias_add_output_data));
+                                ToFloat16(sum_output_data));
     }
 
     if (dml_ep != nullptr) {
@@ -151,6 +156,36 @@ static void RunTest(
   }
 }
 
+static void RunTest(
+    const std::vector<float>& input_data,
+    const std::vector<float>& skip_data,
+    const std::vector<float>& gamma_data,
+    const std::vector<float>& beta_data,
+    const std::vector<float>& bias_data,
+    const std::vector<float>& output_data,
+    const std::vector<float>& sum_output_data,
+    float epsilon,
+    int batch_size,
+    int sequence_length,
+    int hidden_size,
+    bool use_float16 = false,
+    bool no_beta = false,
+    bool simplified = false,
+    bool use_token_count = false,
+    bool broadcast_skip = false,
+    bool no_batch_size = false) {
+  RunOneTest(false, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data,
+             epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified,
+             use_token_count, broadcast_skip, no_batch_size);
+
+  // strict mode does not support skip broadcasting.
+  if (!broadcast_skip) {
+    RunOneTest(true, input_data, skip_data, gamma_data, beta_data, bias_data, output_data, sum_output_data,
+               epsilon, batch_size, sequence_length, hidden_size, use_float16, no_beta, simplified,
+               use_token_count, broadcast_skip, no_batch_size);
+  }
+}
+
 TEST(SkipLayerNormTest, SkipLayerNormNullInput) {
   int batch_size = 1;
   int sequence_length = 0;
@@ -359,8 +394,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec) {
           true /*use_float16*/,
           false /*no_beta*/,
           false /*simplified*/,
-          false /*use_token_count*/,
-          true /*strict*/);
+          false /*use_token_count*/);
 }
 
 TEST(SkipLayerNormTest, SkipLayerNormBatch1_NoBeta) {
@@ -648,8 +682,7 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16_vec_token_count) {
           true /*use_float16*/,
           false /*no_beta*/,
           false /*simplified*/,
-          true /*use_token_count*/,
-          true /*strict*/);
+          true /*use_token_count*/);
 }
 
 TEST(SkipLayerNormTest, SkipLayerNormBatch2_TokenCount) {
@@ -776,13 +809,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
           batch_size,
           sequence_length,
           hidden_size,
-          false,
-          false,
-          false,
-          false,
-          false,
-          false,
-          true);
+          false,  // use_float16
+          false,  // no_beta
+          false,  // simplified
+          false,  // use_token_count
+          true,   // broadcast_skip
+          true);  // no_batch_size
 }
 
 TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) {
@@ -823,13 +855,12 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) {
           batch_size,
           sequence_length,
           hidden_size,
-          false,
-          false,
-          false,
-          false,
-          false,
-          true,
-          false);
+          false,  // use_float16
+          false,  // no_beta
+          false,  // simplified
+          false,  // use_token_count
+          true,   // broadcast_skip
+          false); // no_batch_size
 }
 
 #endif
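The strict path guarded by strict_ above is driven by the CUDA execution provider option named in the new error message. A hedged usage sketch follows: the option key "enable_skip_layer_norm_strict_mode" is taken from that message, while passing it as the string "1" and the "model.onnx" path are assumptions of this example, not something this patch shows.

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "skip_layer_norm_demo"};
  Ort::SessionOptions session_options;

  // Configure the CUDA EP with the strict SkipLayerNorm option. The key
  // matches the message in the strict shape check; the "1" value is an
  // assumption of this sketch.
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  Ort::ThrowOnError(Ort::GetApi().CreateCUDAProviderOptions(&cuda_options));
  const char* keys[] = {"enable_skip_layer_norm_strict_mode"};
  const char* values[] = {"1"};
  Ort::ThrowOnError(Ort::GetApi().UpdateCUDAProviderOptions(cuda_options, keys, values, 1));
  session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);

  // "model.onnx" is a placeholder path.
  Ort::Session session(env, "model.onnx", session_options);

  Ort::GetApi().ReleaseCUDAProviderOptions(cuda_options);
  return 0;
}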