Commit

update cost
tianleiwu committed Jul 2, 2024
1 parent 2684e97 commit e4ac550
Showing 1 changed file with 6 additions and 18 deletions.
24 changes: 6 additions & 18 deletions onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
@@ -129,10 +129,6 @@ class SparseAttentionBase {
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;

-    // if (!past_present_share_buffer) {
-    //   memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
-    // }
-
     const int loop_len = batch_size * num_heads_;
     const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

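Note: the `alpha` kept as context in this hunk is the standard attention softmax scale, defaulting to 1/sqrt(head_size) when the `scale` attribute is 0. A minimal standalone sketch of that default (head_size = 128 is an assumed example value, not taken from the file):

    // Sketch: default softmax scale when the 'scale' attribute is 0.
    #include <cmath>
    #include <cstdio>

    int main() {
      const int head_size = 128;      // assumed example value
      const float scale_attr = 0.0f;  // attribute not set
      const float alpha = scale_attr == 0.0f
                              ? 1.0f / std::sqrt(static_cast<float>(head_size))
                              : scale_attr;
      std::printf("alpha = %f\n", alpha);  // ~0.088388 for head_size 128
      return 0;
    }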
@@ -148,7 +144,7 @@ class SparseAttentionBase {
     unit_cost.bytes_loaded += static_cast<double>(probs_matrix_bytes);
     unit_cost.bytes_stored += static_cast<double>(probs_matrix_bytes);

-    // cost to concatenate current key to cache
+    // Cost to concatenate current key to cache (assume past and present share buffer).
     double bytes_to_copy_key = static_cast<double>(sizeof(T) * sequence_length * head_size);
     unit_cost.bytes_loaded += bytes_to_copy_key;
     unit_cost.bytes_stored += bytes_to_copy_key;
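Note: `bytes_to_copy_key` models the traffic of appending the new key tokens to the shared KV cache, which is why it counts sizeof(T) * sequence_length * head_size rather than the whole present buffer. A hypothetical sketch of such a per-head append (illustration only; `append_key_chunk` and its arguments are not from this file):

    // Illustration of the copy whose traffic bytes_to_copy_key models:
    // appending one head's new key tokens to the shared present buffer.
    #include <cstring>

    template <typename T>
    void append_key_chunk(T* present_head,    // start of this head's chunk in present_key
                          const T* new_keys,  // [sequence_length, head_size] for this head
                          int past_sequence_length,
                          int sequence_length,
                          int head_size) {
      // Copies sequence_length * head_size elements after the cached past keys,
      // i.e. exactly the sizeof(T) * sequence_length * head_size bytes counted above.
      std::memcpy(present_head + static_cast<size_t>(past_sequence_length) * head_size,
                  new_keys,
                  sizeof(T) * static_cast<size_t>(sequence_length) * head_size);
    }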
@@ -329,29 +325,21 @@ class SparseAttentionBase {
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;

-    // if (!past_present_share_buffer) {
-    //   memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
-    // }
-
-    // The cost of Gemm
+    // The cost of Gemm.
     TensorOpCost unit_cost;
+    // Here we use total_sequence_length to estimate total_key_lengths[batch_index] used in GEMM.
     unit_cost.compute_cycles =
-        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * present_buffer_sequence_length);
+        static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * total_sequence_length);
     unit_cost.bytes_loaded = static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + head_size) *
-                                                 present_buffer_sequence_length * sizeof(T));
+                                                 total_sequence_length * sizeof(T));
     unit_cost.bytes_stored = static_cast<double>(sequence_length * head_size * sizeof(T));

     if (present_value) {
-      double bytes_to_copy_value = static_cast<double>(present_buff_chunk_length * sizeof(T));
+      double bytes_to_copy_value = static_cast<double>(sizeof(T) * sequence_length * head_size);
       unit_cost.bytes_loaded += bytes_to_copy_value;
       unit_cost.bytes_stored += bytes_to_copy_value;
     }

-    const size_t bytes_to_copy_trans = SafeInt<size_t>(head_size) * sizeof(T);
-    double bytes_to_copy_trans_all = static_cast<double>(sequence_length * bytes_to_copy_trans);
-    unit_cost.bytes_loaded += bytes_to_copy_trans_all;
-    unit_cost.bytes_stored += bytes_to_copy_trans_all;
-
     DUMP_CPU_TENSOR_INIT();

     ThreadPool::TryParallelFor(
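Note: this per-unit cost feeds ThreadPool::TryParallelFor below, which uses it to decide how finely to partition the batch * num_heads loop. Basing the GEMM estimate on present_buffer_sequence_length (the allocated KV-cache capacity) overstates the work whenever the buffer is larger than the tokens actually attended to; total_sequence_length is the tighter estimate. A small sketch with assumed example sizes (not taken from the file) shows the size of the gap:

    // Sketch: per-unit GEMM cycle estimate before vs. after this commit
    // for a single decoding step, with assumed example sizes.
    #include <cstdio>

    int main() {
      const double sequence_length = 1;                     // one new token (decoding)
      const double head_size = 128;
      const double total_sequence_length = 1024;            // keys actually attended to
      const double present_buffer_sequence_length = 8192;   // allocated KV-cache capacity

      const double old_cycles = 2 * sequence_length * head_size * present_buffer_sequence_length;
      const double new_cycles = 2 * sequence_length * head_size * total_sequence_length;
      std::printf("old=%.0f new=%.0f ratio=%.1fx\n",
                  old_cycles, new_cycles, old_cycles / new_cycles);  // ratio = 8.0x here
      return 0;
    }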
