Skip to content

Commit

Permalink
Added missed valid_seq_lengths from FusedSdpa prompt_attention.
Browse files Browse the repository at this point in the history
Removed workaround for fusedsdpa when query_heads and kv_heads are not the same.
  • Loading branch information
libinta committed Sep 23, 2024
1 parent 84b2490 commit b92ff50
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 20 deletions.
1 change: 1 addition & 0 deletions vllm/attention/backends/habana_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def forward(
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
Expand Down
20 changes: 0 additions & 20 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]


#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The kv go from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = kv.shape
if n_rep == 1:
return kv
kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand Down Expand Up @@ -144,10 +128,6 @@ def prompt_attention(
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
#TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
key = repeat_kv(key, int(query_heads // kv_heads))
value = repeat_kv(value, int(query_heads // kv_heads))
softmax_mode = 'fast'
recompute_mode = True
attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
Expand Down

0 comments on commit b92ff50

Please sign in to comment.