[CUDA] Fix SkipLayerNorm strict mode when skip has broadcast #17896

Merged · 3 commits · Oct 13, 2023
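For context (this summary and sketch are not part of the original PR page): `SkipLayerNorm` fuses `output = LayerNorm(input + skip + bias)`. This PR drops the strict-mode requirement that `skip` match `input`'s shape and instead lets every code path accept a `skip` that broadcasts across the batch dimension, by indexing the flat skip buffer modulo its total element count. A minimal CPU sketch of the intended semantics, assuming a row-major `[rows, hidden]` layout (the helper name `SkipLayerNormRef` is hypothetical):

```cpp
#include <cmath>
#include <vector>

// Hypothetical reference implementation, for illustration only.
// rows = batch * sequence_length, ld = hidden_size,
// skip_size = total element count of the (possibly broadcast) skip tensor.
void SkipLayerNormRef(float* output, float* sum_output,
                      const float* input, const float* skip, const float* bias,
                      const float* gamma, const float* beta, float epsilon,
                      int ld, int rows, int skip_size) {
  for (int r = 0; r < rows; ++r) {
    const int offset = r * ld;
    std::vector<float> sum(ld);
    for (int i = 0; i < ld; ++i) {
      const int idx = offset + i;
      // Modulo wrapping emulates broadcast: a [1, S, H] skip repeats per batch.
      float v = input[idx] + skip[idx % skip_size];
      if (bias != nullptr) v += bias[i];
      sum[i] = v;
      if (sum_output != nullptr) sum_output[idx] = v;  // optional output 3
    }
    // Plain LayerNorm over the hidden dimension.
    float mean = 0.f, var = 0.f;
    for (float v : sum) mean += v;
    mean /= static_cast<float>(ld);
    for (float v : sum) var += (v - mean) * (v - mean);
    var /= static_cast<float>(ld);
    const float inv_std = 1.f / std::sqrt(var + epsilon);
    for (int i = 0; i < ld; ++i) {
      float n = (sum[i] - mean) * inv_std * gamma[i];
      if (beta != nullptr) n += beta[i];
      output[offset + i] = n;
    }
  }
}
```

The CUDA kernels in the diff below implement the same arithmetic, with the modulo indexing visible as `skip[idx % skip_size]`.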
31 changes: 18 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -3,6 +3,7 @@

#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/nn/layer_norm_impl.h"
#include "core/common/narrow.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_impl.h"
#include "contrib_ops/cpu/skip_layer_norm_helper.h"
@@ -50,23 +51,25 @@ template <typename T, bool Simplified>
Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
const Tensor* skip = ctx->Input<Tensor>(1);
if (strict_ && skip->Shape() != input->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"'input' and 'skip' shall have same shape when enable_skip_layer_norm_strict_mode is True");
}

const Tensor* gamma = ctx->Input<Tensor>(2);

const Tensor* beta = Simplified ? nullptr : ctx->Input<Tensor>(3);
const Tensor* bias = Simplified ? ctx->Input<Tensor>(3) : ctx->Input<Tensor>(4);

Tensor* output = ctx->Output(0, input->Shape());

// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
// Optional output for the sum of skip, input and bias tensors (It is also the input of Layer Normalization).
Tensor* sum_output = ctx->Output(3, input->Shape());

const auto& input_dims = input->Shape().GetDims();
size_t input_dims_size = input_dims.size();
const auto& skip_dims = skip->Shape().GetDims();
size_t skip_dims_size = skip_dims.size();

int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);
int hidden_size = onnxruntime::narrow<int>(input_dims[input_dims_size - 1]);

ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
skip,
@@ -76,12 +79,15 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
hidden_size,
input_dims_size));

const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
int row_count = onnxruntime::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
if (row_count == 0) {
return Status::OK();
}

int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
typedef typename ToCudaType<T>::MappedType CudaT;

const int skip_size = onnxruntime::narrow<int>(skip->Shape().Size());

if (strict_) {
HostApplyLayerNorm<CudaT, float, CudaT, Simplified>(
GetDeviceProp(),
@@ -97,21 +103,20 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
} else {
LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_broadcasted,
skip_size);
}

128 changes: 64 additions & 64 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -66,45 +66,53 @@ int NextSize(int x) {
}

template <typename T, int NumUnroll>
bool CanVectorized(T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, const int ld, const int next_size) {
bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
const T* gamma, const T* beta, const int ld, const int next_size) {
constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
return ld % NumUnroll == 0 && reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip_input_bias_add_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 && reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 && reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 && next_size / NumUnroll >= kMinBlockSize &&
return ld % NumUnroll == 0 &&
reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
next_size / NumUnroll >= kMinBlockSize &&
next_size / NumUnroll <= kMaxBlockSize;
}
} // namespace

template <typename T, unsigned TPB, bool Simplified>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip,
const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) {
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
const int ld, int skip_size) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
const bool has_bias = (bias != nullptr);

// Reduce sum of x and x^2, and the results are divided by ld.
KeyValuePairSum pair_sum;
// reduce x and x^2
cub::KeyValuePair<T, T> thread_data(0, 0);

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;

const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx];
const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i];
T val = input[idx];
if (has_bias) {
val += bias[i];
}
val += skip[idx % skip_size];

const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));

if (skip_input_bias_add_output != nullptr) {
skip_input_bias_add_output[idx] = val;
if (sum_output != nullptr) {
sum_output[idx] = val;
}

output[idx] = val;
}

if (Simplified) {
SimplifiedLayerNorm<T, TPB>(thread_data.value, ld, offset, gamma, epsilon, output);
return;
@@ -115,27 +123,24 @@ __global__ void SkipLayerNormKernel(
// Vectorized kernel
template <typename T, unsigned TPB, int ILP, bool Simplified>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) {
T* output, T* sum_output, const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, T epsilon,
int ld, int skip_size) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
const int idx = blockIdx.x * ld + threadIdx.x * ILP;

using VecT = aligned_vector<T, ILP>;

T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
T skip_v[ILP], bias_v[ILP], sum_v[ILP];

VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
// load input to sum_v
VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
*sum_val = *reinterpret_cast<const VecT*>(&input[idx]);

VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
if (skip_broadcasted) {
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
} else {
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
}
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);

if (hasBias) {
const bool has_bias = (bias != nullptr);
if (has_bias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
@@ -145,59 +150,52 @@ __global__ void SkipLayerNormKernelSmall(
if (ILP * threadIdx.x < ld) {
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
const bool has_sum_output = (sum_output != nullptr);

#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];

if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[i] = input_v[i];
if (has_bias) {
sum_v[i] += bias_v[i];
}
sum_v[i] += skip_v[i];

const T rldval = rld * input_v[i];
const T rldval = rld * sum_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
rldvalsq_sum += rldval * sum_v[i];
}

if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
if (has_sum_output) {
*(reinterpret_cast<VecT*>(&sum_output[idx])) = *reinterpret_cast<VecT*>(&sum_v);
}

thread_data = cub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}

if (Simplified) {
SimplifiedLayerNormSmall<T, TPB, ILP>(input_v, thread_data.value, ld, idx, gamma, epsilon, output);
SimplifiedLayerNormSmall<T, TPB, ILP>(sum_v, thread_data.value, ld, idx, gamma, epsilon, output);
return;
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
LayerNormSmall<T, TPB, ILP>(sum_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}

template <typename T, bool Simplified>
void LaunchSkipLayerNormKernel(
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, int ld, int row_count, bool skip_broadcasted, int skip_size) {
if (row_count == 0) {
return;
}

bool hasBias = (bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;

cudaStream_t stream, T* output, T* sum_output,
const T* input, const T* skip, const T* bias, const T* gamma, const T* beta, float epsilon,
int ld, int row_count, int skip_size) {
const int next_size = NextSize(ld);
const int grid_size = row_count;
bool flag_vec2 =
CanVectorized<T, 2>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
bool flag_vec4 =
CanVectorized<T, 4>(output, skip_input_bias_add_output, input, skip, gamma, beta, bias, ld, next_size);
bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);

#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

switch (next_size) {
#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified> \
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, \
skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size)
#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<T, kMaxBlockSize, Simplified><<<grid_size, kMaxBlockSize, 0, stream>>>( \
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size)
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
if (flag_vec4) { \
@@ -215,25 +213,27 @@ void LaunchSkipLayerNormKernel(
} \
} \
} break

switch (next_size) {
CASE_NEXT_SIZE(kSizes[0]);
CASE_NEXT_SIZE(kSizes[1]);
CASE_NEXT_SIZE(kSizes[2]);
CASE_NEXT_SIZE(kSizes[3]);
CASE_NEXT_SIZE(kSizes[4]);
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
}

#undef CASE_NEXT_SIZE
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
}
}

#define SKIPLAYERNORM_IMPL(T, Simplified) \
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, \
T * skip_input_bias_add_output, \
const T* input, const T* skip, const T* gamma, \
const T* beta, const T* bias, float epsilon, \
int ld, int row_count, bool skip_broadcasted, int skip_size);
#define SKIPLAYERNORM_IMPL(T, Simplified) \
template void LaunchSkipLayerNormKernel<T, Simplified>(cudaStream_t stream, T * output, T * sum_output, \
const T* input, const T* skip, const T* bias, \
const T* gamma, const T* beta, float epsilon, \
int ld, int row_count, int skip_size);
SKIPLAYERNORM_IMPL(float, true);
SKIPLAYERNORM_IMPL(float, false);
SKIPLAYERNORM_IMPL(half, true);
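One design note on the vectorized path (added for illustration): `CanVectorized` only allows the `KernelSmall` variants when every buffer is aligned for `aligned_vector<T, NumUnroll>` and `ld` divides evenly by `NumUnroll`, so each thread can issue one wide load per `NumUnroll` elements. A sketch of such a type, assuming it follows the common CUDA aligned-load pattern (the actual onnxruntime definition may differ):

```cpp
// Hypothetical mirror of aligned_vector (common CUDA pattern): packing ILP
// elements into one over-aligned struct lets a single 64/128-bit instruction
// load or store all of them at once.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};

static_assert(alignof(aligned_vector<float, 4>) == 16,
              "float4-style vectorized access needs 16-byte alignment");
```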
23 changes: 11 additions & 12 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -11,18 +11,17 @@ namespace cuda {
template <typename T, bool Simplified>
void LaunchSkipLayerNormKernel(
cudaStream_t stream,
T* output, // normalized output tensor
T* skip_input_bias_add_output, // sum of the input and skip (and bias if it exists) tensors output
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int row_count, // number of rows. That is total number of elements divided by hidden size.
bool skip_broadcasted, // determines if broadcasting should be implemented
int skip_size); // determines size of the skip tensor
T* output, // normalized output tensor
T* sum_output, // sum of the input and skip (and bias if it exists) tensors output
const T* input, // input tensor
const T* skip, // skip tensor
const T* bias, // bias tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int row_count, // number of rows. That is total number of elements divided by hidden size.
int skip_size); // number of elements of the skip tensor

} // namespace cuda
} // namespace contrib
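A hypothetical host-side call under the new signature (buffer names, shapes, and the wrapper itself are illustrative, not from the PR). With a skip of shape `[1, seq, hidden]`, the caller now passes the skip's element count instead of a broadcast flag:

```cpp
#include <cuda_runtime.h>
#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"  // assumed include path

// Hypothetical wrapper; all d_* arguments are device buffers owned by the caller.
void RunSkipLayerNormBroadcast(cudaStream_t stream, float* d_output,
                               const float* d_input, const float* d_skip,
                               const float* d_gamma, const float* d_beta) {
  const int hidden_size = 768;
  const int row_count = 8 * 128;        // batch * sequence_length rows
  const int skip_size = 1 * 128 * 768;  // element count of a [1, 128, 768] skip
  onnxruntime::contrib::cuda::LaunchSkipLayerNormKernel<float, /*Simplified=*/false>(
      stream, d_output, /*sum_output=*/nullptr,
      d_input, d_skip, /*bias=*/nullptr,
      d_gamma, d_beta, /*epsilon=*/1e-5f,
      hidden_size, row_count, skip_size);
}
```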