diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index e4b09b00f030c..973ef8d304e2e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -51,11 +51,11 @@ half maybe2half(float x) {
 
 // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
 // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
-constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
+constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
 constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
 constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
 constexpr int kMinBlockSize = 32;
-constexpr int kMaxBlockSize = 256;
+constexpr int kMaxBlockSize = 1024;
 
 int NextSize(int x) {
   for (size_t i = 0; i < kNumOfSizes; ++i) {
@@ -63,14 +63,13 @@ int NextSize(int x) {
       return kSizes[i];
     }
   }
-  return kMaxSize;
+  return kMaxSize + 1;
 }
 
-template <typename T, int NumUnroll>
-bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
-                   const T* gamma, const T* beta, const int ld, const int next_size) {
-  constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
-  return ld % NumUnroll == 0 &&
+bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias,
+                   const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) {
+  int alignment = element_size * num_unroll;
+  return ld % num_unroll == 0 &&
          reinterpret_cast<uint64_t>(output) % alignment == 0 &&
          reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
          reinterpret_cast<uint64_t>(input) % alignment == 0 &&
@@ -78,8 +77,8 @@ bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, cons
          reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
          reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
          reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
-         next_size / NumUnroll >= kMinBlockSize &&
-         next_size / NumUnroll <= kMaxBlockSize;
+         next_size / num_unroll >= kMinBlockSize &&
+         next_size / num_unroll <= kMaxBlockSize;
 }
 
 }  // namespace
@@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel(
     int ld, int row_count, int skip_size) {
   const int next_size = NextSize(ld);
   const int grid_size = row_count;
-  bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
-  bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
+  bool can_unroll_vec4 = CanVectorized(output, sum_output, input,
+                                       skip, bias, gamma,
+                                       beta, ld, next_size,
+                                       4, sizeof(T));
+  bool can_unroll_vec8 = CanVectorized(output, sum_output, input,
+                                       skip, bias, gamma,
+                                       beta, ld, next_size,
+                                       8, sizeof(T));
 
 #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll)                                                   \
   SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>(  \
@@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel(
   SkipLayerNormKernel<T, block_size, Simplified><<<grid_size, block_size, 0, stream>>>(  \
       output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)
 
-#define CASE_NEXT_SIZE(next_size_value)                                          \
-  case next_size_value: {                                                        \
-    static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize);    \
-    if (flag_vec4) {                                                             \
-      constexpr int block_size = next_size_value / 4;                            \
-      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4);                                    \
-    } else if (flag_vec2) {                                                      \
-      constexpr int block_size = next_size_value / 2;                            \
-      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2);                                    \
-    } else {                                                                     \
-      if (next_size_value <= kMaxBlockSize) {                                    \
-        constexpr int block_size = next_size_value;                              \
-        LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);                                  \
-      } else {                                                                   \
-        constexpr int block_size = 256;                                          \
-        LAUNCH_SKIP_LAYER_NORM_KERNEL();                                         \
-      }                                                                          \
-    }                                                                            \
+#define CASE_NEXT_SIZE(next_size_value)                                          \
+  case next_size_value: {                                                        \
+    static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize);  \
+    if constexpr (next_size_value >= 8 * 256) {                                  \
+      if (can_unroll_vec8) {                                                     \
+        constexpr int block_size = next_size_value / 8;                          \
+        LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8);                                  \
+      } else {                                                                   \
+        constexpr int block_size = 256;                                          \
+        LAUNCH_SKIP_LAYER_NORM_KERNEL();                                         \
+      }                                                                          \
+    } else {                                                                     \
+      if (can_unroll_vec4) {                                                     \
+        constexpr int block_size = next_size_value / 4;                          \
+        LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4);                                  \
+      } else {                                                                   \
+        if (next_size_value <= kMaxBlockSize) {                                  \
+          constexpr int block_size = next_size_value;                            \
+          LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);                                \
+        } else {                                                                 \
+          constexpr int block_size = 256;                                        \
+          LAUNCH_SKIP_LAYER_NORM_KERNEL();                                       \
+        }                                                                        \
+      }                                                                          \
+    }                                                                            \
   } break
 
   switch (next_size) {
-    case kSizes[0]: {
-      constexpr int block_size = kSizes[0];
-      // TODO: Add back the small TensorRT kernel for 32. No need to use vertorized kernel for such small size.
-      LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);
-      break;
-    }
+    CASE_NEXT_SIZE(kSizes[0]);
     CASE_NEXT_SIZE(kSizes[1]);
     CASE_NEXT_SIZE(kSizes[2]);
     CASE_NEXT_SIZE(kSizes[3]);
     CASE_NEXT_SIZE(kSizes[4]);
     CASE_NEXT_SIZE(kSizes[5]);
-    // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize.
+    CASE_NEXT_SIZE(kSizes[6]);
+    CASE_NEXT_SIZE(kSizes[7]);
     default: {
       constexpr int block_size = 256;
       LAUNCH_SKIP_LAYER_NORM_KERNEL();
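For reviewers, here is a small host-only sketch of the dispatch logic introduced above. It is not part of the patch: it re-implements `NextSize` and the size/unroll bounds from `CanVectorized` in plain C++ so the kernel variant and block size chosen by the new `CASE_NEXT_SIZE` cascade can be printed for a given hidden size, under the assumption that all seven buffers satisfy the pointer-alignment requirement (address divisible by `element_size * num_unroll`). The helper names `CanUnroll` and `DescribeDispatch` are invented for this illustration.

```cpp
// Standalone host-side sketch (not part of the patch). It re-derives the decision
// that the new CASE_NEXT_SIZE dispatch makes for a given hidden size `ld`,
// assuming all seven buffers pass the pointer-alignment checks in CanVectorized.
#include <cstddef>
#include <cstdio>

namespace {

constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
constexpr int kMaxBlockSize = 1024;

// Same behavior as NextSize in the patch: returns the smallest listed size that
// covers x, or kMaxSize + 1 so the switch falls through to the default path.
int NextSize(int x) {
  for (size_t i = 0; i < kNumOfSizes; ++i) {
    if (x <= kSizes[i]) {
      return kSizes[i];
    }
  }
  return kMaxSize + 1;
}

// The data-independent part of CanVectorized: the real function additionally
// requires every buffer address to be a multiple of element_size * num_unroll.
bool CanUnroll(int ld, int next_size, int num_unroll) {
  return ld % num_unroll == 0 &&
         next_size / num_unroll >= kMinBlockSize &&
         next_size / num_unroll <= kMaxBlockSize;
}

// Prints which kernel variant and block size the dispatch would choose.
void DescribeDispatch(int ld) {
  const int next_size = NextSize(ld);
  if (next_size > kMaxSize) {
    std::printf("ld=%5d -> SkipLayerNormKernel, block_size=256\n", ld);
  } else if (next_size >= 8 * 256 && CanUnroll(ld, next_size, 8)) {
    std::printf("ld=%5d -> SkipLayerNormKernelSmall, unroll=8, block_size=%d\n", ld, next_size / 8);
  } else if (next_size < 8 * 256 && CanUnroll(ld, next_size, 4)) {
    std::printf("ld=%5d -> SkipLayerNormKernelSmall, unroll=4, block_size=%d\n", ld, next_size / 4);
  } else if (next_size <= kMaxBlockSize) {
    std::printf("ld=%5d -> SkipLayerNormKernelSmall, unroll=1, block_size=%d\n", ld, next_size);
  } else {
    std::printf("ld=%5d -> SkipLayerNormKernel, block_size=256\n", ld);
  }
}

}  // namespace

int main() {
  // Typical hidden sizes; 10240 exceeds kMaxSize and takes the default path.
  const int test_lds[] = {128, 768, 1024, 2048, 4096, 5120, 8192, 10240};
  for (int ld : test_lds) {
    DescribeDispatch(ld);
  }
  return 0;
}
```

The key invariant is block_size = next_size / num_unroll, so each thread handles num_unroll elements of a row; raising kMaxBlockSize to 1024 and adding the unroll-by-8 branch is what lets the new sizes 4096, 5120, and 8192 stay on the vectorized SkipLayerNormKernelSmall path instead of falling back to the general kernel.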