
Commit f6c0c84

update
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Dec 20, 2024
1 parent 510722e commit f6c0c84
Showing 1 changed file with 4 additions and 22 deletions.
vllm/attention/backends/hpu_attn.py: 26 changes (4 additions & 22 deletions)
@@ -48,8 +48,8 @@ def prompt_attention(
     value = value.transpose(1, 2)
     query_heads = query.size(1)
     kv_heads = key.size(1)
-    if attn_bias is not None or fsdpa_op is None:
-    #if fsdpa_op is None:
+    #if attn_bias is not None or fsdpa_op is None:
+    if fsdpa_op is None:
         if query_heads != kv_heads:
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
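
This hunk flips which condition is active: previously the unfused matmul/softmax path was taken whenever an attn_bias was supplied or no fused kernel existed; after the change the fused scaled-dot-product-attention op (fsdpa_op) is used whenever it is available, bias or not. Below is a minimal sketch of the resulting dispatch in plain PyTorch; the fused-op call signature and the GQA head handling are simplified and are assumptions, not the file's actual code.

import torch

def prompt_attention_sketch(query, key, value, attn_bias=None,
                            scale=None, fsdpa_op=None):
    # Inputs arrive as (batch, seq_len, num_heads, head_dim) and are
    # transposed to (batch, num_heads, seq_len, head_dim), as in the diff.
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    scale = scale if scale is not None else query.size(-1) ** -0.5

    if fsdpa_op is None:
        # Fallback path only when no fused kernel is available
        # (the GQA head un-flattening from the original file is omitted here).
        attn = torch.matmul(query * scale, key.transpose(-1, -2))
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, value)
    else:
        # Placeholder call: the real ModuleFusedSDPA takes additional
        # arguments (dropout, causal flag, softmax mode, ...) not shown
        # in this diff.
        out = fsdpa_op(query, key, value, attn_bias, scale)

    return out.transpose(1, 2)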
@@ -80,24 +80,6 @@ def prompt_attention(
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
 
-
-class VLLMKVCache_dev(torch.nn.Module):
-
-    def __init__(self):
-        super(VLLMKVCache_dev, self).__init__()
-        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
-                                                'true').lower() == 'true'
-
-    def forward(self, input, cache, block_indices, block_offset):
-        insert_or_update_cache(input, cache, block_indices, block_offset)
-        return cache
-
-    def fetch_from_cache(self, cache, blocks):
-        if self.use_contiguous_pa:
-            return cache[:blocks.size(0)]
-        else:
-            return cache.index_select(0, blocks)
-
 class HPUAttentionBackend(AttentionBackend):
 
     @staticmethod
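
This hunk, together with the constructor swap below, drops the locally defined VLLMKVCache_dev and returns to the stock VLLMKVCache (presumably the version provided by the HPU extension package rather than this local copy). The deleted class documents the expected behaviour: forward() writes incoming blocks into the cache via insert_or_update_cache, and fetch_from_cache() either takes a contiguous prefix slice when VLLM_CONTIGUOUS_PA is enabled or gathers the requested block ids with index_select. The following is a self-contained toy re-creation of that logic for illustration only; insert_or_update_cache is HPU-specific and not shown in this diff, so a plain index_copy_ stand-in is used and block_offset is ignored.

import os
import torch


class KVCacheSketch(torch.nn.Module):
    """Toy stand-in for the removed VLLMKVCache_dev (illustration only)."""

    def __init__(self):
        super().__init__()
        # Same environment switch as the deleted class.
        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
                                                'true').lower() == 'true'

    def forward(self, input, cache, block_indices, block_offset):
        # The real class calls the HPU kernel insert_or_update_cache();
        # here a whole-block index_copy_ is used and block_offset is ignored.
        cache.index_copy_(0, block_indices, input)
        return cache

    def fetch_from_cache(self, cache, blocks):
        if self.use_contiguous_pa:
            # Contiguous paged attention: blocks occupy a prefix of the
            # cache, so a slice up to the number of blocks is enough.
            return cache[:blocks.size(0)]
        # Otherwise gather exactly the requested block ids.
        return cache.index_select(0, blocks)


# Usage: a cache of 8 blocks, each (block_size=4, num_kv_heads=2, head_dim=8).
kv = KVCacheSketch()
cache = torch.zeros(8, 4, 2, 8)
new_blocks = torch.randn(3, 4, 2, 8)
cache = kv(new_blocks, cache, torch.tensor([0, 1, 2]), None)
fetched = kv.fetch_from_cache(cache, torch.tensor([0, 1, 2]))
print(fetched.shape)  # torch.Size([3, 4, 2, 8])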
@@ -206,8 +188,8 @@ def __init__(
         self.matmul_av = Matmul()
         self.batch2block_matmul = Matmul()
         self.block2batch_matmul = Matmul()
-        self.k_cache = VLLMKVCache_dev()
-        self.v_cache = VLLMKVCache_dev()
+        self.k_cache = VLLMKVCache()
+        self.v_cache = VLLMKVCache()
         self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
             else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
