From f6c0c84a012597fe8ff595842d6eba45dfe5d769 Mon Sep 17 00:00:00 2001
From: Chendi Xue
Date: Fri, 20 Dec 2024 04:42:22 +0200
Subject: [PATCH] update

Signed-off-by: Chendi Xue
---
 vllm/attention/backends/hpu_attn.py | 26 ++++----------------------
 1 file changed, 4 insertions(+), 22 deletions(-)

diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 69ec99f7f2ffe..95fb115228c00 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -48,8 +48,8 @@ def prompt_attention(
     value = value.transpose(1, 2)
     query_heads = query.size(1)
     kv_heads = key.size(1)
-    if attn_bias is not None or fsdpa_op is None:
-    #if fsdpa_op is None:
+    #if attn_bias is not None or fsdpa_op is None:
+    if fsdpa_op is None:
         if query_heads != kv_heads:
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
@@ -80,24 +80,6 @@ def prompt_attention(
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
 
-
-class VLLMKVCache_dev(torch.nn.Module):
-
-    def __init__(self):
-        super(VLLMKVCache_dev, self).__init__()
-        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
-                                                'true').lower() == 'true'
-
-    def forward(self, input, cache, block_indices, block_offset):
-        insert_or_update_cache(input, cache, block_indices, block_offset)
-        return cache
-
-    def fetch_from_cache(self, cache, blocks):
-        if self.use_contiguous_pa:
-            return cache[:blocks.size(0)]
-        else:
-            return cache.index_select(0, blocks)
-
 class HPUAttentionBackend(AttentionBackend):
 
     @staticmethod
@@ -206,8 +188,8 @@ def __init__(
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        self.k_cache = VLLMKVCache_dev()
-        self.v_cache = VLLMKVCache_dev()
+        self.k_cache = VLLMKVCache()
+        self.v_cache = VLLMKVCache()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
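
Note (not part of the patch): a minimal sketch of the dispatch the first hunk
toggles, assuming only the names visible in the hunk context (attn_bias,
fsdpa_op); choose_prompt_attention_path is a hypothetical helper for
illustration, not code from hpu_attn.py.

    def choose_prompt_attention_path(attn_bias, fsdpa_op):
        # Old guard ("attn_bias is not None or fsdpa_op is None") forced the
        # unfused matmul/softmax path whenever an attention bias was supplied.
        # New guard falls back only when no fused kernel is available, so
        # fsdpa_op is now also used when attn_bias is not None.
        if fsdpa_op is None:
            return "unfused matmul/softmax path"
        return "fused SDPA path via fsdpa_op"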