[BeamSearch] Optimize key cache reordering (microsoft#17771)
### Description

Replace onnxruntime::cuda::Transpose4DKernelParallelizeMultipleElementsPerThreadInInnermostDim()
with a custom transpose kernel in ReorderPastState(). The original
implementation does not benefit from vectorized loads or coalesced (write)
accesses, and it does not fully utilize the threads in the block.
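
For reference, the reordering being performed is the chunked permutation `[B, N, S, head_size/x, x] -> [B, N, head_size/x, S, x]`, where `x = 16 / sizeof(T)`. A minimal CPU sketch of the same reordering is shown below (illustrative only; the function name and loop structure are not part of this change):

```cpp
#include <cstdint>
#include <cstring>

// CPU reference for the K-cache reordering (illustrative sketch, not from this PR).
// Views the cache as [B, N, S, H/x, x] and writes [B, N, H/x, S, x],
// where x = 16 / sizeof(T) (4 for float, 8 for half); each x-element chunk moves intact.
template <typename T>
void ReorderKCacheReference(const T* in, T* out, int B, int N, int S, int H) {
  const int x = 16 / static_cast<int>(sizeof(T));  // elements per 16-byte chunk
  const int H2 = H / x;                            // chunked head size
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < S; ++s)
        for (int h2 = 0; h2 < H2; ++h2) {
          const int64_t src = ((((int64_t)b * N + n) * S + s) * H2 + h2) * x;
          const int64_t dst = ((((int64_t)b * N + n) * H2 + h2) * S + s) * x;
          std::memcpy(out + dst, in + src, x * sizeof(T));
        }
}
```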

Benchmarked with a TNLGv4 model (batch=4, seq_len=4K):
- transpose kernel speedup: ~1.9X (392 μs -> 206 μs)
- overall reordering speedup: ~1.48X

Latency:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/34c7ab73-3da1-4c41-a036-e9fb6a966891)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/337818ec-9598-4d8a-9e9b-7215b6862498)

GPU metrics:
before:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/4962248f-703c-49bd-8586-deaeccd9bce0)
after:

![image](https://github.com/microsoft/onnxruntime/assets/52801275/a795a892-4c5d-432d-8375-0bb67385d2bc)


2 people authored and kleiti committed Mar 22, 2024
1 parent 5cccf57 commit f4d37a2
Showing 3 changed files with 85 additions and 23 deletions.
61 changes: 61 additions & 0 deletions onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -1315,6 +1315,67 @@ template void BufferExpansionKernelLauncher(const int32_t* input,
int chunk_size,
cudaStream_t stream);

// Supports head_size up to 128 (i.e. chunked_head_size up to kTileSize).
constexpr unsigned int kTileSize = 32;
constexpr unsigned int kSeqTileSize = 16;

// Transposes the key cache from [B, N, max_length, chunked_head_size] to
// [B, N, chunked_head_size, max_length] in float4 (16-byte) chunks. Each block
// handles a kSeqTileSize x chunked_head_size tile of one (batch, head) pair and
// stages it in shared memory so that both the loads and the stores are coalesced.
__global__ void ReorderPastStatesKernel(float4* out_buffer,
                                        const float4* in_buffer,
                                        int batch_size,
                                        int num_heads,
                                        int max_length,
                                        int chunked_head_size) {
  // The +1 padding on the inner dimension avoids shared-memory bank conflicts.
  __shared__ float4 tile[kSeqTileSize][kTileSize + 1];

  const int b = blockIdx.z;
  const int n = blockIdx.y;
  const int s_base = blockIdx.x * kSeqTileSize;
  const int s = s_base + threadIdx.y;
  const int base_offset = (b * num_heads + n) * max_length * chunked_head_size;

  // Vectorized, coalesced load: consecutive threads read consecutive float4 chunks
  // of the same sequence position.
  if (s < max_length) {
    const int in_offset = base_offset + s * chunked_head_size + threadIdx.x;
    tile[threadIdx.y][threadIdx.x] = in_buffer[in_offset];
  }

  __syncthreads();

  // Re-linearize the thread index so that consecutive threads write consecutive
  // sequence positions, making the stores to the transposed layout coalesced as well.
  const int tidx = threadIdx.x + threadIdx.y * chunked_head_size;
  const int tidx_x = tidx % kSeqTileSize;
  const int tidx_y = tidx / kSeqTileSize;

  const int s2 = s_base + tidx_x;

  if (s2 < max_length) {
    const int out_offset = base_offset + tidx_y * max_length + s2;
    out_buffer[out_offset] = tile[tidx_x][tidx_y];
  }
}

void ReorderPastStatesKernelLauncher(void* out_buffer,
                                     const void* in_buffer,
                                     int batch_size,
                                     int num_heads,
                                     int max_length,
                                     int head_size,
                                     int chunk_size,
                                     cudaStream_t stream) {
  // [B, N, max_length, H2(head_size/chunk_size), equiv_chunk_size] ->
  // [B, N, H2(head_size/chunk_size), max_length, equiv_chunk_size]
  const int chunked_head_size = head_size / chunk_size;
  const dim3 block(chunked_head_size, kSeqTileSize);
  const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize, num_heads, batch_size);
  if (chunk_size == 4 || chunk_size == 8) {
    ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<float4*>(out_buffer),
                                                        reinterpret_cast<const float4*>(in_buffer),
                                                        batch_size,
                                                        num_heads,
                                                        max_length,
                                                        chunked_head_size);
  } else {
    // chunk_size is 16 / sizeof(T), so only float (4) and half (8) reach the fast path.
    ORT_THROW("ReorderPastStatesKernelLauncher only supports float or half");
  }
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
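
As a rough worked example of the launch shape (the values below are assumed for illustration, not taken from this PR): with half precision and head_size = 128, chunk_size = 16 / sizeof(half) = 8 and chunked_head_size = 16, so each block launches 16 × 16 = 256 threads and the grid tiles the sequence dimension in steps of kSeqTileSize:

```cpp
#include <cuda_runtime.h>

// Launch-shape arithmetic only (assumed example values; kSeqTileSize mirrors the constant above).
constexpr unsigned int kSeqTileSize = 16;
constexpr int head_size = 128;   // fp16 example
constexpr int max_length = 4096;
constexpr int num_heads = 32;
constexpr int batch_size = 4;
constexpr int chunk_size = 16 / 2;                         // 16-byte float4 / sizeof(half) = 8 elements
constexpr int chunked_head_size = head_size / chunk_size;  // 16, must not exceed kTileSize (32)
const dim3 block(chunked_head_size, kSeqTileSize);         // 16 x 16 = 256 threads per block
const dim3 grid((max_length + kSeqTileSize - 1) / kSeqTileSize,  // 256 tiles along the sequence
                num_heads, batch_size);
```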
@@ -213,6 +213,14 @@ void BufferExpansionKernelLauncher(const T* input,
int chunk_size,
cudaStream_t stream);

void ReorderPastStatesKernelLauncher(void* out_buffer,
                                     const void* in_buffer,
                                     int batch_size,
                                     int num_heads,
                                     int max_length,
                                     int head_size,
                                     int chunk_size,
                                     cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
@@ -56,19 +56,23 @@ namespace GenerationCudaDeviceHelper {
// It might be better to forcefully require the same type since cast node generates
// extra overhead.
Status ReorderPastState(
const void* cuda_device_prop,
const void*,
Tensor& past_state,
Tensor& past_state_staging,
Stream* stream) {
ORT_ENFORCE(stream);
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
cublasHandle_t cublas_handle = static_cast<CudaStream*>(stream)->cublas_handle_;

const auto& past_state_shape = past_state.Shape();

const auto& past_state_dims = past_state_shape.GetDims();
const bool packed_past = past_state_dims.size() == 5;

size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0];
size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1];
size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2];
size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3];

// Copy the 'K' values into the temp staging buffer
size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes();
void* past_state_staging_buffer = past_state_staging.MutableDataRaw();
@@ -79,27 +83,16 @@ Status ReorderPastState(
// [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T)
int64_t chunk_size = static_cast<int64_t>(16 / past_state.DataType()->Size());

std::vector<size_t> permutation_vector = {0, 1, 3, 2, 4};
gsl::span<size_t> permutation(permutation_vector.data(), 5);

// "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need
size_t offset = packed_past ? 1 : 0;
TensorShape transpose_input_shape_override = {past_state_shape[offset],
past_state_shape[offset + 1],
past_state_shape[offset + 2],
past_state_shape[offset + 3] / chunk_size,
chunk_size};

TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1],
past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2],
chunk_size};

// TODO(hasesh): Explore perf tuning for this Transpose operation
return onnxruntime::cuda::Transpose::DoTranspose(*static_cast<const cudaDeviceProp*>(cuda_device_prop), cuda_stream,
cublas_handle, permutation,
past_state_staging, past_state,
&transpose_input_shape_override,
&transpose_output_shape_override);
  cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(),
                                        past_state_staging_buffer,
                                        static_cast<int>(batch_size),
                                        static_cast<int>(num_heads),
                                        static_cast<int>(max_length),
                                        static_cast<int>(head_size),
                                        static_cast<int>(chunk_size),
                                        cuda_stream);

  return Status::OK();
}

Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {