diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
index 5c083d64ee542..ff3178b56c2a6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
+++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
@@ -147,14 +147,16 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
   __shared__ T rsigma;  // 1 / std.dev.
 
   T beta_v[ILP], gamma_v[ILP], output_v[ILP];
-  if (beta != nullptr) {
-    VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
-    *beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
-  }
-  VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
-  *gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
+  const bool is_valid = ILP * threadIdx.x < ld;
+  if (is_valid) {
+    if (beta != nullptr) {
+      VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
+      *beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
+    }
 
-  VecT* output_val = reinterpret_cast<VecT*>(&output_v);
+    VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
+    *gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
+  }
 
   KeyValuePairSum pair_sum;
   const cub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
@@ -165,13 +167,15 @@ __device__ inline void LayerNormSmall(const T* input_v, const cub::KeyValuePair<
   }
   __syncthreads();
 
-  if (ILP * threadIdx.x < ld) {
+  if (is_valid) {
 #pragma unroll
     for (int i = 0; i < ILP; i++) {
       output_v[i] = (beta != nullptr)
                         ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i]
                         : gamma_v[i] * (input_v[i] - mu) * rsigma;
     }
+
+    VecT* output_val = reinterpret_cast<VecT*>(&output_v);
     *(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
   }
 }
@@ -186,12 +190,15 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa
   using BlockReduce = cub::BlockReduce<T, TPB>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   __shared__ T rsigma;  // 1 / std.dev.
-  T gamma_v[ILP], output_v[ILP];
 
-  VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
-  *gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
+  const bool is_valid = ILP * threadIdx.x < ld;
 
-  VecT* output_val = reinterpret_cast<VecT*>(&output_v);
+  T gamma_v[ILP], output_v[ILP];
+
+  if (is_valid) {
+    VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
+    *gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
+  }
 
   const T sum = BlockReduce(temp_storage).Sum(thread_data);
 
@@ -200,11 +207,13 @@ __device__ inline void SimplifiedLayerNormSmall(const T* input_v, const T& threa
   }
   __syncthreads();
 
-  if (ILP * threadIdx.x < ld) {
+  if (is_valid) {
 #pragma unroll
     for (int i = 0; i < ILP; i++) {
       output_v[i] = gamma_v[i] * input_v[i] * rsigma;
    }
+
+    VecT* output_val = reinterpret_cast<VecT*>(&output_v);
     *(reinterpret_cast<VecT*>(&output[idx])) = *output_val;
   }
 }
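// Note on the guard pattern above: the launcher rounds ld up to the next supported size, so
// TPB * ILP can exceed ld. The trailing threads must skip the vectorized loads and stores
// entirely, while still reaching the collective block-wide reduction. A minimal standalone
// sketch of that pattern follows; the RowScaleSmall kernel and Vec type are illustrative only,
// not the actual onnxruntime code.

#include <cuda_runtime.h>

// Illustrative vector type: ILP consecutive elements moved as one aligned unit.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) Vec {
  T val[ILP];
};

// One block per row of width ld; TPB * ILP may exceed ld.
template <typename T, int TPB, int ILP>
__global__ void RowScaleSmall(const T* input, const T* gamma, T* output, int ld) {
  using VecT = Vec<T, ILP>;
  const int idx = blockIdx.x * ld + threadIdx.x * ILP;

  // Threads whose first element starts at or beyond ld would read/write past the end
  // of the row, so both the loads and the final store are guarded by is_valid.
  const bool is_valid = ILP * threadIdx.x < ld;

  T in_v[ILP], gamma_v[ILP], out_v[ILP];
  if (is_valid) {
    *reinterpret_cast<VecT*>(&in_v) = *reinterpret_cast<const VecT*>(&input[idx]);
    *reinterpret_cast<VecT*>(&gamma_v) = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
  }

  // A collective cub::BlockReduce would go here in the real kernels; every thread must
  // reach it, which is why the guard wraps the loads instead of returning early.

  if (is_valid) {
#pragma unroll
    for (int i = 0; i < ILP; i++) {
      out_v[i] = gamma_v[i] * in_v[i];
    }
    *reinterpret_cast<VecT*>(&output[idx]) = *reinterpret_cast<VecT*>(&out_v);
  }
}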
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index 224c2fa38e8c1..e4b09b00f030c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -52,17 +52,18 @@ half maybe2half(float x) {
 // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
 // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
 constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
+constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
+constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
 constexpr int kMinBlockSize = 32;
 constexpr int kMaxBlockSize = 256;
 
 int NextSize(int x) {
-  size_t len = sizeof(kSizes) / sizeof(kSizes[0]);
-  for (size_t i = 0; i < len; ++i) {
+  for (size_t i = 0; i < kNumOfSizes; ++i) {
     if (x <= kSizes[i]) {
       return kSizes[i];
     }
   }
-  return kSizes[len - 1];
+  return kMaxSize;
 }
 
 template <typename T, int NumUnroll>
@@ -129,25 +130,26 @@ __global__ void SkipLayerNormKernelSmall(
   const int idx = blockIdx.x * ld + threadIdx.x * ILP;
   using VecT = aligned_vector<T, ILP>;
+  T sum_v[ILP];
 
-  T skip_v[ILP], bias_v[ILP], sum_v[ILP];
+  cub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
 
-  // load input to sum_v
-  VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
-  *sum_val = *reinterpret_cast<const VecT*>(&input[idx]);
+  if (ILP * threadIdx.x < ld) {  // load data under this guard to avoid reading out-of-bounds
+    T skip_v[ILP], bias_v[ILP];
 
-  VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
-  *skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
+    // load input to sum_v
+    VecT* sum_val = reinterpret_cast<VecT*>(&sum_v);
+    *sum_val = *reinterpret_cast<const VecT*>(&input[idx]);
 
-  const bool has_bias = (bias != nullptr);
-  if (has_bias) {
-    VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
-    *bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
-  }
+    VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
+    *skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
 
-  cub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
+    const bool has_bias = (bias != nullptr);
+    if (has_bias) {
+      VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
+      *bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
+    }
 
-  if (ILP * threadIdx.x < ld) {
     T rldval_sum = T(0.f);
     T rldvalsq_sum = T(0.f);
     const bool has_sum_output = (sum_output != nullptr);
@@ -192,36 +194,48 @@ void LaunchSkipLayerNormKernel(
   SkipLayerNormKernelSmall<T, block_size, num_unroll><<<grid_size, block_size, 0, stream>>>(  \
       output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
 
-#define LAUNCH_SKIP_LAYER_NORM_KERNEL()                                           \
-  SkipLayerNormKernel<T, kMaxBlockSize><<<grid_size, kMaxBlockSize, 0, stream>>>( \
+#define LAUNCH_SKIP_LAYER_NORM_KERNEL()                                     \
+  SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>( \
       output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
 
-#define CASE_NEXT_SIZE(next_size_value)                 \
-  case next_size_value: {                               \
-    if (flag_vec4) {                                    \
-      constexpr int block_size = next_size_value / 4;   \
-      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4);           \
-    } else if (flag_vec2) {                             \
-      constexpr int block_size = next_size_value / 2;   \
-      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2);           \
-    } else {                                            \
-      if (next_size_value <= kMaxBlockSize) {           \
-        constexpr int block_size = next_size_value;     \
-        LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);         \
-      } else {                                          \
-        LAUNCH_SKIP_LAYER_NORM_KERNEL();                \
-      }                                                 \
-    }                                                   \
+#define CASE_NEXT_SIZE(next_size_value)                                       \
+  case next_size_value: {                                                     \
+    static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \
+    if (flag_vec4) {                                                          \
+      constexpr int block_size = next_size_value / 4;                         \
+      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4);                                 \
+    } else if (flag_vec2) {                                                   \
+      constexpr int block_size = next_size_value / 2;                         \
+      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2);                                 \
+    } else {                                                                  \
+      if (next_size_value <= kMaxBlockSize) {                                 \
+        constexpr int block_size = next_size_value;                           \
+        LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);                               \
+      } else {                                                                \
+        constexpr int block_size = 256;                                       \
+        LAUNCH_SKIP_LAYER_NORM_KERNEL();                                      \
+      }                                                                       \
+    }                                                                         \
   } break
 
   switch (next_size) {
-    CASE_NEXT_SIZE(kSizes[0]);
+    case kSizes[0]: {
+      constexpr int block_size = kSizes[0];
+      // TODO: Add back the small TensorRT kernel for 32. No need to use vectorized kernel for such a small size.
+      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);
+      break;
+    }
     CASE_NEXT_SIZE(kSizes[1]);
     CASE_NEXT_SIZE(kSizes[2]);
     CASE_NEXT_SIZE(kSizes[3]);
     CASE_NEXT_SIZE(kSizes[4]);
     CASE_NEXT_SIZE(kSizes[5]);
-    CASE_NEXT_SIZE(kSizes[6]);
+    // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize.
+    default: {
+      constexpr int block_size = 256;
+      LAUNCH_SKIP_LAYER_NORM_KERNEL();
+      break;
+    }
   }
 
 #undef CASE_NEXT_SIZE
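// For reference, the dispatch in LaunchSkipLayerNormKernel after this change works as follows:
// NextSize rounds the hidden size up to the next entry of kSizes, the vector width (4, 2, or 1)
// is chosen from the alignment flags, block_size becomes next_size / ILP, and anything that
// lands on kMaxSize falls through to the non-vectorized general kernel because ld may exceed
// kMaxSize. Below is a rough host-side sketch of that selection; PickSkipLayerNormKernel and
// LaunchPlan are hypothetical names, not the actual launch macros.

#include <cstddef>

constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMaxBlockSize = 256;

struct LaunchPlan {
  int block_size;  // threads per block
  int ilp;         // elements per thread in the "small" kernel; 0 selects the general kernel
};

inline LaunchPlan PickSkipLayerNormKernel(int ld, bool can_vec4, bool can_vec2) {
  // NextSize: round ld up to the next supported size, capped at kMaxSize.
  int next_size = kMaxSize;
  for (size_t i = 0; i < kNumOfSizes; ++i) {
    if (ld <= kSizes[i]) {
      next_size = kSizes[i];
      break;
    }
  }

  if (next_size == kSizes[0]) {
    return {kSizes[0], 1};  // 32: scalar small kernel, vectorization not worth it
  }
  if (next_size == kMaxSize) {
    return {kMaxBlockSize, 0};  // default branch: ld may be larger than kMaxSize
  }
  if (can_vec4) {
    return {next_size / 4, 4};
  }
  if (can_vec2) {
    return {next_size / 2, 2};
  }
  return next_size <= kMaxBlockSize ? LaunchPlan{next_size, 1}
                                    : LaunchPlan{kMaxBlockSize, 0};
}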