Skip to content

Commit

Permalink
Fix coded aperture training (fashion mnist).
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Dec 13, 2023
1 parent 5e1f8cc commit 61651cf
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 22 deletions.
13 changes: 9 additions & 4 deletions configs/train_coded_aperture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ defaults:

# Train Dataset
files:
dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
dataset: fashion_mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: /scratch/bezzam
downsample: 16 # TODO use simulation instead?
downsample: 16 # TODO use downsample simulation instead?
n_files: 100
crop:
vertical: [810, 2240]
horizontal: [1310, 2750]

torch_device: "cuda:1"

#Trainable Mask
trainable_mask:
Expand All @@ -28,11 +34,10 @@ simulation:
scene2mask: 40e-2
mask2sensor: 2e-3
sensor: "rpi_hq"
downsample: 16
object_height: 0.30

training:
crop_preloss: False # crop region for computing loss
crop_preloss: True # crop region for computing loss
batch_size: 4
epoch: 25
eval_batch_size: 16
Expand Down
16 changes: 12 additions & 4 deletions lensless/hardware/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,17 +232,25 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):

super().__init__(**kwargs)

def create_mask(self):
def create_mask(self, row=None, col=None, mask=None):
"""
Creating coded aperture mask.
"""

if mask is not None:
raise NotImplementedError("Mask loading not implemented yet.")

# if row and col are provided, use them
if row is None and col is None:
row = self.row
col = self.col

# outer product
if self.row is not None and self.col is not None:
if row is not None and col is not None:
if self.is_torch:
self.mask = torch.outer(self.row, self.col)
self.mask = torch.outer(row, col)
else:
self.mask = np.outer(self.row, self.col)
self.mask = np.outer(row, col)
else:
assert self.mask is not None

Expand Down
28 changes: 21 additions & 7 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def project(self):


class TrainablePSF(TrainableMask):
# class TrainablePSF(torch.nn.Module, TrainableMask):
"""
Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically.
Expand All @@ -88,15 +89,18 @@ class TrainablePSF(TrainableMask):

def __init__(self, initial_psf, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):

# BEFORE
super().__init__(optimizer, lr, **kwargs)

# cast as learnable parameters
self._psf = torch.nn.Parameter(initial_psf)

# set optimizer
initial_param = [self._psf]
self._set_optimizer(initial_param)

# # cast as learnable parameters
# super().__init__()
# self._psf = torch.nn.Parameter(initial_psf)
# self._optimizer = getattr(torch.optim, optimizer)([self._psf], lr=lr)
# self._counter = 0

# checks
assert len(initial_psf.shape) == 4, "Mask must be of shape (depth, height, width, channels)"
self.grayscale = grayscale
Expand All @@ -121,6 +125,7 @@ def project(self):


class AdafruitLCD(TrainableMask):
# class AdafruitLCD(torch.nn.Module, TrainableMask):
def __init__(
self,
initial_vals,
Expand Down Expand Up @@ -156,7 +161,9 @@ def __init__(
Whether to flip the mask vertically, by default False
"""

super().__init__(optimizer, lr, **kwargs)
super().__init__(optimizer, lr, **kwargs) # when using TrainableMask init
# super().__init__() # when using torch.nn.Module

self.train_mask_vals = train_mask_vals
if train_mask_vals:
self._vals = torch.nn.Parameter(initial_vals)
Expand All @@ -175,6 +182,8 @@ def __init__(
), "If color filter is not trainable, mask values must be trainable"

# set optimizer
# self._optimizer = getattr(torch.optim, optimizer)(initial_param, lr=lr)
# self._counter = 0
self._set_optimizer(initial_param)

self.slm_param = slm_dict[slm]
Expand Down Expand Up @@ -282,9 +291,14 @@ def __init__(
self._set_optimizer(initial_param)

def get_psf(self):
self._mask_obj.create_mask()
self._mask_obj.create_mask(self._row, self._col)
self._mask_obj.compute_psf()
return self._mask_obj.psf.unsqueeze(0)
psf = self._mask_obj.psf.unsqueeze(0)

# # need normalize the PSF? would think so but NAN comes up if included
# psf = psf / psf.norm()

return psf

def project(self):
if self.separable:
Expand Down
8 changes: 7 additions & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ def train_epoch(self, data_loader):

# update psf according to mask
if self.use_mask:
self.recon._set_psf(self.mask.get_psf().to(self.device))
new_psf = self.mask.get_psf().to(self.device)
self.recon._set_psf(new_psf)

# forward pass
y_pred = self.recon.batch_call(X.to(self.device))
Expand Down Expand Up @@ -584,6 +585,11 @@ def train_epoch(self, data_loader):
loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(p))
loss_v.backward()

# check mask parameters are learning
if self.use_mask:
for p in self.mask.parameters():
assert p.grad is not None

if self.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(self.recon.parameters(), self.clip_grad_norm)

Expand Down
14 changes: 8 additions & 6 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,6 @@ def simulate_dataset(config, generator=None):
train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

if config.files.n_files is not None:
train_size = int((1 - config.files.test_size) * config.files.n_files)
test_size = config.files.n_files - train_size
train_ds = Subset(train_ds, np.arange(train_size))
test_ds = Subset(test_ds, np.arange(test_size))

elif config.files.dataset == "fashion_mnist":
transform = transforms.Compose(transforms_list)
train_ds = datasets.FashionMNIST(
Expand All @@ -127,6 +121,7 @@ def simulate_dataset(config, generator=None):
transform = transforms.Compose(transforms_list)
train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)

elif config.files.dataset == "CelebA":
root = config.files.celeba_root
data_path = os.path.join(root, "celeba")
Expand All @@ -152,6 +147,13 @@ def simulate_dataset(config, generator=None):
else:
raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.")

if config.files.dataset != "CelebA":
if config.files.n_files is not None:
train_size = int((1 - config.files.test_size) * config.files.n_files)
test_size = config.files.n_files - train_size
train_ds = Subset(train_ds, np.arange(train_size))
test_ds = Subset(test_ds, np.arange(test_size))

# convert PSF
if config.simulation.grayscale and not is_grayscale(psf):
psf = rgb2gray(psf)
Expand Down

0 comments on commit 61651cf

Please sign in to comment.