Even more TF test fixes (huggingface#28146)
* Fix vision text dual encoder

* Small cleanup for wav2vec2 (not fixed yet)

* Small fix for vision_encoder_decoder

* Fix SAM builds

* Update TFBertTokenizer test with modern exporting + tokenizer

* Fix DeBERTa

* Fix DeBERTav2

* Try RAG fix but it's impossible to test locally

* Actually fix RAG now that I got FAISS working somehow

* Fix Wav2Vec2, add sermon

* Fix Hubert
Rocketknight1 authored Dec 21, 2023
1 parent f9a98c4 commit 260b9d2
Showing 11 changed files with 46 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
@@ -57,7 +57,7 @@ def convert_tf_weight_name_to_pt_weight_name(
transposed with regards to each other
"""
if name_scope is not None:
- if not tf_name.startswith(name_scope):
+ if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
raise ValueError(
f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
"in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
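The guard above rejects TF weight names that do not start with the expected name scope; the added clause exempts `final_logits_bias`, a weight that in some TF seq2seq models is not created under the shared model scope. A minimal sketch of the resulting behaviour, using a hypothetical `check_name_scope` helper and made-up weight names rather than the real Transformers internals:

```python
def check_name_scope(tf_name: str, name_scope: str) -> None:
    # Hypothetical stand-in for the guard inside convert_tf_weight_name_to_pt_weight_name.
    if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
        raise ValueError(f"Weight name {tf_name} does not start with name_scope {name_scope}.")


check_name_scope("tf_model/encoder/layer_0/kernel:0", "tf_model")  # prefixed: passes
check_name_scope("final_logits_bias:0", "tf_model")                # exempted: passes
# check_name_scope("decoder/layer_0/kernel:0", "tf_model")         # would raise ValueError
```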
4 changes: 2 additions & 2 deletions src/transformers/models/deberta/modeling_tf_deberta.py
@@ -638,10 +638,10 @@ def build(self, input_shape=None):
self.pos_dropout.build(None)
if getattr(self, "pos_proj", None) is not None:
with tf.name_scope(self.pos_proj.name):
- self.pos_proj.build(None)
+ self.pos_proj.build([self.config.hidden_size])
if getattr(self, "pos_q_proj", None) is not None:
with tf.name_scope(self.pos_q_proj.name):
- self.pos_q_proj.build(None)
+ self.pos_q_proj.build([self.config.hidden_size])

def transpose_for_scores(self, tensor: tf.Tensor) -> tf.Tensor:
shape = shape_list(tensor)[:-1] + [self.num_attention_heads, -1]
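Passing an explicit shape to `build()` matters because a keras `Dense` layer can only create its kernel once the size of its last input dimension is known; `pos_proj` and `pos_q_proj` both take hidden states as input, so `[self.config.hidden_size]` is the right shape. An illustrative sketch outside of Transformers (the hidden size of 768 is just an example value):

```python
import tensorflow as tf

hidden_size = 768  # example value

dense = tf.keras.layers.Dense(hidden_size, name="pos_proj")
# With an unknown input shape the kernel cannot be created; giving the last
# dimension explicitly materializes the weights immediately, which is what
# deterministic weight naming and checkpoint loading rely on.
dense.build([hidden_size])
print([w.shape for w in dense.weights])  # kernel (768, 768) and bias (768,)
```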
11 changes: 5 additions & 6 deletions src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py
@@ -362,6 +362,9 @@ def __init__(self, config: DebertaV2Config, **kwargs):
self.config = config

def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
with tf.name_scope("conv"):
self.conv_kernel = self.add_weight(
name="kernel",
@@ -371,13 +374,9 @@ def __init__(self, config: DebertaV2Config, **kwargs):
self.conv_bias = self.add_weight(
name="bias", shape=[self.config.hidden_size], initializer=tf.zeros_initializer()
)
- return
- if self.built:
- return
- self.built = True
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
- self.LayerNorm.build(None)
+ self.LayerNorm.build([None, None, self.config.hidden_size])
if getattr(self, "dropout", None) is not None:
with tf.name_scope(self.dropout.name):
self.dropout.build(None)
@@ -453,7 +452,7 @@ def build(self, input_shape=None):
self.conv.build(None)
if getattr(self, "LayerNorm", None) is not None:
with tf.name_scope(self.LayerNorm.name):
- self.LayerNorm.build([None, None, self.config.hidden_size])
+ self.LayerNorm.build([None, self.config.hidden_size])
if getattr(self, "layer", None) is not None:
for layer in self.layer:
with tf.name_scope(layer.name):
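The reordering above restores the usual idempotent `build()` pattern: the `if self.built: return` guard and `self.built = True` move to the top, and the stray early `return` that previously skipped the `LayerNorm` and dropout builds is removed. A minimal sketch of the pattern with illustrative names and sizes (not the actual DeBERTa-v2 layer):

```python
import tensorflow as tf

class ExampleConvBlock(tf.keras.layers.Layer):
    def __init__(self, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.LayerNorm = tf.keras.layers.LayerNormalization(name="LayerNorm")

    def build(self, input_shape=None):
        if self.built:      # guard first ...
            return
        self.built = True   # ... then create every weight and sub-layer exactly once
        with tf.name_scope("conv"):
            self.conv_kernel = self.add_weight(
                name="kernel", shape=[3, self.hidden_size, self.hidden_size]
            )
            self.conv_bias = self.add_weight(
                name="bias", shape=[self.hidden_size], initializer="zeros"
            )
        with tf.name_scope(self.LayerNorm.name):
            # LayerNorm only needs the size of the axis it normalizes over.
            self.LayerNorm.build([None, None, self.hidden_size])


block = ExampleConvBlock(hidden_size=32)
block.build()
block.build()  # second call is a no-op thanks to the guard
print(len(block.weights))  # 4: conv kernel, conv bias, LayerNorm gamma and beta
```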
9 changes: 4 additions & 5 deletions src/transformers/models/hubert/modeling_tf_hubert.py
@@ -401,7 +401,6 @@ def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
)
self.explicit_padding = explicit_padding
self.filter_axis = 2
- self.initialized = False
self.kernel_norm_axes = tf.constant([0, 1])

def _init_norm(self):
@@ -428,13 +427,13 @@ def build(self, input_shape):
dtype=self.weight_v.dtype,
trainable=True,
)
+ self._init_norm()
self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)

def call(self, inputs):
- if not self.initialized:
- self._init_norm()
- self.initialized = True

+ # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent.
+ # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls
+ # a functional 1d convolution with normalized weights that it generates (but does not store!)
self._normalize_kernel()

padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
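With the `initialized` flag gone, the one-time `_init_norm()` now happens in `build()`; `call()` still re-normalizes the kernel in place, which is what the TODO above flags for a future rewrite. A hypothetical, self-contained layer along the lines the TODO suggests (derive the effective kernel functionally on every call, never assign it) could look like this; it is not the layer shipped in Transformers:

```python
import tensorflow as tf

class WeightNormConv1D(tf.keras.layers.Layer):
    """Illustrative weight-normalized 1D convolution; names and details are assumptions."""

    def __init__(self, filters: int, kernel_size: int, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size

    def build(self, input_shape):
        in_channels = int(input_shape[-1])
        # Direction (v) and magnitude (g) are the stored variables; the effective
        # kernel is derived from them on every call, so call() stays idempotent.
        self.weight_v = self.add_weight(
            name="weight_v", shape=(self.kernel_size, in_channels, self.filters)
        )
        self.weight_g = self.add_weight(
            name="weight_g", shape=(1, 1, self.filters), initializer="ones"
        )
        self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros")

    def call(self, inputs):
        kernel = tf.nn.l2_normalize(self.weight_v, axis=[0, 1]) * self.weight_g
        return tf.nn.conv1d(inputs, kernel, stride=1, padding="SAME") + self.bias


x = tf.random.normal((2, 50, 16))
print(WeightNormConv1D(filters=32, kernel_size=3)(x).shape)  # (2, 50, 32)
```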
9 changes: 9 additions & 0 deletions src/transformers/models/rag/modeling_tf_rag.py
@@ -720,6 +720,15 @@ def call(
generator_dec_attentions=gen_outputs.decoder_attentions,
)

+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ with tf.name_scope(self.generator.name):
+ self.generator.build(None)
+ with tf.name_scope(self.question_encoder.name):
+ self.question_encoder.build(None)


@add_start_docstrings_to_model_forward(
"""
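RAG wraps two complete sub-models, so the added `build()` makes sure both are built under their own name scopes even when the model is built without ever being called, e.g. before `save_pretrained` (the RAG tests further down call `build_in_name_scope()` for the same reason). An illustrative pattern with Dense stand-ins instead of the real question encoder and generator:

```python
import tensorflow as tf

class Composite(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.question_encoder = tf.keras.layers.Dense(8, name="question_encoder")
        self.generator = tf.keras.layers.Dense(8, name="generator")

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Building each sub-layer inside tf.name_scope keeps its variables under
        # the sub-layer's own name, which the TF/PT weight conversion expects.
        with tf.name_scope(self.generator.name):
            self.generator.build([None, 8])
        with tf.name_scope(self.question_encoder.name):
            self.question_encoder.build([None, 8])


layer = Composite(name="composite")
layer.build()
print([w.name for w in layer.weights])
```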
3 changes: 3 additions & 0 deletions src/transformers/models/sam/modeling_tf_sam.py
@@ -604,6 +604,9 @@ def build(self, input_shape=None):
if getattr(self, "iou_prediction_head", None) is not None:
with tf.name_scope(self.iou_prediction_head.name):
self.iou_prediction_head.build(None)
+ for mlp in self.output_hypernetworks_mlps:
+ with tf.name_scope(mlp.name):
+ mlp.build(None)

def call(
self,
8 changes: 4 additions & 4 deletions src/transformers/models/vision_text_dual_encoder/modeling_tf_vision_text_dual_encoder.py
@@ -247,16 +247,16 @@ def tf_to_pt_weight_rename(self, tf_weight):
# However, the name of that extra layer is the name of the MainLayer in the base model.
if "vision_model" in tf_weight:
if tf_weight.count("vision_model") == 1:
return re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight)
return (re.sub(r"vision_model\..*?\.", "vision_model.", tf_weight),)
elif tf_weight.count("vision_model") == 2:
return re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight)
return (re.sub(r"vision_model\..*?\.vision_model", "vision_model.vision_model", tf_weight),)
else:
raise ValueError(
f"Unexpected weight name {tf_weight}. Please file an issue on the"
" Transformers repo to let us know about this error!"
)
elif "text_model" in tf_weight:
return re.sub(r"text_model\..*?\.", "text_model.", tf_weight)
return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),)
else:
return (tf_weight,)

@@ -598,7 +598,7 @@ def from_vision_text_pretrained(
if text_model.name != "text_model":
raise ValueError("text model must be created with the name `text_model`.")

- model.build() # Ensure model is fully built
+ model.build_in_name_scope() # Ensure model is fully built

return model

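Both changes follow the current cross-loading conventions: `tf_to_pt_weight_rename` now always returns a tuple of candidate names rather than a bare string (the final `else` branch already did), and explicit building goes through `build_in_name_scope()` so the created weights pick up the model's name prefix. A hedged sketch of how a tuple-returning rename hook can be consumed; the loop is illustrative, not the actual loading code:

```python
import re

def tf_to_pt_weight_rename(tf_weight: str):
    # Simplified stand-in mirroring the diff above: always return a tuple,
    # even when there is only a single candidate name.
    if "text_model" in tf_weight:
        return (re.sub(r"text_model\..*?\.", "text_model.", tf_weight),)
    return (tf_weight,)


# Illustrative consumer: try each candidate until one matches the PyTorch state dict.
pt_state_dict = {"text_model.embeddings.weight": "..."}
for candidate in tf_to_pt_weight_rename("text_model.clip.embeddings.weight"):
    if candidate in pt_state_dict:
        print("matched", candidate)  # matched text_model.embeddings.weight
        break
```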
17 changes: 8 additions & 9 deletions src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
@@ -435,7 +435,6 @@ def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
)
self.explicit_padding = explicit_padding
self.filter_axis = 2
- self.initialized = False
self.kernel_norm_axes = tf.constant([0, 1])

def _init_norm(self):
@@ -462,13 +461,13 @@ def build(self, input_shape):
dtype=self.weight_v.dtype,
trainable=True,
)
+ self._init_norm()
self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)

def call(self, inputs):
- if not self.initialized:
- self._init_norm()
- self.initialized = True

+ # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent.
+ # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls
+ # a functional 1d convolution with normalized weights that it generates (but does not store!)
self._normalize_kernel()

padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
@@ -1208,13 +1207,13 @@ def __init__(self, config: Wav2Vec2Config, **kwargs):
self.encoder = TFWav2Vec2Encoder(config, name="encoder")

def build(self, input_shape=None):
- self.masked_spec_embed = self.add_weight(
- shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
- )

if self.built:
return
self.built = True
+ if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = self.add_weight(
+ shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
+ )
if getattr(self, "feature_extractor", None) is not None:
with tf.name_scope(self.feature_extractor.name):
self.feature_extractor.build(None)
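Besides mirroring the Hubert changes, the main layer's `build()` now creates `masked_spec_embed` only after the `built` guard and only when masking is actually configured, so a model with both masking probabilities at zero no longer creates that weight. A small illustrative sketch of conditional weight creation (the config class and names here are placeholders, not Transformers code):

```python
import tensorflow as tf

class MaskingConfig:
    hidden_size = 32
    mask_time_prob = 0.0
    mask_feature_prob = 0.0


class ExampleMainLayer(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Optional weight: only created when the feature that uses it is enabled.
        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            self.masked_spec_embed = self.add_weight(
                shape=(self.config.hidden_size,),
                initializer="uniform",
                trainable=True,
                name="masked_spec_embed",
            )


layer = ExampleMainLayer(MaskingConfig())
layer.build()
print(len(layer.weights))  # 0 here, since both masking probabilities are 0.0
```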
19 changes: 7 additions & 12 deletions tests/models/bert/test_tokenization_bert_tf.py
@@ -28,7 +28,7 @@ def __init__(self, tokenizer):

def call(self, inputs):
tokenized = self.tokenizer(inputs)
- out = self.bert(**tokenized)
+ out = self.bert(tokenized)
return out["pooler_output"]


@@ -41,13 +41,8 @@ class BertTokenizationTest(unittest.TestCase):
def setUp(self):
super().setUp()

- self.tokenizers = [
- BertTokenizer.from_pretrained(checkpoint) for checkpoint in (TOKENIZER_CHECKPOINTS * 2)
- ] # repeat for when fast_bert_tokenizer=false
- self.tf_tokenizers = [TFBertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS] + [
- TFBertTokenizer.from_pretrained(checkpoint, use_fast_bert_tokenizer=False)
- for checkpoint in TOKENIZER_CHECKPOINTS
- ]
+ self.tokenizers = [BertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
+ self.tf_tokenizers = [TFBertTokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
assert len(self.tokenizers) == len(self.tf_tokenizers)

self.test_sentences = [
@@ -94,15 +89,15 @@ def test_graph_mode(self):
self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))

@slow
- def test_saved_model(self):
+ def test_export_for_inference(self):
for tf_tokenizer in self.tf_tokenizers:
model = ModelToSave(tokenizer=tf_tokenizer)
test_inputs = tf.convert_to_tensor(self.test_sentences)
out = model(test_inputs) # Build model with some sample inputs
with TemporaryDirectory() as tempdir:
save_path = Path(tempdir) / "saved.model"
- model.save(save_path)
- loaded_model = tf.keras.models.load_model(save_path)
- loaded_output = loaded_model(test_inputs)
+ model.export(save_path)
+ loaded_model = tf.saved_model.load(save_path)
+ loaded_output = loaded_model.serve(test_inputs)
# We may see small differences because the loaded model is compiled, so we need an epsilon for the test
self.assertLessEqual(tf.reduce_max(tf.abs(out - loaded_output)), 1e-5)
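The renamed test exercises the newer Keras export path instead of a full Keras SavedModel round-trip: in recent TF/Keras versions, `model.export()` writes a SavedModel with a default inference endpoint, `tf.saved_model.load()` reloads it without reconstructing a Keras object, and the endpoint is invoked as `.serve(...)`. A hedged usage sketch with a toy model in place of the tokenizer-plus-BERT pipeline used in the test:

```python
from pathlib import Path
from tempfile import TemporaryDirectory

import tensorflow as tf

toy = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
sample = tf.random.normal((2, 8))
eager_out = toy(sample)

with TemporaryDirectory() as tmpdir:
    save_path = Path(tmpdir) / "exported.model"
    toy.export(save_path)                 # SavedModel with a `serve` endpoint
    reloaded = tf.saved_model.load(str(save_path))
    loaded_out = reloaded.serve(sample)   # call the exported inference endpoint

# The reloaded graph may differ from eager execution by tiny numerical amounts,
# which is why the test above compares with an epsilon rather than exact equality.
print(float(tf.reduce_max(tf.abs(eager_out - loaded_out))))
```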
2 changes: 2 additions & 0 deletions tests/models/rag/test_modeling_tf_rag.py
@@ -1005,6 +1005,7 @@ def test_rag_sequence_from_pretrained(self):
retriever=rag_retriever,
config=rag_config,
)
+ rag_sequence.build_in_name_scope()
# check that the from pretrained methods work
rag_sequence.save_pretrained(tmp_dirname)
rag_sequence.from_pretrained(tmp_dirname, retriever=rag_retriever)
@@ -1056,6 +1057,7 @@ def test_rag_token_from_pretrained(self):
retriever=rag_retriever,
config=rag_config,
)
+ rag_token.build_in_name_scope()
# check that the from pretrained methods work
rag_token.save_pretrained(tmp_dirname)
rag_token.from_pretrained(tmp_dirname, retriever=rag_retriever)
1 change: 1 addition & 0 deletions tests/models/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
@@ -858,6 +858,7 @@ def test_encoder_decoder_from_pretrained(self):
pretrained_encoder_dir,
pretrained_decoder_dir,
)
+ enc_dec_model.build_in_name_scope()
# check that the from pretrained methods work
enc_dec_model.save_pretrained(tmp_dirname)
enc_dec_model = TFVisionEncoderDecoderModel.from_pretrained(tmp_dirname)
