Skip to content

Commit

Permalink
Start adding support for multi GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 25, 2023
1 parent 816f050 commit aefc253
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 27 deletions.
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ files:
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution

torch: True
torch_device: 'cuda'
torch_device: cuda:0
multi_gpu: True

display:
# How many iterations to wait for intermediate plot.
Expand Down
6 changes: 3 additions & 3 deletions docs/source/reconstruction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@


.. autoclass:: lensless.TrainableReconstructionAlgorithm
:members: batch_call, apply, reset, set_data
:members: forward, apply, reset, set_data
:special-members: __init__
:show-inheritance:

Expand All @@ -78,15 +78,15 @@
~~~~~~~~~~~~~~

.. autoclass:: lensless.UnrolledFISTA
:members: batch_call
:members: forward
:special-members: __init__
:show-inheritance:

Unrolled ADMM
~~~~~~~~~~~~~

.. autoclass:: lensless.UnrolledADMM
:members: batch_call
:members: forward
:special-members: __init__
:show-inheritance:

Expand Down
2 changes: 1 addition & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
prediction = model.apply(plot=False, save=False, **kwargs)

else:
prediction = model.batch_call(lensless, **kwargs)
prediction = model.forward(lensless, **kwargs)

# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
Expand Down
3 changes: 1 addition & 2 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,9 @@ def _prepare_process_block(self, process):

return process_function, process_model, process_param

def batch_call(self, batch):
def forward(self, batch):
"""
Method for performing iterative reconstruction on a batch of images.
This implementation is a properly vectorized implementation of FISTA.
Parameters
----------
Expand Down
3 changes: 1 addition & 2 deletions lensless/recon/unrolled_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def __init__(
self._PsiT = psi_adj
self._PsiTPsi = psi_gram(self._padded_shape)

self._PsiTPsi = self._PsiTPsi.to(self._psf.device)

self._PsiTPsi = torch.nn.Parameter(self._PsiTPsi.to(self._psf.device), requires_grad=False)
self.reset()

def _Psi(self, x):
Expand Down
50 changes: 49 additions & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,53 @@ def create_process_network(network, depth, device="cpu"):
return (process, process_name)


def device_checks(device=None, multi_gpu=False, logger=None):
    """
    Resolve which device(s) to use and report the choice through ``logger``.

    Parameters
    ----------
    device : str, optional
        Either ``"cpu"`` or a CUDA device with an explicit index, e.g.
        ``"cuda:1"``. If None, ``"cuda:0"`` is selected when CUDA is
        available and ``"cpu"`` otherwise.
    multi_gpu : bool, optional
        Whether to use every available GPU. Only honoured when more than one
        GPU is detected.
    logger : callable, optional
        Function called with each status message. Defaults to ``print``.

    Returns
    -------
    device : str
        Main device.
    use_cuda : bool
        Whether CUDA is used.
    multi_gpu : bool
        Whether more than one GPU is actually used; forced to False when
        ``device_ids`` is None or contains a single ID.
    device_ids : list of int or None
        GPU IDs with the main device first, or None when no list was built
        (CPU, or automatic selection without multi-GPU).

    Raises
    ------
    ValueError
        If ``device`` is neither ``"cpu"`` nor of the form ``"cuda:<int>"``.
    AssertionError
        If a GPU is requested but CUDA is unavailable, or the requested GPU
        index is out of range.
    """

    if logger is None:
        logger = print

    use_cuda = torch.cuda.is_available()
    device_ids = None
    if device is None:
        if use_cuda:
            logger("CUDA available, using GPU.")
            device = "cuda:0"
            n_gpus = torch.cuda.device_count()
            if n_gpus > 1 and multi_gpu:
                logger(f"-- using {n_gpus} GPUs")
                device_ids = list(range(n_gpus))
        else:
            device = "cpu"
            logger("CUDA not available, using CPU.")
    elif device == "cpu":
        use_cuda = False
    else:
        try:
            gpu_id = int(device.split(":")[1])
        except (IndexError, ValueError) as err:
            # IndexError: no ":<id>" suffix (e.g. "cuda"); ValueError: non-integer ID.
            raise ValueError(
                "Bad device specification. Should be 'cpu' or something like 'cuda:1' to set the GPU ID."
            ) from err
        assert use_cuda, f"No GPU available but device set to {device}."
        n_gpus = torch.cuda.device_count()
        assert 0 <= gpu_id < n_gpus, f"GPU {device} not available"
        if n_gpus > 1 and multi_gpu:
            # Put the requested GPU first so that it acts as the main device.
            device_ids = list(range(n_gpus))
            device_ids[0], device_ids[gpu_id] = device_ids[gpu_id], device_ids[0]
        else:
            device_ids = [gpu_id]

    if device_ids is None or len(device_ids) == 1:
        multi_gpu = False
    logger(f"main device : {device}")
    logger(f"multi GPU : {multi_gpu}")
    logger(f"device ids : {device_ids}")

    return device, use_cuda, multi_gpu, device_ids


class Trainer:
def __init__(
self,
Expand Down Expand Up @@ -312,6 +359,7 @@ def __init__(
"""

self.device = recon._psf.device
self.logger = logger
self.recon = recon
Expand Down Expand Up @@ -471,7 +519,7 @@ def train_epoch(self, data_loader, disp=-1):
self.recon._set_psf(self.mask.get_psf())

# forward pass
y_pred = self.recon.batch_call(X.to(self.device))
y_pred = self.recon.forward(X.to(self.device))
# normalizing each output
eps = 1e-12
y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps
Expand Down
42 changes: 29 additions & 13 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from lensless.recon.utils import create_process_network
from lensless.utils.image import rgb2gray, is_grayscale
from lensless.utils.simulation import FarFieldSimulator
from lensless.recon.utils import Trainer
from lensless.recon.utils import Trainer, device_checks
import torch
from torchvision import transforms, datasets
from lensless.utils.io import load_psf
Expand All @@ -61,12 +61,25 @@
log = logging.getLogger(__name__)


class MyDataParallel(torch.nn.DataParallel):
    """
    ``DataParallel`` wrapper that forwards unknown attributes to the wrapped module.

    Attribute lookup is first attempted on the wrapper itself; only when that
    raises ``AttributeError`` is the attribute fetched from ``self.module``.
    """

    def __getattr__(self, name):
        try:
            attribute = super().__getattr__(name)
        except AttributeError:
            # Not found on the wrapper: fall back to the wrapped module.
            attribute = getattr(self.module, name)
        return attribute


def simulate_dataset(config):

if config.torch_device == "cuda" and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# if config.torch_device == "cuda" and torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
device, use_cuda, multi_gpu, device_ids = device_checks(config.torch_device, config.multi_gpu)

import pudb

pudb.set_trace()

# prepare PSF
psf_fp = os.path.join(get_original_cwd(), config.files.psf)
Expand Down Expand Up @@ -211,12 +224,9 @@ def train_unrolled(config):
if save:
save = os.getcwd()

if config.torch_device == "cuda" and torch.cuda.is_available():
log.info("Using GPU for training.")
device = "cuda"
else:
log.info("Using CPU for training.")
device = "cpu"
device, use_cuda, multi_gpu, device_ids = device_checks(
config.torch_device, config.multi_gpu, logger=log.info
)

# load dataset and create dataloader
train_set = None
Expand Down Expand Up @@ -293,7 +303,7 @@ def train_unrolled(config):
learn_tk=config.reconstruction.unrolled_fista.learn_tk,
pre_process=pre_process,
post_process=post_process,
).to(device)
)
elif config.reconstruction.method == "unrolled_admm":
recon = UnrolledADMM(
psf,
Expand All @@ -304,10 +314,16 @@ def train_unrolled(config):
tau=config.reconstruction.unrolled_admm.tau,
pre_process=pre_process,
post_process=post_process,
).to(device)
)
else:
raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")

if multi_gpu:
# recon = torch.nn.DataParallel(recon, device_ids=device_ids)
recon = MyDataParallel(recon, device_ids=device_ids)
if use_cuda:
recon.to(device)

# constructing algorithm name by appending pre and post process
algorithm_name = config.reconstruction.method
if config.reconstruction.pre_process.network is not None:
Expand Down
8 changes: 4 additions & 4 deletions test/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def post_process(x, noise):
next(recon.parameters(), None) is not None
), f"{algorithm.__name__} has no trainable parameters"

res = recon.batch_call(data)
res = recon.forward(data)
loss = torch.mean(res)
loss.backward()

Expand All @@ -197,7 +197,7 @@ def post_process(x, noise):

@pytest.mark.parametrize("algorithm", trainable_algos)
def test_trainable_batch(algorithm):
# test if batch_call and pally give the same result for any batch size
# test if forward and apply give the same result for any batch size
if not torch_is_available:
return
for dtype, torch_type in [("float32", torch.float32), ("float64", torch.float64)]:
Expand All @@ -215,8 +215,8 @@ def post_process(x, noise):
recon = algorithm(
psf, dtype=dtype, n_iter=_n_iter, pre_process=pre_process, post_process=post_process
)
res1 = recon.batch_call(data1)
res2 = recon.batch_call(data2)
res1 = recon.forward(data1)
res2 = recon.forward(data2)
recon.set_data(data2[0])
res3 = recon.apply(disp_iter=None, plot=False)

Expand Down

0 comments on commit aefc253

Please sign in to comment.