diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5a646e3f..758a4fea 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -29,7 +29,7 @@ Changed Bugfix ~~~~~ -- Nothing +- ``lensless.hardware.trainable_mask.AdafruitLCD`` input handling. 1.0.6 - (2024-02-21) -------------------- diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 254bd5b1..9d937b8d 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -116,7 +116,6 @@ def project(self): class AdafruitLCD(TrainableMask): - # class AdafruitLCD(torch.nn.Module, TrainableMask): def __init__( self, initial_vals, @@ -126,7 +125,7 @@ def __init__( color_filter=None, rotate=None, flipud=False, - use_waveprop=None, + use_waveprop=False, vertical_shift=None, horizontal_shift=None, scene2mask=None, @@ -145,9 +144,23 @@ def __init__( slm_param : :py:class:`~lensless.hardware.slm.SLMParam` SLM parameters. rotate : float, optional - Rotation angle in degrees, by default None + Rotation angle in degrees, by default None. flipud : bool, optional - Whether to flip the mask vertically, by default False + Whether to flip the mask vertically, by default False. + use_waveprop : bool, optional + Whether to use wave propagation for simulating PSF. If False, PSF will simply be intensity of mask pattern, by default False. + vertical_shift : int, optional + Vertical shift of the mask, by default None. + horizontal_shift : int, optional + Horizontal shift of the mask, by default None. + scene2mask : :py:class:`~torch.Tensor`, optional + Distance from scene to mask. Used for wave propagation, by default None. + mask2sensor : :py:class:`~torch.Tensor`, optional + Distance from mask to sensor. Used for wave propagation, by default None. + downsample : int, optional + Downsample factor, by default None. + min_val : float, optional + Minimum value for mask weights, by default 0. """ super().__init__(**kwargs) @@ -158,6 +171,7 @@ def __init__( else: self._vals = initial_vals + initial_param = None if color_filter is not None: self._color_filter = torch.nn.Parameter(color_filter) if train_mask_vals: @@ -168,6 +182,9 @@ def __init__( assert ( train_mask_vals ), "If color filter is not trainable, mask values must be trainable" + initial_param = [self._vals] + self._color_filter = None + assert initial_param is not None, "Initial parameters must be set" # set optimizer self._set_optimizer(initial_param)