Merge remote-tracking branch 'origin/main' into quant
xiaoyu-work committed Mar 25, 2024
2 parents d95e8cf + 7d976cf commit 23fd040
Showing 31 changed files with 3,668 additions and 544 deletions.
@@ -164,17 +164,14 @@ export const createConv2DTransposeMatMulProgramInfo =
       const outWidth = isChannelsLast ? outputShape[2] : outputShape[3];
       const outHeight = isChannelsLast ? outputShape[1] : outputShape[2];
       const outChannels = isChannelsLast ? outputShape[3] : outputShape[1];
-      const isVec4 =
-          isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0;
+      // TODO: enable vec4 for NCHW
+      const isVec4 = isChannelsLast && (inChannels % 4 === 0 && inChannels % 3) && outChannels % 4 === 0;
 
       // TODO: fine tune size
       const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
       const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
-      const workGroupSize: [number, number, number] = isVec4 ?
-          [8, 8, 1] :
-          [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
-      const elementsPerThread =
-          isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1];
+      const workGroupSize: [number, number, number] = [8, 8, 1];
+      const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
       const dispatch = [
         Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
         Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
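For context on the new dispatch sizing above: each dispatch axis is the logical problem size divided by the work-group size times the per-invocation tile, rounded up. Below is a minimal standalone sketch of that arithmetic in C++ (not code from this commit; the shape values, and the assumption that dimAOuter exceeds 8 so the [4, 4, 1] tile is chosen, are hypothetical and only make the numbers concrete):

#include <cstdio>
#include <cstdint>

// Ceiling division: how many groups are needed to cover `size` elements
// when each group covers `perGroup` elements.
static int64_t CeilDiv(int64_t size, int64_t perGroup) {
  return (size + perGroup - 1) / perGroup;
}

int main() {
  // Hypothetical channels-last output: outHeight = 56, outWidth = 56, outChannels = 64.
  const int64_t dispatchX = 64;       // channels-last: X axis covers outChannels
  const int64_t dispatchY = 56 * 56;  // channels-last: Y axis covers outWidth * outHeight
  const int64_t workGroupSize[3] = {8, 8, 1};
  // Assume dimAOuter > 8, so each invocation computes a 4 x 4 tile.
  const int64_t elementsPerThread[3] = {4, 4, 1};

  const int64_t dispatch[3] = {
      CeilDiv(dispatchX, workGroupSize[0] * elementsPerThread[0]),  // 64 / 32   -> 2
      CeilDiv(dispatchY, workGroupSize[1] * elementsPerThread[1]),  // 3136 / 32 -> 98
      1,
  };
  printf("dispatch = [%lld, %lld, %lld]\n",
         (long long)dispatch[0], (long long)dispatch[1], (long long)dispatch[2]);
  return 0;
}

With these example numbers the program would be launched as a 2 x 98 x 1 grid of 8 x 8 work groups, each invocation producing a 4 x 4 output tile.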
262 changes: 262 additions & 0 deletions js/web/test/data/ops/conv-transpose.jsonc

Large diffs are not rendered by default.

23 changes: 20 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -159,8 +159,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
-      do_rotary_ == false &&
-      key != nullptr &&
       (parameters.head_size & 7) == 0 &&
       parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
       (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -172,18 +170,31 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   if (use_memory_efficient_attention && needs_buff) {
     kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size);
   }
+  size_t rotary_buffer_bytes = 0;
+  if (use_memory_efficient_attention && do_rotary_) {
+    rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size;
+    rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
+  }
   size_t fmha_buffer_bytes = 0;
   if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
     fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
   }
+  size_t unpacked_qkv_bytes = 0;
+  if (use_memory_efficient_attention && parameters.is_packed_qkv) {
+    unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T));
+  }
   auto k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer<void>(rotary_buffer_bytes, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer<void>(unpacked_qkv_bytes, context->GetComputeStream());
 #else
   constexpr bool use_memory_efficient_attention = false;
   auto k_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
 #endif
 
   // seqlens_k buffer
@@ -251,7 +262,13 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   if (fmha_buffer != nullptr) {
     data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
   }
-  // Rotary
+  if (unpacked_qkv_buffer != nullptr) {
+    data.unpacked_qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
+  }
+  if (rotary_buffer != nullptr) {
+    data.rotary_buffer = reinterpret_cast<CudaT*>(rotary_buffer.get());
+  }
+  // Rotary Embedding
   if (parameters.do_rotary) {
     data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
     data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
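Reading the scratch-buffer changes above: when memory-efficient attention runs with rotary embedding, the kernel now reserves a rotary workspace sized for two rotated tensors plus per-token int64 position ids, and when Q/K/V arrive packed it reserves room to unpack them into (num_heads + 2 * kv_num_heads) heads. A standalone sketch of that byte arithmetic follows (hypothetical parameter values, fp16 elements assumed; the field names only mirror the parameters used in the diff):

#include <cstdio>
#include <cstdint>
#include <cstddef>

// Hypothetical GQA configuration, used only to illustrate the byte counts.
struct GqaParams {
  int64_t batch_size;
  int64_t num_heads;     // Q heads
  int64_t kv_num_heads;  // K/V heads
  int64_t sequence_length;
  int64_t head_size;
};

int main() {
  const GqaParams p{/*batch_size=*/2, /*num_heads=*/32, /*kv_num_heads=*/8,
                    /*sequence_length=*/128, /*head_size=*/128};
  const std::size_t elem_size = sizeof(uint16_t);  // assume fp16 elements

  // Two rotated tensors (factor of 2), plus one int64 position id per token.
  std::size_t rotary_buffer_bytes =
      2 * elem_size * p.batch_size * p.num_heads * p.sequence_length * p.head_size;
  rotary_buffer_bytes += sizeof(int64_t) * p.batch_size * p.sequence_length;

  // Unpacked Q, K and V when the input arrives as a packed QKV tensor.
  const std::size_t unpacked_qkv_bytes =
      p.batch_size * p.sequence_length *
      (p.num_heads + 2 * p.kv_num_heads) * p.head_size * elem_size;

  printf("rotary_buffer_bytes = %zu\n", rotary_buffer_bytes);
  printf("unpacked_qkv_bytes  = %zu\n", unpacked_qkv_bytes);
  return 0;
}

For this configuration the sketch prints 4196352 and 3145728 bytes, respectively.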
