
Commit

Move hparam logging to on_train_start with test batch from train_dataloader
Lars Kuehmichel committed Oct 27, 2023
1 parent a17f51e commit eb29b16
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/lightning_trainable/trainable/trainable.py
@@ -238,6 +238,14 @@ def on_before_optimizer_step(self, optimizer):
             case other:
                 raise NotImplementedError(f"Unrecognized grad norm: {other}")
 
+    def on_train_start(self) -> None:
+        # get hparams metrics with a test batch
+        test_batch = next(iter(self.trainer.train_dataloader))
+        metrics = self.compute_metrics(test_batch, 0)
+
+        # add hparams to tensorboard
+        self.logger.log_hyperparams(self.hparams, metrics)
+
     @torch.enable_grad()
     def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
         """
@@ -256,15 +264,6 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
 
         trainer = self.configure_trainer(logger_kwargs, trainer_kwargs)
 
-        # TODO: move this to on_train_start(), see https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-hyperparameters
-        metrics_list = trainer.validate(self)
-        if metrics_list is not None and len(metrics_list) > 0:
-            metrics = metrics_list[0]
-        else:
-            metrics = {}
-        trainer.logger.log_hyperparams(self.hparams, metrics)
-
-
         trainer.fit(self, **fit_kwargs)
 
         return {
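The pattern introduced by this commit can be illustrated outside the library: log hyperparameters together with metrics computed on a single training batch from inside the `on_train_start` hook. The following is a minimal sketch, assuming Lightning 2.x with a TensorBoardLogger; `ToyRegressor`, its dataloader, and its `compute_metrics` are hypothetical stand-ins for the repository's `Trainable` class, not its actual code.

```python
import torch
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, TensorDataset


class ToyRegressor(L.LightningModule):
    """Hypothetical module illustrating hparam logging in on_train_start."""

    def __init__(self, hidden_size: int = 32, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()  # exposes hidden_size and lr via self.hparams
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def compute_metrics(self, batch, batch_idx) -> dict:
        # illustrative metric dict; the real Trainable class defines its own
        x, y = batch
        return {"loss": torch.nn.functional.mse_loss(self.net(x), y)}

    def training_step(self, batch, batch_idx):
        return self.compute_metrics(batch, batch_idx)["loss"]

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def on_train_start(self) -> None:
        # one batch from the already-prepared train dataloader supplies initial
        # metric values to pair with the hparams in TensorBoard's HPARAMS tab
        test_batch = next(iter(self.trainer.train_dataloader))
        metrics = self.compute_metrics(test_batch, 0)
        self.logger.log_hyperparams(self.hparams, metrics)


if __name__ == "__main__":
    x = torch.randn(256, 1)
    data = DataLoader(TensorDataset(x, 2 * x), batch_size=32)
    trainer = L.Trainer(max_epochs=1, logger=TensorBoardLogger("logs"))
    trainer.fit(ToyRegressor(), data)
```

Logging from `on_train_start` also removes the extra `trainer.validate(self)` pass that the old code in `fit()` ran solely to obtain metric values for `log_hyperparams`, which is what the TODO in the deleted block pointed toward.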
