[CUDA] Fix SkipLayerNorm vectorized kernel out-of-bounds read (microsoft#17943)

Fix a bug in microsoft#11803:
When the hidden size is not exactly equal to the next size (for example,
ld=320 in stable diffusion), the current vectorized kernel might read
out of bounds and cause a CUDA failure.
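
The following is a minimal standalone sketch of the failure mode and of
the guard the fix introduces. It is not the ONNX Runtime kernel:
GammaLoadSketch, its launch parameters, and the buffer sizes are made up
for illustration, assuming ILP=4, ld=320, and a block sized for
NextSize(320)=384.

#include <cuda_runtime.h>
#include <cstdio>

// Each thread loads 4 consecutive floats of gamma, as the vectorized kernel does for ILP = 4.
__global__ void GammaLoadSketch(const float* gamma, float* out, int ld) {
  constexpr int kIlp = 4;
  const int start = threadIdx.x * kIlp;

  // The guard added by the fix: only threads whose 4-element slice lies inside the row
  // may touch memory. Without it, a block sized for NextSize(320) = 384 has 96 threads,
  // and threads 80..95 would read gamma[320..383], past the end of a 320-element buffer.
  const bool is_valid = start < ld;

  if (is_valid) {
    const float4 v = *reinterpret_cast<const float4*>(&gamma[start]);  // 16-byte vector load
    out[start + 0] = v.x;
    out[start + 1] = v.y;
    out[start + 2] = v.z;
    out[start + 3] = v.w;
  }
}

int main() {
  constexpr int kLd = 320;        // hidden size, e.g. stable diffusion
  constexpr int kNextSize = 384;  // NextSize(320) determines the block size
  constexpr int kIlp = 4;

  float *gamma = nullptr, *out = nullptr;
  cudaMalloc(&gamma, kLd * sizeof(float));  // only ld elements exist on the device
  cudaMalloc(&out, kLd * sizeof(float));
  cudaMemset(gamma, 0, kLd * sizeof(float));

  GammaLoadSketch<<<1, kNextSize / kIlp>>>(gamma, out, kLd);
  std::printf("sync: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));

  cudaFree(gamma);
  cudaFree(out);
  return 0;
}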

Also resolve another issue: for the first and last sizes, the current
macro generates dead code (branches that can never run). Here we change
it so that those branches are avoided for the boundary sizes.
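
As a sketch of the dead-code issue (hypothetical names, not the actual
dispatch code): for the smallest size the fallback to the non-vectorized
kernel can never be taken, and for the largest size the vectorized path
must never be taken because ld may exceed kMaxSize, so both boundary
sizes now get dedicated handling and the macro carries a static_assert
that it is only expanded for interior sizes, as in the diff below.

#include <cstdio>

constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
constexpr int kMaxSize = kSizes[sizeof(kSizes) / sizeof(kSizes[0]) - 1];
constexpr int kMaxBlockSize = 256;

// Stand-in for the macro body: only interior sizes may instantiate it, so neither the
// "too small to need the fallback" nor the "too large for the vectorized kernel" branch
// exists as unreachable code inside it.
template <int NextSizeValue>
void DispatchInterior(bool flag_vec4, bool flag_vec2) {
  static_assert(NextSizeValue > kSizes[0] && NextSizeValue < kMaxSize,
                "boundary sizes use dedicated cases outside this dispatch");
  if (flag_vec4) {
    std::printf("small kernel, block_size=%d, unroll=4\n", NextSizeValue / 4);
  } else if (flag_vec2) {
    std::printf("small kernel, block_size=%d, unroll=2\n", NextSizeValue / 2);
  } else if (NextSizeValue <= kMaxBlockSize) {
    std::printf("small kernel, block_size=%d, unroll=1\n", NextSizeValue);
  } else {
    std::printf("regular kernel, block_size=%d\n", kMaxBlockSize);
  }
}

int main() {
  DispatchInterior<384>(true, false);    // interior size: compiles and dispatches
  // DispatchInterior<32>(false, false);    // rejected at compile time, as intended
  // DispatchInterior<2048>(false, false);  // rejected at compile time, as intended
  return 0;
}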

Performance tests with stable diffusion show that performance is on par
before and after this fix.
tianleiwu authored and kleiti committed Mar 22, 2024
1 parent 607a655 commit 9cffe91
Showing 2 changed files with 72 additions and 49 deletions.
35 changes: 22 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
@@ -147,14 +147,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
__shared__ T rsigma; // 1 / std.dev.
T beta_v[ILP], gamma_v[ILP], output_v[ILP];

if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
}
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
const bool is_valid = ILP * threadIdx.x < ld;
if (is_valid) {
if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
}

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
}

KeyValuePairSum pair_sum;
const cub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
@@ -165,13 +167,15 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
}
__syncthreads();

if (ILP * threadIdx.x < ld) {
if (is_valid) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = (beta != nullptr)
? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
: gamma_v[i] * (input_v[i] - mu) * rsigma;
}

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
*(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
}
}
@@ -186,12 +190,15 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa
using BlockReduce = cub::BlockReduce<T, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T rsigma; // 1 / std.dev.
T gamma_v[ILP], output_v[ILP];

VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
const bool is_valid = ILP * threadIdx.x < ld;

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
T gamma_v[ILP], output_v[ILP];

if (is_valid) {
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
}

const T sum = BlockReduce(temp_storage).Sum(thread_data);

Expand All @@ -200,11 +207,13 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa
}
__syncthreads();

if (ILP * threadIdx.x < ld) {
if (is_valid) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = gamma_v[i] * input_v[i] * rsigma;
}

VecT* output_val = reinterpret_cast<VecT*>(&output_v);
*(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
}
}
86 changes: 50 additions & 36 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -52,17 +52,18 @@ half maybe2half(float x) {
// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by warp_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
constexpr int kMaxBlockSize = 256;

int NextSize(int x) {
size_t len = sizeof(kSizes) / sizeof(kSizes[0]);
for (size_t i = 0; i < len; ++i) {
for (size_t i = 0; i < kNumOfSizes; ++i) {
if (x <= kSizes[i]) {
return kSizes[i];
}
}
return kSizes[len - 1];
return kMaxSize;
}

template <typename T, int NumUnroll>
@@ -129,25 +130,26 @@ __global__ void SkipLayerNormKernelSmall(
const int idx = blockIdx.x * ld + threadIdx.x * ILP;

using VecT = aligned_vector<T, ILP>;
T sum_v[ILP];

T skip_v[ILP], bias_v[ILP], sum_v[ILP];
cub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));

// load input to sum_v
VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
*sum_val = *reinterpret_cast<const VecT*>(&input[idx]);
if (ILP * threadIdx.x < ld) { // load data under this guard to avoid reading out-of-bounds
T skip_v[ILP], bias_v[ILP];

VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
// load input to sum_v
VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
*sum_val = *reinterpret_cast<const VecT*>(&input[idx]);

const bool has_bias = (bias != nullptr);
if (has_bias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);

cub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
const bool has_bias = (bias != nullptr);
if (has_bias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}

if (ILP * threadIdx.x < ld) {
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
const bool has_sum_output = (sum_output != nullptr);
@@ -192,36 +194,48 @@ void LaunchSkipLayerNormKernel(
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<T, kMaxBlockSize, Simplified><<<grid_size, kMaxBlockSize, 0, stream>>>( \
#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<T, block_size, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
if (flag_vec4) { \
constexpr int block_size = next_size_value / 4; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
} else if (flag_vec2) { \
constexpr int block_size = next_size_value / 2; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \
} else { \
if (next_size_value <= kMaxBlockSize) { \
constexpr int block_size = next_size_value; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
} else { \
LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
} \
} \
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \
if (flag_vec4) { \
constexpr int block_size = next_size_value / 4; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
} else if (flag_vec2) { \
constexpr int block_size = next_size_value / 2; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \
} else { \
if (next_size_value <= kMaxBlockSize) { \
constexpr int block_size = next_size_value; \
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
} else { \
constexpr int block_size = 256; \
LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
} \
} \
} break

switch (next_size) {
CASE_NEXT_SIZE(kSizes[0]);
case kSizes[0]: {
constexpr int block_size = kSizes[0];
// TODO: Add back the small TensorRT kernel for 32. No need to use vectorized kernel for such a small size.
LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);
break;
}
CASE_NEXT_SIZE(kSizes[1]);
CASE_NEXT_SIZE(kSizes[2]);
CASE_NEXT_SIZE(kSizes[3]);
CASE_NEXT_SIZE(kSizes[4]);
CASE_NEXT_SIZE(kSizes[5]);
CASE_NEXT_SIZE(kSizes[6]);
// kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize.
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
break;
}
}

#undef CASE_NEXT_SIZE
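
For reference, a standalone sketch of the size rounding that drives the
dispatch above: NextSize is copied from the diff, while main and its
expected output are illustrative. An ld of 320 rounds up to 384, so the
vectorized block over-covers the row and needs the in-kernel guard; an
ld above 2048 maps to kMaxSize and takes the default, non-vectorized
branch.

#include <cstdio>

constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];

int NextSize(int x) {
  for (size_t i = 0; i < kNumOfSizes; ++i) {
    if (x <= kSizes[i]) {
      return kSizes[i];
    }
  }
  return kMaxSize;
}

int main() {
  std::printf("NextSize(320)  = %d\n", NextSize(320));   // 384: block over-covers ld by 64 elements
  std::printf("NextSize(768)  = %d\n", NextSize(768));   // 768: exact match, the guard never trips
  std::printf("NextSize(4096) = %d\n", NextSize(4096));  // 2048: handled by the default case
  return 0;
}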
