Merge pull request IntelLabs#185 from melo-gonzo/sam-multi-pipeline-fix
SAM Multi Task/Data Pipeline Fix
melo-gonzo authored Apr 18, 2024
2 parents 96ae607 + f19566c commit 7375112
Showing 2 changed files with 19 additions and 1 deletion.
18 changes: 18 additions & 0 deletions matsciml/lightning/callbacks.py
@@ -8,6 +8,7 @@
 from logging import DEBUG, getLogger
 from pathlib import Path
 from time import time
+from copy import copy
 from typing import Any, Callable, Dict, Iterator, Optional

 import numpy as np
@@ -753,6 +754,21 @@ def on_train_batch_start(
         self.batch = batch
         self.batch_idx = batch_idx

+    def extract_optimizer_specific_loss(self, trainer, optimizer, loss):
+        optimizer_names = copy(trainer.model.optimizer_names)
+        opt_idx = [opt == optimizer for opt in trainer.optimizers].index(True)
+        loss_keys = optimizer_names[opt_idx]
+        if loss_keys == ("Global", "Encoder"):
+            optimizer_names.pop(opt_idx)
+            global_loss = 0
+            for dataset, task in optimizer_names:
+                global_loss += loss[dataset][task]["loss"]
+            return {"loss": global_loss}
+        else:
+            for key in loss_keys:
+                loss = loss[key]
+            return loss
+
     def on_before_optimizer_step(
         self,
         trainer: Trainer,
@@ -763,6 +779,8 @@
         org_weights = self._first_step(optimizer)
         with torch.enable_grad():
             loss = task._compute_losses(self.batch)
+            if len(trainer.optimizers) > 1:
+                loss = self.extract_optimizer_specific_loss(trainer, optimizer, loss)
             loss = self._get_loss(loss)
             if torch.isfinite(loss):
                 trainer.strategy.backward(loss, optimizer=optimizer)
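In a multitask run, `_compute_losses` returns a nested dict keyed by dataset and then task, and `trainer.model.optimizer_names` holds one `(dataset, task)` pair per optimizer plus a `("Global", "Encoder")` entry for the shared encoder. The new `extract_optimizer_specific_loss` picks out the loss belonging to whichever optimizer is currently stepping: the encoder optimizer gets the sum of every task loss, while each task optimizer gets only its own sub-dict. Below is a minimal standalone sketch of that selection logic over a toy loss dict; the dataset/task names and values are illustrative assumptions, not the matsciml API.

```python
from copy import copy

# Hypothetical nested losses, shaped like a multitask _compute_losses
# result: dataset -> task -> {"loss": value}.
losses = {
    "IS2REDataset": {"energy": {"loss": 0.5}},
    "S2EFDataset": {"force": {"loss": 1.25}},
}

# One entry per optimizer: task-specific pairs plus the shared encoder.
optimizer_names = [
    ("IS2REDataset", "energy"),
    ("S2EFDataset", "force"),
    ("Global", "Encoder"),
]

def select_loss(opt_idx: int) -> dict:
    names = copy(optimizer_names)
    loss_keys = names[opt_idx]
    if loss_keys == ("Global", "Encoder"):
        # The encoder is shared across tasks, so its optimizer steps on
        # the sum of every task-specific loss.
        names.pop(opt_idx)
        return {"loss": sum(losses[d][t]["loss"] for d, t in names)}
    # Otherwise walk the nested dict: dataset key first, then task key.
    dataset, task = loss_keys
    return losses[dataset][task]

print(select_loss(0))  # {'loss': 0.5}
print(select_loss(2))  # {'loss': 1.75}, summed for the shared encoder
```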
2 changes: 1 addition & 1 deletion matsciml/lightning/tests/test_sam.py
@@ -365,5 +365,5 @@ def test_multitask_sam():
         ("IS2REDataset", is2re),
         ("S2EFDataset", s2ef),
     )
-    trainer = pl.Trainer(fast_dev_run=10)
+    trainer = pl.Trainer(fast_dev_run=10, callbacks=SAM())
     trainer.fit(task, datamodule=dm)
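The test previously built the multitask Trainer without the callback, so the multi-optimizer path was never exercised; passing `callbacks=SAM()` covers it. For orientation, here is a skeleton showing the two Lightning hooks the callback coordinates. The hook names and signatures are the standard PyTorch Lightning 2.x callback API; the bodies are placeholders, not the matsciml implementation.

```python
import pytorch_lightning as pl

class SAMHookSketch(pl.Callback):
    """Skeleton of the hook ordering a SAM-style callback relies on."""

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Cache the batch so the loss can be recomputed at the
        # perturbed weights, before the optimizer actually steps.
        self.batch = batch
        self.batch_idx = batch_idx

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # Called once per optimizer. With multiple optimizers, only the
        # loss belonging to *this* optimizer should drive the SAM step.
        ...

# Lightning accepts a single callback instance or a list of them.
trainer = pl.Trainer(fast_dev_run=10, callbacks=SAMHookSketch())
```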
