Skip to content

Commit

Permalink
only use v1
Browse files Browse the repository at this point in the history
  • Loading branch information
kailums committed Jan 2, 2024
1 parent 7827660 commit 5b25de3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/rocm/bert/paged_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,10 @@ Status PagedAttention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr int PARTITION_SIZE = 512;
int max_num_partitions = ((input_metadata->max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE);
//TODO : Tune this heuristic.
bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE ||
(kv_quant_param != nullptr && kv_quant_param->Shape().Size() > 0);
// bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE ||
// (kv_quant_param != nullptr && kv_quant_param->Shape().Size() > 0);
int64_t generation_qeury_shape[3] = {num_valid_tokens - num_prompt_tokens, num_heads_, head_size_};
if (use_v1){
if (true){
paged_attention_v1(Stream(context),
output->MutableData<MLFloat16>() + num_prompt_tokens * num_heads_ * head_size_,
query_data + num_prompt_tokens * num_heads_ * head_size_,
Expand Down

0 comments on commit 5b25de3

Please sign in to comment.