From aefc253598799a3463f444eebb0fef1b3a6e6437 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 25 Sep 2023 12:17:23 +0000 Subject: [PATCH] Start adding support for multi GPU. --- configs/train_unrolledADMM.yaml | 3 +- docs/source/reconstruction.rst | 6 ++-- lensless/eval/benchmark.py | 2 +- lensless/recon/trainable_recon.py | 3 +- lensless/recon/unrolled_admm.py | 3 +- lensless/recon/utils.py | 50 ++++++++++++++++++++++++++++++- scripts/recon/train_unrolled.py | 42 ++++++++++++++++++-------- test/test_algos.py | 8 ++--- 8 files changed, 90 insertions(+), 27 deletions(-) diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 3871be0d..2571bfd8 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -13,7 +13,8 @@ files: downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution torch: True -torch_device: 'cuda' +torch_device: cuda:0 +multi_gpu: True display: # How many iterations to wait for intermediate plot. diff --git a/docs/source/reconstruction.rst b/docs/source/reconstruction.rst index e5b927f4..4674f327 100644 --- a/docs/source/reconstruction.rst +++ b/docs/source/reconstruction.rst @@ -69,7 +69,7 @@ .. autoclass:: lensless.TrainableReconstructionAlgorithm - :members: batch_call, apply, reset, set_data + :members: forward, apply, reset, set_data :special-members: __init__ :show-inheritance: @@ -78,7 +78,7 @@ ~~~~~~~~~~~~~~ .. autoclass:: lensless.UnrolledFISTA - :members: batch_call + :members: forward :special-members: __init__ :show-inheritance: @@ -86,7 +86,7 @@ ~~~~~~~~~~~~~ .. autoclass:: lensless.UnrolledADMM - :members: batch_call + :members: forward :special-members: __init__ :show-inheritance: diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 885766f3..3c5ad63e 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -75,7 +75,7 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): prediction = model.apply(plot=False, save=False, **kwargs) else: - prediction = model.batch_call(lensless, **kwargs) + prediction = model.forward(lensless, **kwargs) # Convert to [N*D, C, H, W] for torchmetrics prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 82fd883d..e3436390 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -122,10 +122,9 @@ def _prepare_process_block(self, process): return process_function, process_model, process_param - def batch_call(self, batch): + def forward(self, batch): """ Method for performing iterative reconstruction on a batch of images. - This implementation is a properly vectorized implementation of FISTA. Parameters ---------- diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py index 43b6b956..dbe26044 100644 --- a/lensless/recon/unrolled_admm.py +++ b/lensless/recon/unrolled_admm.py @@ -98,8 +98,7 @@ def __init__( self._PsiT = psi_adj self._PsiTPsi = psi_gram(self._padded_shape) - self._PsiTPsi = self._PsiTPsi.to(self._psf.device) - + self._PsiTPsi = torch.nn.Parameter(self._PsiTPsi.to(self._psf.device), requires_grad=False) self.reset() def _Psi(self, x): diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 2ca758c6..792d4db9 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -237,6 +237,53 @@ def create_process_network(network, depth, device="cpu"): return (process, process_name) +def device_checks(device=None, multi_gpu=False, logger=None): + + if logger is None: + logger = print + + use_cuda = torch.cuda.is_available() + device_ids = None + if device is None: + if use_cuda: + logger("CUDA available, using GPU.") + device = "cuda:0" + n_gpus = torch.cuda.device_count() + if n_gpus > 1 and multi_gpu: + print(f"-- using {n_gpus} GPUs") + device_ids = np.arange(n_gpus).tolist() + else: + device = "cpu" + logger("CUDA not available, using CPU.") + else: + if device == "cpu": + use_cuda = False + else: + try: + gpu_id = int(device.split(":")[1]) + except ValueError: + raise ValueError( + "Bad device specification. Should be 'cpu' or something like 'cuda:1' to set the GPU ID." + ) + assert use_cuda, f"No GPU available but device set to {device}." + n_gpus = torch.cuda.device_count() + assert gpu_id < n_gpus, f"GPU {device} not available" + if n_gpus > 1 and multi_gpu: + device_ids = np.arange(n_gpus) + device_ids[[0, gpu_id]] = device_ids[[gpu_id, 0]] + device_ids = device_ids.tolist() + elif device is not None: + device_ids = [int(device.split(":")[1])] + + if device_ids is None or len(device_ids) == 1: + multi_gpu = False + logger(f"main device : {device}") + logger(f"multi GPU : {multi_gpu}") + logger(f"device ids : {device_ids}") + + return device, use_cuda, multi_gpu, device_ids + + class Trainer: def __init__( self, @@ -312,6 +359,7 @@ def __init__( """ + self.device = recon._psf.device self.logger = logger self.recon = recon @@ -471,7 +519,7 @@ def train_epoch(self, data_loader, disp=-1): self.recon._set_psf(self.mask.get_psf()) # forward pass - y_pred = self.recon.batch_call(X.to(self.device)) + y_pred = self.recon.forward(X.to(self.device)) # normalizing each output eps = 1e-12 y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index c9be1ee4..b3cb0e9c 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -49,7 +49,7 @@ from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray, is_grayscale from lensless.utils.simulation import FarFieldSimulator -from lensless.recon.utils import Trainer +from lensless.recon.utils import Trainer, device_checks import torch from torchvision import transforms, datasets from lensless.utils.io import load_psf @@ -61,12 +61,25 @@ log = logging.getLogger(__name__) +class MyDataParallel(torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + def simulate_dataset(config): - if config.torch_device == "cuda" and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" + # if config.torch_device == "cuda" and torch.cuda.is_available(): + # device = "cuda" + # else: + # device = "cpu" + device, use_cuda, multi_gpu, device_ids = device_checks(config.torch_device, config.multi_gpu) + + import pudb + + pudb.set_trace() # prepare PSF psf_fp = os.path.join(get_original_cwd(), config.files.psf) @@ -211,12 +224,9 @@ def train_unrolled(config): if save: save = os.getcwd() - if config.torch_device == "cuda" and torch.cuda.is_available(): - log.info("Using GPU for training.") - device = "cuda" - else: - log.info("Using CPU for training.") - device = "cpu" + device, use_cuda, multi_gpu, device_ids = device_checks( + config.torch_device, config.multi_gpu, logger=log.info + ) # load dataset and create dataloader train_set = None @@ -293,7 +303,7 @@ def train_unrolled(config): learn_tk=config.reconstruction.unrolled_fista.learn_tk, pre_process=pre_process, post_process=post_process, - ).to(device) + ) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( psf, @@ -304,10 +314,16 @@ def train_unrolled(config): tau=config.reconstruction.unrolled_admm.tau, pre_process=pre_process, post_process=post_process, - ).to(device) + ) else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") + if multi_gpu: + # recon = torch.nn.DataParallel(recon, device_ids=device_ids) + recon = MyDataParallel(recon, device_ids=device_ids) + if use_cuda: + recon.to(device) + # constructing algorithm name by appending pre and post process algorithm_name = config.reconstruction.method if config.reconstruction.pre_process.network is not None: diff --git a/test/test_algos.py b/test/test_algos.py index b63b5a42..c464f2b1 100644 --- a/test/test_algos.py +++ b/test/test_algos.py @@ -183,7 +183,7 @@ def post_process(x, noise): next(recon.parameters(), None) is not None ), f"{algorithm.__name__} has no trainable parameters" - res = recon.batch_call(data) + res = recon.forward(data) loss = torch.mean(res) loss.backward() @@ -197,7 +197,7 @@ def post_process(x, noise): @pytest.mark.parametrize("algorithm", trainable_algos) def test_trainable_batch(algorithm): - # test if batch_call and pally give the same result for any batch size + # test if forward and pally give the same result for any batch size if not torch_is_available: return for dtype, torch_type in [("float32", torch.float32), ("float64", torch.float64)]: @@ -215,8 +215,8 @@ def post_process(x, noise): recon = algorithm( psf, dtype=dtype, n_iter=_n_iter, pre_process=pre_process, post_process=post_process ) - res1 = recon.batch_call(data1) - res2 = recon.batch_call(data2) + res1 = recon.forward(data1) + res2 = recon.forward(data2) recon.set_data(data2[0]) res3 = recon.apply(disp_iter=None, plot=False)