Skip to content

Commit

Permalink
Small fixes with unrolled training. (#99)
Browse files Browse the repository at this point in the history
* Add support for training from measured CelebA.

* Update example to compare with original.

* Update default unrolled config.

* Clean up global shutter capture.

* Fix nbits for global shutter.

* Long exposure comments.

* Fix path.

* Fix aperture.

* Update setup for Python 3.11

* Improve benchmarking.

* Use natural sorting.

* Save analysis.

* Save eval examples.

* Set seed.

* Fix typo.

* Add support to benchmark on DigiCamCelebA dataset.

* Better align simulated PSF.

* Add support to train adafruit mask.

* Fix data type of shape for new PyTorch.

* Add sensor.

* Add option to set number of channels.

* Add more options to analyzing measured dataset.

* Fix resizing.

* Update configs.

* Add option to train mask color filter.

* Add and improve hardware utilities.

* Add more features to unrolled training.

* Update configs.

* update changelog.

* Small fixes.

* Fix crop copying issue and update mask preparation.

---------

Co-authored-by: Eric Bezzam <[email protected]>
  • Loading branch information
ebezzam and Eric Bezzam authored Nov 20, 2023
1 parent 53ac235 commit a0c687c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(

self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
self.crop = crop
self.crop = crop.copy() if crop is not None else None
if downsample != 1:
if self.vertical_shift is not None:
self.vertical_shift = int(self.vertical_shift // downsample)
Expand Down
7 changes: 3 additions & 4 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def simulate_dataset(config, generator=None):

else:
# training mask / PSF
# mask = prep_trainable_mask(config, psf, downsample=config.files.downsample)
mask = prep_trainable_mask(config, psf, downsample=config.simulation.downsample)
mask = prep_trainable_mask(config, psf)
psf = mask.get_psf().to(device)

# -- load dataset
Expand Down Expand Up @@ -163,8 +162,6 @@ def simulate_dataset(config, generator=None):
**config.simulation,
)

# import pudb; pudb.set_trace()

# create Pytorch dataset and dataloader
crop = config.files.crop.copy() if config.files.crop is not None else None
if mask is None:
Expand Down Expand Up @@ -270,6 +267,7 @@ def simulate_dataset(config, generator=None):
def prep_trainable_mask(config, psf=None, downsample=None):
mask = None
color_filter = None
downsample = config.files.downsample if downsample is None else downsample
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

Expand Down Expand Up @@ -475,6 +473,7 @@ def train_unrolled(config):
# lensless, lensed = dataset[_idx]
lensless, lensed = test_set[_idx]
recon = ADMM(psf)

recon.set_data(lensless.to(psf.device))
res = recon.apply(disp_iter=None, plot=False, n_iter=10)
res_np = res[0].cpu().numpy()
Expand Down

0 comments on commit a0c687c

Please sign in to comment.