diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 54ad4d5a504..155d9e3e6ab 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -569,10 +569,11 @@ def forward( batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) # Get the target length - target_seqlen = first_layer_past_key_value.shape[-1] + 1 + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( - (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), + (attention_mask.shape[0], past_length), dtype=attention_mask.dtype, device=attention_mask.device, ) @@ -587,7 +588,7 @@ def forward( # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 outputs = self.language_model( diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index dda9549a4f2..1b20353410c 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -441,10 +441,10 @@ def forward( if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, 0, :, :] + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0) + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) target_length = input_ids.shape[1] past_length = first_layer_past_key_value.shape[-1] diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 7e4469f306b..1c7e3200904 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -423,7 +423,7 @@ def test_small_model_integration_test(self): output = model(**inputs) expected_slice = torch.tensor( - [[-4.7695, -4.5664, -0.2786], [-10.6172, -10.8906, -2.5234], [-6.7344, -7.2422, -0.6758]], + [[-4.7695, -4.5664, -0.2786], [-10.6250, -10.8906, -2.5254], [-6.7383, -7.2461, -0.6787]], dtype=torch.float32, device=torch_device, )