undo unrelated; static_cast; comments
tianleiwu committed Aug 19, 2024
1 parent 9c78a6d commit 388dabe
Showing 5 changed files with 18 additions and 17 deletions.
6 changes: 1 addition & 5 deletions onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,11 +53,7 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
                        const char* name,
                        const T* tensor,
                        gsl::span<const int64_t>& dims) {
-  if (!dumper->IsEnabled()) {
-    return;
-  }
-
-  if ((tensor == nullptr || dims.size() == 0)) {
+  if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
     std::cout << std::string(name) << " is None" << std::endl;
     return;
   }
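For context, here is a minimal standalone sketch of the combined guard above, using a hypothetical Dumper type and std::vector dims in place of the real TConsoleDumper and gsl::span: the "is None" message is printed only when dumping is enabled and the tensor is missing or empty.

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for the console dumper interface.
    struct Dumper {
      bool enabled = false;
      bool IsEnabled() const { return enabled; }
    };

    template <typename T>
    void PrintTensorIfPresent(const Dumper* dumper, const char* name,
                              const T* tensor, const std::vector<int64_t>& dims) {
      // Report a missing or empty tensor only when dumping is enabled.
      if (dumper->IsEnabled() && (tensor == nullptr || dims.empty())) {
        std::cout << std::string(name) << " is None" << std::endl;
        return;
      }
      // ... a real PrintTensorByDims would continue here and dump the values.
    }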
13 changes: 7 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -384,7 +384,8 @@ Status CudnnFlashAttention(
   ORT_UNUSED_PARAMETER(parameters);
   ORT_UNUSED_PARAMETER(data);
   ORT_UNUSED_PARAMETER(scale);
-  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "cudnn flash attention does not support float tensor");
+  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
+                         "cudnn flash attention does not support float tensor");
 }

 #if USE_MEMORY_EFFICIENT_ATTENTION
@@ -580,11 +581,11 @@ Status QkvToContext(
   void* fused_runner = data.fused_runner;

   // At most one fused kernel is enabled.
-  assert((int(data.use_flash_attention) +
-          int(data.use_memory_efficient_attention) +
-          int(fused_runner != nullptr) +
-          int(data.fused_cross_attention_kernel != nullptr) +
-          int(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);
+  assert((static_cast<int>(data.use_flash_attention) +
+          static_cast<int>(data.use_memory_efficient_attention) +
+          static_cast<int>(fused_runner != nullptr) +
+          static_cast<int>(data.fused_cross_attention_kernel != nullptr) +
+          static_cast<int>(data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention)) <= 1);

   ORT_RETURN_IF_ERROR(PrepareQkv<T>(parameters, data, stream, max_threads_per_block));

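The assertion above relies on each term (a bool, a pointer check, or an enum comparison) converting to 0 or 1, so the sum of the static_cast<int> values can stay at or below 1 only when at most one attention path is selected. A minimal sketch with hypothetical flag names:

    #include <cassert>

    // Hypothetical flags standing in for the attention dispatch options in QkvToContext.
    struct DispatchFlags {
      bool use_flash_attention = false;
      bool use_memory_efficient_attention = false;
      bool has_fused_runner = false;
      bool has_fused_cross_attention = false;
      bool use_cudnn_flash_attention = false;
    };

    void CheckAtMostOneKernelEnabled(const DispatchFlags& f) {
      // Each flag contributes 0 or 1 to the sum; a total above 1 means two kernels were selected.
      assert((static_cast<int>(f.use_flash_attention) +
              static_cast<int>(f.use_memory_efficient_attention) +
              static_cast<int>(f.has_fused_runner) +
              static_cast<int>(f.has_fused_cross_attention) +
              static_cast<int>(f.use_cudnn_flash_attention)) <= 1);
    }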
10 changes: 7 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu
@@ -2,6 +2,9 @@
 // Licensed under the MIT License.

 #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
+#include <memory>
+#include <vector>
+#include <unordered_map>
 #include <cudnn.h>

Check warning on line 8 in onnxruntime/contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.cu (GitHub Actions / Optional Lint C++): [cpplint] reported by reviewdog: Found C system header after C++ system header. Should be: cudnn_flash_attention.h, c system, c++ system, other. [build/include_order] [4]
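For reference, the include order cpplint asks for would look roughly like the sketch below (related header, then C system, then C++ system headers); this only illustrates the lint rule and is not a change made in this commit.

    // Ordering suggested by the cpplint rule above (sketch only).
    #include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"

    #include <cudnn.h>        // C system header

    #include <memory>         // C++ system headers
    #include <unordered_map>
    #include <vector>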

#if CUDNN_MAJOR < 9
@@ -306,13 +309,14 @@ struct BytesHash {
       value ^= ptr[i];
       value *= 0x01000193;
     }
-    return (size_t)value;
+    return static_cast<size_t>(value);
   }
 };

 // Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
-// TODO: since we the key includes sequence lengths, we may want to limit the cache size.
-thread_local std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;
+// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
+thread_local
+std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;

 void run(
     void* output,
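The hunk above touches two related pieces: a byte-wise hash over the graph key (0x01000193 is the 32-bit FNV prime, so this looks like FNV-1a) and a thread_local cache of cuDNN attention graphs keyed by it. Below is a self-contained sketch of that pattern with simplified stand-in types; the FNV offset basis and the fields of GraphParams are assumptions, since neither is visible in this hunk.

    #include <cstddef>
    #include <cstdint>
    #include <memory>
    #include <unordered_map>

    // Byte-wise 32-bit FNV-1a hash over a trivially copyable key (sketch only;
    // the real GraphParams key and fe::graph::Graph value come from ORT and cudnn-frontend).
    template <typename T>
    struct BytesHash {
      size_t operator()(const T& key) const {
        const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&key);
        uint32_t value = 0x811C9DC5;  // FNV offset basis (assumed; not visible in the hunk)
        for (size_t i = 0; i < sizeof(T); ++i) {
          value ^= ptr[i];
          value *= 0x01000193;        // FNV prime, as in the diff above
        }
        return static_cast<size_t>(value);
      }
    };

    // Illustrative key; the real one carries attention shapes and sequence lengths,
    // and should have no uninitialized padding since the hash reads raw bytes.
    struct GraphParams {
      int batch_size;
      int sequence_length;
      bool operator==(const GraphParams& other) const {
        return batch_size == other.batch_size && sequence_length == other.sequence_length;
      }
    };

    struct Graph {};  // stand-in for fe::graph::Graph

    // One cache per thread, because cached cuDNN execution plans are not thread safe.
    thread_local std::unordered_map<GraphParams, std::shared_ptr<Graph>, BytesHash<GraphParams>>
        mha_graph_cache;

As the TODO in the diff notes, keying on sequence lengths means the cache can grow with every new length seen, so a bounded cache or eviction policy may eventually be needed.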
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
 // Abbreviation and Meanings:
 //   T: token_count
 //   B: batch_size
-//   S: sequence_length
+//   S: sequence_length (input sequence length of query)
 //   N: num_heads
 //   H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
 //   H_v: v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
 //   bias (Q/K/V) : (D + D + D_v)
 //   token_offset : (B, S)
 //   cu_seq_len_shape : (B + 1)
-//   attention_bias : (B or 1, N or 1, S, S) or NULL
+//   attention_bias : (B, N, S, S), (1, N, S, S) or NULL
   const auto& input_dims = input_shape.GetDims();
   if (input_dims.size() != 2) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
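The updated comment narrows attention_bias to the shapes (B, N, S, S) or the batch-broadcast (1, N, S, S); broadcasting over num_heads is no longer listed. A hedged sketch of the kind of shape check this documents (not the actual ORT validation code):

    #include <cstdint>
    #include <vector>

    // Returns true if dims match (B, N, S, S) or (1, N, S, S); a sketch only,
    // where B = batch_size, N = num_heads, S = sequence_length.
    bool IsSupportedAttentionBiasShape(const std::vector<int64_t>& dims,
                                       int64_t batch_size, int64_t num_heads,
                                       int64_t sequence_length) {
      if (dims.size() != 4) return false;
      const bool batch_ok = (dims[0] == batch_size || dims[0] == 1);  // batch dim may broadcast
      return batch_ok && dims[1] == num_heads &&
             dims[2] == sequence_length && dims[3] == sequence_length;
    }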
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
 // Input 'value': None
 // Input 'token_offset': (batch_size, sequence_length)
 // Input 'cumulative_sequence_length': (batch_size + 1)
-// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
+// Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
 // Output 'output': (token_count, v_hidden_size)

   const auto& query_dims = query_shape.GetDims();
