From 954d782ae45dc61873c21e067273cf851a223e78 Mon Sep 17 00:00:00 2001
From: calpt
Date: Thu, 2 Mar 2023 22:25:20 +0100
Subject: [PATCH] Only compute fusion reg loss if fusion layer is trained (#505)

---
 src/transformers/adapters/model_mixin.py | 10 ++-
 src/transformers/adapters/trainer.py     |  3 +-
 tests_adapters/test_adapter_trainer.py   | 92 ++++++++++++++++++++----
 3 files changed, 86 insertions(+), 19 deletions(-)

diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py
index a87abb4fe4..54bae8d654 100644
--- a/src/transformers/adapters/model_mixin.py
+++ b/src/transformers/adapters/model_mixin.py
@@ -858,15 +858,19 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
         context.adapter_fusion_attentions = defaultdict(dict)
 
     def get_fusion_regularization_loss(self):
-        reg_loss = 0.0
+        reg_loss = None
 
         target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)
         for i, layer in self.iter_layers():
             for module in layer.modules():
                 if isinstance(module, AdapterLayer):
                     for _, layer_fusion in module.adapter_fusion_layer.items():
-                        if hasattr(layer_fusion, "value"):
-                            reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum()
+                        if hasattr(layer_fusion, "value") and layer_fusion.value.weight.requires_grad:
+                            layer_reg_loss = 0.01 * (target - layer_fusion.value.weight).pow(2).sum()
+                            if reg_loss is None:
+                                reg_loss = layer_reg_loss
+                            else:
+                                reg_loss += layer_reg_loss
 
         return reg_loss
 
diff --git a/src/transformers/adapters/trainer.py b/src/transformers/adapters/trainer.py
index 81f9ae74f7..5746294d89 100644
--- a/src/transformers/adapters/trainer.py
+++ b/src/transformers/adapters/trainer.py
@@ -259,7 +259,8 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
         model = kwargs.pop("model")
         if self.trainer.train_adapter_fusion:
             fusion_reg_loss = model.base_model.get_fusion_regularization_loss()
-            fusion_reg_loss.backward()
+            if fusion_reg_loss is not None:
+                fusion_reg_loss.backward()
 
 
 class Seq2SeqAdapterTrainer(AdapterTrainer, Seq2SeqTrainer):
diff --git a/tests_adapters/test_adapter_trainer.py b/tests_adapters/test_adapter_trainer.py
index efcb9fe65a..eb8ffd4a3e 100644
--- a/tests_adapters/test_adapter_trainer.py
+++ b/tests_adapters/test_adapter_trainer.py
@@ -21,6 +21,14 @@ class TestAdapterTrainer(unittest.TestCase):
+    def get_model_config(self):
+        return BertConfig(
+            hidden_size=32,
+            num_hidden_layers=4,
+            num_attention_heads=4,
+            intermediate_size=37,
+        )
+
     def test_resume_training(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
@@ -29,7 +37,7 @@ def test_resume_training(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.add_adapter("additional_adapter")
         model.set_active_adapters("adapter")
@@ -52,7 +60,7 @@ def test_resume_training(self):
             trainer.train()
 
         # create second model that should resume the training of the first
-        model_resume = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model_resume.add_adapter("adapter")
         model_resume.add_adapter("additional_adapter")
         model_resume.set_active_adapters("adapter")
@@ -78,7 +86,7 @@ def test_resume_training_with_fusion(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.add_adapter("additional_adapter")
         model.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
@@ -101,7 +109,7 @@ def test_resume_training_with_fusion(self):
             )
             trainer.train()
 
-        model_resume = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model_resume = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model_resume.add_adapter("adapter")
         model_resume.add_adapter("additional_adapter")
         model_resume.add_adapter_fusion(Fuse("adapter", "additional_adapter"))
@@ -155,7 +163,7 @@ def test_training_load_best_model_at_end_full_model(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
@@ -189,7 +197,7 @@ def test_training_load_best_model_at_end_adapter(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("adapter")
         model.train_adapter("adapter")
 
@@ -221,7 +229,7 @@ def test_training_load_best_model_at_end_fusion(self):
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
         eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
 
-        model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+        model = AutoModelForSequenceClassification.from_config(self.get_model_config())
         model.add_adapter("fuse_adapter_1")
         model.add_adapter("fuse_adapter_2")
         model.add_adapter_fusion(Fuse("fuse_adapter_1", "fuse_adapter_2"))
@@ -254,7 +262,7 @@ def test_reloading_prediction_head(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model = AutoAdapterModel.from_config(self.get_model_config())
         model.add_classification_head("adapter", num_labels=3)
         model.add_classification_head("dummy", num_labels=2)
 
@@ -288,7 +296,7 @@ def test_reloading_prediction_head(self):
             trainer.train()
 
         # create second model that should resume the training of the first
-        model_resume = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model_resume = AutoAdapterModel.from_config(self.get_model_config())
         model_resume.add_classification_head("adapter", num_labels=3)
         model_resume.add_classification_head("dummy", num_labels=2)
 
@@ -323,7 +331,7 @@ def test_general(self):
         )
         train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
 
-        model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+        model = AutoAdapterModel.from_config(self.get_model_config())
 
         model.add_classification_head("task", num_labels=3)
 
@@ -364,6 +372,61 @@ def test_general(self):
         self.assertEqual("task", model.active_head)
         self.assertEqual(Stack("task"), model.active_adapters)
 
+    def test_train_with_frozen_adapter_fusion(self):
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        data_args = GlueDataTrainingArguments(
+            task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
+        )
+        train_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="train")
+
+        model = AutoAdapterModel.from_config(self.get_model_config())
+
+        model.add_adapter("a")
+        model.add_adapter("b")
+
+        adapter_setup = Fuse("a", "b")
+
+        model.add_adapter_fusion(adapter_setup, set_active=True)
+
+        model.add_adapter("c")
+        model.add_classification_head("c")
+
+        model.train_adapter("c")
+
+        model.active_adapters = Stack(Fuse("a", "b"), "c")
+
+        # Since our config has a value matrix, make sure it is regularized.
+        # We do this by patching the fusion regularization function.
+        regularization_called = False
+        orig_fusion_regularization_loss = model.base_model.get_fusion_regularization_loss
+
+        def patched_fusion_reg_loss():
+            nonlocal regularization_called
+            regularization_called = True
+            return orig_fusion_regularization_loss()
+
+        model.base_model.get_fusion_regularization_loss = patched_fusion_reg_loss
+
+        with TemporaryDirectory() as tempdir:
+            training_args = TrainingArguments(
+                output_dir=tempdir,
+                do_train=True,
+                learning_rate=0.1,
+                logging_steps=1,
+                max_steps=1,
+                save_steps=1,
+                remove_unused_columns=False,
+            )
+            trainer = AdapterTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+            )
+
+            trainer.train()
+
+        self.assertTrue(regularization_called)
+
     @require_ray
     def test_hyperparameter_search_works_with_AdapterTrainer(self):
         tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
@@ -375,12 +438,13 @@ def hp_space(params):
             from ray import tune
+
             return {
                 "learning_rate": tune.choice([0.1, 0.2]),
             }
 
         def model_init(trail=None):
-            model = AutoAdapterModel.from_pretrained("bert-base-uncased")
+            model = AutoAdapterModel.from_config(self.get_model_config())
 
             model.add_classification_head("task", num_labels=3)
 
@@ -406,12 +470,10 @@ def model_init(trail=None):
             model_init=model_init,
             args=training_args,
             train_dataset=train_dataset,
-            eval_dataset=eval_dataset
+            eval_dataset=eval_dataset,
         )
 
-        trainer.hyperparameter_search(
-            direction="minimize", hp_space=hp_space, backend="ray", n_trials=2
-        )
+        trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, backend="ray", n_trials=2)
 
 
 if __name__ == "__main__":
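
A minimal, self-contained sketch of the behaviour this patch introduces (the helper name `fusion_reg_loss_sketch` and its arguments are illustrative, not adapter-transformers API): the regularization loss starts out as `None` and only accumulates terms for fusion value matrices that are actually being trained, so when every fusion layer is frozen the function returns `None` and the caller skips the backward pass.

    import torch

    def fusion_reg_loss_sketch(value_weights, hidden_size):
        # Mirrors the patched get_fusion_regularization_loss: start from None and only
        # add a term for fusion value matrices that require gradients.
        reg_loss = None
        target = torch.zeros((hidden_size, hidden_size)).fill_diagonal_(1.0)
        for weight in value_weights:
            if weight.requires_grad:
                layer_reg_loss = 0.01 * (target - weight).pow(2).sum()
                reg_loss = layer_reg_loss if reg_loss is None else reg_loss + layer_reg_loss
        return reg_loss

    # Caller-side guard, as in the patched trainer callback: a fully frozen fusion
    # yields None, so there is nothing to backpropagate and backward() is skipped.
    frozen = [torch.nn.Parameter(torch.eye(4), requires_grad=False)]
    trained = [torch.nn.Parameter(torch.eye(4), requires_grad=True)]
    for params in (frozen, trained):
        loss = fusion_reg_loss_sketch(params, hidden_size=4)
        if loss is not None:
            loss.backward()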