diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml
new file mode 100644
index 00000000..4219910d
--- /dev/null
+++ b/configs/train_coded_aperture.yaml
@@ -0,0 +1,38 @@
+# python scripts/recon/train_unrolled.py -cn train_coded_aperture
+defaults:
+  - train_unrolledADMM
+  - _self_
+
+# Train Dataset
+files:
+  dataset: mnist   # Simulated: "mnist", "fashion_mnist", "cifar10", "CelebA". Measured: "DiffuserCam"
+  celeba_root: /scratch/bezzam
+  downsample: 8
+
+# Trainable Mask
+trainable_mask:
+  mask_type: TrainableCodedAperture
+  optimizer: Adam
+  mask_lr: 1e-3
+  initial_value:
+    method: MLS
+    n_bits: 8
+    # MURA not working...
+    # method: MURA
+    # n_bits: 3
+
+simulation:
+  grayscale: False
+  flip: False
+  scene2mask: 40e-2
+  mask2sensor: 2e-3
+  sensor: "rpi_hq"
+  downsample: 16
+  object_height: 0.30
+
+training:
+  crop_preloss: False  # crop region for computing loss
+  batch_size: 8
+  epoch: 25
+  eval_batch_size: 16
+  save_every: 5
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index f7602f01..16c6040d 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -84,6 +84,7 @@ trainable_mask:
   initial_value: psf
   grayscale: False
   mask_lr: 1e-3
+  optimizer: Adam
   L1_strength: 1.0  #False or float

   target: "object_plane"  # "original" or "object_plane" or "label"
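The coded-aperture path is selected by the *type* of `trainable_mask.initial_value`: a nested mapping (as in the new config above) is turned into a mask object, while a plain string (`psf`, `random`, or a `.npy` path) keeps the previous behaviour of `prep_trainable_mask` in `scripts/recon/train_unrolled.py`. A minimal sketch of that dispatch, assuming only OmegaConf and values that mirror the YAML above:

import omegaconf
from omegaconf import OmegaConf

# Stand-in for the Hydra-resolved "trainable_mask" section of train_coded_aperture.yaml.
cfg = OmegaConf.create(
    {
        "mask_type": "TrainableCodedAperture",
        "optimizer": "Adam",
        "mask_lr": 1e-3,
        "initial_value": {"method": "MLS", "n_bits": 8},
    }
)

# A dict-valued ``initial_value`` selects the "build from mask config" branch;
# a string such as "psf" or "random" falls through to the existing PSF-tensor path.
if isinstance(cfg.initial_value, omegaconf.dictconfig.DictConfig):
    print("mask config:", OmegaConf.to_container(cfg.initial_value))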
diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py
index f6a5d797..4ffa7b36 100644
--- a/lensless/hardware/mask.py
+++ b/lensless/hardware/mask.py
@@ -53,6 +53,8 @@ def __init__(
         size=None,
         feature_size=None,
         psf_wavelength=[460e-9, 550e-9, 640e-9],
+        is_torch=False,
+        torch_device="cpu",
         **kwargs
     ):
         """
@@ -95,6 +97,7 @@
             assert np.all(resolution * feature_size <= size)

         self.resolution = resolution
+        self.resolution = (int(self.resolution[0]), int(self.resolution[1]))
         self.size = size
         if feature_size is None:
             self.feature_size = self.size / self.resolution
@@ -102,6 +105,9 @@
             self.feature_size = feature_size
         self.distance_sensor = distance_sensor

+        self.is_torch = is_torch
+        self.torch_device = torch_device
+
         # create mask
         self.mask = None
         self.create_mask()
@@ -155,19 +161,28 @@ def compute_psf(self):
         Compute the intensity PSF with bandlimited angular spectrum (BLAS) for each wavelength.
         Common to all types of masks.
         """
-        psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64)
+        if self.is_torch:
+            psf = torch.zeros(
+                tuple(self.resolution) + (len(self.psf_wavelength),), dtype=torch.complex64
+            )
+        else:
+            psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64)
         for i, wv in enumerate(self.psf_wavelength):
             psf[:, :, i] = angular_spectrum(
                 u_in=self.mask,
                 wv=wv,
                 d1=self.feature_size,
                 dz=self.distance_sensor,
-                dtype=np.float32,
+                dtype=np.float32 if not self.is_torch else torch.float32,
                 bandlimit=True,
+                device=self.torch_device if self.is_torch else None,
             )[0]

         # intensity PSF
-        self.psf = np.abs(psf) ** 2
+        if self.is_torch:
+            self.psf = torch.abs(psf) ** 2
+        else:
+            self.psf = np.abs(psf) ** 2


 class CodedAperture(Mask):
@@ -196,33 +211,55 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):
         self.method = method
         self.n_bits = n_bits

+        assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'"
+        # TODO? use: https://github.com/bpops/codedapertures
+
+        # initialize parameters
+        if self.method.upper() == "MURA":
+            mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
+            self.row = mask[0, :]
+            self.col = mask[:, 0]
+            outer = np.outer(self.row, self.col)
+            assert np.all(outer == mask)
+        else:
+            seq = max_len_seq(self.n_bits)[0]
+            self.row = seq
+            self.col = seq
+
+        if kwargs.get("is_torch", False):
+            self.row = torch.from_numpy(self.row).float()
+            self.col = torch.from_numpy(self.col).float()
+
         super().__init__(**kwargs)

     def create_mask(self):
         """
-        Creating coded aperture mask using either the MURA of MLS method.
+        Create the coded aperture mask as the outer product of the row and column patterns.
         """
-        assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'"

-        # Generating pattern
-        if self.method.upper() == "MURA":
-            self.mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
-            self.row = 2 * self.mask[0, :] - 1
-            self.col = 2 * self.mask[:, 0] - 1
+        # outer product
+        if self.is_torch:
+            self.mask = torch.outer(self.row, self.col)
         else:
-            seq = max_len_seq(self.n_bits)[0] * 2 - 1
-            h_r = np.r_[seq, seq]
-            self.row = h_r
-            self.col = h_r
-            self.mask = (np.outer(h_r, h_r) + 1) / 2
+            self.mask = np.outer(self.row, self.col)

-        # Upscaling
+        # resize to sensor shape
         if np.any(self.resolution != self.mask.shape):
-            upscaled_mask = resize(
-                self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,)
-            ).squeeze()
-            upscaled_mask = np.clip(upscaled_mask, 0, 1)
-            self.mask = np.round(upscaled_mask).astype(int)
+
+            if self.is_torch:
+                self.mask = self.mask.unsqueeze(0).unsqueeze(0)
+                self.mask = torch.nn.functional.interpolate(
+                    self.mask, size=tuple(self.resolution), mode="nearest"
+                ).squeeze()
+            else:
+                # self.mask = resize(self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,))
+                self.mask = resize(
+                    self.mask[:, :, np.newaxis],
+                    shape=tuple(self.resolution) + (1,),
+                    interpolation=cv.INTER_NEAREST,
+                ).squeeze()
+
+        # assert np.all(np.unique(self.mask) == np.array([0, 1]))

     def is_prime(self, n):
         """
@@ -246,6 +283,7 @@ def squarepattern(self, p):
         p: int
             Number of bits.
         """
+
         if not self.is_prime(p):
             raise ValueError("p is not a valid length. It must be prime.")
         A = np.zeros((p, p), dtype=int)
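With `is_torch=True`, `CodedAperture` keeps its row/column patterns and mask as `torch` tensors and computes the intensity PSF through the same angular-spectrum call. A minimal sketch of building such a mask from a sensor, assuming `Mask.from_sensor` forwards these keyword arguments to the constructor (values mirror `configs/train_coded_aperture.yaml`):

import torch
from lensless.hardware.mask import CodedAperture

# MLS coded aperture on the torch backend for the Raspberry Pi HQ sensor,
# downsampled 16x and placed 2 mm from the sensor (see train_coded_aperture.yaml).
mask = CodedAperture.from_sensor(
    "rpi_hq",
    downsample=16,
    is_torch=True,
    distance_sensor=2e-3,
    method="MLS",
    n_bits=8,
)

mask.compute_psf()                      # intensity PSF per wavelength (R, G, B)
assert isinstance(mask.mask, torch.Tensor)
print(mask.mask.shape, mask.psf.shape)  # (H, W) pattern and (H, W, 3) PSF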
diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py
index 0785204e..08d8fa46 100644
--- a/lensless/hardware/sensor.py
+++ b/lensless/hardware/sensor.py
@@ -213,6 +213,7 @@ def from_name(cls, name, downsample=None):
             Sensor.

         """
+
         if name not in SensorOptions.values():
             raise ValueError(f"Sensor {name} not supported.")
         sensor_specs = sensor_dict[name].copy()
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index 8c15353a..fd6c28cd 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -26,24 +26,28 @@ class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):

     """

-    def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
+    def __init__(self, optimizer="Adam", lr=1e-3, **kwargs):
         """
         Base constructor. Derived constructor may define new state variables

         Parameters
         ----------
-        initial_mask : :py:class:`~torch.Tensor`
-            Initial mask parameters.
         optimizer : str, optional
             Optimizer to use for updating the mask parameters, by default "Adam"
         lr : float, optional
             Learning rate for the mask parameters, by default 1e-3
         """
         super().__init__()
-        self._mask = torch.nn.Parameter(initial_mask)
-        self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr)
-        self.train_mask_vals = True
-        self._counter = 0
+        self.optimizer = optimizer
+        self.lr = lr
+        self._counter = 0
+
+    def _set_optimizer(self, param):
+        """Set the optimizer for the mask parameters."""
+        self._optimizer = getattr(torch.optim, self.optimizer)(param, lr=self.lr)

     @abc.abstractmethod
     def get_psf(self):
@@ -66,7 +70,7 @@ def update_mask(self):

     def get_vals(self):
         """Get the mask parameters."""
-        return self._mask
+        return self._param

     @abc.abstractmethod
     def project(self):
@@ -85,30 +89,38 @@ class TrainablePSF(TrainableMask):
         Otherwise PSF will be returned as RGB. By default False.
     """

-    def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):
-        super().__init__(initial_mask, optimizer, lr, **kwargs)
-        assert (
-            len(initial_mask.shape) == 4
-        ), "Mask must be of shape (depth, height, width, channels)"
+    def __init__(self, initial_psf, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):
+
+        super().__init__(optimizer, lr, **kwargs)
+
+        # cast as learnable parameters
+        self._psf = torch.nn.Parameter(initial_psf)
+
+        # set optimizer
+        initial_param = [self._psf]
+        self._set_optimizer(initial_param)
+
+        # checks
+        assert len(initial_psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"
         self.grayscale = grayscale
-        self._is_grayscale = is_grayscale(initial_mask)
+        self._is_grayscale = is_grayscale(initial_psf)
         if grayscale:
-            assert self._is_grayscale, "Mask must be grayscale"
+            assert self._is_grayscale, "PSF must be grayscale"

     def get_psf(self):
         if self._is_grayscale:
             if self.grayscale:
                 # simulation in grayscale
-                return self._mask
+                return self._psf
             else:
                 # replicate to 3 channels
-                return self._mask.expand(-1, -1, -1, 3)
+                return self._psf.expand(-1, -1, -1, 3)
         else:
             # assume RGB
-            return self._mask
+            return self._psf

     def project(self):
-        self._mask.data = torch.clamp(self._mask, 0, 1)
+        self._psf.data = torch.clamp(self._psf, 0, 1)


 class AdafruitLCD(TrainableMask):
@@ -147,23 +159,27 @@ def __init__(
             Whether to flip the mask vertically, by default False
         """

-        super().__init__(initial_vals, **kwargs)
+        super().__init__(optimizer, lr, **kwargs)

         self.train_mask_vals = train_mask_vals
+        if train_mask_vals:
+            self._mask = torch.nn.Parameter(initial_vals)
+        else:
+            self._mask = initial_vals
+
         if color_filter is not None:
             self.color_filter = torch.nn.Parameter(color_filter)
             if train_mask_vals:
-                param = [self._mask, self.color_filter]
+                initial_param = [self._mask, self.color_filter]
             else:
-                del self._mask
-                self._mask = initial_vals
-                param = [self.color_filter]
-            self._optimizer = getattr(torch.optim, optimizer)(param, lr=lr)
+                initial_param = [self.color_filter]
         else:
-            self.color_filter = None
             assert (
                 train_mask_vals
             ), "If color filter is not trainable, mask values must be trainable"
+            self.color_filter = None
+            initial_param = [self._mask]
+
+        # set optimizer
+        self._set_optimizer(initial_param)

         self.slm_param = slm_dict[slm]
         self.device = slm
         self.sensor = VirtualSensor.from_name(sensor, downsample=downsample)
@@ -233,16 +249,40 @@ def project(self):
         ).unsqueeze(-1).unsqueeze(-1)


-class TrainableCodedAperture(CodedAperture):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+class TrainableCodedAperture(TrainableMask):
+    def __init__(
+        self, sensor_name, downsample=None, binary=True, optimizer="Adam", lr=1e-3, **kwargs
+    ):
+        """
+        TODO: Distinguish between separable and non-separable.
+        """
+
+        super().__init__(optimizer, lr, **kwargs)
+
+        assert "distance_sensor" in kwargs, "Distance to sensor must be specified"
+        assert "method" in kwargs, "Method must be specified."
+        assert "n_bits" in kwargs, "Number of bits must be specified."
+
+        # initialize mask
+        self._mask = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)
+
+        # set learnable parameters (should be immediate attributes of the class)
+        self._row = torch.nn.Parameter(self._mask.row)
+        self._col = torch.nn.Parameter(self._mask.col)
+        initial_param = [self._row, self._col]
         self.binary = binary
-        self.row = torch.nn.Parameter(self.row)
-        self.col = torch.nn.Parameter(self.col)
+
+        # set optimizer
+        self._set_optimizer(initial_param)

     def get_psf(self):
-        return super().compute_psf()
+        # rebuild the mask from the learnable row/column patterns so that
+        # gradients can flow back to them
+        self._mask.row = self._row
+        self._mask.col = self._col
+        self._mask.create_mask()
+        self._mask.compute_psf()
+        return self._mask.psf.unsqueeze(0)

     def project(self):
-        self.row.data = torch.clamp(self.row, 0, 1)
-        self.col.data = torch.clamp(self.col, 0, 1)
+        self._row.data = torch.clamp(self._row, 0, 1)
+        self._col.data = torch.clamp(self._col, 0, 1)
+        if self.binary:
+            self._row.data = torch.round(self._row)
+            self._col.data = torch.round(self._col)
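`TrainableCodedAperture` wraps a torch-backed `CodedAperture` and exposes only the separable row/column patterns as learnable parameters; `get_psf()` rebuilds the mask and PSF from the current values, and `project()` keeps them in [0, 1] (rounded when `binary=True`). A minimal sketch of one update step, assuming `update_mask()` steps the internal optimizer and re-projects as for the other trainable masks (the placeholder loss stands in for the reconstruction loss of the unrolled network):

import torch
from lensless.hardware.trainable_mask import TrainableCodedAperture

mask = TrainableCodedAperture(
    sensor_name="rpi_hq",
    downsample=16,
    binary=True,
    optimizer="Adam",
    lr=1e-3,
    distance_sensor=2e-3,   # forwarded to CodedAperture.from_sensor via **kwargs
    method="MLS",
    n_bits=8,
)

psf = mask.get_psf()        # (1, H, W, 3), rebuilt from the current row/col values
loss = psf.mean()           # placeholder loss
loss.backward()
mask.update_mask()          # optimizer step, then clamp/round via project()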
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index eaace9a8..8a326a7a 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -33,6 +33,7 @@
 """

 import logging
+import omegaconf
 import hydra
 from hydra.utils import get_original_cwd
 import os
@@ -271,56 +272,77 @@ def prep_trainable_mask(config, psf=None, downsample=None):
     if config.trainable_mask.mask_type is not None:
         mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

-        if config.trainable_mask.initial_value == "random":
-            if psf is not None:
-                initial_mask = torch.rand_like(psf)
-            else:
-                sensor = VirtualSensor.from_name(config.simulation.sensor, downsample=downsample)
-                resolution = sensor.resolution
-                initial_mask = torch.rand((1, *resolution, 3))
-        elif config.trainable_mask.initial_value == "psf":
-            initial_mask = psf.clone()
-        # if file ending with "npy"
-        elif config.trainable_mask.initial_value.endswith("npy"):
-            pattern = np.load(os.path.join(get_original_cwd(), config.trainable_mask.initial_value))
-
-            initial_mask = full2subpattern(
-                pattern=pattern,
-                shape=config.trainable_mask.ap_shape,
-                center=config.trainable_mask.ap_center,
-                slm=config.trainable_mask.slm,
+        if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig):
+
+            from lensless.hardware.trainable_mask import TrainableCodedAperture
+
+            # from mask config
+            mask = mask_class(
+                sensor_name=config.simulation.sensor,
+                downsample=downsample,
+                distance_sensor=config.simulation.mask2sensor,
+                optimizer=config.trainable_mask.optimizer,
+                lr=config.trainable_mask.mask_lr,
+                **config.trainable_mask.initial_value,
             )
-            initial_mask = torch.from_numpy(initial_mask.astype(np.float32))
-
-            # prepare color filter if needed
-            from waveprop.devices import slm_dict
-            from waveprop.devices import SLMParam as SLMParam_wp
-
-            slm_param = slm_dict[config.trainable_mask.slm]
-            if (
-                config.trainable_mask.train_color_filter
-                and SLMParam_wp.COLOR_FILTER in slm_param.keys()
-            ):
-                color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
-                color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)
-
-                # add small random values
-                color_filter = color_filter + 0.1 * torch.rand_like(color_filter)
-        else:
-            raise ValueError(
-                f"Initial PSF value {config.trainable_mask.initial_value} not supported"
-            )
-
-        if config.trainable_mask.grayscale and not is_grayscale(initial_mask):
-            initial_mask = rgb2gray(initial_mask)
-
-        mask = mask_class(
-            initial_mask,
-            optimizer="Adam",
-            downsample=downsample,
-            color_filter=color_filter,
-            **config.trainable_mask,
-        )
+
+        else:
+
+            if config.trainable_mask.initial_value == "random":
+                if psf is not None:
+                    initial_mask = torch.rand_like(psf)
+                else:
+                    sensor = VirtualSensor.from_name(
+                        config.simulation.sensor, downsample=downsample
+                    )
+                    resolution = sensor.resolution
+                    initial_mask = torch.rand((1, *resolution, 3))
+            elif config.trainable_mask.initial_value == "psf":
+                initial_mask = psf.clone()
+            # if file ending with "npy"
+            elif config.trainable_mask.initial_value.endswith("npy"):
+                pattern = np.load(
+                    os.path.join(get_original_cwd(), config.trainable_mask.initial_value)
+                )
+
+                initial_mask = full2subpattern(
+                    pattern=pattern,
+                    shape=config.trainable_mask.ap_shape,
+                    center=config.trainable_mask.ap_center,
+                    slm=config.trainable_mask.slm,
+                )
+                initial_mask = torch.from_numpy(initial_mask.astype(np.float32))
+
+                # prepare color filter if needed
+                from waveprop.devices import slm_dict
+                from waveprop.devices import SLMParam as SLMParam_wp
+
+                slm_param = slm_dict[config.trainable_mask.slm]
+                if (
+                    config.trainable_mask.train_color_filter
+                    and SLMParam_wp.COLOR_FILTER in slm_param.keys()
+                ):
+                    color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
+                    color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)
+
+                    # add small random values
+                    color_filter = color_filter + 0.1 * torch.rand_like(color_filter)
+
+            else:
+                raise ValueError(
+                    f"Initial PSF value {config.trainable_mask.initial_value} not supported"
+                )
+
+            if config.trainable_mask.grayscale and not is_grayscale(initial_mask):
+                initial_mask = rgb2gray(initial_mask)
+
+            mask = mask_class(
+                initial_mask,
+                downsample=downsample,
+                color_filter=color_filter,
+                **config.trainable_mask,
+            )

     return mask
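For the string-valued `initial_value` settings (`psf`, `random`, or a `.npy` pattern), `prep_trainable_mask` still builds a PSF tensor and hands it to the mask class, e.g. `TrainablePSF`. A minimal sketch of that path with a hypothetical PSF shape (the script itself uses `psf.clone()` or a random tensor at the sensor resolution):

import torch
from lensless.hardware.trainable_mask import TrainablePSF

initial_psf = torch.rand((1, 240, 320, 3))   # hypothetical (depth, height, width, channels) PSF
mask = TrainablePSF(initial_psf, optimizer="Adam", lr=1e-3, grayscale=False)

psf = mask.get_psf()   # returned as-is for RGB, or expanded to 3 channels if grayscale
mask.project()         # clamp learned PSF values to [0, 1]
print(psf.shape)       # torch.Size([1, 240, 320, 3])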