diff --git a/configs/train_unrolled_multimask.yaml b/configs/train_unrolled_multimask.yaml index ff639b5c..80eee4ed 100644 --- a/configs/train_unrolled_multimask.yaml +++ b/configs/train_unrolled_multimask.yaml @@ -3,11 +3,15 @@ defaults: - train_unrolledADMM - _self_ + +torch_device: 'cuda:0' +device_ids: [0, 1, 2, 3] + # Dataset files: - dataset: bezzam/DigiCam-Mirflickr-MultiMask-1K + dataset: bezzam/DigiCam-Mirflickr-MultiMask-10K huggingface_dataset: True - downsample: 1.6 + downsample: 1 image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down @@ -19,5 +23,5 @@ alignment: training: batch_size: 4 epoch: 25 - eval_batch_size: 10 + eval_batch_size: 16 diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index e5b927f4..4674f327 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -69,7 +69,7 @@ .. autoclass:: lensless.TrainableReconstructionAlgorithm - :members: batch_call, apply, reset, set_data + :members: forward, apply, reset, set_data :special-members: __init__ :show-inheritance: @@ -78,7 +78,7 @@ ~~~~~~~~~~~~~~ .. autoclass:: lensless.UnrolledFISTA - :members: batch_call + :members: forward :special-members: __init__ :show-inheritance: @@ -86,7 +86,7 @@ ~~~~~~~~~~~~~ .. autoclass:: lensless.UnrolledADMM - :members: batch_call + :members: forward :special-members: __init__ :show-inheritance: diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 13c3912b..1aa84cc2 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -141,11 +141,12 @@ def benchmark( ) else: - prediction = model.batch_call(lensless, psfs, **kwargs) + prediction = model.forward(lensless, psfs, **kwargs) if unrolled_output_factor: unrolled_out = prediction[-1] prediction = prediction[0] + prediction_original = prediction.clone() # Convert to [N*D, C, H, W] for torchmetrics prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) @@ -192,7 +193,13 @@ def benchmark( # compute metrics for metric in metrics: if metric == "ReconstructionError": - metrics_values[metric].append(model.reconstruction_error().cpu().item()) + metrics_values[metric].append( + model.reconstruction_error( + prediction=prediction_original, lensless=lensless + ) + .cpu() + .item() + ) else: if "LPIPS" in metric: if prediction.shape[1] == 1: diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 7cd48735..5f7b6d52 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -3,6 +3,7 @@ # ================== # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import pathlib as plib @@ -192,7 +193,7 @@ def unfreeze_post_process(self): for param in self.post_process_model.parameters(): param.requires_grad = True - def batch_call(self, batch, psfs=None): + def forward(self, batch, psfs=None): """ Method for performing iterative reconstruction on a batch of images. This implementation is a properly vectorized implementation of FISTA. @@ -216,7 +217,7 @@ def batch_call(self, batch, psfs=None): # assert same shape assert psfs.shape == batch.shape, "psfs must have the same shape as batch" # -- update convolver - self._convolver = RealFFTConvolve2D(psfs.to(self._psf.device), **self._convolver_param) + self._convolver = RealFFTConvolve2D(psfs.to(self._data.device), **self._convolver_param) # pre process data if self.pre_process is not None: diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py index 8c923ddb..b20eaa07 100644 --- a/lensless/recon/unrolled_admm.py +++ b/lensless/recon/unrolled_admm.py @@ -3,6 +3,7 @@ # ================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm @@ -130,19 +131,23 @@ def _PsiT(self, U): return finite_diff_adj(U) def reset(self, batch_size=1): + + if self._data is not None: + device = self._data.device + else: + device = self._convolver._H.device + # ensure that mu1, mu2, mu3, tau are positive - self._mu1 = torch.abs(self._mu1_p) - self._mu2 = torch.abs(self._mu2_p) - self._mu3 = torch.abs(self._mu3_p) - self._tau = torch.abs(self._tau_p) + self._mu1 = torch.abs(self._mu1_p).to(device) + self._mu2 = torch.abs(self._mu2_p).to(device) + self._mu3 = torch.abs(self._mu3_p).to(device) + self._tau = torch.abs(self._tau_p).to(device) # TODO initialize without padding if self._initial_est is not None: - self._image_est = self._initial_est + self._image_est = self._initial_est.to(device) else: - self._image_est = torch.zeros([1] + self._padded_shape, dtype=self._dtype).to( - self._psf.device - ) + self._image_est = torch.zeros([1] + self._padded_shape, dtype=self._dtype).to(device) self._X = torch.zeros_like(self._image_est) self._U = torch.zeros_like(self._Psi(self._image_est)) @@ -163,7 +168,7 @@ def reset(self, batch_size=1): self._R_divmat = 1.0 / ( self._mu1[:, None, None, None, None, None] * (torch.abs(self._convolver._Hadj * self._convolver._H))[None, ...] - + self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi) + + self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi).to(device) + self._mu3[:, None, None, None, None, None] ).type(self._complex_dtype) diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index e09cd233..47f28fe7 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -615,7 +615,7 @@ def train_epoch(self, data_loader): self.recon._set_psf(self.mask.get_psf().to(self.device)) # forward pass - y_pred = self.recon.batch_call(X, psfs=psfs) + y_pred = self.recon.forward(X, psfs=psfs) if self.unrolled_output_factor: unrolled_out = y_pred[1] y_pred = y_pred[0] diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index aba35bf9..97612d2e 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -59,6 +59,14 @@ log = logging.getLogger(__name__) +class MyDataParallel(torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config): @@ -80,13 +88,19 @@ def train_unrolled(config): if save: save = os.getcwd() + use_cuda = False if "cuda" in config.torch_device and torch.cuda.is_available(): # if config.torch_device == "cuda" and torch.cuda.is_available(): log.info("Using GPU for training.") device = config.torch_device + use_cuda = True else: log.info("Using CPU for training.") device = "cpu" + # device, use_cuda, multi_gpu, device_ids = device_checks( + # config.torch_device, config.multi_gpu, logger=log.info, + # ) + device_ids = config.device_ids # load dataset and create dataloader train_set = None @@ -355,7 +369,7 @@ def train_unrolled(config): post_process=post_process if post_proc_delay is None else None, skip_unrolled=config.reconstruction.skip_unrolled, return_unrolled_output=True if config.unrolled_output_factor > 0 else False, - ).to(device) + ) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( psf, @@ -368,7 +382,7 @@ def train_unrolled(config): post_process=post_process if post_proc_delay is None else None, skip_unrolled=config.reconstruction.skip_unrolled, return_unrolled_output=True if config.unrolled_output_factor > 0 else False, - ).to(device) + ) elif config.reconstruction.method == "trainable_inv": recon = TrainableInversion( psf, @@ -376,10 +390,15 @@ def train_unrolled(config): pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, return_unrolled_output=True if config.unrolled_output_factor > 0 else False, - ).to(device) + ) else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") + if device_ids is not None: + recon = MyDataParallel(recon, device_ids=device_ids) + if use_cuda: + recon.to(device) + # constructing algorithm name by appending pre and post process algorithm_name = config.reconstruction.method if config.reconstruction.pre_process.network is not None: diff --git a/test/test_algos.py b/test/test_algos.py index b63b5a42..b5e3d94c 100644 --- a/test/test_algos.py +++ b/test/test_algos.py @@ -183,7 +183,7 @@ def post_process(x, noise): next(recon.parameters(), None) is not None ), f"{algorithm.__name__} has no trainable parameters" - res = recon.batch_call(data) + res = recon.forward(data) loss = torch.mean(res) loss.backward() @@ -215,8 +215,8 @@ def post_process(x, noise): recon = algorithm( psf, dtype=dtype, n_iter=_n_iter, pre_process=pre_process, post_process=post_process ) - res1 = recon.batch_call(data1) - res2 = recon.batch_call(data2) + res1 = recon.forward(data1) + res2 = recon.forward(data2) recon.set_data(data2[0]) res3 = recon.apply(disp_iter=None, plot=False)