Commit a4ee6ea

a

wangyems committed Oct 2, 2023 (1 parent: ac4e726)
Showing 3 changed files with 92 additions and 24 deletions.
onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu (64 additions, 0 deletions)
@@ -1315,6 +1315,70 @@ template void BufferExpansionKernelLauncher(const int32_t* input,
                                            int chunk_size,
                                            cudaStream_t stream);

template <typename T>
__global__ void ReorderPastStatesKernel(T* out_buffer,
                                        const T* in_buffer,
                                        int batch_size,
                                        int num_heads,
                                        int max_length,
                                        int chunked_head_size,
                                        int equv_chunk_size) {
  // [B, N, max_length, H2(head_size/chunk_size), chunk_size] -> [B, N, H2(head_size/chunk_size), max_length, chunk_size]
  const int b = blockIdx.x;
  const int n = blockIdx.y;
  const int s = blockIdx.z;
  const int h2 = threadIdx.x;
  const int c = threadIdx.y;

  const int in_offset = b * num_heads * max_length * chunked_head_size * equv_chunk_size +
                        n * max_length * chunked_head_size * equv_chunk_size +
                        s * chunked_head_size * equv_chunk_size +
                        h2 * equv_chunk_size +
                        c;

  const int out_offset = b * num_heads * max_length * chunked_head_size * equv_chunk_size +
                         n * max_length * chunked_head_size * equv_chunk_size +
                         h2 * max_length * equv_chunk_size +
                         s * equv_chunk_size +
                         c;

  out_buffer[out_offset] = in_buffer[in_offset];
}

void ReorderPastStatesKernelLauncher(void* out_buffer,
                                     const void* in_buffer,
                                     int batch_size,
                                     int num_heads,
                                     int max_length,
                                     int head_size,
                                     int chunk_size,
                                     cudaStream_t stream) {
  const dim3 grid(batch_size, num_heads, max_length);
  const dim3 block(head_size / chunk_size, chunk_size / 4);
  if (chunk_size == 4) {
    // float
    ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<float4*>(out_buffer),
                                                        reinterpret_cast<const float4*>(in_buffer),
                                                        batch_size,
                                                        num_heads,
                                                        max_length,
                                                        head_size / chunk_size,
                                                        chunk_size / 4);
  } else if (chunk_size == 8) {
    // half
    ReorderPastStatesKernel<<<grid, block, 0, stream>>>(reinterpret_cast<Half4*>(out_buffer),
                                                        reinterpret_cast<const Half4*>(in_buffer),
                                                        batch_size,
                                                        num_heads,
                                                        max_length,
                                                        head_size / chunk_size,
                                                        chunk_size / 4);
  } else {
    // throw not support error
    ORT_THROW("ReorderPastStatesKernelLauncher only support float or half");
  }
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
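
The new kernel and launcher above implement a vectorized transpose of the K cache: the buffer is viewed as [B, N, max_length, head_size / chunk_size, chunk_size], dimensions 2 and 3 are swapped, and each thread moves one vector (a 16-byte float4 for fp32, an 8-byte Half4 for fp16). The following host-side sketch is not part of this commit; it mirrors the same index arithmetic with plain float elements (treating each chunk as chunk_size scalars rather than hardware vectors) and could serve as a reference when unit-testing the kernel's offsets.

// Host-side reference for ReorderPastStatesKernel's index mapping (a sketch,
// not part of this commit). It reorders [B, N, S, H2, C] -> [B, N, H2, S, C],
// where H2 = head_size / chunk_size and C = chunk_size scalars per 16-byte chunk.
#include <cstddef>
#include <vector>

void ReorderPastStatesReference(std::vector<float>& out,
                                const std::vector<float>& in,
                                int batch_size, int num_heads, int max_length,
                                int chunked_head_size, int chunk_size) {
  for (int b = 0; b < batch_size; ++b)
    for (int n = 0; n < num_heads; ++n)
      for (int s = 0; s < max_length; ++s)
        for (int h2 = 0; h2 < chunked_head_size; ++h2)
          for (int c = 0; c < chunk_size; ++c) {
            // Same strides as the kernel, expressed in scalar elements.
            const size_t in_offset =
                ((((static_cast<size_t>(b) * num_heads + n) * max_length + s) * chunked_head_size + h2) * chunk_size) + c;
            const size_t out_offset =
                ((((static_cast<size_t>(b) * num_heads + n) * chunked_head_size + h2) * max_length + s) * chunk_size) + c;
            out[out_offset] = in[in_offset];
          }
}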
@@ -213,6 +213,15 @@ void BufferExpansionKernelLauncher(const T* input,
                                   int chunk_size,
                                   cudaStream_t stream);


void ReorderPastStatesKernelLauncher(void* out_buffer,
                                     const void* in_buffer,
                                     int batch_size,
                                     int num_heads,
                                     int max_length,
                                     int head_size,
                                     int chunk_size,
                                     cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
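
A call site for this launcher might look like the sketch below; the wrapper function and variable names are illustrative assumptions, not code from this change. The caller derives chunk_size as 16 / sizeof(element), i.e. 4 for float and 8 for half, and the launcher then dispatches to float4 or Half4 vectors accordingly.

// Hypothetical caller (a sketch, not code from this change). The declaration is
// repeated here so the snippet is self-contained; in the repo it comes from the
// header shown above.
#include <cstddef>
#include <cuda_runtime.h>

namespace onnxruntime { namespace contrib { namespace cuda {
void ReorderPastStatesKernelLauncher(void* out_buffer, const void* in_buffer,
                                     int batch_size, int num_heads, int max_length,
                                     int head_size, int chunk_size, cudaStream_t stream);
}}}  // namespace onnxruntime::contrib::cuda

void ReorderKCache(void* out_buffer, const void* in_buffer,
                   int batch_size, int num_heads, int max_length,
                   int head_size, size_t element_size, cudaStream_t stream) {
  // One chunk is 16 bytes: 4 float elements or 8 half elements.
  const int chunk_size = static_cast<int>(16 / element_size);
  onnxruntime::contrib::cuda::ReorderPastStatesKernelLauncher(
      out_buffer, in_buffer, batch_size, num_heads, max_length,
      head_size, chunk_size, stream);
}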
@@ -56,50 +56,45 @@ namespace GenerationCudaDeviceHelper {
// It might be better to forcefully require the same type since cast node generates
// extra overhead.
Status ReorderPastState(
-    const void* cuda_device_prop,
+    const void*,
    Tensor& past_state,
    Tensor& past_state_staging,
    Stream* stream) {
  ORT_ENFORCE(stream);
  cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
-  cublasHandle_t cublas_handle = static_cast<CudaStream*>(stream)->cublas_handle_;

  const auto& past_state_shape = past_state.Shape();

  const auto& past_state_dims = past_state_shape.GetDims();
  const bool packed_past = past_state_dims.size() == 5;

  size_t batch_size = packed_past ? past_state_dims[1] : past_state_dims[0];
  size_t num_heads = packed_past ? past_state_dims[2] : past_state_dims[1];
  size_t max_length = packed_past ? past_state_dims[3] : past_state_dims[2];
  size_t head_size = packed_past ? past_state_dims[4] : past_state_dims[3];
+  std::cout << "75: " << std::endl;
  // Copy the 'K' values into the temp staging buffer
  size_t past_state_size = packed_past ? past_state.SizeInBytes() / 2 : past_state.SizeInBytes();
  void* past_state_staging_buffer = past_state_staging.MutableDataRaw();
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_state_staging_buffer, past_state.DataRaw(), past_state_size,
                                       cudaMemcpyDeviceToDevice, cuda_stream));

+  std::cout << "81: " << std::endl;
  // Now consider the original 'K' values to be of shape [B, N, max_length, head_size / x, x] and transpose it into
  // [B, N, head_size / x, max_length, x], where x = 16 / sizeof(T)
  int64_t chunk_size = static_cast<int64_t>(16 / past_state.DataType()->Size());
+  std::cout << "chunk_size: " << chunk_size << std::endl;
+  cuda::ReorderPastStatesKernelLauncher(past_state.MutableDataRaw(),
+                                        past_state_staging_buffer,
+                                        static_cast<int>(batch_size),
+                                        static_cast<int>(num_heads),
+                                        static_cast<int>(max_length),
+                                        static_cast<int>(head_size),
+                                        static_cast<int>(chunk_size),
+                                        cuda_stream);
+  cudaDeviceSynchronize();
+  std::cout << "reordered past state kernel done" << std::endl;

-  std::vector<size_t> permutation_vector = {0, 1, 3, 2, 4};
-  gsl::span<size_t> permutation(permutation_vector.data(), 5);
-
-  // "Fake" the shapes of the input and output tensors of the Transpose operation to suit our need
-  size_t offset = packed_past ? 1 : 0;
-  TensorShape transpose_input_shape_override = {past_state_shape[offset],
-                                                past_state_shape[offset + 1],
-                                                past_state_shape[offset + 2],
-                                                past_state_shape[offset + 3] / chunk_size,
-                                                chunk_size};
-
-  TensorShape transpose_output_shape_override = {past_state_shape[offset], past_state_shape[offset + 1],
-                                                 past_state_shape[offset + 3] / chunk_size, past_state_shape[offset + 2],
-                                                 chunk_size};
-
-  // TODO(hasesh): Explore perf tuning for this Transpose operation
-  return onnxruntime::cuda::Transpose::DoTranspose(*static_cast<const cudaDeviceProp*>(cuda_device_prop), cuda_stream,
-                                                   cublas_handle, permutation,
-                                                   past_state_staging, past_state,
-                                                   &transpose_input_shape_override,
-                                                   &transpose_output_shape_override);
+  return Status::OK();
}

Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {
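
As a worked example of the launch geometry picked by ReorderPastStatesKernelLauncher (the dimensions below are assumed for illustration, not taken from this commit): for an fp16 cache with head_size 64, chunk_size is 16 / sizeof(half) = 8, the grid spans (batch, head, sequence position), and each block has (head_size / chunk_size) x (chunk_size / 4) = 8 x 2 threads, each copying one 8-byte Half4.

// Worked example of the grid/block sizes chosen by ReorderPastStatesKernelLauncher
// (assumed dimensions; illustrative only).
#include <cstdio>

int main() {
  const int batch_size = 2, num_heads = 16, max_length = 256, head_size = 64;
  const int element_size = 2;                  // sizeof(half)
  const int chunk_size = 16 / element_size;    // 8 half elements per 16-byte chunk
  const int block_x = head_size / chunk_size;  // 8 chunks per head
  const int block_y = chunk_size / 4;          // 2 Half4 vectors per chunk
  std::printf("grid=(%d,%d,%d), block=(%d,%d): %d threads per block, one 8-byte Half4 each\n",
              batch_size, num_heads, max_length, block_x, block_y, block_x * block_y);
  return 0;
}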
