diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index b035ef7b10d3e3..276f94aebdbb9e 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1184,6 +1184,11 @@ def get_encoder(self):
     def get_decoder(self):
         return self.decoder
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py
index db5b554e82d43a..deaa8b5dafe676 100644
--- a/tests/models/mbart/test_modeling_mbart.py
+++ b/tests/models/mbart/test_modeling_mbart.py
@@ -327,6 +327,43 @@ def test_generate_fp16(self):
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    def test_ensure_weights_are_shared(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.tie_word_embeddings = True
+        model = MBartForConditionalGeneration(config)
+
+        # MBart shares four weights.
+        # Leaving these untied is fine for torch.load, but it is an issue for safetensors.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            1,
+        )
+
+        config.tie_word_embeddings = False
+        model = MBartForConditionalGeneration(config)
+
+        # MBart shares four weights.
+        # Leaving these untied is fine for torch.load, but it is an issue for safetensors.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            2,
+        )
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
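
Note on the data_ptr() check used in the new test: parameters that are truly tied alias a single storage, so collecting their data pointers in a set collapses to one element. Below is a minimal standalone sketch of that check using throwaway nn.Embedding/nn.Linear modules as hypothetical stand-ins for MBart's embeddings and LM head (not the patch itself):

    import torch
    from torch import nn

    embed_tokens = nn.Embedding(10, 4)      # stand-in for the shared input embedding
    lm_head = nn.Linear(4, 10, bias=False)  # stand-in for the output projection

    # Tying: point both modules at the very same Parameter (one storage, two
    # references), roughly what _tie_or_clone_weights arranges in _tie_weights.
    lm_head.weight = embed_tokens.weight

    # Once the weights are shared, the set of data pointers collapses to one address.
    assert len({embed_tokens.weight.data_ptr(), lm_head.weight.data_ptr()}) == 1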