Skip to content

Commit

Permalink
also log epoch
Browse files: browse the repository at this point in the history
  • Loading branch information
wiederm committed Nov 7, 2024
1 parent 667909f commit afb0c59
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions modelforge/train/training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1157,7 +1157,9 @@ def _log_figures_for_each_phase(
# Log outlier error counts for non-training phases
if phase != "train":
self._identify__and_log_top_k_errors(errors, gathered_indices, phase)
self.log_dict(self.outlier_errors_over_epochs, on_epoch=True)
self.log_dict(
self.outlier_errors_over_epochs, on_epoch=True, rank_zero_only=True
)

def _identify__and_log_top_k_errors(
self,
Expand Down Expand Up @@ -1197,7 +1199,9 @@ def _identify__and_log_top_k_errors(
if key not in self.outlier_errors_over_epochs:
self.outlier_errors_over_epochs[key] = 0
self.outlier_errors_over_epochs[key] += 1
log.info(f"{phase} : Outlier error {error} at index {idx}.")
log.info(
f"{self.current_epoch}: {phase} : Outlier error {error} at index {idx}."
)

def _clear_error_tracking(self, preds, targets, incides):
"""
Expand Down Expand Up @@ -1293,24 +1297,28 @@ def on_train_epoch_end(self):
self._log_time()
self._log_histograms()
# log the weights of the different loss components
for key, weight in self.loss.weights_scheduling.items():
self.log(
f"loss/{key}/weight",
weight[self.current_epoch],
)
if self.trainer.is_global_zero:
for key, weight in self.loss.weights_scheduling.items():
self.log(
f"loss/{key}/weight",
weight[self.current_epoch],
rank_zero_only=True,
)

def _log_learning_rate(self):
"""Logs the current learning rate."""
sch = self.lr_schedulers()
try:
self.log(
"lr",
sch.get_last_lr()[0],
on_epoch=True,
prog_bar=True,
)
except AttributeError:
pass
if self.trainer.is_global_zero:
try:
self.log(
"lr",
sch.get_last_lr()[0],
on_epoch=True,
prog_bar=True,
rank_zero_only=True,
)
except AttributeError:
pass

def _log_metrics(self, metrics: ModuleDict, phase: str):
"""
Expand Down

0 comments on commit afb0c59

Please sign in to comment.