optimize SLN with large dimension (#18138)
### Description
Optimize SkipLayerNorm for large dimensions (>= 2048) by handling 8 elements per thread. This avoids writing the sum of input, skip, and bias out to main memory and re-loading it. For dimension 4096 with a small batch size, latency drops from ~18 us to ~3.8 us on an A100.
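The pattern is easier to see outside the diff: give each thread a contiguous slice of the row so the `input + skip + bias` values stay in registers between the mean/variance reduction and the normalization step. Below is a minimal, hypothetical CUDA sketch of that idea (one block per row, `kUnroll` elements per thread; the names, fp32 types, and `cub::BlockReduce` reduction are illustrative assumptions, not the actual ONNX Runtime kernel):

```cuda
// Hypothetical sketch (not the ORT kernel): one thread block per row,
// kBlockSize * kUnroll == hidden_size, so each thread owns kUnroll elements.
#include <cub/cub.cuh>

template <int kBlockSize, int kUnroll>
__global__ void SkipLayerNormRowSketch(float* output, const float* input,
                                       const float* skip, const float* bias,
                                       const float* gamma, const float* beta,
                                       float epsilon, int hidden_size) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ float mean, inv_std;

  const int row_offset = blockIdx.x * hidden_size;
  const int col = threadIdx.x * kUnroll;

  // Load input + skip + bias once; keep the per-thread slice in registers.
  float vals[kUnroll];
  float thread_sum = 0.f;
#pragma unroll
  for (int i = 0; i < kUnroll; ++i) {
    vals[i] = input[row_offset + col + i] + skip[row_offset + col + i] + bias[col + i];
    thread_sum += vals[i];
  }

  // First block-wide reduction: row mean.
  const float row_sum = BlockReduce(temp_storage).Sum(thread_sum);
  if (threadIdx.x == 0) mean = row_sum / hidden_size;
  __syncthreads();  // publish mean; also required before reusing temp_storage

  // Second block-wide reduction: row variance, computed from registers.
  float thread_sq = 0.f;
#pragma unroll
  for (int i = 0; i < kUnroll; ++i) {
    const float d = vals[i] - mean;
    thread_sq += d * d;
  }
  const float row_sq = BlockReduce(temp_storage).Sum(thread_sq);
  if (threadIdx.x == 0) inv_std = rsqrtf(row_sq / hidden_size + epsilon);
  __syncthreads();

  // Normalize straight from registers; the summed values never round-trip
  // through global memory.
#pragma unroll
  for (int i = 0; i < kUnroll; ++i) {
    output[row_offset + col + i] = (vals[i] - mean) * inv_std * gamma[col + i] + beta[col + i];
  }
}

// Example launch for hidden_size = 4096: 512 threads, 8 elements per thread.
// SkipLayerNormRowSketch<512, 8><<<row_count, 512, 0, stream>>>(
//     output, input, skip, bias, gamma, beta, 1e-12f, 4096);
```

A row that does not fit this register-resident pattern falls back to the generic kernel, which is the write-and-reload round trip the description above refers to.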

yufenglee authored Oct 30, 2023
1 parent 348a963 commit 90d1f53
Showing 1 changed file: onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu (45 additions, 37 deletions)
@@ -51,35 +51,34 @@ half maybe2half(float x) {

// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
- constexpr int kSizes[] = {32, 64, 128, 384, 768, 1024, 2048};
+ constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kMaxSize = kSizes[kNumOfSizes - 1];
constexpr int kMinBlockSize = 32;
- constexpr int kMaxBlockSize = 256;
+ constexpr int kMaxBlockSize = 1024;

int NextSize(int x) {
for (size_t i = 0; i < kNumOfSizes; ++i) {
if (x <= kSizes[i]) {
return kSizes[i];
}
}
- return kMaxSize;
+ return kMaxSize + 1;
}

- template <typename T, int NumUnroll>
- bool CanVectorized(T* output, T* sum_output, const T* input, const T* skip, const T* bias,
- const T* gamma, const T* beta, const int ld, const int next_size) {
- constexpr int alignment = std::alignment_of<aligned_vector<T, NumUnroll>>::value;
- return ld % NumUnroll == 0 &&
+ bool CanVectorized(void* output, void* sum_output, const void* input, const void* skip, const void* bias,
+ const void* gamma, const void* beta, const int ld, const int next_size, int num_unroll, int element_size) {
+ int alignment = element_size * num_unroll;
+ return ld % num_unroll == 0 &&
reinterpret_cast<uint64_t>(output) % alignment == 0 &&
reinterpret_cast<uint64_t>(sum_output) % alignment == 0 &&
reinterpret_cast<uint64_t>(input) % alignment == 0 &&
reinterpret_cast<uint64_t>(skip) % alignment == 0 &&
reinterpret_cast<uint64_t>(bias) % alignment == 0 &&
reinterpret_cast<uint64_t>(gamma) % alignment == 0 &&
reinterpret_cast<uint64_t>(beta) % alignment == 0 &&
- next_size / NumUnroll >= kMinBlockSize &&
- next_size / NumUnroll <= kMaxBlockSize;
+ next_size / num_unroll >= kMinBlockSize &&
+ next_size / num_unroll <= kMaxBlockSize;
}
} // namespace

@@ -187,8 +186,14 @@ void LaunchSkipLayerNormKernel(
int ld, int row_count, int skip_size) {
const int next_size = NextSize(ld);
const int grid_size = row_count;
- bool flag_vec2 = CanVectorized<T, 2>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
- bool flag_vec4 = CanVectorized<T, 4>(output, sum_output, input, skip, bias, gamma, beta, ld, next_size);
+ bool can_unroll_vec4 = CanVectorized(output, sum_output, input,
+ skip, bias, gamma,
+ beta, ld, next_size,
+ 4, sizeof(T));
+ bool can_unroll_vec8 = CanVectorized(output, sum_output, input,
+ skip, bias, gamma,
+ beta, ld, next_size,
+ 8, sizeof(T));

#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified><<<grid_size, block_size, 0, stream>>>( \
@@ -198,39 +203,42 @@ void LaunchSkipLayerNormKernel(
SkipLayerNormKernel<T, block_size, Simplified><<<grid_size, block_size, 0, stream>>>( \
output, sum_output, input, skip, bias, gamma, beta, maybe2half<T>(epsilon), ld, skip_size)

- #define CASE_NEXT_SIZE(next_size_value) \
- case next_size_value: { \
- static_assert(next_size_value > kSizes[0] && next_size_value < kMaxSize); \
- if (flag_vec4) { \
- constexpr int block_size = next_size_value / 4; \
- LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
- } else if (flag_vec2) { \
- constexpr int block_size = next_size_value / 2; \
- LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(2); \
- } else { \
- if (next_size_value <= kMaxBlockSize) { \
- constexpr int block_size = next_size_value; \
- LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
- } else { \
- constexpr int block_size = 256; \
- LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
- } \
- } \
+ #define CASE_NEXT_SIZE(next_size_value) \
+ case next_size_value: { \
+ static_assert(next_size_value >= kSizes[0] && next_size_value <= kMaxSize); \
+ if constexpr (next_size_value >= 8 * 256) { \
+ if (can_unroll_vec8) { \
+ constexpr int block_size = next_size_value / 8; \
+ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(8); \
+ } else { \
+ constexpr int block_size = 256; \
+ LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
+ } \
+ } else { \
+ if (can_unroll_vec4) { \
+ constexpr int block_size = next_size_value / 4; \
+ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(4); \
+ } else { \
+ if (next_size_value <= kMaxBlockSize) { \
+ constexpr int block_size = next_size_value; \
+ LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1); \
+ } else { \
+ constexpr int block_size = 256; \
+ LAUNCH_SKIP_LAYER_NORM_KERNEL(); \
+ } \
+ } \
+ } \
+ } break

switch (next_size) {
- case kSizes[0]: {
- constexpr int block_size = kSizes[0];
- // TODO: Add back the small TensorRT kernel for 32. No need to use vertorized kernel for such small size.
- LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(1);
- break;
- }
+ CASE_NEXT_SIZE(kSizes[0]);
CASE_NEXT_SIZE(kSizes[1]);
CASE_NEXT_SIZE(kSizes[2]);
CASE_NEXT_SIZE(kSizes[3]);
CASE_NEXT_SIZE(kSizes[4]);
CASE_NEXT_SIZE(kSizes[5]);
- // kMaxSize shall not run vectorized kernel since ld might be larger than kMaxSize.
CASE_NEXT_SIZE(kSizes[6]);
+ CASE_NEXT_SIZE(kSizes[7]);
default: {
constexpr int block_size = 256;
LAUNCH_SKIP_LAYER_NORM_KERNEL();
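Read together, `NextSize` and the `CASE_NEXT_SIZE` macro determine which kernel a given hidden dimension takes. Below is a small standalone sketch of that dispatch: a hypothetical helper that copies the constants from this diff and assumes every pointer passes the alignment checks in `CanVectorized`, so only the size checks matter.

```cpp
#include <cstdio>

// Illustrative copies of the constants in this diff (not the ORT header).
constexpr int kSizes[] = {128, 384, 768, 1024, 2048, 4096, 5120, 8192};
constexpr int kMaxSize = kSizes[sizeof(kSizes) / sizeof(kSizes[0]) - 1];
constexpr int kMinBlockSize = 32;
constexpr int kMaxBlockSize = 1024;

int NextSize(int x) {
  for (int s : kSizes) {
    if (x <= s) return s;
  }
  return kMaxSize + 1;  // forces the non-vectorized default case
}

// Assumes all pointers are suitably aligned for the chosen unroll factor.
const char* PickPath(int ld) {
  const int next_size = NextSize(ld);
  auto unroll_ok = [&](int unroll) {
    return ld % unroll == 0 &&
           next_size / unroll >= kMinBlockSize &&
           next_size / unroll <= kMaxBlockSize;
  };
  if (next_size > kMaxSize) return "generic kernel, block_size = 256";
  if (next_size >= 8 * 256)  // large sizes: try the 8-element unroll first
    return unroll_ok(8) ? "small kernel, unroll = 8" : "generic kernel, block_size = 256";
  if (unroll_ok(4)) return "small kernel, unroll = 4";
  return next_size <= kMaxBlockSize ? "small kernel, unroll = 1"
                                    : "generic kernel, block_size = 256";
}

int main() {
  std::printf("ld = 4096 -> %s\n", PickPath(4096));  // unroll = 8, block size 512
  std::printf("ld = 768  -> %s\n", PickPath(768));   // unroll = 4, block size 192
  std::printf("ld = 9000 -> %s\n", PickPath(9000));  // larger than the table, generic kernel
}
```

For example, ld = 4096 now maps to an exact table entry and takes the unroll-8 path with a 512-thread block, which is the case the description's A100 numbers refer to.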
