Mllama: fix tests (huggingface#34000)
* fix tests

* don't need this

* style
zucchini-nlp authored Oct 9, 2024
1 parent 295a90c · commit 5ee52ae
Showing 4 changed files with 53 additions and 92 deletions.
src/transformers/cache_utils.py: 1 addition & 2 deletions
@@ -813,8 +813,7 @@ def _quantize(self, tensor, axis):
if is_optimum_quanto_available():
from optimum.quanto import quantize_weight

- scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
- qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+ qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
return qtensor
elif is_quanto_available():
logger.warning_once(
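For orientation, the sketch below mirrors the updated call pattern: the externally computed scale and zero-point are no longer passed to quantize_weight. The tensor shape, the qint4 dtype, and the group size of 64 are illustrative assumptions, and the exact optimum-quanto signature may vary across versions.

```python
# Sketch only: mirrors the positional call used in QuantizedCache._quantize after this
# change; shape, dtype, and group size are made-up example values, not taken from the diff.
import torch
from optimum.quanto import qint4, quantize_weight

dummy_key_states = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
qtensor = quantize_weight(dummy_key_states, qint4, 0, 64)  # tensor, qtype, axis, group_size
print(type(qtensor), qtensor.shape)
```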
src/transformers/generation/candidate_generator.py: 10 additions & 0 deletions
@@ -425,6 +425,16 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
model_kwargs[mask_key] = mask[:, :mask_length_diff]
elif mask_length_diff > 0:
model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)

if "cross_attention_mask" in model_kwargs:
# Mllama case is special and has another mask for cross attention model
cross_mask = model_kwargs["cross_attention_mask"]
if mask_length_diff < 0:
model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)

return model_kwargs


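To make the new branch above concrete, here is a standalone sketch of the same grow/crop logic on a dummy cross-attention mask. The (batch, text_seq_len, num_images, num_tiles) layout is an assumption read off the slicing and repeat pattern, not something stated in the diff.

```python
# Standalone illustration of the cross_attention_mask handling added above.
import torch

batch, seq_len, num_images, num_tiles = 2, 5, 1, 4  # made-up sizes
cross_mask = torch.ones(batch, seq_len, num_images, num_tiles)

mask_length_diff = 3  # >0 when candidate tokens were appended, <0 when they were rolled back
if mask_length_diff < 0:
    # Crop the mask back to the shorter sequence length.
    cross_mask = cross_mask[:, :mask_length_diff]
elif mask_length_diff > 0:
    # Repeat the last text position's mask once per newly appended token.
    new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
    cross_mask = torch.cat([cross_mask, new_mask], dim=1)

print(cross_mask.shape)  # torch.Size([2, 8, 1, 4])
```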
src/transformers/models/mllama/modeling_mllama.py: 3 additions & 1 deletion
@@ -1042,7 +1042,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
"MllamaSelfAttentionDecoderLayer",
]
_supports_cache_class = True
- _supports_static_cache = False
+ _supports_static_cache = False # static cache cannot have different shapes for each layer
_supports_sdpa = True
_supports_quantized_cache = True

@@ -1980,6 +1980,8 @@ def forward(
MLLAMA_START_DOCSTRING,
)
class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+ _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting
+
def __init__(self, config: MllamaConfig):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
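The two flags touched above are capability switches that generation code checks before building a cache. The snippet below is only a schematic of that kind of check, not the actual GenerationMixin code path.

```python
# Schematic of how capability flags like _supports_static_cache and
# _supports_quantized_cache are typically consumed; illustration only.
def check_cache_support(model, cache_implementation: str) -> None:
    if cache_implementation == "static" and not getattr(model, "_supports_static_cache", False):
        raise ValueError("Static cache is disabled for this model (per-layer cache shapes differ).")
    if cache_implementation == "quantized" and not getattr(model, "_supports_quantized_cache", False):
        raise ValueError("Quantized cache is disabled for this model (encoder-decoder style cross attention).")
```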
tests/models/mllama/test_modeling_mllama.py: 39 additions & 89 deletions
@@ -138,14 +138,6 @@ def setUp(self):
def test_eager_matches_sdpa_generate(self):
super().test_eager_matches_sdpa_generate()

@unittest.skip(reason="The outputs don't match, no idea why")
def test_beam_search_low_memory(self):
pass

@unittest.skip(reason="Quanto test is borken")
def test_generate_with_quant_cache(self):
pass


class MllamaVisionText2TextModelTester:
def __init__(
@@ -208,6 +200,7 @@ def __init__(
self.image_size = 224
self.max_num_images = 1
self.max_image_tiles = 4
+ self.image_length = 904

def get_config(self):
return MllamaConfig(
@@ -329,6 +322,43 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))

+ def _check_attentions_for_generate(
+ self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+ ):
+ # Mllama has cross attention layers and those have a different shape than normal attention layers
+ self.assertIsInstance(attentions, tuple)
+ self.assertListEqual(
+ [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+ )
+ self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+ cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
+
+ for idx, iter_attentions in enumerate(attentions):
+ tgt_len = min_length + idx if not use_cache else 1
+ src_len = min_length + idx
+
+ expected_shape = (
+ batch_size * num_beam_groups,
+ config.num_attention_heads,
+ tgt_len,
+ src_len,
+ )
+
+ expected_shape_cross = (
+ batch_size * num_beam_groups,
+ config.num_attention_heads,
+ tgt_len,
+ self.model_tester.image_length,
+ )
+
+ expected_shapes = [
+ expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
+ for layer_idx in range(len(iter_attentions))
+ ]
+
+ self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)

@require_torch_sdpa
@slow
@is_flaky()
@@ -342,94 +372,14 @@ def test_eager_matches_sdpa_inference_1_bfloat16(self):
# A workaround to override parametrized test with flaky decorator
super().test_eager_matches_sdpa_inference_1_bfloat16()

@unittest.skip(reason="Static cache not supported")
def test_static_cache_matches_dynamic(self):
# TypeError: list indices must be integers or slices, not tuple
# TODO: @raushan, please look into this for new cache format
pass

@unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
def test_generate_compile_fullgraph(self):
pass

@unittest.skip(reason="The outputs don't match, no idea why")
def test_beam_search_low_memory(self):
pass

@unittest.skip(reason="Mllama is not yet supported by compile")
@unittest.skip("For some unknown reasons the tests fails in CrossAttention layer when doing torch.sdpa(). ")
def test_sdpa_can_compile_dynamic(self):
- # TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
- # relevant issue: https://github.com/pytorch/pytorch/issues/133166
pass

@unittest.skip(reason="The test itself is broken") # TODO @zucchini-nlp
def test_generate_with_quant_cache(self):
pass

@unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
def test_model_parallelism(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_compile_cuda_graph_time(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_torch_compile_fullgraph(self):
pass

@unittest.skip(reason="Device side assert triggered")
def test_assisted_decoding_with_num_logits_to_keep(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_beam_sample_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_constrained_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_dola_decoding_sample(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_generate_methods_with_num_logits_to_keep(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_greedy_generate_dict_outputs(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_group_beam_search_generate_dict_output(self):
pass

@unittest.skip(reason="Failing test, need to fix")
def test_model_parallel_beam_search(self):
pass

@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_0(self):
super().test_new_cache_format_0()

@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_1(self):
super().test_new_cache_format_1()

@is_flaky() # TODO (joao, raushan) - investigate why this test is flaky (probably depends on the model initialization)
def test_new_cache_format_2(self):
super().test_new_cache_format_2()

@unittest.skip(reason="Failing test, need to fix")
def test_sample_generate_dict_output(self):
pass

def test_generate_text_only_with_cache(self):
"""
Tests that our cached generation with text-only inputs works. When mllama was introduced, this feature
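As a quick sanity check on the shape logic in the _check_attentions_for_generate helper added above, here is a worked example with made-up numbers; only image_length corresponds to the tester attribute from the diff.

```python
# Illustrative values; only image_length matches the tester attribute added above.
batch_size, num_beam_groups, num_heads = 2, 1, 4
min_length, idx, use_cache = 5, 3, True
image_length = 904

tgt_len = min_length + idx if not use_cache else 1  # with a cache, each step attends from one new token
src_len = min_length + idx                          # self-attention keys span the full prefix

expected_shape = (batch_size * num_beam_groups, num_heads, tgt_len, src_len)             # (2, 4, 1, 8)
expected_shape_cross = (batch_size * num_beam_groups, num_heads, tgt_len, image_length)  # (2, 4, 1, 904)
```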
