Commit 16d728d

Merge branch 'main' into yuslepukhin/fix_cuda_unit_tests

yuslepukhin committed Mar 26, 2024
2 parents bb7239d + dfa891a

Showing 42 changed files with 3,974 additions and 586 deletions.
```diff
@@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo =
       const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
       const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
       const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
-      const isVec4 =
-          isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0;
+      // TODO: enable vec4 for NCHW
+      const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0;

       // TODO: fine tune size
       const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
       const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
-      const workGroupSize: [number, number, number] = isVec4 ?
-          [8, 8, 1] :
-          [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
-      const elementsPerThread =
-          isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1];
+      const workGroupSize: [number, number, number] = [8, 8, 1];
+      const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
       const dispatch = [
         Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
         Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
```
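The replacement hardcodes an 8×8 workgroup and picks the per-thread tile from `dimAOuter`, so the dispatch arithmetic is worth a concrete check. Below is a minimal standalone sketch (C++ for illustration; the TypeScript above is the real source). The shape values, and the reading of `dimAOuter` as the output-row count of the implicit matmul, are assumptions for the example, not taken from the PR:

```cpp
// Standalone sketch of the new dispatch arithmetic with hypothetical shapes.
#include <array>
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical channels-last conv-transpose output: 32 x 32 x 64 (H x W x C).
  const int outWidth = 32, outHeight = 32, outChannels = 64;
  // Assumption: dimAOuter is the output-row count of the implicit matmul.
  const int dimAOuter = outWidth * outHeight;

  const int dispatchX = outChannels;  // channels-last branch
  const int dispatchY = outWidth * outHeight;

  const std::array<int, 3> workGroupSize = {8, 8, 1};
  const std::array<int, 3> elementsPerThread =
      dimAOuter <= 8 ? std::array<int, 3>{4, 1, 1} : std::array<int, 3>{4, 4, 1};

  // Each workgroup covers workGroupSize[i] * elementsPerThread[i] outputs on axis i.
  const int groupsX = static_cast<int>(
      std::ceil(static_cast<double>(dispatchX) / workGroupSize[0] / elementsPerThread[0]));
  const int groupsY = static_cast<int>(
      std::ceil(static_cast<double>(dispatchY) / workGroupSize[1] / elementsPerThread[1]));

  std::printf("dispatch: %d x %d x 1\n", groupsX, groupsY);  // prints: dispatch: 2 x 32 x 1
}
```

With these example shapes each workgroup covers a 32×32 output tile (8 threads × 4 elements per axis), so the 64×1024 problem dispatches 2 × 32 × 1 groups.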
262 changes: 262 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc

Large diffs are not rendered by default.

23 changes: 20 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
```diff
@@ -159,8 +159,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
-      do_rotary_ == false &&
-      key != nullptr &&
       (parameters.head_size & 7) == 0 &&
       parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
       (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
```
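Note that `(parameters.head_size & 7) == 0` is the usual bit-mask spelling of `parameters.head_size % 8 == 0`: the memory-efficient path remains eligible only when the head size is a multiple of 8.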
```diff
@@ -172,18 +170,31 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   if (use_memory_efficient_attention && needs_buff) {
     kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size);
   }
+  size_t rotary_buffer_bytes = 0;
+  if (use_memory_efficient_attention && do_rotary_) {
+    rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size;
+    rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
+  }
   size_t fmha_buffer_bytes = 0;
   if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
     fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
   }
+  size_t unpacked_qkv_bytes = 0;
+  if (use_memory_efficient_attention && parameters.is_packed_qkv) {
+    unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T));
+  }
   auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
 #else
   constexpr bool use_memory_efficient_attention = false;
   auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
 #endif

   // seqlens_k buffer
```
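The two new workspace sizes are plain byte counts and can be sanity-checked standalone. A minimal sketch mirroring the formulas above, with hypothetical GQA parameters and `T` = fp16 (`sizeof(T) == 2`); only the formulas come from the diff:

```cpp
// Standalone sanity check of the scratch-buffer sizing added above.
// All parameter values are hypothetical.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const std::size_t sizeof_T = 2;  // fp16
  const std::size_t batch_size = 1, num_heads = 32, kv_num_heads = 8;
  const std::size_t sequence_length = 128, head_size = 128;

  // Rotary workspace per the diff: two Q-sized tensors plus one int64 per token.
  std::size_t rotary_buffer_bytes =
      2 * sizeof_T * batch_size * num_heads * sequence_length * head_size;
  rotary_buffer_bytes += sizeof(std::int64_t) * batch_size * sequence_length;

  // Unpacked QKV when the input arrives packed: Q heads plus K and V heads.
  const std::size_t unpacked_qkv_bytes =
      batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size * sizeof_T;

  std::printf("rotary buffer: %zu bytes\n", rotary_buffer_bytes);   // 2098176
  std::printf("unpacked QKV:  %zu bytes\n", unpacked_qkv_bytes);    // 1572864
}
```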
```diff
@@ -251,7 +262,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   if (fmha_buffer != nullptr) {
     data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
   }
-  // Rotary
+  if (unpacked_qkv_buffer != nullptr) {
+    data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
+  }
+  if (rotary_buffer != nullptr) {
+    data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
+  }
+  // Rotary Embedding
   if (parameters.do_rotary) {
     data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
     data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
```
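For context on what the new `rotary_buffer` feeds: rotary embedding rotates pairs of query/key components by position-dependent angles read from the `cos_cache`/`sin_cache` tensors wired up above. Below is a minimal CPU sketch of the common half-rotation formulation; it is an assumption, for illustration only, that ORT's GPU kernel uses this exact pairing and cache layout:

```cpp
// CPU sketch of half-rotation rotary embedding (illustrative, not the CUDA kernel).
#include <cstddef>
#include <vector>

// Rotates pairs (x[i], x[i + half]) of one head's vector by the cached
// angle for token position `pos`. Assumed cache layout: [max_seq_len, head_size / 2].
void apply_rotary(std::vector<float>& x, std::size_t pos, std::size_t head_size,
                  const std::vector<float>& cos_cache,
                  const std::vector<float>& sin_cache) {
  const std::size_t half = head_size / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float c = cos_cache[pos * half + i];
    const float s = sin_cache[pos * half + i];
    const float x0 = x[i];
    const float x1 = x[i + half];
    x[i] = x0 * c - x1 * s;          // standard 2-D rotation of the pair
    x[i + half] = x0 * s + x1 * c;
  }
}
```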