Skip to content

Commit

Permalink
Merge pull request IntelLabs#187 from melo-gonzo/sam-multidata-and-dd…
Browse files Browse the repository at this point in the history
…p-fix

SAM Callback DDP and Multi Data Fixes.
  • Loading branch information
melo-gonzo authored Apr 19, 2024
2 parents 7375112 + 6b2718f commit 46c1737
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions matsciml/lightning/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,16 +754,20 @@ def on_train_batch_start(
self.batch = batch
self.batch_idx = batch_idx

def extract_optimizer_specific_loss(self, trainer, optimizer, loss):
optimizer_names = copy(trainer.model.optimizer_names)
opt_idx = [opt == optimizer for opt in trainer.optimizers].index(True)
def extract_optimizer_specific_loss(self, task, optimizer, loss):
optimizer_names = copy(task.optimizer_names)
opt_idx = [opt.optimizer == optimizer for opt in task.optimizers()].index(True)
loss_keys = optimizer_names[opt_idx]
if loss_keys == ("Global", "Encoder"):
optimizer_names.pop(opt_idx)
global_loss = 0
for dataset, task in optimizer_names:
global_loss += loss[dataset][task]["loss"]
if loss.get(dataset, None) is not None:
global_loss += loss[dataset][task]["loss"]
return {"loss": global_loss}
# When some datasets have less samples than others, they wont have a loss value
if loss_keys[0] not in loss:
loss = {"loss": None}
else:
for key in loss_keys:
loss = loss[key]
Expand All @@ -779,11 +783,13 @@ def on_before_optimizer_step(
org_weights = self._first_step(optimizer)
with torch.enable_grad():
loss = task._compute_losses(self.batch)
if len(trainer.optimizers) > 1:
loss = self.extract_optimizer_specific_loss(trainer, optimizer, loss)
# this is for the multitask case where there is more than on optimizer
if len(task.optimizers()) > 1:
loss = self.extract_optimizer_specific_loss(task, optimizer, loss)
loss = self._get_loss(loss)
if torch.isfinite(loss):
trainer.strategy.backward(loss, optimizer=optimizer)
if loss is not None:
if torch.isfinite(loss):
trainer.strategy.backward(loss, optimizer=optimizer)
with torch.no_grad():
self._second_step(optimizer, org_weights)

Expand Down

0 comments on commit 46c1737

Please sign in to comment.