diff --git a/vllm/sequence.py b/vllm/sequence.py index 669124319c4f4..53c8a4b73b4e3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1182,7 +1182,8 @@ def update(self, second_last_token_hidden_states: Optional[torch.Tensor] = None): """Update hidden states from target model invocation. Only used for decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) + if len(seq_group_metadata_list) < len(hidden_states): + hidden_states = hidden_states[:len(seq_group_metadata_list)] self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) self.hidden_states = torch.cat([self.hidden_states, hidden_states])