Merge branch 'main' of github.com:LCAV/LenslessPiCam into main
ebezzam committed Sep 22, 2023
2 parents 6b16b86 + 5c62cec commit 816f050
Showing 2 changed files with 46 additions and 14 deletions.
44 changes: 36 additions & 8 deletions lensless/recon/utils.py
@@ -246,6 +246,7 @@ def __init__(
test_size=0.15,
mask=None,
batch_size=4,
eval_batch_size=10,
loss="l2",
lpips=None,
l1_mask=None,
@@ -257,6 +258,7 @@ def __init__(
metric_for_best_model=None,
save_every=None,
gamma=None,
logger=None,
):
"""
Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace <https://huggingface.co/docs/transformers/main_classes/trainer>`__.
@@ -281,6 +283,8 @@ def __init__(
Trainable mask to use for training. If none, training with fix psf, by default None.
batch_size : int, optional
Batch size to use for training, by default 4.
eval_batch_size : int, optional
Batch size to use for evaluation, by default 10.
loss : str, optional
Loss function to use for training "l1" or "l2", by default "l2".
lpips : float, optional
@@ -303,11 +307,13 @@ def __init__(
Save model every ``save_every`` epochs. If None, just save best model.
gamma : float, optional
Gamma correction to apply to PSFs when plotting. If None, no gamma correction is applied. Default is None.
logger : :py:class:`logging.Logger`, optional
Logger to use for logging. If None, just print to terminal. Default is None.
"""
self.device = recon._psf.device

self.logger = logger
self.recon = recon

assert train_dataset is not None
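For reference, a minimal sketch of a standard-library logger that could be passed in through the new ``logger`` argument (the name, level, and format below are assumptions, not part of this commit):

import logging

# Hypothetical logger setup; any logging.Logger exposing .info() works here.
logger = logging.getLogger("lensless.trainer")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
logger.addHandler(handler)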
@@ -319,7 +325,10 @@ def __init__(
train_dataset, test_dataset = torch.utils.data.random_split(
train_dataset, [train_size, test_size]
)
print(f"Train size : {train_size}, Test size : {test_size}")
if self.logger is not None:
self.logger.info(f"Train size : {train_size}, Test size : {test_size}")
else:
print(f"Train size : {train_size}, Test size : {test_size}")

self.train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
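The same logger-or-print conditional recurs in several hunks below; one possible refactor, not part of this commit, would be a small helper method on the Trainer:

def _log(self, msg):
    # Hypothetical helper (not in the commit): send messages to the logger
    # when one is provided, otherwise fall back to printing to the terminal.
    if self.logger is not None:
        self.logger.info(msg)
    else:
        print(msg, flush=True)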
@@ -330,6 +339,7 @@ def __init__(
self.test_dataset = test_dataset
self.lpips = lpips
self.skip_NAN = skip_NAN
self.eval_batch_size = eval_batch_size

if mask is not None:
assert isinstance(mask, TrainableMask)
@@ -413,10 +423,16 @@ def learning_rate_function(epoch):

def detect_nan(grad):
if torch.isnan(grad).any():
print(grad, flush=True)
if self.logger:
self.logger.info(grad)
else:
print(grad, flush=True)
for name, param in recon.named_parameters():
if param.requires_grad:
print(name, param)
if self.logger:
self.logger.info(name, param)
else:
print(name, param)
raise ValueError("Gradient is NaN")
return grad
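The registration of this hook is outside the hunks shown here; in PyTorch, a per-parameter gradient hook such as ``detect_nan`` is typically attached as follows (a sketch, not the commit's exact code):

# Sketch: attach the NaN-detection hook to every trainable parameter.
for param in recon.parameters():
    if param.requires_grad:
        param.register_hook(detect_nan)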

@@ -505,7 +521,10 @@ def train_epoch(self, data_loader, disp=-1):
is_NAN = True
break
if is_NAN:
print("NAN detected in gradiant, skipping training step")
if self.logger is not None:
self.logger.info("NAN detected in gradiant, skipping training step")
else:
print("NAN detected in gradiant, skipping training step")
i += 1
continue
self.optimizer.step()
@@ -518,6 +537,9 @@ def train_epoch(self, data_loader, disp=-1):
pbar.set_description(f"loss : {mean_loss}")
i += 1

if self.logger is not None:
self.logger.info(f"loss : {mean_loss}")

return mean_loss

def evaluate(self, mean_loss, save_pt):
@@ -534,7 +556,7 @@ def evaluate(self, mean_loss, save_pt):
if self.test_dataset is None:
return
# benchmarking
current_metrics = benchmark(self.recon, self.test_dataset, batchsize=10)
current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size)

# update metrics with current metrics
self.metrics["LOSS"].append(mean_loss)
@@ -615,13 +637,19 @@ def train(self, n_epoch=1, save_pt=None, disp=-1):

self.evaluate(-1, save_pt)
for epoch in range(n_epoch):
print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
if self.logger is not None:
self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
else:
print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
mean_loss = self.train_epoch(self.train_dataloader, disp=disp)
# offset because of evaluate before loop
self.on_epoch_end(mean_loss, save_pt, epoch + 1)
self.scheduler.step()

print(f"Train time : {time.time() - start_time} s")
if self.logger is not None:
self.logger.info(f"Train time : {time.time() - start_time} s")
else:
print(f"Train time : {time.time() - start_time} s")

def save(self, epoch, path="recon", include_optimizer=False):
# create directory if it does not exist
16 changes: 10 additions & 6 deletions scripts/recon/train_unrolled.py
@@ -212,10 +212,10 @@ def train_unrolled(config):
save = os.getcwd()

if config.torch_device == "cuda" and torch.cuda.is_available():
print("Using GPU for training.")
log.info("Using GPU for training.")
device = "cuda"
else:
print("Using CPU for training.")
log.info("Using CPU for training.")
device = "cpu"

# load dataset and create dataloader
@@ -265,8 +265,8 @@ def train_unrolled(config):
assert train_set is not None
assert psf is not None

print("Train test size : ", len(train_set))
print("Test test size : ", len(test_set))
log.info(f"Train test size : {len(train_set)}")
log.info(f"Test test size : {len(test_set)}")

start_time = time.time()

@@ -321,8 +321,9 @@ def train_unrolled(config):
n_param += sum(p.numel() for p in mask.parameters())
log.info(f"Training model with {n_param} parameters")

print(f"Setup time : {time.time() - start_time} s")
print(f"PSF shape : {psf.shape}")
log.info(f"Setup time : {time.time() - start_time} s")
log.info(f"PSF shape : {psf.shape}")
log.info(f"Results saved in {save}")
trainer = Trainer(
recon=recon,
train_dataset=train_set,
Expand All @@ -340,10 +341,13 @@ def train_unrolled(config):
metric_for_best_model=config.training.metric_for_best_model,
save_every=config.training.save_every,
gamma=config.display.gamma,
logger=log,
)

trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)

log.info(f"Results saved in {save}")


if __name__ == "__main__":
train_unrolled()
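The module-level ``log`` used throughout this script is not defined in the hunks shown above; a typical pattern for a Hydra-managed script would be the following (an assumption, not visible in this diff):

import logging

# Module-level logger, picked up by Hydra's logging configuration at runtime.
log = logging.getLogger(__name__)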
