diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index f9597bf5..f6a5d797 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -94,7 +94,6 @@ def __init__( assert np.all(feature_size > 0), "Feature size should be positive" assert np.all(resolution * feature_size <= size) - self.phase_mask = None self.resolution = resolution self.size = size if feature_size is None: diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index f0d258ba..8c15353a 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -12,6 +12,7 @@ from lensless.hardware.slm import get_programmable_mask, get_intensity_psf from lensless.hardware.sensor import VirtualSensor from waveprop.devices import slm_dict +from lensless.hardware.mask import CodedAperture class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): @@ -230,3 +231,18 @@ def project(self): self.color_filter.data = self.color_filter / self.color_filter.sum( dim=[1, 2] ).unsqueeze(-1).unsqueeze(-1) + + +class TrainableCodedAperture(CodedAperture): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.row = torch.nn.Parameter(self.row) + self.col = torch.nn.Parameter(self.col) + + def get_psf(self): + return super().compute_psf() + + def project(self): + self.row.data = torch.clamp(self.row, 0, 1) + self.col.data = torch.clamp(self.col, 0, 1)