diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 093cc298..6cd20cd0 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -207,7 +207,7 @@ def __init__( self.vertical_shift = vertical_shift self.horizontal_shift = horizontal_shift - self.crop = crop + self.crop = crop.copy() if crop is not None else None if downsample != 1: if self.vertical_shift is not None: self.vertical_shift = int(self.vertical_shift // downsample) diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 9365a262..feef375a 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -93,8 +93,7 @@ def simulate_dataset(config, generator=None): else: # training mask / PSF - # mask = prep_trainable_mask(config, psf, downsample=config.files.downsample) - mask = prep_trainable_mask(config, psf, downsample=config.simulation.downsample) + mask = prep_trainable_mask(config, psf) psf = mask.get_psf().to(device) # -- load dataset @@ -163,8 +162,6 @@ def simulate_dataset(config, generator=None): **config.simulation, ) - # import pudb; pudb.set_trace() - # create Pytorch dataset and dataloader crop = config.files.crop.copy() if config.files.crop is not None else None if mask is None: @@ -270,6 +267,7 @@ def simulate_dataset(config, generator=None): def prep_trainable_mask(config, psf=None, downsample=None): mask = None color_filter = None + downsample = config.files.downsample if downsample is None else downsample if config.trainable_mask.mask_type is not None: mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) @@ -475,6 +473,7 @@ def train_unrolled(config): # lensless, lensed = dataset[_idx] lensless, lensed = test_set[_idx] recon = ADMM(psf) + recon.set_data(lensless.to(psf.device)) res = recon.apply(disp_iter=None, plot=False, n_iter=10) res_np = res[0].cpu().numpy()