diff --git a/src/lightning_trainable/trainable/trainable.py b/src/lightning_trainable/trainable/trainable.py
index bff3e36..fc094e4 100644
--- a/src/lightning_trainable/trainable/trainable.py
+++ b/src/lightning_trainable/trainable/trainable.py
@@ -238,6 +238,14 @@ def on_before_optimizer_step(self, optimizer):
             case other:
                 raise NotImplementedError(f"Unrecognized grad norm: {other}")
 
+    def on_train_start(self) -> None:
+        # get hparams metrics with a test batch
+        test_batch = next(iter(self.trainer.train_dataloader))
+        metrics = self.compute_metrics(test_batch, 0)
+
+        # add hparams to tensorboard
+        self.logger.log_hyperparams(self.hparams, metrics)
+
     @torch.enable_grad()
     def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs: dict = None) -> dict:
         """
@@ -256,15 +264,6 @@ def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None, fit_kwargs
 
         trainer = self.configure_trainer(logger_kwargs, trainer_kwargs)
 
-        # TODO: move this to on_train_start(), see https://lightning.ai/docs/pytorch/stable/extensions/logging.html#logging-hyperparameters
-        metrics_list = trainer.validate(self)
-        if metrics_list is not None and len(metrics_list) > 0:
-            metrics = metrics_list[0]
-        else:
-            metrics = {}
-        trainer.logger.log_hyperparams(self.hparams, metrics)
-
-
         trainer.fit(self, **fit_kwargs)
 
         return {
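
For context, here is a minimal, self-contained sketch (not part of the patch) of the pattern the new on_train_start() hook relies on: TensorBoardLogger.log_hyperparams accepts an optional metrics dict, so logging hyperparameters together with an initial metric value inside on_train_start() makes that metric appear as a column in TensorBoard's HPARAMS tab. The toy module and metric names below are illustrative assumptions, not taken from lightning_trainable.

# Sketch only, not part of the patch above.
# Assumes a TensorBoardLogger, whose log_hyperparams(params, metrics) accepts
# an initial metrics dict; other loggers may ignore or not support it.
import torch
import lightning.pytorch as pl


class ToyModel(pl.LightningModule):
    def __init__(self, hidden_size: int = 32, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(8, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        self.log("train/loss", loss)
        return loss

    def on_train_start(self) -> None:
        # compute an initial value for the metric on one training batch,
        # then log it alongside the hyperparameters so the HPARAMS tab
        # in TensorBoard gets a metric column for this run
        x, y = next(iter(self.trainer.train_dataloader))
        with torch.no_grad():
            loss = torch.nn.functional.mse_loss(
                self.net(x.to(self.device)), y.to(self.device)
            )
        self.logger.log_hyperparams(self.hparams, {"train/loss": loss.item()})

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)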