From eb29b1603999c91fa0d4d157a1a139302143901e Mon Sep 17 00:00:00 2001
From: Lars Kuehmichel
Date: Fri, 27 Oct 2023 15:34:43 +0200
Subject: [PATCH] Move hparam logging to on_train_start with test batch from
 train_dataloader

---
 src/lightning_trainable/trainable/trainable.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index bff3e36..fc094e4 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -238,6 +238,14 @@ def on_before_optimizer_step(self, optimizer):
             case other:
                 raise NotImplementedError(f"Unrecognized grad norm: {other}")
 
+    def on_train_start(self) -> None:
+        # get hparams metrics with a test batch
+        test_batch = next(iter(self.trainer.train_dataloader))
+        metrics = self.compute_metrics(test_batch, 0)
+
+        # add hparams to tensorboard
+        self.logger.log_hyperparams(self.hparams, metrics)
+
     @torch.enable_grad()
     def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
         """
@@ -256,15 +264,6 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwarg
 
         trainer = self.configure_trainer(logger_kwargs, trainer_kwargs)
 
-        # TODO: move this to on_train_start(), see https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-hyperparameters
-        metrics_list = trainer.validate(self)
-        if metrics_list is not None and len(metrics_list) > 0:
-            metrics = metrics_list[0]
-        else:
-            metrics = {}
-        trainer.logger.log_hyperparams(self.hparams, metrics)
-
-
         trainer.fit(self, **fit_kwargs)
 
         return {
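
Note for reviewers: below is a minimal, standalone sketch of the pattern this patch
adopts, for readers unfamiliar with the hook. The class name LitModel, the
compute_metrics helper, and the metric names are illustrative assumptions, not the
repository's actual Trainable class; only on_train_start, trainer.train_dataloader,
and logger.log_hyperparams come from PyTorch Lightning itself.

    import lightning.pytorch as pl
    import torch

    class LitModel(pl.LightningModule):
        def __init__(self, lr: float = 1e-3):
            super().__init__()
            self.save_hyperparameters()  # populates self.hparams
            self.layer = torch.nn.Linear(4, 1)

        def compute_metrics(self, batch, batch_idx) -> dict:
            # hypothetical metric computation, standing in for the
            # repository's Trainable.compute_metrics
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            return {"loss": loss}

        def on_train_start(self) -> None:
            # pull a single batch from the already-initialized train dataloader
            test_batch = next(iter(self.trainer.train_dataloader))
            metrics = self.compute_metrics(test_batch, 0)
            # register hparams together with initial metric values in TensorBoard
            self.logger.log_hyperparams(self.hparams, metrics)

        def training_step(self, batch, batch_idx):
            return self.compute_metrics(batch, batch_idx)["loss"]

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

Compared to the removed code in fit(), this avoids running a full validation pass
just to obtain initial metrics and runs inside the training loop, where the
dataloader and logger are guaranteed to be set up.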