Skip to content

Commit

Permalink
Fix AdafruitLCD input handling. (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam authored Feb 26, 2024
1 parent 9e1e8f1 commit 81c54ca
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Changed
Bugfix
~~~~~

- Nothing
- ``lensless.hardware.trainable_mask.AdafruitLCD`` input handling.

1.0.6 - (2024-02-21)
--------------------
Expand Down
25 changes: 21 additions & 4 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def project(self):


class AdafruitLCD(TrainableMask):
# class AdafruitLCD(torch.nn.Module, TrainableMask):
def __init__(
self,
initial_vals,
Expand All @@ -126,7 +125,7 @@ def __init__(
color_filter=None,
rotate=None,
flipud=False,
use_waveprop=None,
use_waveprop=False,
vertical_shift=None,
horizontal_shift=None,
scene2mask=None,
Expand All @@ -145,9 +144,23 @@ def __init__(
slm_param : :py:class:`~lensless.hardware.slm.SLMParam`
SLM parameters.
rotate : float, optional
Rotation angle in degrees, by default None
Rotation angle in degrees, by default None.
flipud : bool, optional
Whether to flip the mask vertically, by default False
Whether to flip the mask vertically, by default False.
use_waveprop : bool, optional
Whether to use wave propagation for simulating PSF. If False, PSF will simply be intensity of mask pattern, by default False.
vertical_shift : int, optional
Vertical shift of the mask, by default None.
horizontal_shift : int, optional
Horizontal shift of the mask, by default None.
scene2mask : :py:class:`~torch.Tensor`, optional
Distance from scene to mask. Used for wave propagation, by default None.
mask2sensor : :py:class:`~torch.Tensor`, optional
Distance from mask to sensor. Used for wave propagation, by default None.
downsample : int, optional
Downsample factor, by default None.
min_val : float, optional
Minimum value for mask weights, by default 0.
"""

super().__init__(**kwargs)
Expand All @@ -158,6 +171,7 @@ def __init__(
else:
self._vals = initial_vals

initial_param = None
if color_filter is not None:
self._color_filter = torch.nn.Parameter(color_filter)
if train_mask_vals:
Expand All @@ -168,6 +182,9 @@ def __init__(
assert (
train_mask_vals
), "If color filter is not trainable, mask values must be trainable"
initial_param = [self._vals]
self._color_filter = None
assert initial_param is not None, "Initial parameters must be set"

# set optimizer
self._set_optimizer(initial_param)
Expand Down

0 comments on commit 81c54ca

Please sign in to comment.