Commit c54f79c (parent: fad5de2)
tianleiwu committed Aug 17, 2024
Showing 4 changed files with 16 additions and 12 deletions.
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
                        const char* name,
                        const T* tensor,
                        gsl::span<const int64_t>& dims) {
-  if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
+  if (!dumper->IsEnabled()) {
+    return;
+  }
+
+  if ((tensor == nullptr || dims.size() == 0)) {
     std::cout << std::string(name) << " is None" << std::endl;
     return;
   }
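The change above makes the dumper return immediately when disabled, so the "is None" message (and any subsequent printing) is only reached when dumping is actually enabled. Below is a minimal, self-contained sketch of the resulting control flow; it is not the real template, which takes a TConsoleDumper* and gsl::span and goes on to print the tensor contents — here the dumper is reduced to a bool and std::span stands in for gsl::span.

```cpp
#include <cstdint>
#include <iostream>
#include <span>
#include <string>

// Sketch of the guard order after this commit (simplified, not the real
// PrintTensorByDims from console_dumper.h).
template <typename T>
void PrintTensorByDimsSketch(bool dumper_enabled,
                             const char* name,
                             const T* tensor,
                             std::span<const int64_t> dims) {
  if (!dumper_enabled) {
    return;  // dumping disabled: print nothing, even for null/empty tensors
  }

  if (tensor == nullptr || dims.size() == 0) {
    std::cout << std::string(name) << " is None" << std::endl;
    return;
  }

  // ... the real implementation dumps the tensor values here ...
  std::cout << std::string(name) << " has " << dims.size() << " dimensions" << std::endl;
}

int main() {
  const int64_t dims[] = {2, 3};
  PrintTensorByDimsSketch<float>(true, "query", nullptr, dims);   // prints: query is None
  PrintTensorByDimsSketch<float>(false, "query", nullptr, dims);  // prints nothing
  return 0;
}
```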
@@ -26,9 +26,9 @@ void run(
     void* q,
     void* k,
     void* v,
-    void* bias, // (optional) attention bias with shape [b or 1, h_q or 1, s_q, s_kv].
-    int* mask_sequence_lengths_q, // (optional) sequence lengths of q for padding mask. Shape: [batch_size]
-    int* mask_sequence_lengths_kv, // (optional) sequence lengths of k or v for padding mask. Shape: [batch_size]
+    void* bias, // (optional) attention bias with shape [b or 1, h_q or 1, s_q, s_kv].
+    int* mask_sequence_lengths_q, // (optional) sequence lengths of q for padding mask. Shape: [batch_size]
+    int* mask_sequence_lengths_kv, // (optional) sequence lengths of k or v for padding mask. Shape: [batch_size]
     int batch_size,
     int num_heads_q,
     int num_heads_kv,
@@ -38,11 +38,11 @@ void run(
     int sequence_length_kv,
     float scale,
     bool is_causal,
-    bool is_bf16, // True if bfloat16, otherwise float16
-    bool broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0
-    bool broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1
-    int sliding_window, // sliding window length. 0 means no sliding window.
-    AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported
+    bool is_bf16, // True if bfloat16, otherwise float16
+    bool broadcast_attn_bias_dim_0, // broadcast attention bias dimension 0
+    bool broadcast_attn_bias_dim_1, // broadcast attention bias dimension 1
+    int sliding_window, // sliding window length. 0 means no sliding window.
+    AttentionQkvFormat qkv_format, // Q_K_V_BNSH, Q_K_V_BSNH, Q_K_V_BSNH_BNSH_BNSH are supported
     cudnnHandle_t handle,
     Stream* stream,
     AllocatorPtr allocator);
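For context on the broadcast_attn_bias_dim_0 and broadcast_attn_bias_dim_1 parameters documented here: a broadcast dimension means the attention bias has size 1 in its batch and/or head dimension and is reused across all batches or heads. The sketch below shows the element addressing this implies for a contiguous row-major bias of shape [b or 1, h_q or 1, s_q, s_kv]; it is illustrative only — the actual cuDNN path expresses broadcasting through tensor strides rather than explicit offsets, and the function name here is hypothetical.

```cpp
#include <cstddef>

// Illustrative only (not the cuDNN code path): element offset into a
// contiguous row-major attention bias of shape [b or 1, h_q or 1, s_q, s_kv].
// A broadcast dimension has extent 1 and therefore contributes no stride.
inline size_t AttentionBiasElementOffset(int b, int h, int i, int j,
                                         int num_heads_q, int s_q, int s_kv,
                                         bool broadcast_attn_bias_dim_0,
                                         bool broadcast_attn_bias_dim_1) {
  const size_t batch_index = broadcast_attn_bias_dim_0 ? 0 : static_cast<size_t>(b);
  const size_t head_index = broadcast_attn_bias_dim_1 ? 0 : static_cast<size_t>(h);
  const size_t head_extent = broadcast_attn_bias_dim_1 ? 1 : static_cast<size_t>(num_heads_q);
  return ((batch_index * head_extent + head_index) * s_q + i) * s_kv + j;
}
```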
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
 // Abbreviation and Meanings:
 //   T:    token_count
 //   B:    batch_size
-//   S:    sequence_length (input sequence length of query)
+//   S:    sequence_length
 //   N:    num_heads
 //   H:    head size for Q and K, aka q_head_size or v_head_size or qk_head_size
 //   H_v:  v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
 //   bias (Q/K/V)     : (D + D + D_v)
 //   token_offset     : (B, S)
 //   cu_seq_len_shape : (B + 1)
-//   attention_bias   : (B, N, S, S), (1, N, S, S) or NULL
+//   attention_bias   : (B or 1, N or 1, S, S) or NULL
 const auto& input_dims = input_shape.GetDims();
 if (input_dims.size() != 2) {
   return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
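The updated comment documents that both the batch and head dimensions of attention_bias may be broadcast (size 1). A hedged sketch of the shape check this implies follows — illustrative only, with a hypothetical helper name, not the actual CheckInputs logic.

```cpp
#include <cstdint>
#include <vector>

// Illustrative check (not the real CheckInputs code): an optional attention
// bias must be absent, or have shape (B or 1, N or 1, S, S).
inline bool AttentionBiasShapeIsValid(const std::vector<int64_t>& bias_dims,
                                      int64_t batch_size,         // B
                                      int64_t num_heads,          // N
                                      int64_t sequence_length) {  // S
  if (bias_dims.empty()) {
    return true;  // attention_bias is optional
  }
  return bias_dims.size() == 4 &&
         (bias_dims[0] == batch_size || bias_dims[0] == 1) &&
         (bias_dims[1] == num_heads || bias_dims[1] == 1) &&
         bias_dims[2] == sequence_length &&
         bias_dims[3] == sequence_length;
}
```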
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
 //   Input 'value': None
 //   Input 'token_offset': (batch_size, sequence_length)
 //   Input 'cumulative_sequence_length': (batch_size + 1)
-//   Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
+//   Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
 //   Output 'output': (token_count, v_hidden_size)

 const auto& query_dims = query_shape.GetDims();
