[CUDA] Fix SkipLayerNorm strict mode when skip has broadcast (#17896)
In SkipLayerNorm (SLN) strict mode, the current code (#16510) does not handle skip broadcasting correctly. There are two issues:
(1) Skip-related parameters are not passed to the CUDA kernel in strict mode.
(2) The strict-mode kernel itself mishandles skip broadcasting (for example, cuWelfordMuSigma2 does not support it).

Here we remove support for skip broadcasting in strict mode; the operator now returns an error stating that strict mode only supports input and skip of the same shape.
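For reference, the guard (as it appears in the skip_layer_norm.cc diff below) reduces to:

```cpp
// Strict mode now rejects broadcastable skip shapes such as (1, seq_len, hidden)
// up front instead of launching a kernel that mishandles them.
if (strict_ && skip->Shape() != input->Shape()) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                         "'input' and 'skip' shall have same shape when "
                         "enable_skip_layer_norm_strict_mode is True");
}
```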

Other changes:
* skip_size was misleading when there is no broadcasting. It is now always the element count of the skip tensor, so the kernel's indexing is correct whether or not skip is broadcast.
* Refactor the code to be more efficient: (1) the kernel no longer needs to check whether skip is broadcast; (2) one local buffer is removed (input is loaded directly into sum_v, saving a local copy).
* Compute input + bias + skip instead of input + skip + bias. This order follows the common pattern in transformer models (here we assume graph fusion distinguishes input and skip correctly; the fusion code needs a double check later). See the kernel sketch after this list.
* Update the unit tests so that strict mode is exercised in each test case (except when skip is broadcast) for higher test coverage.
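For illustration, here is the refactored inner loop of the non-strict kernel, a condensed excerpt of SkipLayerNormKernel from the skip_layer_norm_impl.cu diff below. Because skip_size is now the element count of the skip tensor, the modulo covers batch-broadcast skip without a separate flag:

```cuda
// Condensed excerpt: offset, reverse_ld, pair_sum and thread_data are set up
// earlier in the kernel as shown in the diff.
for (int i = threadIdx.x; i < ld; i += TPB) {
  const int idx = offset + i;
  T val = input[idx];
  if (has_bias) {
    val += bias[i];                // bias first: input + bias + skip
  }
  val += skip[idx % skip_size];    // modulo handles batch-broadcast skip

  const T rldval = reverse_ld * val;
  thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));

  if (sum_output != nullptr) {
    sum_output[idx] = val;         // optional sum output (the LayerNorm input)
  }
  output[idx] = val;               // stash val; the epilogue normalizes output in place
}
```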

### Motivation and Context

SLN strict mode does not support skip broadcasting, but the current code silently runs anyway (and the kernel might fail).
tianleiwu authored and jchen351 committed Oct 18, 2023
1 parent 78984d2 commit 62c4cd7
Showing 6 changed files with 171 additions and 148 deletions.
31 changes: 18 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -3,6 +3,7 @@

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/nn/layer_norm_impl.h"
#include "core/common/narrow.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_impl.h"
#include "contrib_ops/cpu/skip_layer_norm_helper.h"
@@ -50,23 +51,25 @@ template <typename T, bool Simplified>
Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
const Tensor* skip = ctx->Input<Tensor>(1);
if (strict_ && skip->Shape() != input->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True");
}

const Tensor* gamma = ctx->Input<Tensor>(2);

const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3);
const Tensor* bias = Simplified ? ctx->Input<Tensor>(3) : ctx->Input<Tensor>(4);

Tensor* output = ctx->Output(0, input->Shape());

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
// Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization).
Tensor* sum_output = ctx->Output(3, input->Shape());

const auto& input_dims = input->Shape().GetDims();
size_t input_dims_size = input_dims.size();
const auto& skip_dims = skip->Shape().GetDims();
size_t skip_dims_size = skip_dims.size();

int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);
int hidden_size = onnxruntime::narrow<int>(input_dims[input_dims_size - 1]);

ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
skip,
@@ -76,12 +79,15 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
hidden_size,
input_dims_size));

const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
int row_count = onnxruntime::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
if (row_count == 0) {
return Status::OK();
}

int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
typedef typename ToCudaType<T>::MappedType CudaT;

const int skip_size = onnxruntime::narrow<int>(skip->Shape().Size());

if (strict_) {
HostApplyLayerNorm<CudaT, float, CudaT, Simplified>(
GetDeviceProp(),
@@ -97,21 +103,20 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
} else {
LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_broadcasted,
skip_size);
}

128 changes: 64 additions & 64 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -66,45 +66,53 @@ int NextSize(int x) {
}

template <typename T, int NumUnroll>
bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, const int ld, const int next_size) {
bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
const T* gamma, const T* beta, const int ld, const int next_size) {
constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
return ld % NumUnroll == 0 && reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip_input_bias_add_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 && reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 && reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize &&
return ld % NumUnroll == 0 &&
reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
next_size / NumUnroll >= kMinBlockSize &&
next_size / NumUnroll <= kMaxBlockSize;
}
} // namespace

template <typename T, unsigned TPB, bool Simplified>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip,
const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) {
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
const int ld, int skip_size) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
const bool has_bias = (bias != nullptr);

// Reduce sum of x and x^2, and the results are divided by ld.
KeyValuePairSum pair_sum;
// reduce x and x^2
cub::KeyValuePair<T, T> thread_data(0, 0);

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;

const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx];
const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i];
T val = input[idx];
if (has_bias) {
val += bias[i];
}
val += skip[idx % skip_size];

const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));

if (skip_input_bias_add_output != nullptr) {
skip_input_bias_add_output[idx] = val;
if (sum_output != nullptr) {
sum_output[idx] = val;
}

output[idx] = val;
}

if (Simplified) {
SimplifiedLayerNorm<T, TPB>(thread_data.value, ld, offset, gamma, epsilon, output);
return;
@@ -115,27 +123,24 @@ __global__ void SkipLayerNormKernel(
// Vectorized kernel
template <typename T, unsigned TPB, int ILP, bool Simplified>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) {
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
int ld, int skip_size) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
const int idx = blockIdx.x * ld + threadIdx.x * ILP;

using VecT = aligned_vector<T, ILP>;

T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
T skip_v[ILP], bias_v[ILP], sum_v[ILP];

VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
// load input to sum_v
VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
*sum_val = *reinterpret_cast<const VecT*>(&input[idx]);

VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
if (skip_broadcasted) {
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
} else {
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
}
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);

if (hasBias) {
const bool has_bias = (bias != nullptr);
if (has_bias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
@@ -145,59 +150,52 @@ __global__ void SkipLayerNormKernelSmall(
if (ILP * threadIdx.x < ld) {
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
const bool has_sum_output = (sum_output != nullptr);

#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];

if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[i] = input_v[i];
if (has_bias) {
sum_v[i] += bias_v[i];
}
sum_v[i] += skip_v[i];

const T rldval = rld * input_v[i];
const T rldval = rld * sum_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
rldvalsq_sum += rldval * sum_v[i];
}

if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
if (has_sum_output) {
*(reinterpret_cast<VecT*>(&sum_output[idx])) = *reinterpret_cast<VecT*>(&sum_v);
}

thread_data = cub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}

if (Simplified) {
SimplifiedLayerNormSmall<T, TPB, ILP>(input_v, thread_data.value, ld, idx, gamma, epsilon, output);
SimplifiedLayerNormSmall<T, TPB, ILP>(sum_v, thread_data.value, ld, idx, gamma, epsilon, output);
return;
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
LayerNormSmall<T, TPB, ILP>(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}

template <typename T, bool Simplified>
void LaunchSkipLayerNormKernel(
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) {
if (row_count == 0) {
return;
}

bool hasBias = (bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;

cudaStream_t stream, T* output, T* sum_output,
const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon,
int ld, int row_count, int skip_size) {
const int next_size = NextSize(ld);
const int grid_size = row_count;
bool flag_vec2 =
CanVectorized<T, 2>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
bool flag_vec4 =
CanVectorized<T, 4>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);

#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

switch (next_size) {
#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified> \
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, \
skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size)
#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<T, kMaxBlockSize, Simplified><<<grid_size, kMaxBlockSize, 0, stream>>>( \
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size)
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
if (flag_vec4) { \
@@ -215,25 +213,27 @@ void LaunchSkipLayerNormKernel(
} \
} \
} break

switch (next_size) {
CASE_NEXT_SIZE(kSizes[0]);
CASE_NEXT_SIZE(kSizes[1]);
CASE_NEXT_SIZE(kSizes[2]);
CASE_NEXT_SIZE(kSizes[3]);
CASE_NEXT_SIZE(kSizes[4]);
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
}

#undef CASE_NEXT_SIZE
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
}
}

#define SKIPLAYERNORM_IMPL(T, Simplified) \
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
T * skip_input_bias_add_output, \
const T* input, const T* skip, const T* gamma, \
const T* beta, const T* bias, float epsilon, \
int ld, int row_count, bool skip_broadcasted, int skip_size);
#define SKIPLAYERNORM_IMPL(T, Simplified) \
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, T * sum_output, \
const T* input, const T* skip, const T* bias, \
const T* gamma, const T* beta, float epsilon, \
int ld, int row_count, int skip_size);
SKIPLAYERNORM_IMPL(float, true);
SKIPLAYERNORM_IMPL(float, false);
SKIPLAYERNORM_IMPL(half, true);
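As context for the vectorization gate above: CanVectorized selects the unrolled kernels only when the hidden size divides evenly by the unroll factor and every tensor pointer meets the alignment of the vector type. A minimal illustration of the alignment test, assuming aligned_vector<T, N> is aligned to sizeof(T) * N (PointerAligned is a name invented for this sketch):

```cpp
#include <cstdint>

// Illustration: a 4-wide half vector spans 8 bytes, so vectorized loads and
// stores are only safe when every pointer is 8-byte aligned.
template <typename T, int NumUnroll>
bool PointerAligned(const T* p) {
  constexpr uint64_t alignment = sizeof(T) * NumUnroll;  // e.g. 2 * 4 = 8 for half
  return reinterpret_cast<uint64_t>(p) % alignment == 0;
}
```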
23 changes: 11 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -11,18 +11,17 @@ namespace cuda {
template <typename T, bool Simplified>
void LaunchSkipLayerNormKernel(
cudaStream_t stream,
T* output, // normalized output tensor
T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int row_count, // number of rows. That is total number of elements divided by hidden size.
bool skip_broadcasted, // determines if broadcasting should be implemented
int skip_size); // determines size of the skip tensor
T* output, // normalized output tensor
T* sum_output, // sum of the input and skip (and bias if it exists) tensors output
const T* input, // input tensor
const T* skip, // skip tensor
const T* bias, // bias tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int row_count, // number of rows. That is total number of elements divided by hidden size.
int skip_size); // number of elements of the skip tensor

} // namespace cuda
} // namespace contrib
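A hypothetical call site for the updated launcher, shown as a sketch only: the buffer names and shapes are invented, the caller is assumed to include this header, and <float, false> is one of the instantiations provided by the SKIPLAYERNORM_IMPL macros above.

```cpp
#include <cuda_runtime.h>

// Sketch: batch = 2, seq_len = 128, hidden = 768, no skip broadcast,
// regular (non-simplified) LayerNorm. All pointers are device buffers.
void RunSkipLayerNorm(cudaStream_t stream,
                      float* d_output, float* d_sum_output,
                      const float* d_input, const float* d_skip,
                      const float* d_bias, const float* d_gamma,
                      const float* d_beta) {
  const int hidden_size = 768;                    // leading dimension (ld)
  const int row_count = 2 * 128;                  // batch * sequence length
  const int skip_size = row_count * hidden_size;  // no broadcast: same element count as input
  onnxruntime::contrib::cuda::LaunchSkipLayerNormKernel<float, false>(
      stream, d_output, d_sum_output, d_input, d_skip, d_bias,
      d_gamma, d_beta, /*epsilon=*/1e-5f, hidden_size, row_count, skip_size);
}
```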