diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml
index b1e5e30a..e560da2d 100644
--- a/configs/train_coded_aperture.yaml
+++ b/configs/train_coded_aperture.yaml
@@ -5,9 +5,15 @@ defaults:
 
 # Train Dataset
 files:
-  dataset: mnist  # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
+  dataset: fashion_mnist  # Simulated: "mnist", "fashion_mnist", "cifar10", "CelebA". Measured: "DiffuserCam"
   celeba_root: /scratch/bezzam
-  downsample: 16  # TODO use simulation instead?
+  downsample: 16  # TODO: downsample via the simulation instead?
+  n_files: 100
+  crop:
+    vertical: [810, 2240]
+    horizontal: [1310, 2750]
+
+torch_device: "cuda:1"
 
 #Trainable Mask
 trainable_mask:
@@ -28,11 +34,10 @@ simulation:
   scene2mask: 40e-2
   mask2sensor: 2e-3
   sensor: "rpi_hq"
-  downsample: 16
   object_height: 0.30
 
 training:
-  crop_preloss: False  # crop region for computing loss
+  crop_preloss: True  # crop region for computing loss
   batch_size: 4
   epoch: 25
   eval_batch_size: 16
diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py
index 7c34d627..17393cce 100644
--- a/lensless/hardware/mask.py
+++ b/lensless/hardware/mask.py
@@ -232,17 +232,25 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):
 
         super().__init__(**kwargs)
 
-    def create_mask(self):
+    def create_mask(self, row=None, col=None, mask=None):
         """
         Creating coded aperture mask.
         """
+        if mask is not None:
+            raise NotImplementedError("Mask loading not implemented yet.")
+
+        # use row and col if provided, otherwise fall back to the stored values
+        if row is None and col is None:
+            row = self.row
+            col = self.col
+
         # outer product
-        if self.row is not None and self.col is not None:
+        if row is not None and col is not None:
             if self.is_torch:
-                self.mask = torch.outer(self.row, self.col)
+                self.mask = torch.outer(row, col)
             else:
-                self.mask = np.outer(self.row, self.col)
+                self.mask = np.outer(row, col)
         else:
             assert self.mask is not None
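The widened `create_mask(row=None, col=None, mask=None)` signature lets a caller pass tensors explicitly instead of mutating the stored `self.row`/`self.col`. A minimal standalone sketch of the fallback behavior, using a toy stand-in class rather than the repository's `CodedAperture` (names and shapes here are illustrative only):

```python
import torch


class ApertureSketch:
    """Toy stand-in for CodedAperture, only to illustrate the create_mask fallback."""

    def __init__(self, row, col):
        self.row, self.col, self.mask = row, col, None

    def create_mask(self, row=None, col=None):
        # fall back to the stored row/col when no explicit tensors are passed
        if row is None and col is None:
            row, col = self.row, self.col
        self.mask = torch.outer(row, col)  # (len(row), len(col)) aperture


aperture = ApertureSketch(torch.rand(8), torch.rand(8))
aperture.create_mask()                               # uses the stored row/col
aperture.create_mask(torch.rand(8), torch.rand(8))   # uses the provided tensors
```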
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index eef57933..5c0c16b3 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -76,6 +76,7 @@ def project(self):
 
 
 class TrainablePSF(TrainableMask):
+    # class TrainablePSF(torch.nn.Module, TrainableMask):
     """
     Class for defining an object that directly optimizes the PSF, without any constraints
     on what can be realized physically.
@@ -88,15 +89,18 @@ class TrainablePSF(TrainableMask):
 
     def __init__(self, initial_psf, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):
+        # BEFORE: when using TrainableMask init
         super().__init__(optimizer, lr, **kwargs)
-
-        # cast as learnable parameters
         self._psf = torch.nn.Parameter(initial_psf)
-
-        # set optimizer
         initial_param = [self._psf]
         self._set_optimizer(initial_param)
 
+        # # cast as learnable parameters
+        # super().__init__()
+        # self._psf = torch.nn.Parameter(initial_psf)
+        # self._optimizer = getattr(torch.optim, optimizer)([self._psf], lr=lr)
+        # self._counter = 0
+
         # checks
         assert len(initial_psf.shape) == 4, "Mask must be of shape (depth, height, width, channels)"
         self.grayscale = grayscale
@@ -121,6 +125,7 @@ def project(self):
 
 
 class AdafruitLCD(TrainableMask):
+    # class AdafruitLCD(torch.nn.Module, TrainableMask):
     def __init__(
         self,
         initial_vals,
@@ -156,7 +161,9 @@ def __init__(
             Whether to flip the mask vertically, by default False
         """
 
-        super().__init__(optimizer, lr, **kwargs)
+        super().__init__(optimizer, lr, **kwargs)  # when using TrainableMask init
+        # super().__init__()  # when using torch.nn.Module
+
         self.train_mask_vals = train_mask_vals
         if train_mask_vals:
             self._vals = torch.nn.Parameter(initial_vals)
@@ -175,6 +182,8 @@ def __init__(
             ), "If color filter is not trainable, mask values must be trainable"
 
         # set optimizer
+        # self._optimizer = getattr(torch.optim, optimizer)(initial_param, lr=lr)
+        # self._counter = 0
         self._set_optimizer(initial_param)
 
         self.slm_param = slm_dict[slm]
@@ -282,9 +291,14 @@ def __init__(
         self._set_optimizer(initial_param)
 
     def get_psf(self):
-        self._mask_obj.create_mask()
+        self._mask_obj.create_mask(self._row, self._col)
         self._mask_obj.compute_psf()
-        return self._mask_obj.psf.unsqueeze(0)
+        psf = self._mask_obj.psf.unsqueeze(0)
+
+        # # need to normalize the PSF? Would think so, but NaN comes up if included.
+        # psf = psf / psf.norm()
+
+        return psf
 
     def project(self):
         if self.separable:
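On the commented-out `psf = psf / psf.norm()` in `get_psf`: if the mask values drift toward zero, the norm can underflow and the division produces NaN/inf, which matches the symptom noted in the comment. A hedged, standalone sketch of a guarded normalization (the `eps` value and the choice of L2 norm are assumptions, not the repository's method):

```python
import torch


def safe_normalize(psf: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # clamp the norm away from zero so the division cannot produce NaN/inf
    return psf / psf.norm().clamp(min=eps)


psf = torch.zeros(1, 1, 4, 4)             # worst case: an all-zero PSF
print(safe_normalize(psf).isnan().any())  # tensor(False)
```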
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index f249f2b1..388871a2 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -535,7 +535,8 @@ def train_epoch(self, data_loader):
 
             # update psf according to mask
             if self.use_mask:
-                self.recon._set_psf(self.mask.get_psf().to(self.device))
+                new_psf = self.mask.get_psf().to(self.device)
+                self.recon._set_psf(new_psf)
 
             # forward pass
             y_pred = self.recon.batch_call(X.to(self.device))
@@ -584,6 +585,11 @@ def train_epoch(self, data_loader):
                     loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(p))
             loss_v.backward()
 
+            # sanity check: mask parameters should receive gradients
+            if self.use_mask:
+                for p in self.mask.parameters():
+                    assert p.grad is not None
+
             if self.clip_grad_norm is not None:
                 torch.nn.utils.clip_grad_norm_(self.recon.parameters(), self.clip_grad_norm)
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index 132464f0..a855de96 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -109,12 +109,6 @@ def simulate_dataset(config, generator=None):
         train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
         test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)
 
-        if config.files.n_files is not None:
-            train_size = int((1 - config.files.test_size) * config.files.n_files)
-            test_size = config.files.n_files - train_size
-            train_ds = Subset(train_ds, np.arange(train_size))
-            test_ds = Subset(test_ds, np.arange(test_size))
-
     elif config.files.dataset == "fashion_mnist":
         transform = transforms.Compose(transforms_list)
         train_ds = datasets.FashionMNIST(
@@ -127,6 +121,7 @@ def simulate_dataset(config, generator=None):
         transform = transforms.Compose(transforms_list)
         train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform)
         test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform)
+
     elif config.files.dataset == "CelebA":
         root = config.files.celeba_root
         data_path = os.path.join(root, "celeba")
@@ -152,6 +147,13 @@ def simulate_dataset(config, generator=None):
     else:
         raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.")
 
+    if config.files.dataset != "CelebA":
+        if config.files.n_files is not None:
+            train_size = int((1 - config.files.test_size) * config.files.n_files)
+            test_size = config.files.n_files - train_size
+            train_ds = Subset(train_ds, np.arange(train_size))
+            test_ds = Subset(test_ds, np.arange(test_size))
+
     # convert PSF
     if config.simulation.grayscale and not is_grayscale(psf):
        psf = rgb2gray(psf)
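The `p.grad is not None` assertion added to `train_epoch` catches parameters detached from the graph, but a gradient can exist and still be all-NaN after `backward()`. A possible stronger check, assuming the mask exposes `named_parameters()` like a `torch.nn.Module` (which the commented-out multiple-inheritance experiments above appear to be working toward):

```python
import torch


def check_mask_grads(mask: torch.nn.Module) -> None:
    # verify every parameter received a finite gradient after backward()
    for name, p in mask.named_parameters():
        assert p.grad is not None, f"{name}: no gradient (detached from the loss?)"
        assert torch.isfinite(p.grad).all(), f"{name}: non-finite gradient"


# usage on any nn.Module after loss.backward()
lin = torch.nn.Linear(2, 1)
loss = lin(torch.rand(3, 2)).sum()
loss.backward()
check_mask_grads(lin)
```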