From 5c62cecfabe0aa3f9105812851d53e3e0cd32dc3 Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Fri, 22 Sep 2023 14:36:18 +0200
Subject: [PATCH] Logging (#92)

* Fix ADMM ordering

* Write training output to log file.

---------

Co-authored-by: Yohann PERRON
---
 lensless/recon/utils.py         | 44 +++++++++++++++++++++++++++------
 scripts/recon/train_unrolled.py | 16 +++++++-----
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 2409dd80..2ca758c6 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -246,6 +246,7 @@ def __init__(
         test_size=0.15,
         mask=None,
         batch_size=4,
+        eval_batch_size=10,
         loss="l2",
         lpips=None,
         l1_mask=None,
@@ -257,6 +258,7 @@ def __init__(
         metric_for_best_model=None,
         save_every=None,
         gamma=None,
+        logger=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__.
@@ -281,6 +283,8 @@ def __init__(
             Trainable mask to use for training. If none, training with fix psf, by default None.
         batch_size : int, optional
             Batch size to use for training, by default 4.
+        eval_batch_size : int, optional
+            Batch size to use for evaluation, by default 10.
         loss : str, optional
             Loss function to use for training "l1" or "l2", by default "l2".
         lpips : float, optional
@@ -303,11 +307,13 @@ def __init__(
             Save model every ``save_every`` epochs. If None, just save best model.
         gamma : float, optional
             Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
+        logger : :py:class:`logging.Logger`, optional
+            Logger to use for logging. If None, just print to terminal. Default is None.
         """

         self.device = recon._psf.device
-
+        self.logger = logger
         self.recon = recon

         assert train_dataset is not None

@@ -319,7 +325,10 @@ def __init__(
             train_dataset, test_dataset = torch.utils.data.random_split(
                 train_dataset, [train_size, test_size]
             )
-            print(f"Train size : {train_size}, Test size : {test_size}")
+            if self.logger is not None:
+                self.logger.info(f"Train size : {train_size}, Test size : {test_size}")
+            else:
+                print(f"Train size : {train_size}, Test size : {test_size}")

         self.train_dataloader = torch.utils.data.DataLoader(
             dataset=train_dataset,
@@ -330,6 +339,7 @@ def __init__(
         self.test_dataset = test_dataset
         self.lpips = lpips
         self.skip_NAN = skip_NAN
+        self.eval_batch_size = eval_batch_size

         if mask is not None:
             assert isinstance(mask, TrainableMask)
@@ -413,10 +423,16 @@ def learning_rate_function(epoch):

         def detect_nan(grad):
             if torch.isnan(grad).any():
-                print(grad, flush=True)
+                if self.logger:
+                    self.logger.info(grad)
+                else:
+                    print(grad, flush=True)
                 for name, param in recon.named_parameters():
                     if param.requires_grad:
-                        print(name, param)
+                        if self.logger:
+                            self.logger.info(name, param)
+                        else:
+                            print(name, param)
                 raise ValueError("Gradient is NaN")
             return grad

@@ -505,7 +521,10 @@ def train_epoch(self, data_loader, disp=-1):
                         is_NAN = True
                         break
                 if is_NAN:
-                    print("NAN detected in gradiant, skipping training step")
+                    if self.logger is not None:
+                        self.logger.info("NAN detected in gradiant, skipping training step")
+                    else:
+                        print("NAN detected in gradiant, skipping training step")
                     i += 1
                     continue
             self.optimizer.step()
@@ -518,6 +537,9 @@ def train_epoch(self, data_loader, disp=-1):
             pbar.set_description(f"loss : {mean_loss}")
             i += 1

+        if self.logger is not None:
+            self.logger.info(f"loss : {mean_loss}")
+
         return mean_loss

     def evaluate(self, mean_loss, save_pt):
@@ -534,7 +556,7 @@ def evaluate(self, mean_loss, save_pt):
         if self.test_dataset is None:
             return
         # benchmarking
-        current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10)
+        current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size)

         # update metrics with current metrics
         self.metrics["LOSS"].append(mean_loss)
@@ -615,13 +637,19 @@ def train(self, n_epoch=1, save_pt=None, disp=-1):
         self.evaluate(-1, save_pt)

         for epoch in range(n_epoch):
-            print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
+            if self.logger is not None:
+                self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
+            else:
+                print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
             mean_loss = self.train_epoch(self.train_dataloader, disp=disp)
             # offset because of evaluate before loop
             self.on_epoch_end(mean_loss, save_pt, epoch + 1)
             self.scheduler.step()

-        print(f"Train time : {time.time() - start_time} s")
+        if self.logger is not None:
+            self.logger.info(f"Train time : {time.time() - start_time} s")
+        else:
+            print(f"Train time : {time.time() - start_time} s")

     def save(self, epoch, path="recon", include_optimizer=False):
         # create directory if it does not exist
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index 5cbee7bf..c9be1ee4 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -212,10 +212,10 @@ def train_unrolled(config):
     save = os.getcwd()

     if config.torch_device == "cuda" and torch.cuda.is_available():
-        print("Using GPU for training.")
+        log.info("Using GPU for training.")
         device = "cuda"
     else:
-        print("Using CPU for training.")
+        log.info("Using CPU for training.")
         device = "cpu"

     # load dataset and create dataloader
@@ -265,8 +265,8 @@ def train_unrolled(config):
     assert train_set is not None
     assert psf is not None

-    print("Train test size : ", len(train_set))
-    print("Test test size : ", len(test_set))
+    log.info(f"Train test size : {len(train_set)}")
+    log.info(f"Test test size : {len(test_set)}")

     start_time = time.time()

@@ -321,8 +321,9 @@ def train_unrolled(config):
         n_param += sum(p.numel() for p in mask.parameters())
     log.info(f"Training model with {n_param} parameters")
-    print(f"Setup time : {time.time() - start_time} s")
-    print(f"PSF shape : {psf.shape}")
+    log.info(f"Setup time : {time.time() - start_time} s")
+    log.info(f"PSF shape : {psf.shape}")
+    log.info(f"Results saved in {save}")

     trainer = Trainer(
         recon=recon,
         train_dataset=train_set,
@@ -340,10 +341,13 @@ def train_unrolled(config):
         metric_for_best_model=config.training.metric_for_best_model,
         save_every=config.training.save_every,
         gamma=config.display.gamma,
+        logger=log,
     )

     trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)

+    log.info(f"Results saved in {save}")
+

 if __name__ == "__main__":
     train_unrolled()
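
Note (not part of the patch above): a minimal sketch of how the two new Trainer arguments might be wired up from user code. It assumes Trainer is importable from lensless.recon.utils, and the logger setup, log file name, and the placeholder recon / train_set objects are illustrative assumptions rather than anything taken from this commit.

# Illustrative sketch only, not part of the patch. Assumes `recon` (reconstruction
# algorithm) and `train_set` (dataset) are built as in scripts/recon/train_unrolled.py.
import logging

from lensless.recon.utils import Trainer

logger = logging.getLogger("train_unrolled")
logger.setLevel(logging.INFO)
logger.addHandler(logging.FileHandler("train.log"))  # write training output to a log file
logger.addHandler(logging.StreamHandler())           # ...and still echo it to the terminal

trainer = Trainer(
    recon=recon,              # placeholder: reconstruction algorithm created elsewhere
    train_dataset=train_set,  # placeholder: dataset, split into train/test internally
    test_size=0.15,
    batch_size=4,
    eval_batch_size=10,       # new argument: benchmark batch size, no longer hard-coded
    logger=logger,            # new argument: if None, the Trainer falls back to print()
)
trainer.train(n_epoch=10, save_pt="recon")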