diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 813b4b25..5a646e3f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,11 +14,17 @@ Added ~~~~~ - Script to upload measured datasets to Hugging Face: ``scripts/data/upload_dataset_huggingface.py`` +- Pytorch support for simulating PSFs of masks. +- ``lensless.hardware.mask.MultiLensArray`` class for simulating multi-lens arrays. +- ``lensless.hardware.trainable_mask.TrainableCodedAperture`` class for training a coded aperture mask pattern. +- Support for other optimizers in ``lensless.utils.Trainer.set_optimizer``. +- ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF. Changed ~~~~~ - Dataset reconstruction script uses datasets from Hugging Face: ``scripts/recon/dataset.py`` +- For trainable masks, set trainable parameters inside the child class. Bugfix ~~~~~ diff --git a/configs/train_celeba_digicam_mask.yaml b/configs/train_celeba_digicam_mask.yaml index 9657e248..8dfd7f73 100644 --- a/configs/train_celeba_digicam_mask.yaml +++ b/configs/train_celeba_digicam_mask.yaml @@ -1,3 +1,4 @@ +# fine-tune mask for PSF, but don't re-simulate # python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask defaults: - train_celeba_digicam @@ -78,7 +79,7 @@ trainable_mask: # horizontal_shift: -100 # [px] - initial_value: adafruit_random_pattern_20231004_174047.npy + initial_value: /home/bezzam/LenslessPiCam/adafruit_random_pattern_20231004_174047.npy ap_center: [58, 76] ap_shape: [19, 25] rotate: 0 # rotation in degrees diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml new file mode 100644 index 00000000..ea39b6ab --- /dev/null +++ b/configs/train_coded_aperture.yaml @@ -0,0 +1,55 @@ +# python scripts/recon/train_unrolled.py -cn train_coded_aperture +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + celeba_root: /scratch/bezzam + downsample: 16 # TODO use downsample simulation instead? + n_files: 100 + crop: + vertical: [810, 2240] + horizontal: [1310, 2750] + +torch_device: "cuda:1" + +optimizer: + # type: Adam # Adam, SGD... + # lr: 1e-4 + type: SGD + lr: 0.01 + +#Trainable Mask +trainable_mask: + mask_type: TrainableCodedAperture + # optimizer: Adam + # mask_lr: 1e-3 + optimizer: SGD + mask_lr: 0.01 + L1_strength: False + binary: False + initial_value: + psf_wavelength: [550e-9] + method: MLS + n_bits: 8 # (2**n_bits-1, 2**n_bits-1) + # method: MURA + # n_bits: 25 # (4*nbits*1, 4*nbits*1) + # # -- applicable for phase masks + # design_wv: 550e-9 + +simulation: + grayscale: True + flip: False + scene2mask: 40e-2 + mask2sensor: 2e-3 + sensor: "rpi_hq" + object_height: 0.30 + +training: + crop_preloss: True # crop region for computing loss + batch_size: 4 + epoch: 25 + eval_batch_size: 16 + save_every: 1 diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 5965efb7..0c7e8e47 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -90,6 +90,7 @@ trainable_mask: initial_value: psf grayscale: False mask_lr: 1e-3 + optimizer: Adam # Adam, SGD... (Pytorch class) L1_strength: 1.0 #False or float target: "object_plane" # "original" or "object_plane" or "label" @@ -135,7 +136,7 @@ training: crop_preloss: False # crop region for computing loss, files.crop should be set optimizer: - type: Adam + type: Adam # Adam, SGD... (Pytorch class) lr: 1e-4 slow_start: False #float how much to reduce lr for first epoch # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index c8322916..c0dffd6e 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -106,17 +106,17 @@ def benchmark( dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device != "cpu")) model.reset() idx = 0 - for lensless, lensed in tqdm(dataloader): - lensless = lensless.to(device) - lensed = lensed.to(device) + with torch.no_grad(): + for lensless, lensed in tqdm(dataloader): + lensless = lensless.to(device) + lensed = lensed.to(device) - # add shot noise - if snr is not None: - for i in range(lensless.shape[0]): - lensless[i] = add_shot_noise(lensless[i], float(snr)) + # add shot noise + if snr is not None: + for i in range(lensless.shape[0]): + lensless[i] = add_shot_noise(lensless[i], float(snr)) - # compute predictions - with torch.no_grad(): + # compute predictions if batchsize == 1: model.set_data(lensless) prediction = model.apply( @@ -126,113 +126,115 @@ def benchmark( else: prediction = model.batch_call(lensless, **kwargs) - if unrolled_output_factor: - unrolled_out = prediction[-1] - prediction = prediction[0] - - # Convert to [N*D, C, H, W] for torchmetrics - prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) - lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) - - if crop is not None: - prediction = prediction[ - ..., - crop["vertical"][0] : crop["vertical"][1], - crop["horizontal"][0] : crop["horizontal"][1], - ] - lensed = lensed[ - ..., - crop["vertical"][0] : crop["vertical"][1], - crop["horizontal"][0] : crop["horizontal"][1], - ] - - if save_idx is not None: - batch_idx = np.arange(idx, idx + batchsize) - - for i, idx in enumerate(batch_idx): - if idx in save_idx: - prediction_np = prediction.cpu().numpy()[i].squeeze() - # switch to [H, W, C] - prediction_np = np.moveaxis(prediction_np, 0, -1) - save_image(prediction_np, fp=os.path.join(output_dir, f"{idx}.png")) - - # normalization - prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True) - if torch.all(prediction_max != 0): - prediction = prediction / prediction_max - else: - print("Warning: prediction is zero") - lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True) - lensed = lensed / lensed_max - - # compute metrics - for metric in metrics: - if metric == "ReconstructionError": - metrics_values[metric].append(model.reconstruction_error().cpu().item()) - else: - if "LPIPS" in metric: - if prediction.shape[1] == 1: - # LPIPS needs 3 channels - metrics_values[metric].append( - metrics[metric]( - prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) - ) - .cpu() - .item() - ) - else: - metrics_values[metric].append( - metrics[metric](prediction, lensed).cpu().item() - ) - else: - metrics_values[metric].append(metrics[metric](prediction, lensed).cpu().item()) + if unrolled_output_factor: + unrolled_out = prediction[-1] + prediction = prediction[0] - # compute metrics for unrolled output - if unrolled_output_factor: + # Convert to [N*D, C, H, W] for torchmetrics + prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) + lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) - # -- convert to CHW and remove depth - unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3) - - # -- extraction region of interest if crop is not None: - unrolled_out = unrolled_out[ + prediction = prediction[ + ..., + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + lensed = lensed[ ..., crop["vertical"][0] : crop["vertical"][1], crop["horizontal"][0] : crop["horizontal"][1], ] - # -- normalization - unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) - if torch.all(unrolled_out_max != 0): - unrolled_out = unrolled_out / unrolled_out_max + if save_idx is not None: + batch_idx = np.arange(idx, idx + batchsize) + + for i, idx in enumerate(batch_idx): + if idx in save_idx: + prediction_np = prediction.cpu().numpy()[i] + # switch to [H, W, C] for saving + prediction_np = np.moveaxis(prediction_np, 0, -1) + save_image(prediction_np, fp=os.path.join(output_dir, f"{idx}.png")) - # -- compute metrics + # normalization + prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True) + if torch.all(prediction_max != 0): + prediction = prediction / prediction_max + else: + print("Warning: prediction is zero") + lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True) + lensed = lensed / lensed_max + + # compute metrics for metric in metrics: if metric == "ReconstructionError": - # only have this for final output - continue + metrics_values[metric].append(model.reconstruction_error().cpu().item()) else: if "LPIPS" in metric: - if unrolled_out.shape[1] == 1: + if prediction.shape[1] == 1: # LPIPS needs 3 channels metrics_values[metric].append( metrics[metric]( - unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) ) .cpu() .item() ) else: - metrics_values[metric + "_unrolled"].append( - metrics[metric](unrolled_out, lensed).cpu().item() + metrics_values[metric].append( + metrics[metric](prediction, lensed).cpu().item() ) else: - metrics_values[metric + "_unrolled"].append( - metrics[metric](unrolled_out, lensed).cpu().item() + metrics_values[metric].append( + metrics[metric](prediction, lensed).cpu().item() ) - model.reset() - idx += batchsize + # compute metrics for unrolled output + if unrolled_output_factor: + + # -- convert to CHW and remove depth + unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3) + + # -- extraction region of interest + if crop is not None: + unrolled_out = unrolled_out[ + ..., + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + + # -- normalization + unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) + if torch.all(unrolled_out_max != 0): + unrolled_out = unrolled_out / unrolled_out_max + + # -- compute metrics + for metric in metrics: + if metric == "ReconstructionError": + # only have this for final output + continue + else: + if "LPIPS" in metric: + if unrolled_out.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric].append( + metrics[metric]( + unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric + "_unrolled"].append( + metrics[metric](unrolled_out, lensed).cpu().item() + ) + else: + metrics_values[metric + "_unrolled"].append( + metrics[metric](unrolled_out, lensed).cpu().item() + ) + + model.reset() + idx += batchsize # average metrics if return_average: diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index f9597bf5..40c7cbcc 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -53,7 +53,9 @@ def __init__( size=None, feature_size=None, psf_wavelength=[460e-9, 550e-9, 640e-9], - **kwargs + is_torch=False, + torch_device="cpu", + **kwargs, ): """ Constructor from parameters of the user's choice. @@ -94,8 +96,8 @@ def __init__( assert np.all(feature_size > 0), "Feature size should be positive" assert np.all(resolution * feature_size <= size) - self.phase_mask = None self.resolution = resolution + self.resolution = (int(self.resolution[0]), int(self.resolution[1])) self.size = size if feature_size is None: self.feature_size = self.size / self.resolution @@ -103,12 +105,17 @@ def __init__( self.feature_size = feature_size self.distance_sensor = distance_sensor + if is_torch: + assert torch_available, "PyTorch is not available" + self.is_torch = is_torch + self.torch_device = torch_device + # create mask - self.mask = None self.create_mask() self.shape = self.mask.shape # PSF + assert hasattr(psf_wavelength, "__len__"), "psf_wavelength should be a list" self.psf_wavelength = psf_wavelength self.psf = None self.compute_psf() @@ -141,7 +148,7 @@ def from_sensor(cls, sensor_name, downsample=None, **kwargs): resolution=tuple(sensor.resolution.copy()), size=tuple(sensor.size.copy()), feature_size=sensor.pixel_size.copy(), - **kwargs + **kwargs, ) @abc.abstractmethod @@ -156,19 +163,30 @@ def compute_psf(self): Compute the intensity PSF with bandlimited angular spectrum (BLAS) for each wavelength. Common to all types of masks. """ - psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64) + if self.is_torch: + psf = torch.zeros( + tuple(self.resolution) + (len(self.psf_wavelength),), + dtype=torch.complex64, + device=self.torch_device, + ) + else: + psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64) for i, wv in enumerate(self.psf_wavelength): psf[:, :, i] = angular_spectrum( u_in=self.mask, wv=wv, d1=self.feature_size, dz=self.distance_sensor, - dtype=np.float32, + dtype=np.float32 if not self.is_torch else torch.float32, bandlimit=True, + device=self.torch_device if self.is_torch else None, )[0] # intensity PSF - self.psf = np.abs(psf) ** 2 + if self.is_torch: + self.psf = torch.abs(psf) ** 2 + else: + self.psf = np.abs(psf) ** 2 class CodedAperture(Mask): @@ -197,33 +215,67 @@ def __init__(self, method="MLS", n_bits=8, **kwargs): self.method = method self.n_bits = n_bits + assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'" + # TODO? use: https://github.com/bpops/codedapertures + + # initialize parameters + if self.method.upper() == "MURA": + self.mask = self.generate_mura(4 * self.n_bits + 1) + self.row = None + self.col = None + else: + seq = max_len_seq(self.n_bits)[0] + self.row = seq + self.col = seq + + if "is_torch" in kwargs and kwargs["is_torch"]: + torch_device = kwargs["torch_device"] if "torch_device" in kwargs else "cpu" + if self.row is not None and self.col is not None: + self.row = torch.from_numpy(self.row).float().to(torch_device) + self.col = torch.from_numpy(self.col).float().to(torch_device) + else: + self.mask = torch.from_numpy(self.mask).float().to(torch_device) + + # needs to be done at the end as it calls create_mask super().__init__(**kwargs) - def create_mask(self): + def create_mask(self, row=None, col=None, mask=None): """ - Creating coded aperture mask using either the MURA of MLS method. + Creating coded aperture mask. """ - assert self.method.upper() in ["MURA", "MLS"], "Method should be either 'MLS' or 'MURA'" - # Generating pattern - if self.method.upper() == "MURA": - self.mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:] - self.row = 2 * self.mask[0, :] - 1 - self.col = 2 * self.mask[:, 0] - 1 - else: - seq = max_len_seq(self.n_bits)[0] * 2 - 1 - h_r = np.r_[seq, seq] - self.row = h_r - self.col = h_r - self.mask = (np.outer(h_r, h_r) + 1) / 2 + if mask is not None: + self.mask = mask + assert row is None and col is None, "Row and col should not be specified" + + elif row is not None: + assert col is not None, "Both row and col should be specified" + self.row = row + self.col = col + + # output product if necessary + if self.row is not None: + if self.is_torch: + self.mask = torch.outer(self.row, self.col) + else: + self.mask = np.outer(self.row, self.col) + assert self.mask is not None, "Mask should be specified" - # Upscaling + # resize to sensor shape if np.any(self.resolution != self.mask.shape): - upscaled_mask = resize( - self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,) - ).squeeze() - upscaled_mask = np.clip(upscaled_mask, 0, 1) - self.mask = np.round(upscaled_mask).astype(int) + + if self.is_torch: + self.mask = self.mask.unsqueeze(0).unsqueeze(0) + self.mask = torch.nn.functional.interpolate( + self.mask, size=tuple(self.resolution), mode="nearest" + ).squeeze() + else: + # self.mask = resize(self.mask[:, :, np.newaxis], shape=tuple(self.resolution) + (1,)) + self.mask = resize( + self.mask[:, :, np.newaxis], + shape=tuple(self.resolution) + (1,), + interpolation=cv.INTER_NEAREST, + ).squeeze() def is_prime(self, n): """ @@ -238,7 +290,7 @@ def is_prime(self, n): return False return all(n % i for i in range(3, int(sqrt(n)) + 1, 2)) - def squarepattern(self, p): + def generate_mura(self, p): """ Generate MURA square pattern. @@ -247,6 +299,7 @@ def squarepattern(self, p): p: int Number of bits. """ + if not self.is_prime(p): raise ValueError("p is not a valid length. It must be prime.") A = np.zeros((p, p), dtype=int) @@ -322,6 +375,219 @@ def simulate(self, obj, snr_db=20): return meas +class MultiLensArray(Mask): + """ + Multi-lens array mask. + """ + + def __init__( + self, + N=None, + radius=None, + loc=None, + refractive_index=1.2, + design_wv=532e-9, + seed=0, + min_height=1e-5, + radius_range=(1e-4, 4e-4), + min_separation=1e-4, + verbose=False, + **kwargs, + ): + """ + Multi-lens array mask constructor. + + Parameters + ---------- + N: int + Number of micro-lenses. + radius: array_like + Radius of the lenses (m). + loc: array_like of tuples + Location of the lenses (m). + refractive_index: float + Refractive index of the mask substrate. Default is 1.2. + design_wv: float + Wavelength used to design the mask (m). Default is 532e-9. + seed: int + Seed for the random number generator. Default is 0. + min_height: float + Minimum height of the lenses (m). Default is 1e-3. + radius_range: array_like + Range of the radius of the lenses (m). Default is (1e-4, 4e-4) m. + min_separation: float + Minimum separation between lenses (m). Default is 1e-4. + verbose: bool + If True, print lens placement information. Default is False. + """ + self.N = N + self.radius = radius + self.loc = loc + self.refractive_index = refractive_index + self.wavelength = design_wv + self.seed = seed + self.min_height = min_height + self.radius_range = radius_range + self.min_separation = min_separation + self.verbose = verbose + + super().__init__(**kwargs) + + def check_asserts(self): + """ + Check the validity of the parameters. + + Generate the locations and radii of the lenses if not specified. + """ + assert ( + self.radius_range[0] < self.radius_range[1] + ), "Minimum radius should be smaller than maximum radius" + if self.radius is not None: + if self.is_torch: + assert torch.all(self.radius >= 0) + else: + assert np.all(self.radius >= 0) + assert ( + self.loc is not None + ), "Location of the lenses should be specified if their radius is specified" + assert len(self.radius) == len( + self.loc + ), "Number of radius should be equal to the number of locations" + self.N = len(self.radius) + circles = ( + np.array([(self.loc[i][0], self.loc[i][1], self.radius[i]) for i in range(self.N)]) + if not self.is_torch + else torch.tensor( + [(self.loc[i][0], self.loc[i][1], self.radius[i]) for i in range(self.N)] + ).to(self.torch_device) + ) + assert self.no_circle_overlap(circles), "lenses should not overlap" + else: + # generate random locations and radii + assert ( + self.N is not None + ), "If positions are not specified, the number of lenses should be specified" + + np.random.seed(self.seed) + self.radius = np.random.uniform(self.radius_range[0], self.radius_range[1], self.N) + # radius get sorted in descending order + self.loc, self.radius = self.place_spheres_on_plane(self.radius) + if self.is_torch: + self.radius = torch.tensor(self.radius).to(self.torch_device) + self.loc = torch.tensor(self.loc).to(self.torch_device) + + def no_circle_overlap(self, circles): + """ + Check if any circle in the list overlaps with another. + + Parameters + ---------- + circles: array_like + List of circles, each represented by a tuple (x, y, r) with location (x, y) and radius r. + """ + for i in range(len(circles)): + if self.does_circle_overlap( + circles[i + 1 :], circles[i][0], circles[i][1], circles[i][2] + ): + return False + return True + + def does_circle_overlap(self, circles, x, y, r): + """Check if a circle overlaps with any in the list.""" + for (cx, cy, cr) in circles: + if sqrt((x - cx) ** 2 + (y - cy) ** 2) <= (r + cr + self.min_separation): + return True, (cx, cy, cr) + return False + + def place_spheres_on_plane(self, radius, max_attempts=1000): + """Try to place circles of given radius on a 2D plane.""" + placed_circles = [] + rad_sorted = sorted(radius, reverse=True) # sort the radius in descending order + loc = [] + r_placed = [] + for r in rad_sorted: + placed = False + for _ in range(max_attempts): + x = np.random.uniform(r, self.size[1] - r) + y = np.random.uniform(r, self.size[0] - r) + if not self.does_circle_overlap(placed_circles, x, y, r): + placed_circles.append((x, y, r)) + loc.append([x, y]) + r_placed.append(r) + placed = True + if self.verbose: + print(f"Placed circle with rad {r}, and center ({x}, {y})") + break + if not placed: + if self.verbose: + print(f"Failed to place circle with rad {r}") + continue + if len(r_placed) < self.N: + warnings.warn(f"Could not place {self.N - len(r_placed)} lenses") + return np.array(loc, dtype=np.float32), np.array(r_placed, dtype=np.float32) + + def create_mask(self, loc=None, radius=None): + """ + Creating multi-lens array mask. + + Parameters + ---------- + loc: array_like of tuples, optional + Location of the lenses (m). + radius: array_like, optional + Radius of the lenses (m). + """ + if radius is not None: + self.radius = radius + if loc is not None: + self.loc = loc + self.check_asserts() + + # convert to pixels (assume same size for x and y) + locs_pix = self.loc * (1 / self.feature_size[0]) + radius_pix = self.radius * (1 / self.feature_size[0]) + height = self.create_height_map(radius_pix, locs_pix) + self.phi = height * (self.refractive_index - 1) * 2 * np.pi / self.wavelength + self.mask = np.exp(1j * self.phi) if not self.is_torch else torch.exp(1j * self.phi) + + def create_height_map(self, radius, locs): + height = ( + np.full((self.resolution[0], self.resolution[1]), self.min_height).astype(np.float32) + if not self.is_torch + else torch.full((self.resolution[0], self.resolution[1]), self.min_height).to( + self.torch_device, dtype=torch.float32 + ) + ) + x = ( + np.arange(self.resolution[0]).astype(np.float32) + if not self.is_torch + else torch.arange(self.resolution[0]).to(self.torch_device) + ) + y = ( + np.arange(self.resolution[1]).astype(np.float32) + if not self.is_torch + else torch.arange(self.resolution[1]).to(self.torch_device) + ) + X, Y = ( + np.meshgrid(x, y, indexing="ij") + if not self.is_torch + else torch.meshgrid(x, y, indexing="ij") + ) + for idx, rad in enumerate(radius): + contribution = self.lens_contribution(X, Y, rad, locs[idx]) * self.feature_size[0] + contribution[(X - locs[idx][1]) ** 2 + (Y - locs[idx][0]) ** 2 > rad**2] = 0 + height = height + contribution + height[height < self.min_height] = self.min_height + return height + + def lens_contribution(self, x, y, radius, loc): + return ( + np.sqrt(radius**2 - (x - loc[1]) ** 2 - (y - loc[0]) ** 2) + if not self.is_torch + else torch.sqrt(radius**2 - (x - loc[1]) ** 2 - (y - loc[0]) ** 2) + ) + + class PhaseContour(Mask): """ Phase contour mask as in `PhlatCam `_. diff --git a/lensless/hardware/sensor.py b/lensless/hardware/sensor.py index 0785204e..08d8fa46 100644 --- a/lensless/hardware/sensor.py +++ b/lensless/hardware/sensor.py @@ -213,6 +213,7 @@ def from_name(cls, name, downsample=None): Sensor. """ + if name not in SensorOptions.values(): raise ValueError(f"Sensor {name} not supported.") sensor_specs = sensor_dict[name].copy() diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index cf791a52..254bd5b1 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -7,14 +7,14 @@ # ############################################################################# import abc -import torch +import omegaconf import numpy as np -from lensless.utils.image import is_grayscale -from lensless.hardware.slm import get_programmable_mask, get_intensity_psf +import torch +from lensless.utils.image import is_grayscale, rgb2gray +from lensless.hardware.slm import full2subpattern, get_programmable_mask, get_intensity_psf from lensless.hardware.sensor import VirtualSensor from waveprop.devices import slm_dict -from lensless.hardware.slm import full2subpattern -from lensless.utils.image import rgb2gray +from lensless.hardware.mask import CodedAperture class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): @@ -28,25 +28,26 @@ class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): """ - def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): + def __init__(self, optimizer="Adam", lr=1e-3, **kwargs): """ Base constructor. Derived constructor may define new state variables Parameters ---------- - initial_mask : :py:class:`~torch.Tensor` - Initial mask parameters. optimizer : str, optional Optimizer to use for updating the mask parameters, by default "Adam" lr : float, optional Learning rate for the mask parameters, by default 1e-3 """ super().__init__() - self._mask = torch.nn.Parameter(initial_mask) - self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr) - self.train_mask_vals = True + self._optimizer = optimizer + self._lr = lr self._counter = 0 + def _set_optimizer(self, param): + """Set the optimizer for the mask parameters.""" + self._optimizer = getattr(torch.optim, self._optimizer)(param, lr=self._lr) + @abc.abstractmethod def get_psf(self): """ @@ -66,10 +67,6 @@ def update_mask(self): self.project() self._counter += 1 - def get_vals(self): - """Get the mask parameters.""" - return self._mask - @abc.abstractmethod def project(self): """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1]).""" @@ -77,6 +74,7 @@ def project(self): class TrainablePSF(TrainableMask): + # class TrainablePSF(torch.nn.Module, TrainableMask): """ Class for defining an object that directly optimizes the PSF, without any constraints on what can be realized physically. @@ -87,40 +85,43 @@ class TrainablePSF(TrainableMask): Otherwise PSF will be returned as RGB. By default False. """ - def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs): - super().__init__(initial_mask, optimizer, lr, **kwargs) - assert ( - len(initial_mask.shape) == 4 - ), "Mask must be of shape (depth, height, width, channels)" + def __init__(self, initial_psf, grayscale=False, **kwargs): + + super().__init__(**kwargs) + self._psf = torch.nn.Parameter(initial_psf) + initial_param = [self._psf] + self._set_optimizer(initial_param) + + # checks + assert len(initial_psf.shape) == 4, "Mask must be of shape (depth, height, width, channels)" self.grayscale = grayscale - self._is_grayscale = is_grayscale(initial_mask) + self._is_grayscale = is_grayscale(initial_psf) if grayscale: - assert self._is_grayscale, "Mask must be grayscale" + assert self._is_grayscale, "PSF must be grayscale" def get_psf(self): if self._is_grayscale: if self.grayscale: # simulation in grayscale - return self._mask + return self._psf else: # replicate to 3 channels - return self._mask.expand(-1, -1, -1, 3) + return self._psf.expand(-1, -1, -1, 3) else: # assume RGB - return self._mask + return self._psf def project(self): - self._mask.data = torch.clamp(self._mask, 0, 1) + self._psf.data = torch.clamp(self._psf, 0, 1) class AdafruitLCD(TrainableMask): + # class AdafruitLCD(torch.nn.Module, TrainableMask): def __init__( self, initial_vals, sensor, slm, - optimizer="Adam", - lr=1e-3, train_mask_vals=True, color_filter=None, rotate=None, @@ -149,23 +150,28 @@ def __init__( Whether to flip the mask vertically, by default False """ - super().__init__(initial_vals, **kwargs) + super().__init__(**kwargs) + self.train_mask_vals = train_mask_vals + if train_mask_vals: + self._vals = torch.nn.Parameter(initial_vals) + else: + self._vals = initial_vals + if color_filter is not None: - self.color_filter = torch.nn.Parameter(color_filter) + self._color_filter = torch.nn.Parameter(color_filter) if train_mask_vals: - param = [self._mask, self.color_filter] + initial_param = [self._vals, self._color_filter] else: - del self._mask - self._mask = initial_vals - param = [self.color_filter] - self._optimizer = getattr(torch.optim, optimizer)(param, lr=lr) + initial_param = [self._color_filter] else: - self.color_filter = None assert ( train_mask_vals ), "If color filter is not trainable, mask values must be trainable" + # set optimizer + self._set_optimizer(initial_param) + self.slm_param = slm_dict[slm] self.device = slm self.sensor = VirtualSensor.from_name(sensor, downsample=downsample) @@ -188,12 +194,12 @@ def __init__( def get_psf(self): mask = get_programmable_mask( - vals=self._mask, + vals=self._vals, sensor=self.sensor, slm_param=self.slm_param, rotate=self.rotate, flipud=self.flipud, - color_filter=self.color_filter, + color_filter=self._color_filter, ) if self.vertical_shift is not None: @@ -226,22 +232,100 @@ def get_psf(self): def project(self): if self.train_mask_vals: - self._mask.data = torch.clamp(self._mask, self.min_val, 1) - if self.color_filter is not None: - self.color_filter.data = torch.clamp(self.color_filter, 0, 1) + self._vals.data = torch.clamp(self._vals, self.min_val, 1) + if self._color_filter is not None: + self._color_filter.data = torch.clamp(self._color_filter, 0, 1) # normalize each row to 1 - self.color_filter.data = self.color_filter / self.color_filter.sum( + self._color_filter.data = self._color_filter / self._color_filter.sum( dim=[1, 2] ).unsqueeze(-1).unsqueeze(-1) +class TrainableCodedAperture(TrainableMask): + def __init__( + self, + sensor_name, + downsample=None, + binary=True, + torch_device="cuda", + **kwargs, + ): + """ + TODO: Distinguish between separable and non-separable. + """ + + # 1) call base constructor so parameters can be set + super().__init__(**kwargs) + + # 2) initialize mask + assert "distance_sensor" in kwargs, "Distance to sensor must be specified" + assert "method" in kwargs, "Method must be specified." + assert "n_bits" in kwargs, "Number of bits must be specified." + self._mask_obj = CodedAperture.from_sensor( + sensor_name, + downsample, + is_torch=True, + torch_device=torch_device, + **kwargs, + ) + + # 3) set learnable parameters (should be immediate attributes of the class) + self._row = None + self._col = None + self._mask = None + if self._mask_obj.row is not None: + # separable + self.separable = True + self._row = torch.nn.Parameter(self._mask_obj.row) + self._col = torch.nn.Parameter(self._mask_obj.col) + initial_param = [self._row, self._col] + else: + # non-separable + self.separable = False + self._mask = torch.nn.Parameter(self._mask_obj.mask) + initial_param = [self._mask] + self.binary = binary + + # 4) set optimizer + self._set_optimizer(initial_param) + + # 5) compute PSF + self._psf = None + self.project() + + def get_psf(self): + return self._psf + + def project(self): + with torch.no_grad(): + if self.separable: + self._row.data = torch.clamp(self._row, 0, 1) + self._col.data = torch.clamp(self._col, 0, 1) + if self.binary: + self._row.data = torch.round(self._row) + self._col.data = torch.round(self._col) + else: + self._mask.data = torch.clamp(self._mask, 0, 1) + if self.binary: + self._mask.data = torch.round(self._mask) + + # recompute PSF + self._mask_obj.create_mask(self._row, self._col, mask=self._mask) + self._mask_obj.compute_psf() + self._psf = self._mask_obj.psf.unsqueeze(0) + self._psf = self._psf / self._psf.norm() + + """ Utility functions to help prepare trainable masks. """ mask_type_to_class = { - "TrainablePSF": TrainablePSF, "AdafruitLCD": AdafruitLCD, + "TrainablePSF": TrainablePSF, + "TrainableCodedAperture": TrainableCodedAperture, + "TrainableHeightVarying": None, + "TrainableMultiLensArray": None, } @@ -250,65 +334,87 @@ def prep_trainable_mask(config, psf=None, downsample=None): mask = None color_filter = None downsample = config["files"]["downsample"] if downsample is None else downsample - if config["trainable_mask"]["mask_type"] is not None: - mask_class = mask_type_to_class[config["trainable_mask"]["mask_type"]] - - if config["trainable_mask"]["initial_value"] == "random": - if psf is not None: - initial_mask = torch.rand_like(psf) - else: - sensor = VirtualSensor.from_name( - config["simulation"]["sensor"], downsample=downsample - ) - resolution = sensor.resolution - initial_mask = torch.rand((1, *resolution, 3)) + mask_type = config["trainable_mask"]["mask_type"] + if mask_type is not None: + assert mask_type in mask_type_to_class.keys(), ( + f"Trainable mask type {mask_type} not supported. " + f"Supported types are {mask_type_to_class.keys()}" + ) + mask_class = mask_type_to_class[mask_type] + + # -- trainable mask object + if isinstance(config["trainable_mask"]["initial_value"], omegaconf.dictconfig.DictConfig): + + # from mask config + mask = mask_class( + # mask = TrainableCodedAperture( + sensor_name=config.simulation.sensor, + downsample=downsample, + distance_sensor=config.simulation.mask2sensor, + optimizer=config.trainable_mask.optimizer, + lr=config.trainable_mask.mask_lr, + binary=config.trainable_mask.binary, + torch_device=config.torch_device, + **config.trainable_mask.initial_value, + ) - elif config["trainable_mask"]["initial_value"] == "psf": - initial_mask = psf.clone() + else: - # if file ending with "npy" - elif config["trainable_mask"]["initial_value"].endswith("npy"): + if config["trainable_mask"]["initial_value"] == "random": + if psf is not None: + initial_mask = torch.rand_like(psf) + else: + sensor = VirtualSensor.from_name( + config["simulation"]["sensor"], downsample=downsample + ) + resolution = sensor.resolution + initial_mask = torch.rand((1, *resolution, 3)) + + elif config["trainable_mask"]["initial_value"] == "psf": + initial_mask = psf.clone() + + # if file ending with "npy" + elif config["trainable_mask"]["initial_value"].endswith("npy"): + + pattern = np.load(config["trainable_mask"]["initial_value"]) + + initial_mask = full2subpattern( + pattern=pattern, + shape=config["trainable_mask"]["ap_shape"], + center=config["trainable_mask"]["ap_center"], + slm=config["trainable_mask"]["slm"], + ) + initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) - pattern = np.load(config["trainable_mask"]["initial_value"]) + # prepare color filter if needed + from waveprop.devices import slm_dict + from waveprop.devices import SLMParam as SLMParam_wp - initial_mask = full2subpattern( - pattern=pattern, - shape=config["trainable_mask"]["ap_shape"], - center=config["trainable_mask"]["ap_center"], - slm=config["trainable_mask"]["slm"], - ) - initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) + slm_param = slm_dict[config["trainable_mask"]["slm"]] + if ( + config["trainable_mask"]["train_color_filter"] + and SLMParam_wp.COLOR_FILTER in slm_param.keys() + ): + color_filter = slm_param[SLMParam_wp.COLOR_FILTER] + color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) - # prepare color filter if needed - from waveprop.devices import slm_dict - from waveprop.devices import SLMParam as SLMParam_wp + # TODO: add small random values? + color_filter = color_filter + 0.1 * torch.rand_like(color_filter) - slm_param = slm_dict[config["trainable_mask"]["slm"]] - if ( - config["trainable_mask"]["train_color_filter"] - and SLMParam_wp.COLOR_FILTER in slm_param.keys() - ): - color_filter = slm_param[SLMParam_wp.COLOR_FILTER] - color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) + else: + raise ValueError( + f"Initial PSF value {config['trainable_mask']['initial_value']} not supported" + ) - # add small random values - color_filter = color_filter + 0.1 * torch.rand_like(color_filter) + # convert to grayscale if needed + if config["trainable_mask"]["grayscale"] and not is_grayscale(initial_mask): + initial_mask = rgb2gray(initial_mask) - else: - raise ValueError( - f"Initial PSF value {config['trainable_mask']['initial_value']} not supported" + mask = mask_class( + initial_mask, + downsample=downsample, + color_filter=color_filter, + **config["trainable_mask"], ) - # convert to grayscale if needed - if config["trainable_mask"]["grayscale"] and not is_grayscale(initial_mask): - initial_mask = rgb2gray(initial_mask) - - mask = mask_class( - initial_mask, - optimizer="Adam", - downsample=downsample, - color_filter=color_filter, - **config["trainable_mask"], - ) - return mask diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 21eb57e9..897328aa 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -21,6 +21,7 @@ from lensless.recon.drunet.network_unet import UNetRes from lensless.utils.io import save_image from lensless.utils.plot import plot_image +from lensless.utils.dataset import SimulatedDatasetTrainableMask def load_drunet(model_path=None, n_channels=3, requires_grad=False): @@ -427,12 +428,24 @@ def __init__( self.skip_NAN = skip_NAN self.eval_batch_size = eval_batch_size + # check if Subset and if simulating dataset + self.simulated_dataset_trainable_mask = False + if isinstance(self.test_dataset, SimulatedDatasetTrainableMask): + # assuming the case for both training and testing + self.simulated_dataset_trainable_mask = True + self.mask = mask if mask is not None: assert isinstance(mask, TrainableMask) self.use_mask = True else: self.use_mask = False + if self.use_mask: + # save original PSF + psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] + psf_np = psf_np.squeeze() # remove (potential) singleton color channel + np.save(os.path.join("psf_original.npy"), psf_np) + save_image(psf_np, os.path.join("psf_original.png")) self.l1_mask = l1_mask self.gamma = gamma @@ -521,11 +534,10 @@ def detect_nan(grad): def set_optimizer(self, last_epoch=-1): - if self.optimizer_config.type == "Adam": - parameters = [{"params": self.recon.parameters()}] - self.optimizer = torch.optim.Adam(parameters, lr=self.optimizer_config.lr) - else: - raise ValueError(f"Unsupported optimizer : {self.optimizer_config.type}") + parameters = [{"params": self.recon.parameters()}] + self.optimizer = getattr(torch.optim, self.optimizer_config.type)( + parameters, lr=self.optimizer_config.lr + ) # Scheduler if self.optimizer_config.slow_start: @@ -602,7 +614,6 @@ def train_epoch(self, data_loader): y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps y = y / y_max - self.optimizer.zero_grad(set_to_none=True) # convert to CHW for loss and remove depth y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) @@ -635,7 +646,9 @@ def train_epoch(self, data_loader): self.Loss_lpips(2 * y_pred - 1, 2 * y - 1) ) if self.use_mask and self.l1_mask: - loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask)) + for p in self.mask.parameters(): + if p.requires_grad: + loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(p)) if self.unrolled_output_factor: # -- normalize @@ -673,25 +686,47 @@ def train_epoch(self, data_loader): # backward pass loss_v.backward() + # check mask parameters are learning + if self.use_mask: + for p in self.mask.parameters(): + assert p.grad is not None + if self.clip_grad_norm is not None: + if self.use_mask: + torch.nn.utils.clip_grad_norm_(self.mask.parameters(), self.clip_grad_norm) torch.nn.utils.clip_grad_norm_(self.recon.parameters(), self.clip_grad_norm) # if any gradient is NaN, skip training step if self.skip_NAN: - is_NAN = False + recon_is_NAN = False + mask_is_NAN = False for param in self.recon.parameters(): if param.grad is not None and torch.isnan(param.grad).any(): - is_NAN = True + recon_is_NAN = True break - if is_NAN: - self.print("NAN detected in gradiant, skipping training step") + if self.use_mask: + for param in self.mask.parameters(): + if param.grad is not None and torch.isnan(param.grad).any(): + mask_is_NAN = True + break + if recon_is_NAN or mask_is_NAN: + if recon_is_NAN: + self.print( + "NAN detected in reconstruction gradient, skipping training step" + ) + if mask_is_NAN: + self.print("NAN detected in mask gradient, skipping training step") i += 1 continue + self.optimizer.step() + self.optimizer.zero_grad(set_to_none=True) # update mask if self.use_mask: self.mask.update_mask() + if self.simulated_dataset_trainable_mask: + self.train_dataloader.dataset.set_psf() mean_loss += (loss_v.item() - mean_loss) * (1 / i) pbar.set_description(f"loss : {mean_loss}") @@ -717,6 +752,10 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None): if self.test_dataset is None: return + if self.use_mask and self.simulated_dataset_trainable_mask: + with torch.no_grad(): + self.test_dataset.set_psf() + output_dir = None if disp is not None: output_dir = os.path.join("eval_recon") @@ -751,7 +790,10 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None): if self.lpips is not None: eval_loss += self.lpips * current_metrics["LPIPS_Vgg"] if self.use_mask and self.l1_mask: - eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy())) + with torch.no_grad(): + for p in self.mask.parameters(): + if p.requires_grad: + eval_loss += self.l1_mask * np.mean(np.abs(p.cpu().detach().numpy())) if self.unrolled_output_factor: unrolled_loss = current_metrics["MSE_unrolled"] if self.lpips is not None: @@ -783,7 +825,6 @@ def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): save_pt = os.getcwd() # save model - # self.save(path=save_pt, include_optimizer=False) epoch_eval_metric = self.evaluate(mean_loss, save_pt, epoch, disp=disp) new_best = False if ( @@ -870,24 +911,17 @@ def save(self, epoch, path="recon", include_optimizer=False): # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) - # save mask - if self.use_mask: - # torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt")) - # save mask as numpy array - if self.mask.train_mask_vals: - np.save( - os.path.join(path, f"mask_epoch{epoch}.npy"), - self.mask._mask.cpu().detach().numpy(), - ) + # save mask parameters + if self.use_mask: - # if color_filter is an attribute - if hasattr(self.mask, "color_filter") and self.mask.color_filter is not None: - # save save numpy array - np.save( - os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"), - self.mask.color_filter.cpu().detach().numpy(), - ) + for name, param in self.mask.named_parameters(): + # save as numpy array + if param.requires_grad: + np.save( + os.path.join(path, f"mask{name}_epoch{epoch}.npy"), + param.cpu().detach().numpy(), + ) torch.save( self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt") @@ -899,9 +933,22 @@ def save(self, epoch, path="recon", include_optimizer=False): save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) plot_image(psf_np, gamma=self.gamma) plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + if epoch == "BEST": + # save difference with original PSF + psf_original = np.load("psf_original.npy") + diff = psf_np - psf_original + np.save(os.path.join(path, "psf_epochBEST_diff.npy"), diff) + diff_abs = np.abs(diff) + save_image(diff_abs, os.path.join(path, "psf_epochBEST_diffabs.png")) + _, ax = plt.subplots() + im = ax.imshow(diff_abs, cmap="gray" if diff_abs.ndim == 2 else None) + plt.colorbar(im, ax=ax) + ax.set_title("Absolute difference with original PSF") + plt.savefig(os.path.join(path, "psf_epochBEST_diffabs_plot.png")) # save optimizer if include_optimizer: torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt")) + # save recon torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}")) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 851a060c..772f718b 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -6,16 +6,18 @@ # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# +from hydra.utils import get_original_cwd import numpy as np import glob import os import torch from abc import abstractmethod -from torch.utils.data import Dataset -from torchvision import transforms +from torch.utils.data import Dataset, Subset +from torchvision import datasets, transforms +from lensless.hardware.trainable_mask import prep_trainable_mask from lensless.utils.simulation import FarFieldSimulator from lensless.utils.io import load_image, load_psf -from lensless.utils.image import resize +from lensless.utils.image import is_grayscale, resize, rgb2gray import re from lensless.hardware.utils import capture from lensless.hardware.utils import display @@ -849,13 +851,18 @@ def __init__( super(SimulatedDatasetTrainableMask, self).__init__(dataset, simulator, **kwargs) - def _get_images_pair(self, index): - # update psf - psf = self._mask.get_psf() - self.sim.set_point_spread_function(psf) + def set_psf(self, psf=None): + """ + Set the PSF of the simulator. - # return simulated images - return super()._get_images_pair(index) + Parameters + ---------- + psf : :py:class:`torch.Tensor`, optional + PSF to use for the simulation. If ``None``, the PSF of the mask is used. + """ + if psf is None: + psf = self._mask.get_psf() + self.sim.set_point_spread_function(psf) class HITLDatasetTrainableMask(SimulatedDatasetTrainableMask): @@ -946,3 +953,220 @@ def __getitem__(self, index): # return simulated images (replace simulated with measured) return img, lensed + + +def simulate_dataset(config, generator=None): + """ + Prepare datasets for training and testing. + + Parameters + ---------- + config : omegaconf.DictConfig + Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function. + generator : torch.Generator, optional + Random number generator, by default ``None``. + """ + + if "cuda" in config.torch_device and torch.cuda.is_available(): + device = config.torch_device + else: + device = "cpu" + + # -- prepare PSF + psf = None + if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf": + psf_fp = os.path.join(get_original_cwd(), config.files.psf) + psf, _ = load_psf( + psf_fp, + downsample=config.files.downsample, + return_float=True, + return_bg=True, + bg_pix=(0, 15), + ) + if config.files.diffusercam_psf: + transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) + psf = transform_BRG2RGB(torch.from_numpy(psf)) + + # drop depth dimension + psf = psf.to(device) + + else: + # training mask / PSF + mask = prep_trainable_mask(config, psf) + psf = mask.get_psf().to(device) + + # -- load dataset + pre_transform = None + transforms_list = [transforms.ToTensor()] + data_path = os.path.join(get_original_cwd(), "data") + if config.simulation.grayscale: + transforms_list.append(transforms.Grayscale()) + + if config.files.dataset == "mnist": + transform = transforms.Compose(transforms_list) + train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) + + elif config.files.dataset == "fashion_mnist": + transform = transforms.Compose(transforms_list) + train_ds = datasets.FashionMNIST( + root=data_path, train=True, download=True, transform=transform + ) + test_ds = datasets.FashionMNIST( + root=data_path, train=False, download=True, transform=transform + ) + elif config.files.dataset == "cifar10": + transform = transforms.Compose(transforms_list) + train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) + test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) + + elif config.files.dataset == "CelebA": + root = config.files.celeba_root + data_path = os.path.join(root, "celeba") + assert os.path.isdir( + data_path + ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" + transform = transforms.Compose(transforms_list) + if config.files.n_files is None: + train_ds = datasets.CelebA( + root=root, split="train", download=False, transform=transform + ) + test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) + else: + ds = datasets.CelebA(root=root, split="all", download=False, transform=transform) + + ds = Subset(ds, np.arange(config.files.n_files)) + + train_size = int((1 - config.files.test_size) * len(ds)) + test_size = len(ds) - train_size + train_ds, test_ds = torch.utils.data.random_split( + ds, [train_size, test_size], generator=generator + ) + else: + raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") + + if config.files.dataset != "CelebA": + if config.files.n_files is not None: + train_size = int((1 - config.files.test_size) * config.files.n_files) + test_size = config.files.n_files - train_size + train_ds = Subset(train_ds, np.arange(train_size)) + test_ds = Subset(test_ds, np.arange(test_size)) + + # convert PSF + if config.simulation.grayscale and not is_grayscale(psf): + psf = rgb2gray(psf) + + # check if gpu is available + device_conv = config.torch_device + if device_conv == "cuda" and torch.cuda.is_available(): + device_conv = "cuda" + else: + device_conv = "cpu" + + # create simulator + simulator = FarFieldSimulator( + psf=psf, + is_torch=True, + **config.simulation, + ) + + # create Pytorch dataset and dataloader + crop = config.files.crop.copy() if config.files.crop is not None else None + if mask is None: + train_ds_prop = SimulatedFarFieldDataset( + dataset=train_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + test_ds_prop = SimulatedFarFieldDataset( + dataset=test_ds, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + else: + if config.measure is not None: + + train_ds_prop = HITLDatasetTrainableMask( + rpi_username=config.measure.rpi_username, + rpi_hostname=config.measure.rpi_hostname, + celeba_root=config.files.celeba_root, + display_config=config.measure.display, + capture_config=config.measure.capture, + mask_center=config.trainable_mask.ap_center, + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + test_ds_prop = HITLDatasetTrainableMask( + rpi_username=config.measure.rpi_username, + rpi_hostname=config.measure.rpi_hostname, + celeba_root=config.files.celeba_root, + display_config=config.measure.display, + capture_config=config.measure.capture, + mask_center=config.trainable_mask.ap_center, + dataset=test_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + else: + + train_ds_prop = SimulatedDatasetTrainableMask( + dataset=train_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + test_ds_prop = SimulatedDatasetTrainableMask( + dataset=test_ds, + mask=mask, + simulator=simulator, + dataset_is_CHW=True, + device_conv=device_conv, + flip=config.simulation.flip, + vertical_shift=config.files.vertical_shift, + horizontal_shift=config.files.horizontal_shift, + crop=crop, + downsample=config.files.downsample, + pre_transform=pre_transform, + ) + + return train_ds_prop, test_ds_prop, mask diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 2ed26675..0badcb43 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -552,7 +552,11 @@ def save_image(img, fp, max_val=255, normalize=True): img_tmp *= max_val img_tmp = img_tmp.astype(np.uint8) - img_tmp = Image.fromarray(img_tmp) + # RGB + if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3: + img_tmp = Image.fromarray(img_tmp) + else: + img_tmp = Image.fromarray(img_tmp.squeeze()) img_tmp.save(fp) diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index a0a6581d..e9cd86be 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -74,14 +74,16 @@ def benchmark_recon(config): psf = dataset.psf crop = dataset.crop + if config.n_files is not None: + dataset = Subset(dataset, np.arange(config.n_files)) + dataset.psf = dataset.dataset.psf + # train-test split train_size = int((1 - config.files.test_size) * len(dataset)) test_size = len(dataset) - train_size _, benchmark_dataset = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - if config.n_files is not None: - benchmark_dataset = Subset(benchmark_dataset, np.arange(config.n_files)) else: raise ValueError(f"Dataset {dataset} not supported") diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 687b8936..4ad8493e 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -38,25 +38,17 @@ import os import numpy as np import time +from lensless.hardware.trainable_mask import prep_trainable_mask from lensless import UnrolledFISTA, UnrolledADMM, TrainableInversion from lensless.utils.dataset import ( DiffuserCamMirflickr, - SimulatedFarFieldDataset, - SimulatedDatasetTrainableMask, DigiCamCelebA, - HITLDatasetTrainableMask, ) from torch.utils.data import Subset -import lensless.hardware.trainable_mask -from lensless.hardware.slm import full2subpattern -from lensless.hardware.sensor import VirtualSensor from lensless.recon.utils import create_process_network -from lensless.utils.image import rgb2gray, is_grayscale -from lensless.utils.simulation import FarFieldSimulator +from lensless.utils.dataset import simulate_dataset from lensless.recon.utils import Trainer import torch -from torchvision import transforms, datasets -from lensless.utils.io import load_psf from lensless.utils.io import save_image from lensless.utils.plot import plot_image from lensless import ADMM @@ -66,265 +58,6 @@ log = logging.getLogger(__name__) -def simulate_dataset(config, generator=None): - - if config.torch_device == "cuda" and torch.cuda.is_available(): - device = "cuda" - else: - device = "cpu" - - # -- prepare PSF - psf = None - if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf": - psf_fp = os.path.join(get_original_cwd(), config.files.psf) - psf, _ = load_psf( - psf_fp, - downsample=config.files.downsample, - return_float=True, - return_bg=True, - bg_pix=(0, 15), - ) - if config.files.diffusercam_psf: - transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) - psf = transform_BRG2RGB(torch.from_numpy(psf)) - - # drop depth dimension - psf = psf.to(device) - - else: - # training mask / PSF - mask = prep_trainable_mask(config, psf) - psf = mask.get_psf().to(device) - - # -- load dataset - pre_transform = None - transforms_list = [transforms.ToTensor()] - data_path = os.path.join(get_original_cwd(), "data") - if config.simulation.grayscale: - transforms_list.append(transforms.Grayscale()) - - if config.files.dataset == "mnist": - transform = transforms.Compose(transforms_list) - train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform) - test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform) - elif config.files.dataset == "fashion_mnist": - transform = transforms.Compose(transforms_list) - train_ds = datasets.FashionMNIST( - root=data_path, train=True, download=True, transform=transform - ) - test_ds = datasets.FashionMNIST( - root=data_path, train=False, download=True, transform=transform - ) - elif config.files.dataset == "cifar10": - transform = transforms.Compose(transforms_list) - train_ds = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform) - test_ds = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform) - elif config.files.dataset == "CelebA": - root = config.files.celeba_root - data_path = os.path.join(root, "celeba") - assert os.path.isdir( - data_path - ), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html" - transform = transforms.Compose(transforms_list) - if config.files.n_files is None: - train_ds = datasets.CelebA( - root=root, split="train", download=False, transform=transform - ) - test_ds = datasets.CelebA(root=root, split="test", download=False, transform=transform) - else: - ds = datasets.CelebA(root=root, split="all", download=False, transform=transform) - - ds = Subset(ds, np.arange(config.files.n_files)) - - train_size = int((1 - config.files.test_size) * len(ds)) - test_size = len(ds) - train_size - train_ds, test_ds = torch.utils.data.random_split( - ds, [train_size, test_size], generator=generator - ) - else: - raise NotImplementedError(f"Dataset {config.files.dataset} not implemented.") - - # convert PSF - if config.simulation.grayscale and not is_grayscale(psf): - psf = rgb2gray(psf) - - # check if gpu is available - device_conv = config.torch_device - if device_conv == "cuda" and torch.cuda.is_available(): - device_conv = "cuda" - else: - device_conv = "cpu" - - # create simulator - simulator = FarFieldSimulator( - psf=psf, - is_torch=True, - **config.simulation, - ) - - # create Pytorch dataset and dataloader - crop = config.files.crop.copy() if config.files.crop is not None else None - if mask is None: - train_ds_prop = SimulatedFarFieldDataset( - dataset=train_ds, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - test_ds_prop = SimulatedFarFieldDataset( - dataset=test_ds, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - else: - if config.measure is not None: - - train_ds_prop = HITLDatasetTrainableMask( - rpi_username=config.measure.rpi_username, - rpi_hostname=config.measure.rpi_hostname, - celeba_root=config.files.celeba_root, - display_config=config.measure.display, - capture_config=config.measure.capture, - mask_center=config.trainable_mask.ap_center, - dataset=train_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - test_ds_prop = HITLDatasetTrainableMask( - rpi_username=config.measure.rpi_username, - rpi_hostname=config.measure.rpi_hostname, - celeba_root=config.files.celeba_root, - display_config=config.measure.display, - capture_config=config.measure.capture, - mask_center=config.trainable_mask.ap_center, - dataset=test_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - else: - - train_ds_prop = SimulatedDatasetTrainableMask( - dataset=train_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - test_ds_prop = SimulatedDatasetTrainableMask( - dataset=test_ds, - mask=mask, - simulator=simulator, - dataset_is_CHW=True, - device_conv=device_conv, - flip=config.simulation.flip, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, - crop=crop, - downsample=config.files.downsample, - pre_transform=pre_transform, - ) - - return train_ds_prop, test_ds_prop, mask - - -def prep_trainable_mask(config, psf=None, downsample=None): - mask = None - color_filter = None - downsample = config.files.downsample if downsample is None else downsample - if config.trainable_mask.mask_type is not None: - mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - - if config.trainable_mask.initial_value == "random": - if psf is not None: - initial_mask = torch.rand_like(psf) - else: - sensor = VirtualSensor.from_name(config.simulation.sensor, downsample=downsample) - resolution = sensor.resolution - initial_mask = torch.rand((1, *resolution, 3)) - elif config.trainable_mask.initial_value == "psf": - initial_mask = psf.clone() - # if file ending with "npy" - elif config.trainable_mask.initial_value.endswith("npy"): - pattern = np.load(os.path.join(get_original_cwd(), config.trainable_mask.initial_value)) - - initial_mask = full2subpattern( - pattern=pattern, - shape=config.trainable_mask.ap_shape, - center=config.trainable_mask.ap_center, - slm=config.trainable_mask.slm, - ) - initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) - - # prepare color filter if needed - from waveprop.devices import slm_dict - from waveprop.devices import SLMParam as SLMParam_wp - - slm_param = slm_dict[config.trainable_mask.slm] - if ( - config.trainable_mask.train_color_filter - and SLMParam_wp.COLOR_FILTER in slm_param.keys() - ): - color_filter = slm_param[SLMParam_wp.COLOR_FILTER] - color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) - - # add small random values - color_filter = color_filter + 0.1 * torch.rand_like(color_filter) - else: - raise ValueError( - f"Initial PSF value {config.trainable_mask.initial_value} not supported" - ) - - if config.trainable_mask.grayscale and not is_grayscale(initial_mask): - initial_mask = rgb2gray(initial_mask) - - mask = mask_class( - initial_mask, - optimizer="Adam", - downsample=downsample, - color_filter=color_filter, - **config.trainable_mask, - ) - - return mask - - @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config): @@ -346,9 +79,10 @@ def train_unrolled(config): if save: save = os.getcwd() - if config.torch_device == "cuda" and torch.cuda.is_available(): + if "cuda" in config.torch_device and torch.cuda.is_available(): + # if config.torch_device == "cuda" and torch.cuda.is_available(): log.info("Using GPU for training.") - device = "cuda" + device = config.torch_device else: log.info("Using CPU for training.") device = "cpu"