From 66e1af8614b3bb68aa9388dfa5b1d1dd9de16532 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Sat, 14 Dec 2024 04:22:04 +0000 Subject: [PATCH] Fix shape error detected by benchmark Signed-off-by: Chendi.Xue --- vllm/sequence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 669124319c4f4..53c8a4b73b4e3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1182,7 +1182,8 @@ def update(self, second_last_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation. Only used for decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) + if len(seq_group_metadata_list) < len(hidden_states): + hidden_states = hidden_states[:len(seq_group_metadata_list)] self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states])