From 388dabe5f1a682769c3bdeb9054c738525963cf5 Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Mon, 19 Aug 2024 08:32:20 -0700
Subject: [PATCH] undo unrelated; static_cast; comments

---
 onnxruntime/contrib_ops/cpu/utils/console_dumper.h  |  6 +-----
 onnxruntime/contrib_ops/cuda/bert/attention_impl.cu | 13 +++++++------
 .../cuda/bert/cudnn_fmha/cudnn_flash_attention.cu   | 10 +++++++---
 .../contrib_ops/cuda/bert/packed_attention.cc       |  4 ++--
 .../cuda/bert/packed_multihead_attention.cc         |  2 +-
 5 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
index 9ebc44f4411eb..12cbc5049a02a 100644
--- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
+++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,11 +53,7 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
                        const char* name,
                        const T* tensor,
                        gsl::span<const int64_t>& dims) {
-  if (!dumper->IsEnabled()) {
-    return;
-  }
-
-  if ((tensor == nullptr || dims.size() == 0)) {
+  if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
     std::cout << std::string(name) << " is None" << std::endl;
     return;
   }
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index f271942ca2215..a02f5c7329b9a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -384,7 +384,8 @@ Status CudnnFlashAttention(
   ORT_UNUSED_PARAMETER(parameters);
   ORT_UNUSED_PARAMETER(data);
   ORT_UNUSED_PARAMETER(scale);
-  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "cudnn flash attention does not support float tensor");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+                         "cudnn flash attention does not support float tensor");
 }
 
 #if USE_MEMORY_EFFICIENT_ATTENTION
@@ -580,11 +581,11 @@ Status QkvToContext(
   void* fused_runner = data.fused_runner;
 
   // At most one fused kernel is enabled.
-  assert((int(data.use_flash_attention) +
-          int(data.use_memory_efficient_attention) +
-          int(fused_runner != nullptr) +
-          int(data.fused_cross_attention_kernel != nullptr) +
-          int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
+  assert((static_cast<int>(data.use_flash_attention) +
+          static_cast<int>(data.use_memory_efficient_attention) +
+          static_cast<int>(fused_runner != nullptr) +
+          static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
+          static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
 
   ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block));
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
index 8d0c6bd2bbb81..426b105dff8db 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
@@ -2,6 +2,9 @@
 // Licensed under the MIT License.
 
 #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
+#include 
+#include 
+#include 
 #include 
 
 #if CUDNN_MAJOR < 9
@@ -306,13 +309,14 @@ struct BytesHash {
       value ^= ptr[i];
       value *= 0x01000193;
     }
-    return (size_t)value;
+    return static_cast<size_t>(value);
   }
 };
 
 // Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
-// TODO: since we the key includes sequence lengths, we may want to limit the cache size.
-thread_local std::unordered_map, BytesHash > mha_graph_cache;
+// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
+thread_local
+std::unordered_map, BytesHash > mha_graph_cache;
 
 void run(
     void* output,
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
index f486d08244547..0e5300f32da3c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
   // Abbreviation and Meanings:
   //   T: token_count
   //   B: batch_size
-  //   S: sequence_length
+  //   S: sequence_length (input sequence length of query)
   //   N: num_heads
   //   H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
   //   H_v: v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
   //   bias (Q/K/V)     : (D + D + D_v)
   //   token_offset     : (B, S)
   //   cu_seq_len_shape : (B + 1)
-  //   attention_bias   : (B or 1, N or 1, S, S) or NULL
+  //   attention_bias   : (B, N, S, S), (1, N, S, S) or NULL
   const auto& input_dims = input_shape.GetDims();
   if (input_dims.size() != 2) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
index b0c3a28df2336..72a4c776d4fce 100644
--- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
   // Input 'value': None
   // Input 'token_offset': (batch_size, sequence_length)
   // Input 'cumulative_sequence_length': (batch_size + 1)
-  // Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
+  // Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
   // Output 'output': (token_count, v_hidden_size)
 
   const auto& query_dims = query_shape.GetDims();