diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index ac67ef4b37e49c..5aad7b23a8a672 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -1248,6 +1248,17 @@ def resize_token_embeddings(
 
         return model_embeds
 
+    def _tie_weights(self):
+        if getattr(self.config, "tie_word_embeddings", True):
+            self._tied_weights_keys = []
+            output_embeddings = self.get_output_embeddings()
+            input_embeddings = self.get_input_embeddings()
+
+            for i in range(self.config.n_codes_total - self.config.n_codes_given):
+                # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
+                self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
+                self._tied_weights_keys.append(f"lm_heads.{i}.weight")
+
     def tie_weights(self):
         """
         Tie the weights between the input embeddings list and the output embeddings list.
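
For context, a minimal sketch (not part of the patch) of how the new `_tie_weights` hook could be exercised and checked. It assumes a randomly initialized `BarkFineModel` built from a default `BarkFineConfig`, and that `get_output_embeddings()` / `get_input_embeddings()` return the `lm_heads` and `input_embeds_layers` module lists as in modeling_bark.py:

# Sketch only: verify that each fine-acoustics LM head ends up sharing its weight
# tensor with the corresponding (offset-by-one) input embedding layer after tying.
from transformers import BarkFineConfig, BarkFineModel

config = BarkFineConfig()      # default config; tie_word_embeddings defaults to True
model = BarkFineModel(config)  # post_init() normally triggers the tying on construction
model._tie_weights()           # call the new hook directly for illustration

for i in range(config.n_codes_total - config.n_codes_given):
    # lm_heads[i].weight should now point at input_embeds_layers[i + 1].weight
    assert model.lm_heads[i].weight is model.input_embeds_layers[i + 1].weight

print(model._tied_weights_keys)  # ["lm_heads.0.weight", "lm_heads.1.weight", ...]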