
Merge pull request IntelLabs#193 from melo-gonzo/sam-callback-multidata-optimizer-fix

SAM Callback Update - Check If Optimizer Is Utilized
laserkelvin authored Apr 19, 2024
2 parents ae76f44 + d868391 commit f83248c
Showing 1 changed file with 30 additions and 13 deletions.
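
For context, here is a minimal sketch of the failure mode this change guards against, written in plain torch rather than matsciml's API (the module names below are illustrative): in a multi-data/multi-task run, an optimizer whose dataset is absent from the current batch contributes no loss, so its parameters' gradients stay None and a SAM-style weight perturbation that reads them would fail.

    # Illustrative only: two independent heads, but the current batch only
    # exercises one of them, so the other never receives gradients.
    import torch

    used_head = torch.nn.Linear(4, 1)
    unused_head = torch.nn.Linear(4, 1)

    loss = used_head(torch.randn(2, 4)).sum()
    loss.backward()

    print(used_head.weight.grad is not None)  # True
    print(unused_head.weight.grad)            # None -> a SAM first step reading this would fail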
43 changes: 30 additions & 13 deletions matsciml/lightning/callbacks.py
@@ -773,25 +773,42 @@ def extract_optimizer_specific_loss(self, task, optimizer, loss):
             loss = loss[key]
         return loss

+    def is_optimizer_used(self, task, optimizer):
+        # Check if only one optimizer is used (single-task case).
+        if isinstance(task.optimizers(), Optimizer):
+            return True
+        # Otherwise, check whether the specific optimizer we are looking at is used
+        # in the current batch. If it is not present, there will be no loss value
+        # and all of the parameters' gradients will be None.
+        optimizer_names = copy(task.optimizer_names)
+        opt_idx = [opt.optimizer == optimizer for opt in task.optimizers()].index(True)
+        used_optimizer_names = self.batch.keys()
+        if optimizer_names[opt_idx][0] in list(used_optimizer_names):
+            return True
+        else:
+            return False

     def on_before_optimizer_step(
         self,
         trainer: Trainer,
         task: BaseTaskModule,
         optimizer: Optimizer,
     ) -> None:
-        with torch.no_grad():
-            org_weights = self._first_step(optimizer)
-        with torch.enable_grad():
-            loss = task._compute_losses(self.batch)
-            # this is for the multitask case where there is more than on optimizer
-            if len(task.optimizers()) > 1:
-                loss = self.extract_optimizer_specific_loss(task, optimizer, loss)
-            loss = self._get_loss(loss)
-            if loss is not None:
-                if torch.isfinite(loss):
-                    trainer.strategy.backward(loss, optimizer=optimizer)
-        with torch.no_grad():
-            self._second_step(optimizer, org_weights)
+        optimizer_is_used = self.is_optimizer_used(task, optimizer)
+        if optimizer_is_used:
+            with torch.no_grad():
+                org_weights = self._first_step(optimizer)
+            with torch.enable_grad():
+                loss = task._compute_losses(self.batch)
+                # this is for the multitask case where there is more than one optimizer
+                if not isinstance(task.optimizers(), Optimizer):
+                    loss = self.extract_optimizer_specific_loss(task, optimizer, loss)
+                loss = self._get_loss(loss)
+                if loss is not None:
+                    if torch.isfinite(loss):
+                        trainer.strategy.backward(loss, optimizer=optimizer)
+            with torch.no_grad():
+                self._second_step(optimizer, org_weights)

     def _norm_weights(self, p: torch.Tensor) -> torch.Tensor:
         return torch.abs(p) if self.adaptive else torch.ones_like(p)
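
To make the new check concrete, here is a standalone sketch of the same matching logic with hypothetical names (an optimizer_names list pairing each optimizer with a (dataset, task) tuple, and batch keys given as a plain set). It mirrors is_optimizer_used above but is not matsciml's actual API: it compares bare optimizers directly, whereas the callback unwraps Lightning-wrapped optimizers via opt.optimizer == optimizer.

    import torch
    from torch.optim import SGD, Optimizer

    def is_optimizer_used(optimizers, optimizer_names, batch_keys, optimizer):
        # Single-task case: one bare optimizer, always used.
        if isinstance(optimizers, Optimizer):
            return True
        # Multi-task case: locate this optimizer, look up the dataset it belongs
        # to, and check whether that dataset appears in the current batch.
        opt_idx = [opt is optimizer for opt in optimizers].index(True)
        return optimizer_names[opt_idx][0] in batch_keys

    opt_a = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    opt_b = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    names = [("materials_project", "regression"), ("oqmd", "regression")]

    print(is_optimizer_used([opt_a, opt_b], names, {"materials_project"}, opt_a))  # True
    print(is_optimizer_used([opt_a, opt_b], names, {"materials_project"}, opt_b))  # False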
