Skip to content

Commit

Permalink
update sam callback with comment about number of optimizers in multitask
Browse files Browse the repository at this point in the history
  • Loading branch information
melo-gonzo authored Apr 19, 2024
1 parent cc84b90 commit 6b2718f
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions matsciml/lightning/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ def on_before_optimizer_step(
org_weights = self._first_step(optimizer)
with torch.enable_grad():
loss = task._compute_losses(self.batch)
# this is for the multitask case where there is more than on optimizer
if len(task.optimizers()) > 1:
loss = self.extract_optimizer_specific_loss(task, optimizer, loss)
loss = self._get_loss(loss)
Expand Down

0 comments on commit 6b2718f

Please sign in to comment.