diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index be8abe92d2a0d..abc4222715380 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -19,6 +19,7 @@ HPUPagedAttentionMetadata) from vllm.logger import init_logger from vllm.utils import is_fake_hpu +from vllm.model_executor.models.utils import split_and_pad_to_length logger = init_logger(__name__) @@ -30,28 +31,6 @@ logger.warning("Could not import HPU FusedSDPA kernel. " "vLLM will use native implementation.") -def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) - - padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) - p3d = (0, 0, 0, 0, 0, target_length - padded_tensor.size(1)) - padded_tensor = torch.nn.functional.pad(padded_tensor, p3d, value=0) - return padded_tensor - -def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): - # we need to copy the key and value tensors to the padded tensors - # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] - padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) - - start = 0 - for i in range(len(seq_lens_tensor_list)): - padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] - start = start + seq_lens_tensor_list[i] - - return padded_tensor - def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -306,8 +285,6 @@ def forward( padded_key_tensor = padded_key_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) padded_value_tensor = padded_value_tensor.flatten(0, 1).unflatten(0, (block_indices.size(0), -1)) - #seq_lens_tensor_merged = torch.tensor(sum(seq_lens_tensor_list), device=seq_lens_tensor.device, dtype=seq_lens_tensor.dtype).unsqueeze(0) - seq_lens_tensor_merged = seq_lens_tensor if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -320,7 +297,6 @@ def forward( if attn_metadata.is_prompt: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) - seq_lens_tensor_merged = seq_lens_tensor if kv_cache is not None: key_cache, value_cache = HPUPagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) @@ -369,7 +345,7 @@ def forward( matmul_qk_op=self.matmul_qk, softmax_op=self.softmax, matmul_av_op=self.matmul_av, - valid_seq_lengths=seq_lens_tensor_merged, + valid_seq_lengths=seq_lens_tensor, fsdpa_op=self.fused_scaled_dot_product_attention, ) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 44d34a4e3f20a..c4e1f0778ae2b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -56,7 +56,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, split_and_pad_to_length) is_hpu = current_platform.is_hpu() @@ -490,6 +490,12 @@ def forward( "residual": residual }) + # we need to split result before do RMSNorm + if attn_metadata.enable_merged_prefill and attn_metadata.is_prompt: + max_len=attn_metadata.slot_mapping.size(1) + seq_lens_tensor_list = attn_metadata.seq_lens_tensor.tolist() + hidden_states = split_and_pad_to_length(hidden_states.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) + residual = split_and_pad_to_length(residual.view(-1, hidden_states.size(2)), max_len, seq_lens_tensor_list) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 269b66806adf4..72d01e1ca4906 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -666,3 +666,26 @@ def extract_layer_index(layer_name: str) -> int: assert len(int_vals) == 1, (f"layer name {layer_name} should" " only contain one integer") return int_vals[0] + +def split_and_pad_to_length(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_list = torch.split_with_sizes(input[:sum(seq_lens_tensor_list)], seq_lens_tensor_list, dim=0) + + padded_tensor = torch.nn.utils.rnn.pad_sequence(padded_list, batch_first=True) + pad_shape = [0] * (input.dim() - 1) * 2 + pad_shape += [0, target_length - padded_tensor.size(1)] + padded_tensor = torch.nn.functional.pad(padded_tensor, pad_shape, value=0) + return padded_tensor + +def split_and_pad_to_length_2(input, target_length, seq_lens_tensor_list): + # we need to copy the key and value tensors to the padded tensors + # shape is [bacth_size, entire_seq_len, num_kv_heads, head_size] + padded_tensor = torch.zeros((len(seq_lens_tensor_list), target_length, input.size(1), input.size(2)), device=input.device, dtype=input.dtype) + + start = 0 + for i in range(len(seq_lens_tensor_list)): + padded_tensor[i, :seq_lens_tensor_list[i], :, :] = input[start: start + seq_lens_tensor_list[i], :, :] + start = start + seq_lens_tensor_list[i] + + return padded_tensor \ No newline at end of file