Skip to content

Commit

Permalink
Fix crop copying issue and update mask preparation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Nov 20, 2023
1 parent 27d27dd commit eb1d33d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(

self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
self.crop = crop
self.crop = crop.copy() if crop is not None else None
if downsample != 1:
if self.vertical_shift is not None:
self.vertical_shift = int(self.vertical_shift // downsample)
Expand Down
7 changes: 3 additions & 4 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def simulate_dataset(config, generator=None):

else:
# training mask / PSF
# mask = prep_trainable_mask(config, psf, downsample=config.files.downsample)
mask = prep_trainable_mask(config, psf, downsample=config.simulation.downsample)
mask = prep_trainable_mask(config, psf)
psf = mask.get_psf().to(device)

# -- load dataset
Expand Down Expand Up @@ -163,8 +162,6 @@ def simulate_dataset(config, generator=None):
**config.simulation,
)

# import pudb; pudb.set_trace()

# create Pytorch dataset and dataloader
crop = config.files.crop.copy() if config.files.crop is not None else None
if mask is None:
Expand Down Expand Up @@ -270,6 +267,7 @@ def simulate_dataset(config, generator=None):
def prep_trainable_mask(config, psf=None, downsample=None):
mask = None
color_filter = None
downsample = config.files.downsample if downsample is None else downsample
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

Expand Down Expand Up @@ -475,6 +473,7 @@ def train_unrolled(config):
# lensless, lensed = dataset[_idx]
lensless, lensed = test_set[_idx]
recon = ADMM(psf)

recon.set_data(lensless.to(psf.device))
res = recon.apply(disp_iter=None, plot=False, n_iter=10)
res_np = res[0].cpu().numpy()
Expand Down

0 comments on commit eb1d33d

Please sign in to comment.