diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index f5984caf96c5c..e724a087fbe83 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -29,6 +29,56 @@
     logger.warning("Could not import HPU FusedSDPA kernel. "
                    "vLLM will use native implementation.")
 
+def prompt_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    p: float = 0.0,
+    scale: Optional[float] = None,
+    matmul_qk_op=torch.matmul,
+    softmax_op=torch.softmax,
+    matmul_av_op=torch.matmul,
+    valid_seq_lengths: Optional[torch.Tensor] = None,
+    fsdpa_op=None,
+) -> torch.Tensor:
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    query_heads = query.size(1)
+    kv_heads = key.size(1)
+    if attn_bias is not None:
+        attn_bias = attn_bias.expand(query.size(0), query_heads, query.size(2), key.size(2))
+    if attn_bias is not None or fsdpa_op is None:
+        if query_heads != kv_heads:
+            query = query.unflatten(1, (kv_heads, -1))
+            key = key.unflatten(1, (kv_heads, 1))
+            value = value.unflatten(1, (kv_heads, 1))
+            if attn_bias is not None:
+                attn_bias = attn_bias.unflatten(1, (kv_heads, -1))
+        attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
+        if attn_bias is not None:
+            attn_weights.add_(attn_bias)
+        attn_weights = softmax_op(attn_weights, dim=-1)
+        attn_weights = matmul_av_op(attn_weights, value)
+        if query_heads != kv_heads:
+            attn_weights = attn_weights.flatten(1, 2)
+    else:
+        VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
+        # TODO: remove after fusedsdpa fix for query_heads != kv_heads
+        if query_heads != kv_heads:
+            if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
+                key = ops.repeat_kv(key, int(query_heads // kv_heads))
+                value = ops.repeat_kv(value, int(query_heads // kv_heads))
+            if attn_bias is not None:
+                attn_bias = attn_bias.unflatten(1, (kv_heads, -1))
+        softmax_mode = 'fast'
+        recompute_mode = True
+        attn_weights = fsdpa_op(query=query, key=key, value=value, attn_mask=attn_bias, dropout_p=0.0, is_causal=True,
+                                scale=scale, softmax_mode=softmax_mode, recompute_mode=recompute_mode,
+                                valid_sequence_lengths=valid_seq_lengths, padding_side='right')
+    attn_weights = attn_weights.transpose(1, 2)
+    return attn_weights
 
 
 class HPUAttentionBackend(AttentionBackend):
@@ -228,8 +278,10 @@ def forward(
         if attn_metadata.is_prompt:
             padded_shape = attn_metadata.input_tokens_padded_tensor
             seq_lens_tensor_list = seq_lens_tensor.tolist()
-            padded_key_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=key.device, dtype=key.dtype)
-            padded_value_tensor = torch.zeros(padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size, device=query.device, dtype=key.dtype)
+            padded_key_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size),
+                                            device=key.device, dtype=key.dtype)
+            padded_value_tensor = torch.zeros((padded_shape[0], padded_shape[1], self.num_kv_heads, self.head_size),
+                                              device=query.device, dtype=key.dtype)
             start = 0
             # we need to copy the key and value tensors to the padded tensors
             # shape is [batch_size, entire_seq_len, num_kv_heads, head_size]
@@ -284,11 +336,16 @@ def forward(
                     attn_bias = attn_bias.tile(
                         (1, self.num_kv_heads, 1, 1))
                     attn_bias.add_(position_bias)
+            elif enable_merged_prefill:
+                pass
             else:
                 attn_bias = None
 
-            #self.fused_scaled_dot_product_attention = None
-            out = ops.prompt_attention(
+            if enable_merged_prefill:
+                prompt_attn_func = prompt_attention
+            else:
+                prompt_attn_func = ops.prompt_attention
+            out = prompt_attn_func(
                 query.view(query_shape),
                 key.view(kv_shape),
                 value.view(kv_shape),
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 6305e96a24919..27e0aa4e98ba1 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -212,6 +212,8 @@ def __init__(self, model, block_size, dtype, enforce_eager, layer_names):
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '1').lower() in ['1', 'true'] \
                                                 and not is_fake_hpu()
+        self.merged_prefill_attn_mask_compute = os.getenv('VLLM_MERGED_PREFILL_ATTN_MASK_COMPUTE',
+                                                          '1').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
         self.layer_names = layer_names
@@ -287,6 +289,28 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
         return attn_metadata
 
+    def _set_merged_attn_bias(self, attn_metadata, batch_size, max_seq_len, device):
+        # create a 2D causal attn mask so each token only attends to earlier tokens of its own sequence
+        if attn_metadata is None or not attn_metadata.is_prompt:
+            return attn_metadata
+        if not self.merged_prefill_attn_mask_compute:
+            return attn_metadata
+        # TODO: Support batch_size > 1
+        seq_lens = attn_metadata.seq_lens_tensor.tolist()
+        causal_attn_mask_tensor = torch.ones((batch_size, max_seq_len, max_seq_len), dtype=torch.bool, device=device)
+        start = 0
+        for i in range(batch_size):
+            for seq_len in seq_lens:
+                # create a triangular (causal) mask for each sequence
+                causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1)
+                causal_attn_mask_tensor[i][start:start + seq_len, start:start + seq_len].copy_(causal_mask)
+                start += seq_len
+        causal_attn_mask_tensor = (torch.zeros_like(causal_attn_mask_tensor, device=device, dtype=self.dtype).masked_fill_(
+            causal_attn_mask_tensor, -10000))  # should be -inf, but -10000 is used for numerical stability
+        causal_attn_mask_tensor = causal_attn_mask_tensor.view(causal_attn_mask_tensor.shape[0], 1, causal_attn_mask_tensor.shape[1], causal_attn_mask_tensor.shape[2])
+
+        attn_metadata = attn_metadata._replace(attn_bias=causal_attn_mask_tensor)
+        return attn_metadata
+
     def _set_block_mapping(self, metadata, batch_size, device, dtype):
         mask = torch.arange(0,
                             self.block_size,
@@ -340,7 +364,9 @@ def _set_indices_and_offsets(self, metadata, block_size, is_prompt):
 
     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                          dtype):
-        if attn_metadata.is_prompt:
+        if attn_metadata.is_prompt and attn_metadata.enable_merged_prefill:
+            attn_metadata = self._set_merged_attn_bias(attn_metadata, batch_size, seq_len, device)
+        elif attn_metadata.is_prompt:
             attn_metadata = self._set_attn_bias(attn_metadata, batch_size,
                                                 seq_len, device, dtype)
         else:
@@ -886,8 +912,8 @@ def _prepare_prompt(
                 block_number = block_table[i // self.block_size]
                 block_offset = i % self.block_size
                 slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
-
+                slot_mapping[-1].append(slot)
+
         max_query_len = max(query_lens)
         real_num_seqs = len(query_lens)
         assert max_query_len > 0
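
For reference, below is a minimal standalone sketch (not part of the patch) of the block-diagonal causal bias that the new _set_merged_attn_bias helper builds for merged prefill; the function name merged_causal_bias and the example lengths are hypothetical and used only for illustration.

# Illustrative sketch only -- mirrors the mask logic of _set_merged_attn_bias
# for a single merged prompt row; names and values here are hypothetical.
import torch


def merged_causal_bias(seq_lens, max_seq_len, dtype=torch.float32):
    # True marks positions that must NOT be attended to.
    masked = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
    start = 0
    for seq_len in seq_lens:
        # Inside each sequence's block, unmask only the lower triangle
        # (past positions); everything outside the block stays masked,
        # so merged prompts cannot attend to each other.
        causal = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool),
                            diagonal=1)
        masked[start:start + seq_len, start:start + seq_len] = causal
        start += seq_len
    # Additive bias: 0 where attention is allowed, -10000 where it is not.
    return torch.zeros((max_seq_len, max_seq_len), dtype=dtype).masked_fill_(
        masked, -10000)


bias = merged_causal_bias(seq_lens=[2, 3], max_seq_len=5)
assert bias[2, 1] == -10000  # the second prompt cannot see the first prompt
assert bias[3, 2] == 0       # within a prompt, later tokens see earlier ones

With seq_lens [2, 3] merged into one padded row of length 5, positions of the second prompt receive a -10000 bias against positions of the first prompt, matching the patch's choice of -10000 over -inf for numerical stability.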