Skip to content

Commit

Permalink
Start interface for trainable coded aperture.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Dec 6, 2023
1 parent 8e3747b commit f2c3c29
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
1 change: 0 additions & 1 deletion lensless/hardware/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(
assert np.all(feature_size > 0), "Feature size should be positive"
assert np.all(resolution * feature_size <= size)

self.phase_mask = None
self.resolution = resolution
self.size = size
if feature_size is None:
Expand Down
16 changes: 16 additions & 0 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict
from lensless.hardware.mask import CodedAperture


class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -230,3 +231,18 @@ def project(self):
self.color_filter.data = self.color_filter / self.color_filter.sum(
dim=[1, 2]
).unsqueeze(-1).unsqueeze(-1)


class TrainableCodedAperture(CodedAperture):
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.row = torch.nn.Parameter(self.row)
self.col = torch.nn.Parameter(self.col)

def get_psf(self):
return super().compute_psf()

def project(self):
self.row.data = torch.clamp(self.row, 0, 1)
self.col.data = torch.clamp(self.col, 0, 1)

0 comments on commit f2c3c29

Please sign in to comment.