From 62d71f4083acccca1c1c9b0eea68db69d9ef759a Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Jun 2023 14:43:43 +0100 Subject: [PATCH] Fix functional TF Whisper and modernize tests (#24301) * Revert whisper change and modify the test_compile_tf_model test * make fixup * Tweak test slightly * Add functional model saving to test * Ensure TF can infer shapes for data2vec * Add override for efficientformer * Mark test as slow --- .../data2vec/modeling_tf_data2vec_vision.py | 4 +- .../models/whisper/modeling_tf_whisper.py | 11 +- .../test_modeling_tf_data2vec_vision.py | 4 - .../test_modeling_tf_efficientformer.py | 18 +++ .../models/funnel/test_modeling_tf_funnel.py | 4 - .../models/lxmert/test_modeling_tf_lxmert.py | 49 -------- .../models/marian/test_modeling_tf_marian.py | 32 ------ tests/models/mbart/test_modeling_tf_mbart.py | 27 ----- .../mobilevit/test_modeling_tf_mobilevit.py | 4 - .../pegasus/test_modeling_tf_pegasus.py | 32 ------ .../segformer/test_modeling_tf_segformer.py | 4 - .../vit_mae/test_modeling_tf_vit_mae.py | 46 -------- tests/test_modeling_tf_common.py | 105 +++--------------- 13 files changed, 40 insertions(+), 300 deletions(-) diff --git a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py index ee8bec20a019ab..a5953467cdd28e 100644 --- a/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_tf_data2vec_vision.py @@ -1383,11 +1383,11 @@ def call( # only keep certain features, and reshape # note that we do +1 as the encoder_hidden_states also includes the initial embeddings features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices] - batch_size = shape_list(pixel_values)[0] patch_resolution = self.config.image_size // self.config.patch_size def reshape_features(x): - x = tf.reshape(x, (batch_size, patch_resolution, patch_resolution, -1)) + # We do it this way so TF can always infer the non-batch dims at compile time + x = tf.reshape(x, (-1, patch_resolution, patch_resolution, self.config.hidden_size)) return x features = [reshape_features(x[:, 1:, :]) for x in features] diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index f33340e1c06a68..4d6ecb85b59930 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -766,12 +766,11 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] batch_size, seq_len = input_shape[0], input_shape[1] - if seq_len > 1: - combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) - else: - combined_attention_mask = _expand_mask( - tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len - ) + combined_attention_mask = tf.cond( + tf.math.greater(seq_len, 1), + lambda: _make_causal_mask(input_shape, past_key_values_length=past_key_values_length), + lambda: _expand_mask(tf.ones((batch_size, seq_len + past_key_values_length)), tgt_len=seq_len), + ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py index 6a30c83ebaf941..320b5ede5c10b2 100644 --- a/tests/models/data2vec/test_modeling_tf_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_tf_data2vec_vision.py @@ -240,10 +240,6 @@ def test_for_image_segmentation(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs) - @unittest.skip("Test was written for TF 1.x and isn't really relevant here") - def test_compile_tf_model(self): - pass - def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True diff --git a/tests/models/efficientformer/test_modeling_tf_efficientformer.py b/tests/models/efficientformer/test_modeling_tf_efficientformer.py index 5301aee561b0f0..059ff1ac129513 100644 --- a/tests/models/efficientformer/test_modeling_tf_efficientformer.py +++ b/tests/models/efficientformer/test_modeling_tf_efficientformer.py @@ -344,6 +344,24 @@ def test_attention_outputs(self): [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], ) + def test_compile_tf_model(self): + # We use a simplified version of this test for EfficientFormer because it requires training=False + # and Keras refuses to let us force that during functional construction + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # Prepare our model + model = model_class(config) + # These are maximally general inputs for the model, with multiple None dimensions + # Hopefully this will catch any conditionals that fail for flexible shapes + functional_inputs = { + key: tf.keras.Input(shape=val.shape[1:], dtype=val.dtype, name=key) + for key, val in model.input_signature.items() + if key in model.dummy_inputs + } + outputs_dict = model(functional_inputs) + self.assertTrue(outputs_dict is not None) + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/funnel/test_modeling_tf_funnel.py b/tests/models/funnel/test_modeling_tf_funnel.py index 5aea7e4309b51e..051da46fadafce 100644 --- a/tests/models/funnel/test_modeling_tf_funnel.py +++ b/tests/models/funnel/test_modeling_tf_funnel.py @@ -390,10 +390,6 @@ def test_for_question_answering(self): def test_saved_model_creation(self): pass - def test_compile_tf_model(self): - # This test fails the CI. TODO Lysandre re-enable it - pass - @require_tf class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase): diff --git a/tests/models/lxmert/test_modeling_tf_lxmert.py b/tests/models/lxmert/test_modeling_tf_lxmert.py index 9d97ddb462ccc9..a99495d008a1a6 100644 --- a/tests/models/lxmert/test_modeling_tf_lxmert.py +++ b/tests/models/lxmert/test_modeling_tf_lxmert.py @@ -532,55 +532,6 @@ def test_save_load(self): self.assert_outputs_same(after_outputs, outputs) - def test_compile_tf_model(self): - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - - for model_class in self.all_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common( - return_obj_labels="PreTraining" in model_class.__name__ - ) - - input_ids = tf.keras.Input( - batch_shape=(self.model_tester.batch_size, self.model_tester.seq_length), - name="input_ids", - dtype="int32", - ) - visual_feats = tf.keras.Input( - batch_shape=( - self.model_tester.batch_size, - self.model_tester.num_visual_features, - self.model_tester.visual_feat_dim, - ), - name="visual_feats", - dtype="int32", - ) - visual_pos = tf.keras.Input( - batch_shape=(self.model_tester.batch_size, self.model_tester.num_visual_features, 4), - name="visual_pos", - dtype="int32", - ) - - # Prepare our model - model = model_class(config) - - # Let's load it from the disk to be sure we can use pretrained weights - with tempfile.TemporaryDirectory() as tmpdirname: - outputs = model(self._prepare_for_class(inputs_dict, model_class)) # build the model - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname) - - outputs_dict = model(input_ids, visual_feats, visual_pos) - hidden_states = outputs_dict[0] - - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - - # Compile extended model - extended_model = tf.keras.Model(inputs=[input_ids, visual_feats, visual_pos], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) - @tooslow def test_saved_model_creation(self): pass diff --git a/tests/models/marian/test_modeling_tf_marian.py b/tests/models/marian/test_modeling_tf_marian.py index 8833938f3b2e34..1a87f4e984a17b 100644 --- a/tests/models/marian/test_modeling_tf_marian.py +++ b/tests/models/marian/test_modeling_tf_marian.py @@ -16,7 +16,6 @@ from __future__ import annotations -import tempfile import unittest import warnings @@ -209,37 +208,6 @@ def test_decoder_model_past_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs) - def test_compile_tf_model(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - - model_class = self.all_generative_model_classes[0] - input_ids = { - "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), - "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), - } - - # Prepare our model - model = model_class(config) - model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. - # Let's load it from the disk to be sure we can use pre-trained weights - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname) - - outputs_dict = model(input_ids) - hidden_states = outputs_dict[0] - - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - - # Compile extended model - extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) - @tooslow def test_saved_model_creation(self): pass diff --git a/tests/models/mbart/test_modeling_tf_mbart.py b/tests/models/mbart/test_modeling_tf_mbart.py index 70a93acf8175a2..753f961d1f1a5a 100644 --- a/tests/models/mbart/test_modeling_tf_mbart.py +++ b/tests/models/mbart/test_modeling_tf_mbart.py @@ -15,7 +15,6 @@ from __future__ import annotations -import tempfile import unittest from transformers import AutoTokenizer, MBartConfig, is_tf_available @@ -118,32 +117,6 @@ def check_decoder_model_past_large_inputs(self, config, inputs_dict): output, past_key_values = outputs.to_tuple() past_key_values = past_key_values[1] - def test_compile_tf_model(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - model_class = self.all_generative_model_classes[0] - input_ids = { - "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), - "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), - } - # Prepare our model - model = model_class(config) - model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. - # Let's load it from the disk to be sure we can use pretrained weights - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname) - outputs_dict = model(input_ids) - hidden_states = outputs_dict[0] - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - # Compile extended model - extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) - def prepare_mbart_inputs_dict( config, diff --git a/tests/models/mobilevit/test_modeling_tf_mobilevit.py b/tests/models/mobilevit/test_modeling_tf_mobilevit.py index 37d7db39e68d3e..c635feae8f4212 100644 --- a/tests/models/mobilevit/test_modeling_tf_mobilevit.py +++ b/tests/models/mobilevit/test_modeling_tf_mobilevit.py @@ -199,10 +199,6 @@ def test_model_common_attributes(self): def test_attention_outputs(self): pass - @unittest.skip("Test was written for TF 1.x and isn't really relevant here") - def test_compile_tf_model(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/pegasus/test_modeling_tf_pegasus.py b/tests/models/pegasus/test_modeling_tf_pegasus.py index 0bd1ed25e7bb06..dcd0479e2cbf35 100644 --- a/tests/models/pegasus/test_modeling_tf_pegasus.py +++ b/tests/models/pegasus/test_modeling_tf_pegasus.py @@ -15,7 +15,6 @@ from __future__ import annotations -import tempfile import unittest from transformers import AutoTokenizer, PegasusConfig, is_tf_available @@ -207,37 +206,6 @@ def test_decoder_model_past_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() self.model_tester.check_decoder_model_past_large_inputs(*config_and_inputs) - def test_compile_tf_model(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - - model_class = self.all_generative_model_classes[0] - input_ids = { - "decoder_input_ids": tf.keras.Input(batch_shape=(2, 2000), name="decoder_input_ids", dtype="int32"), - "input_ids": tf.keras.Input(batch_shape=(2, 2000), name="input_ids", dtype="int32"), - } - - # Prepare our model - model = model_class(config) - model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. - # Let's load it from the disk to be sure we can use pretrained weights - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname) - - outputs_dict = model(input_ids) - hidden_states = outputs_dict[0] - - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - - # Compile extended model - extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) - @tooslow def test_saved_model_creation(self): pass diff --git a/tests/models/segformer/test_modeling_tf_segformer.py b/tests/models/segformer/test_modeling_tf_segformer.py index b831e8ddbc2b14..d3317b2079be5b 100644 --- a/tests/models/segformer/test_modeling_tf_segformer.py +++ b/tests/models/segformer/test_modeling_tf_segformer.py @@ -186,10 +186,6 @@ def test_inputs_embeds(self): def test_model_common_attributes(self): pass - @unittest.skip("Test was written for TF 1.x and isn't really relevant here") - def test_compile_tf_model(self): - pass - def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/vit_mae/test_modeling_tf_vit_mae.py b/tests/models/vit_mae/test_modeling_tf_vit_mae.py index d5e16e96385068..34e5a636d82533 100644 --- a/tests/models/vit_mae/test_modeling_tf_vit_mae.py +++ b/tests/models/vit_mae/test_modeling_tf_vit_mae.py @@ -283,52 +283,6 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) - # overwrite from common since TFViTMAEForPretraining outputs loss along with - # logits and mask indices. loss and mask indices are not suitable for integration - # with other keras modules. - def test_compile_tf_model(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") - - for model_class in self.all_model_classes: - # `pixel_values` implies that the input is an image - inputs = tf.keras.Input( - batch_shape=( - 3, - self.model_tester.num_channels, - self.model_tester.image_size, - self.model_tester.image_size, - ), - name="pixel_values", - dtype="float32", - ) - - # Prepare our model - model = model_class(config) - model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. - # Let's load it from the disk to be sure we can use pretrained weights - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, saved_model=False) - model = model_class.from_pretrained(tmpdirname) - - outputs_dict = model(inputs) - hidden_states = outputs_dict[0] - - # `TFViTMAEForPreTraining` outputs are not recommended to be used for - # downstream application. This is just to check if the outputs of - # `TFViTMAEForPreTraining` can be integrated with other keras modules. - if model_class.__name__ == "TFViTMAEForPreTraining": - hidden_states = outputs_dict["logits"] - - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - - # Compile extended model - extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) - # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise # to generate masks during test def test_keras_save_load(self): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index a4de6e8f471e3e..3c3ca75ff6bb15 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -685,105 +685,30 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): if tf_inputs_dict_with_labels: self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict_with_labels) + @slow def test_compile_tf_model(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - max_input = getattr(self.model_tester, "max_position_embeddings", 512) - optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0) - loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy") + config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: - if model_class.__name__ in ["TFSpeech2TextModel", "TFSpeech2TextForConditionalGeneration"]: - inputs = { - "decoder_input_ids": tf.keras.Input( - batch_shape=(2, max_input), - name="decoder_input_ids", - dtype="int32", - ), - "input_features": tf.keras.Input( - batch_shape=( - 2, - max_input, - self.model_tester.input_feat_per_channel * self.model_tester.input_channels, - ), - name="input_features", - dtype="float32", - ), - } - elif model_class.__name__ in ["TFWhisperModel", "TFWhisperForConditionalGeneration"]: - inputs = { - "decoder_input_ids": tf.keras.Input( - batch_shape=(2, max_input), - name="decoder_input_ids", - dtype="int32", - ), - "input_features": tf.keras.Input( - batch_shape=( - 2, - self.model_tester.num_mel_bins, - self.model_tester.seq_length, - ), - name="input_features", - dtype="float32", - ), - } - elif self.is_encoder_decoder: - inputs = { - "decoder_input_ids": tf.keras.Input( - batch_shape=(2, max_input), - name="decoder_input_ids", - dtype="int32", - ), - "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"), - } - # `pixel_values` implies that the input is an image - elif model_class.main_input_name == "pixel_values": - inputs = tf.keras.Input( - batch_shape=( - 3, - self.model_tester.num_channels, - self.model_tester.image_size, - self.model_tester.image_size, - ), - name="pixel_values", - dtype="float32", - ) - elif model_class.__name__ in ["TFCLIPModel", "TFGroupViTModel", "TFBlipModel"]: - inputs = { - "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"), - "pixel_values": tf.keras.Input( - batch_shape=( - 3, - self.model_tester.vision_model_tester.num_channels, - self.model_tester.vision_model_tester.image_size, - self.model_tester.vision_model_tester.image_size, - ), - name="pixel_values", - dtype="float32", - ), - } - elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING): - inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32") - else: - inputs = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32") - # Prepare our model model = model_class(config) - model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving. - # Let's load it from the disk to be sure we can use pretrained weights - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, saved_model=False) - model = model_class.from_pretrained(tmpdirname) + # These are maximally general inputs for the model, with multiple None dimensions + # Hopefully this will catch any conditionals that fail for flexible shapes + functional_inputs = { + key: tf.keras.Input(shape=val.shape[1:], dtype=val.dtype, name=key) + for key, val in model.input_signature.items() + if key in model.dummy_inputs + } + outputs_dict = model(functional_inputs) - outputs_dict = model(inputs) hidden_states = outputs_dict[0] - # Add a dense layer on top to test integration with other keras modules - outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) - # Compile extended model - extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs]) - extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) + functional_model = tf.keras.Model(inputs=functional_inputs, outputs=hidden_states) + model_out = functional_model.predict(model.dummy_inputs) # Check we can pass inputs with the Keras API + self.assertTrue(model_out is not None) + with tempfile.TemporaryDirectory() as tmpdirname: + functional_model.save(tmpdirname) # Ensure we can save/export the whole functional model def test_keyword_and_dict_args(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()