diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index 56b71a431aca7..869e7f45153ce 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -206,6 +206,7 @@ def forward(
                 matmul_qk_op=self.matmul_qk,
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
+                valid_seq_lengths=attn_metadata.seq_lens_tensor,
             )
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 939d195a12b08..9e901d2ad0b7b 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -96,22 +96,6 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     d = x.shape[-1] // 2
     return F.silu(x[..., :d]) * x[..., d:]
 
-
-#TODO: remove after fusedsdpa fix for query_head != kv_head
-def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-    The kv go from (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = kv.shape
-    if n_rep == 1:
-        return kv
-    kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
-                                     head_dim)
-    return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 def prompt_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -144,10 +128,6 @@ def prompt_attention(
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
     else:
-        #TODO: remove after fusedsdpa fix for query_heads != kv_heads
-        if query_heads != kv_heads:
-            key = repeat_kv(key, int(query_heads // kv_heads))
-            value = repeat_kv(value, int(query_heads // kv_heads))
         softmax_mode = 'fast'
         recompute_mode = True
         attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
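
For context on what the deleted helper did: repeat_kv expanded the KV tensors so their head count matched the query head count before calling FusedSDPA, which is no longer needed once the fused kernel handles query_heads != kv_heads itself. Below is a minimal standalone sketch of that expansion in plain PyTorch, with made-up tensor sizes purely for illustration, showing it is equivalent to torch.repeat_interleave along the head dimension (as the removed docstring stated):

    import torch

    def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (batch, num_kv_heads, seqlen, head_dim) to
        # (batch, num_kv_heads * n_rep, seqlen, head_dim) by duplicating
        # each KV head n_rep times, matching repeat_interleave on dim=1.
        batch, num_kv_heads, slen, head_dim = kv.shape
        if n_rep == 1:
            return kv
        kv = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen,
                                         head_dim)
        return kv.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

    # Hypothetical sizes: 4 KV heads expanded to match 32 query heads.
    kv = torch.randn(2, 4, 16, 64)
    expanded = repeat_kv(kv, n_rep=8)
    assert expanded.shape == (2, 32, 16, 64)
    assert torch.equal(expanded,
                       torch.repeat_interleave(kv, repeats=8, dim=1))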