From 5ee52ae0bc01424a038b5bc7fcaef166dde2e314 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Wed, 9 Oct 2024 14:02:56 +0200
Subject: [PATCH] Mllama: fix tests (#34000)

* fix tests

* don't need this

* style
---
 src/transformers/cache_utils.py             |   3 +-
 .../generation/candidate_generator.py       |  10 ++
 .../models/mllama/modeling_mllama.py        |   4 +-
 tests/models/mllama/test_modeling_mllama.py | 128 ++++++------
 4 files changed, 53 insertions(+), 92 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index b7e4ec914baebe..4e4a1ee26c12d7 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -813,8 +813,7 @@ def _quantize(self, tensor, axis):
         if is_optimum_quanto_available():
             from optimum.quanto import quantize_weight
 
-            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
-            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+            qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
             return qtensor
         elif is_quanto_available():
             logger.warning_once(
diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 54b3f709a9d899..fb7ed2f0b2f553 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -425,6 +425,16 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
         model_kwargs[mask_key] = mask[:, :mask_length_diff]
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
+
+    if "cross_attention_mask" in model_kwargs:
+        # Mllama case is special and has another mask for cross attention model
+        cross_mask = model_kwargs["cross_attention_mask"]
+        if mask_length_diff < 0:
+            model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
+        elif mask_length_diff > 0:
+            new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
+            model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
+
     return model_kwargs
 
 
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 34624d6ef8feb1..0bc77eaeec3324 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -1042,7 +1042,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
         "MllamaSelfAttentionDecoderLayer",
     ]
     _supports_cache_class = True
-    _supports_static_cache = False
+    _supports_static_cache = False  # static cache cannot have different shapes for each layer
     _supports_sdpa = True
     _supports_quantized_cache = True
 
@@ -1980,6 +1980,8 @@ def forward(
     MLLAMA_START_DOCSTRING,
 )
 class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
+
     def __init__(self, config: MllamaConfig):
         super().__init__(config)
         self.vocab_size = config.text_config.vocab_size
diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py
index 85e54f707d7d2e..5c5ca3985ee08f 100644
--- a/tests/models/mllama/test_modeling_mllama.py
+++ b/tests/models/mllama/test_modeling_mllama.py
@@ -138,14 +138,6 @@ def setUp(self):
     def test_eager_matches_sdpa_generate(self):
         super().test_eager_matches_sdpa_generate()
 
-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
-
-    @unittest.skip(reason="Quanto test is borken")
-    def test_generate_with_quant_cache(self):
-        pass
-
 
 class MllamaVisionText2TextModelTester:
     def __init__(
@@ -208,6 +200,7 @@ def __init__(
         self.image_size = 224
         self.max_num_images = 1
         self.max_image_tiles = 4
+        self.image_length = 904
 
     def get_config(self):
         return MllamaConfig(
@@ -329,6 +322,43 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(torch.allclose(out_embeds, out_ids))
 
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Mllama has cross attention layers and those have a different shape than normal attention layers
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx
+
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+
+            expected_shape_cross = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                self.model_tester.image_length,
+            )
+
+            expected_shapes = [
+                expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
+                for layer_idx in range(len(iter_attentions))
+            ]
+
+            self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
+
     @require_torch_sdpa
     @slow
     @is_flaky()
     def test_eager_matches_sdpa_inference_1_bfloat16(self):
         # A workaround to override parametrized test with flaky decorator
         super().test_eager_matches_sdpa_inference_1_bfloat16()
 
-    @unittest.skip(reason="Static cache not supported")
-    def test_static_cache_matches_dynamic(self):
-        # TypeError: list indices must be integers or slices, not tuple
-        # TODO: @raushan, please look into this for new cache format
-        pass
-
-    @unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
-    def test_generate_compile_fullgraph(self):
-        pass
-
-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
-
-    @unittest.skip(reason="Mllama is not yet supported by compile")
+    @unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ")
     def test_sdpa_can_compile_dynamic(self):
-        # TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
-        # relevant issue: https://github.com/pytorch/pytorch/issues/133166
         pass
-
-    @unittest.skip(reason="The test itself is broken")  # TODO @zucchini-nlp
-    def test_generate_with_quant_cache(self):
-        pass
 
     @unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
     def test_model_parallelism(self):
         pass
 
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_compile_cuda_graph_time(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_torch_compile_fullgraph(self):
-        pass
-
-    @unittest.skip(reason="Device side assert triggered")
-    def test_assisted_decoding_with_num_logits_to_keep(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_sample_generate_dict_output(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_search_generate_dict_output(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_constrained_beam_search_generate_dict_output(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_dola_decoding_sample(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_generate_methods_with_num_logits_to_keep(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_greedy_generate_dict_outputs(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_group_beam_search_generate_dict_output(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_model_parallel_beam_search(self):
-        pass
-
-    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
-    def test_new_cache_format_0(self):
-        super().test_new_cache_format_0()
-
-    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
-    def test_new_cache_format_1(self):
-        super().test_new_cache_format_1()
-
-    @is_flaky()  # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
-    def test_new_cache_format_2(self):
-        super().test_new_cache_format_2()
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_sample_generate_dict_output(self):
-        pass
-
     def test_generate_text_only_with_cache(self):
         """
         Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature