Skip to content

Commit

Permalink
Add multi-GPU support to unrolled ADMM.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Feb 29, 2024
1 parent fc5fbf6 commit fc38d30
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 26 deletions.
10 changes: 7 additions & 3 deletions configs/train_unrolled_multimask.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ defaults:
- train_unrolledADMM
- _self_


torch_device: 'cuda:0'
device_ids: [0, 1, 2, 3]

# Dataset
files:
dataset: bezzam/DigiCam-Mirflickr-MultiMask-1K
dataset: bezzam/DigiCam-Mirflickr-MultiMask-10K
huggingface_dataset: True
downsample: 1.6
downsample: 1
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down

Expand All @@ -19,5 +23,5 @@ alignment:
training:
batch_size: 4
epoch: 25
eval_batch_size: 10
eval_batch_size: 16

6 changes: 3 additions & 3 deletions docs/source/reconstruction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


.. autoclass:: lensless.TrainableReconstructionAlgorithm
:members: batch_call, apply, reset, set_data
:members: forward, apply, reset, set_data
:special-members: __init__
:show-inheritance:

Expand All @@ -78,15 +78,15 @@
~~~~~~~~~~~~~~

.. autoclass:: lensless.UnrolledFISTA
:members: batch_call
:members: forward
:special-members: __init__
:show-inheritance:

Unrolled ADMM
~~~~~~~~~~~~~

.. autoclass:: lensless.UnrolledADMM
:members: batch_call
:members: forward
:special-members: __init__
:show-inheritance:

Expand Down
11 changes: 9 additions & 2 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,12 @@ def benchmark(
)

else:
prediction = model.batch_call(lensless, psfs, **kwargs)
prediction = model.forward(lensless, psfs, **kwargs)

if unrolled_output_factor:
unrolled_out = prediction[-1]
prediction = prediction[0]
prediction_original = prediction.clone()

# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
Expand Down Expand Up @@ -192,7 +193,13 @@ def benchmark(
# compute metrics
for metric in metrics:
if metric == "ReconstructionError":
metrics_values[metric].append(model.reconstruction_error().cpu().item())
metrics_values[metric].append(
model.reconstruction_error(
prediction=prediction_original, lensless=lensless
)
.cpu()
.item()
)
else:
if "LPIPS" in metric:
if prediction.shape[1] == 1:
Expand Down
5 changes: 3 additions & 2 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ==================
# Authors :
# Yohann PERRON [[email protected]]
# Eric BEZZAM [[email protected]]
# #############################################################################

import pathlib as plib
Expand Down Expand Up @@ -192,7 +193,7 @@ def unfreeze_post_process(self):
for param in self.post_process_model.parameters():
param.requires_grad = True

def batch_call(self, batch, psfs=None):
def forward(self, batch, psfs=None):
"""
Method for performing iterative reconstruction on a batch of images.
This implementation is a properly vectorized implementation of FISTA.
Expand All @@ -216,7 +217,7 @@ def batch_call(self, batch, psfs=None):
# assert same shape
assert psfs.shape == batch.shape, "psfs must have the same shape as batch"
# -- update convolver
self._convolver = RealFFTConvolve2D(psfs.to(self._psf.device), **self._convolver_param)
self._convolver = RealFFTConvolve2D(psfs.to(self._data.device), **self._convolver_param)

# pre process data
if self.pre_process is not None:
Expand Down
23 changes: 14 additions & 9 deletions lensless/recon/unrolled_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# =================
# Authors :
# Yohann PERRON [[email protected]]
# Eric BEZZAM [[email protected]]
# #############################################################################

from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm
Expand Down Expand Up @@ -130,19 +131,23 @@ def _PsiT(self, U):
return finite_diff_adj(U)

def reset(self, batch_size=1):

if self._data is not None:
device = self._data.device
else:
device = self._convolver._H.device

# ensure that mu1, mu2, mu3, tau are positive
self._mu1 = torch.abs(self._mu1_p)
self._mu2 = torch.abs(self._mu2_p)
self._mu3 = torch.abs(self._mu3_p)
self._tau = torch.abs(self._tau_p)
self._mu1 = torch.abs(self._mu1_p).to(device)
self._mu2 = torch.abs(self._mu2_p).to(device)
self._mu3 = torch.abs(self._mu3_p).to(device)
self._tau = torch.abs(self._tau_p).to(device)

# TODO initialize without padding
if self._initial_est is not None:
self._image_est = self._initial_est
self._image_est = self._initial_est.to(device)
else:
self._image_est = torch.zeros([1] + self._padded_shape, dtype=self._dtype).to(
self._psf.device
)
self._image_est = torch.zeros([1] + self._padded_shape, dtype=self._dtype).to(device)

self._X = torch.zeros_like(self._image_est)
self._U = torch.zeros_like(self._Psi(self._image_est))
Expand All @@ -163,7 +168,7 @@ def reset(self, batch_size=1):
self._R_divmat = 1.0 / (
self._mu1[:, None, None, None, None, None]
* (torch.abs(self._convolver._Hadj * self._convolver._H))[None, ...]
+ self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi)
+ self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi).to(device)
+ self._mu3[:, None, None, None, None, None]
).type(self._complex_dtype)

Expand Down
2 changes: 1 addition & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def train_epoch(self, data_loader):
self.recon._set_psf(self.mask.get_psf().to(self.device))

# forward pass
y_pred = self.recon.batch_call(X, psfs=psfs)
y_pred = self.recon.forward(X, psfs=psfs)
if self.unrolled_output_factor:
unrolled_out = y_pred[1]
y_pred = y_pred[0]
Expand Down
25 changes: 22 additions & 3 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@
log = logging.getLogger(__name__)


class MyDataParallel(torch.nn.DataParallel):
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)


@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def train_unrolled(config):

Expand All @@ -80,13 +88,19 @@ def train_unrolled(config):
if save:
save = os.getcwd()

use_cuda = False
if "cuda" in config.torch_device and torch.cuda.is_available():
# if config.torch_device == "cuda" and torch.cuda.is_available():
log.info("Using GPU for training.")
device = config.torch_device
use_cuda = True
else:
log.info("Using CPU for training.")
device = "cpu"
# device, use_cuda, multi_gpu, device_ids = device_checks(
# config.torch_device, config.multi_gpu, logger=log.info,
# )
device_ids = config.device_ids

# load dataset and create dataloader
train_set = None
Expand Down Expand Up @@ -355,7 +369,7 @@ def train_unrolled(config):
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
)
elif config.reconstruction.method == "unrolled_admm":
recon = UnrolledADMM(
psf,
Expand All @@ -368,18 +382,23 @@ def train_unrolled(config):
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
)
elif config.reconstruction.method == "trainable_inv":
recon = TrainableInversion(
psf,
K=config.reconstruction.trainable_inv.K,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
)
else:
raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")

if device_ids is not None:
recon = MyDataParallel(recon, device_ids=device_ids)
if use_cuda:
recon.to(device)

# constructing algorithm name by appending pre and post process
algorithm_name = config.reconstruction.method
if config.reconstruction.pre_process.network is not None:
Expand Down
6 changes: 3 additions & 3 deletions test/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def post_process(x, noise):
next(recon.parameters(), None) is not None
), f"{algorithm.__name__} has no trainable parameters"

res = recon.batch_call(data)
res = recon.forward(data)
loss = torch.mean(res)
loss.backward()

Expand Down Expand Up @@ -215,8 +215,8 @@ def post_process(x, noise):
recon = algorithm(
psf, dtype=dtype, n_iter=_n_iter, pre_process=pre_process, post_process=post_process
)
res1 = recon.batch_call(data1)
res2 = recon.batch_call(data2)
res1 = recon.forward(data1)
res2 = recon.forward(data2)
recon.set_data(data2[0])
res3 = recon.apply(disp_iter=None, plot=False)

Expand Down

0 comments on commit fc38d30

Please sign in to comment.