Skip to content

Commit

Permalink
Update trainable mask interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Dec 6, 2023
1 parent f2c3c29 commit cd76a3c
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 100 deletions.
38 changes: 38 additions & 0 deletions configs/train_coded_aperture.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# python scripts/recon/train_unrolled.py -cn train_coded_aperture
defaults:
- train_unrolledADMM
- _self_

# Train Dataset
files:
dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: /scratch/bezzam
downsample: 8

#Trainable Mask
trainable_mask:
mask_type: TrainableCodedAperture
optimizer: Adam
mask_lr: 1e-3
initial_value:
method: MLS
n_bits: 8
# MURA not working...
# method: MURA
# n_bits: 3

simulation:
grayscale: False
flip: False
scene2mask: 40e-2
mask2sensor: 2e-3
sensor: "rpi_hq"
downsample: 16
object_height: 0.30

training:
crop_preloss: False # crop region for computing loss
batch_size: 8
epoch: 25
eval_batch_size: 16
save_every: 5
1 change: 1 addition & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ trainable_mask:
initial_value: psf
grayscale: False
mask_lr: 1e-3
optimizer: Adam
L1_strength: 1.0 #False or float

target: "object_plane" # "original" or "object_plane" or "label"
Expand Down
80 changes: 59 additions & 21 deletions lensless/hardware/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(
size=None,
feature_size=None,
psf_wavelength=[460e-9, 550e-9, 640e-9],
is_torch=False,
torch_device="cpu",
**kwargs
):
"""
Expand Down Expand Up @@ -95,13 +97,17 @@ def __init__(
assert np.all(resolution * feature_size <= size)

self.resolution = resolution
self.resolution = (int(self.resolution[0]), int(self.resolution[1]))
self.size = size
if feature_size is None:
self.feature_size = self.size / self.resolution
else:
self.feature_size = feature_size
self.distance_sensor = distance_sensor

self.is_torch = is_torch
self.torch_device = torch_device

# create mask
self.mask = None
self.create_mask()
Expand Down Expand Up @@ -155,19 +161,28 @@ def compute_psf(self):
Compute the intensity PSF with bandlimited angular spectrum (BLAS) for each wavelength.
Common to all types of masks.
"""
psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64)
if self.is_torch:
psf = torch.zeros(
tuple(self.resolution) + (len(self.psf_wavelength),), dtype=torch.complex64
)
else:
psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64)
for i, wv in enumerate(self.psf_wavelength):
psf[:, :, i] = angular_spectrum(
u_in=self.mask,
wv=wv,
d1=self.feature_size,
dz=self.distance_sensor,
dtype=np.float32,
dtype=np.float32 if not self.is_torch else torch.float32,
bandlimit=True,
device=self.torch_device if self.is_torch else None,
)[0]

# intensity PSF
self.psf = np.abs(psf) ** 2
if self.is_torch:
self.psf = torch.abs(psf) ** 2
else:
self.psf = np.abs(psf) ** 2


class CodedAperture(Mask):
Expand Down Expand Up @@ -196,33 +211,55 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):
self.method = method
self.n_bits = n_bits

assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'"
# TODO? use: https://github.com/bpops/codedapertures

# initialize parameters
if self.method.upper() == "MURA":
mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
self.row = self.mask[0, :]
self.col = self.mask[:, 0]
outer = np.outer(self.row, self.col)
assert np.all(outer == mask)
else:
seq = max_len_seq(self.n_bits)[0]
self.row = seq
self.col = seq

if kwargs["is_torch"]:
self.row = torch.from_numpy(self.row).float()
self.col = torch.from_numpy(self.col).float()

super().__init__(**kwargs)

def create_mask(self):
"""
Creating coded aperture mask using either the MURA of MLS method.
Creating coded aperture mask.
"""
assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'"

# Generating pattern
if self.method.upper() == "MURA":
self.mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
self.row = 2 * self.mask[0, :] - 1
self.col = 2 * self.mask[:, 0] - 1
# outer product
if self.is_torch:
self.mask = torch.outer(self.row, self.col)
else:
seq = max_len_seq(self.n_bits)[0] * 2 - 1
h_r = np.r_[seq, seq]
self.row = h_r
self.col = h_r
self.mask = (np.outer(h_r, h_r) + 1) / 2
self.mask = np.outer(self.row, self.col)

# Upscaling
# resize to sensor shape
if np.any(self.resolution != self.mask.shape):
upscaled_mask = resize(
self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,)
).squeeze()
upscaled_mask = np.clip(upscaled_mask, 0, 1)
self.mask = np.round(upscaled_mask).astype(int)

if self.is_torch:
self.mask = self.mask.unsqueeze(0).unsqueeze(0)
self.mask = torch.nn.functional.interpolate(
self.mask, size=tuple(self.resolution), mode="nearest"
).squeeze()
else:
# self.mask = resize(self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,))
self.mask = resize(
self.mask[:, :, np.newaxis],
shape=tuple(self.resolution) + (1,),
interpolation=cv.INTER_NEAREST,
).squeeze()

# assert np.all(np.unique(self.mask) == np.array([0, 1]))

def is_prime(self, n):
"""
Expand All @@ -246,6 +283,7 @@ def squarepattern(self, p):
p: int
Number of bits.
"""

if not self.is_prime(p):
raise ValueError("p is not a valid length. It must be prime.")
A = np.zeros((p, p), dtype=int)
Expand Down
1 change: 1 addition & 0 deletions lensless/hardware/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def from_name(cls, name, downsample=None):
Sensor.
"""

if name not in SensorOptions.values():
raise ValueError(f"Sensor {name} not supported.")
sensor_specs = sensor_dict[name].copy()
Expand Down
106 changes: 73 additions & 33 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,28 @@ class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
"""

def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
def __init__(self, optimizer="Adam", lr=1e-3, **kwargs):
"""
Base constructor. Derived constructor may define new state variables
Parameters
----------
initial_mask : :py:class:`~torch.Tensor`
Initial mask parameters.
optimizer : str, optional
Optimizer to use for updating the mask parameters, by default "Adam"
lr : float, optional
Learning rate for the mask parameters, by default 1e-3
"""
super().__init__()
self._mask = torch.nn.Parameter(initial_mask)
self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr)
self.train_mask_vals = True
self._counter = 0
# self._param = [torch.nn.Parameter(p, requires_grad=True) for p in initial_param]
# # self._param = initial_param
# self._optimizer = getattr(torch.optim, optimizer)(self._param, lr=lr)
# self._counter = 0
self.optimizer = optimizer
self.lr = lr

def _set_optimizer(self, param):
"""Set the optimizer for the mask parameters."""
self._optimizer = getattr(torch.optim, self.optimizer)(param, lr=self.lr)

@abc.abstractmethod
def get_psf(self):
Expand All @@ -66,7 +70,7 @@ def update_mask(self):

def get_vals(self):
"""Get the mask parameters."""
return self._mask
return self._param

@abc.abstractmethod
def project(self):
Expand All @@ -85,30 +89,38 @@ class TrainablePSF(TrainableMask):
Otherwise PSF will be returned as RGB. By default False.
"""

def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):
super().__init__(initial_mask, optimizer, lr, **kwargs)
assert (
len(initial_mask.shape) == 4
), "Mask must be of shape (depth, height, width, channels)"
def __init__(self, initial_psf, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):

super().__init__(optimizer, lr, **kwargs)

# cast as learnable parameters
self._psf = torch.nn.Parameter(initial_psf)

# set optimizer
initial_param = [self._psf]
self._set_optimizer(initial_param)

# checks
assert len(initial_psf.shape) == 4, "Mask must be of shape (depth, height, width, channels)"
self.grayscale = grayscale
self._is_grayscale = is_grayscale(initial_mask)
self._is_grayscale = is_grayscale(initial_psf)
if grayscale:
assert self._is_grayscale, "Mask must be grayscale"
assert self._is_grayscale, "PSF must be grayscale"

def get_psf(self):
if self._is_grayscale:
if self.grayscale:
# simulation in grayscale
return self._mask
return self._psf
else:
# replicate to 3 channels
return self._mask.expand(-1, -1, -1, 3)
return self._psf.expand(-1, -1, -1, 3)
else:
# assume RGB
return self._mask
return self._psf

def project(self):
self._mask.data = torch.clamp(self._mask, 0, 1)
self._psf.data = torch.clamp(self._psf, 0, 1)


class AdafruitLCD(TrainableMask):
Expand Down Expand Up @@ -147,23 +159,27 @@ def __init__(
Whether to flip the mask vertically, by default False
"""

super().__init__(initial_vals, **kwargs)
super().__init__(optimizer, lr, **kwargs)
self.train_mask_vals = train_mask_vals
if train_mask_vals:
self._mask = torch.nn.Parameter(initial_vals)
else:
self._mask = initial_vals

if color_filter is not None:
self.color_filter = torch.nn.Parameter(color_filter)
if train_mask_vals:
param = [self._mask, self.color_filter]
initial_param = [self._mask, self.color_filter]
else:
del self._mask
self._mask = initial_vals
param = [self.color_filter]
self._optimizer = getattr(torch.optim, optimizer)(param, lr=lr)
initial_param = [self.color_filter]
else:
self.color_filter = None
assert (
train_mask_vals
), "If color filter is not trainable, mask values must be trainable"

# set optimizer
self._set_optimizer(initial_param)

self.slm_param = slm_dict[slm]
self.device = slm
self.sensor = VirtualSensor.from_name(sensor, downsample=downsample)
Expand Down Expand Up @@ -233,16 +249,40 @@ def project(self):
).unsqueeze(-1).unsqueeze(-1)


class TrainableCodedAperture(CodedAperture):
def __init__(self, **kwargs):
super().__init__(**kwargs)
class TrainableCodedAperture(TrainableMask):
def __init__(
self, sensor_name, downsample=None, binary=True, optimizer="Adam", lr=1e-3, **kwargs
):
"""
TODO: Distinguish between separable and non-separable.
"""

super().__init__(optimizer, lr, **kwargs)

assert "distance_sensor" in kwargs, "Distance to sensor must be specified"
assert "method" in kwargs, "Method must be specified."
assert "n_bits" in kwargs, "Number of bits must be specified."

# initialize mask
self._mask = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)

# set learnable parameters (should be immediate attributes of the class)
self._row = torch.nn.Parameter(self._mask.row)
self._col = torch.nn.Parameter(self._mask.col)
initial_param = [self._row, self._col]
self.binary = binary

self.row = torch.nn.Parameter(self.row)
self.col = torch.nn.Parameter(self.col)
# set optimizer
self._set_optimizer(initial_param)

def get_psf(self):
return super().compute_psf()
self._mask.create_mask()
self._mask.compute_psf()
return self._mask.psf.unsqueeze(0)

def project(self):
self.row.data = torch.clamp(self.row, 0, 1)
self.col.data = torch.clamp(self.col, 0, 1)
self.col = torch.clamp(self.col, 0, 1)
if self.binary:
self.row.data = torch.round(self.row)
self.col.data = torch.round(self.col)
Loading

0 comments on commit cd76a3c

Please sign in to comment.