From 6b2718fcdf4f40663830b4c41694be07ae5a1f6b Mon Sep 17 00:00:00 2001 From: Carmelo Gonzales <43048528+melo-gonzo@users.noreply.github.com> Date: Fri, 19 Apr 2024 08:52:07 -0700 Subject: [PATCH] update sam callback with comment about number of optimizers in multitask --- matsciml/lightning/callbacks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 5347b4b9..49645e33 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -783,6 +783,7 @@ def on_before_optimizer_step( org_weights = self._first_step(optimizer) with torch.enable_grad(): loss = task._compute_losses(self.batch) + # this is for the multitask case where there is more than on optimizer if len(task.optimizers()) > 1: loss = self.extract_optimizer_specific_loss(task, optimizer, loss) loss = self._get_loss(loss)