From 1bc1ad5b671e27ab54991f1f2f8908abf5f246f1 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Sat, 26 Oct 2024 23:00:31 +0000 Subject: [PATCH] Add script for psf err vs metric. --- configs/authen.yaml | 2 + .../benchmark_digicam_mirflickr_multi.yaml | 4 +- configs/recon_digicam_mirflickr.yaml | 1 + configs/recon_digicam_mirflickr_err.yaml | 22 ++ configs/train_unrolledADMM.yaml | 6 +- lensless/eval/benchmark.py | 4 +- lensless/recon/trainable_recon.py | 30 ++ lensless/utils/dataset.py | 70 +++-- scripts/data/authenticate.py | 2 +- scripts/recon/digicam_mirflickr.py | 5 +- scripts/recon/digicam_mirflickr_psf_err.py | 259 ++++++++++++++++++ scripts/recon/train_learning_based.py | 21 ++ 12 files changed, 399 insertions(+), 27 deletions(-) create mode 100644 configs/recon_digicam_mirflickr_err.yaml create mode 100644 scripts/recon/digicam_mirflickr_psf_err.py diff --git a/configs/authen.yaml b/configs/authen.yaml index 06e56da8..6ad25bd3 100644 --- a/configs/authen.yaml +++ b/configs/authen.yaml @@ -14,6 +14,8 @@ save_idx: [1, 2, 4, 5, 9] font_scale: 1.5 # for plotting confusion matrix +metric: "recon" # "recon", "mse", "lpips" + # Dataset parameters huggingface: repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K" diff --git a/configs/benchmark_digicam_mirflickr_multi.yaml b/configs/benchmark_digicam_mirflickr_multi.yaml index 17ff3508..1523a3c4 100644 --- a/configs/benchmark_digicam_mirflickr_multi.yaml +++ b/configs/benchmark_digicam_mirflickr_multi.yaml @@ -25,9 +25,9 @@ algorithms: [ ## -- reconstructions trained on measured data "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave", - "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", + # "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave", "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1", - "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", + # "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips", # ## -- reconstructions trained on other datasets/systems # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", diff --git a/configs/recon_digicam_mirflickr.yaml b/configs/recon_digicam_mirflickr.yaml index 38143f0c..cc21c339 100644 --- a/configs/recon_digicam_mirflickr.yaml +++ b/configs/recon_digicam_mirflickr.yaml @@ -3,6 +3,7 @@ defaults: - defaults_recon - _self_ +dataset: mirflickr_single_25k # for loading model, "mirflickr_single_25k" or "mirflickr_multi_25k" cache_dir: /dev/shm # fn: null # if not null, download this file from https://huggingface.co/datasets/bezzam/DigiCam-Mirflickr-SingleMask-25K/tree/main diff --git a/configs/recon_digicam_mirflickr_err.yaml b/configs/recon_digicam_mirflickr_err.yaml new file mode 100644 index 00000000..2e08bbc7 --- /dev/null +++ b/configs/recon_digicam_mirflickr_err.yaml @@ -0,0 +1,22 @@ +# python scripts/recon/digicam_mirflickr.py +defaults: + - defaults_recon + - _self_ + +cache_dir: null +metrics_fp : null +hf_repo: null # by default use one in model config + +# set model +# -- for learning-based methods (comment if using ADMM) +model: Unet4M+U5+Unet4M_wave + +# # -- for ADMM with fixed parameters +# model: admm +# n_iter: 10 + +device: cuda:1 +save_idx: [1, 2, 4, 5, 9] +n_files: null +percent_pixels_wrong: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] +flip: True # whether to flip mask values (True) or reset them (False) \ No newline at end of file diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 054d9e9b..ce4d851d 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -90,6 +90,10 @@ reconstruction: init_pre: True # if `init_processors`, set pre-procesor is available init_post: True # if `init_processors`, set post-procesor is available + # processing PSF + psf_network: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64], with skip connection + psf_residual: True # if psf_network used, whether to use residual connection for original PSF estimate + # background subtraction (if dataset has corresponding background images) direct_background_subtraction: False # True or False learned_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64] @@ -193,10 +197,10 @@ optimizer: type: AdamW # Adam, SGD... (Pytorch class) lr: 1e-4 lr_step_epoch: True # True -> update LR at end of each epoch, False at the end of each mini-batch + cosine_decay_warmup: True # if set, cosine decay with warmup of 5% final_lr: False # if set, exponentially decay *to* this value exp_decay: False # if set, exponentially decay *with* this value slow_start: False #float how much to reduce lr for first epoch - cosine_decay_warmup: True # if set, cosine decay with warmup of 5% # Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html step: False # int, period of learning rate decay. False to not apply gamma: 0.1 # float, factor for learning rate decay diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 73c03b09..57a052f6 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -157,7 +157,9 @@ def benchmark( ) else: - prediction = model.forward(lensless, psfs, background=background, **kwargs) + prediction = model.forward( + batch=lensless, psfs=psfs, background=background, **kwargs + ) if unrolled_output_factor or pre_process_aux: pre_process_out = prediction[2] diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index d689adeb..8903f589 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -58,6 +58,8 @@ def __init__( legacy_denoiser=False, compensation=None, compensation_residual=True, + psf_network=None, + psf_residual=True, # background subtraction direct_background_subtraction=False, background_network=None, @@ -95,6 +97,10 @@ def __init__( Post-processor must be defined if compensation provided. compensation_residual : bool, optional Whether to use residual connection in compensation layer. + psf_network : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + Function or model to apply to PSF prior to camera inversion. + psf_residual : bool, optional + Whether to use residual connection in PSF network. """ assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" @@ -141,6 +147,12 @@ def __init__( ), "Cannot use direct_background_subtraction and background_network at the same time." self.set_background_network(background_network) + # PSF network + self.psf_network = None + self.psf_residual = psf_residual + if psf_network is not None: + self.set_psf_network(psf_network) + # compensation branch self.return_intermediate = return_intermediate self.compensation_branch = compensation @@ -227,6 +239,13 @@ def set_background_network(self, background_network): self.background_network_param, ) = self._prepare_process_block(background_network) + def set_psf_network(self, psf_network): + ( + self.psf_network, + self.psf_network_model, + self.psf_network_param, + ) = self._prepare_process_block(psf_network) + def freeze_pre_process(self): """ Method for freezing the pre process block. @@ -307,6 +326,15 @@ def forward(self, batch, psfs=None, background=None): ).to(self._data.device) self._data = torch.clamp(self._data, 0, 1) + # set / transform PSFs if need be + if self.psf_network is not None: + if psfs is None: + psfs = self._psf + if self.psf_residual: + psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) + psfs + else: + psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) + if psfs is not None: # assert same shape assert psfs.shape == batch.shape, "psfs must have the same shape as batch" @@ -381,6 +409,8 @@ def apply( algorithm, the number of iteration isn't required. Note that `set_data` must be called beforehand. + # TODO apply PSF network + Parameters ---------- disp_iter : int diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 725f418b..7ea29328 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -32,7 +32,6 @@ import warnings from waveprop.noise import add_shot_noise from lensless.utils.image import shift_with_pad -from PIL import Image def convert(text): @@ -1351,6 +1350,9 @@ def __init__( self.display_res = display_res self.return_mask_label = return_mask_label self.force_rgb = force_rgb # if some data is not 3D + self.sensor = sensor + self.slm = slm + self.simulation_config = simulation_config # augmentation self.random_flip = random_flip @@ -1372,6 +1374,7 @@ def __init__( downsample_fact = min(sensor_res / lensless.shape[:2]) else: downsample_fact = 1 + self.downsample_fact = downsample_fact # deduce recon shape from original image self.alignment = None @@ -1410,6 +1413,7 @@ def __init__( # download all masks # TODO: reshape directly with lensless image shape self.multimask = False + self.huggingface_repo = huggingface_repo if psf is not None: # download PSF from huggingface psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") @@ -1435,28 +1439,32 @@ def __init__( for i in range(len(self.dataset)): mask_labels.append(self.dataset[i]["mask_label"]) mask_labels = list(set(mask_labels)) + self.mask_labels = mask_labels # simulate all PSFs self.psf = dict() for label in mask_labels: - mask_fp = hf_hub_download( - repo_id=huggingface_repo, - filename=f"masks/mask_{label}.npy", - repo_type="dataset", - ) - mask_vals = np.load(mask_fp) - mask = AdafruitLCD( - initial_vals=torch.from_numpy(mask_vals.astype(np.float32)), - sensor=sensor, - slm=slm, - downsample=downsample_fact, - flipud=self.rotate or flipud, # TODO separate commands? - use_waveprop=simulation_config.get("use_waveprop", False), - scene2mask=simulation_config.get("scene2mask", None), - mask2sensor=simulation_config.get("mask2sensor", None), - deadspace=simulation_config.get("deadspace", True), - ) - self.psf[label] = mask.get_psf().detach() + + mask_vals = self.get_mask_vals(label) + self.psf[label] = self.simulate_psf(mask_vals) + # mask_fp = hf_hub_download( + # repo_id=huggingface_repo, + # filename=f"masks/mask_{label}.npy", + # repo_type="dataset", + # ) + # mask_vals = np.load(mask_fp) + # mask = AdafruitLCD( + # initial_vals=torch.from_numpy(mask_vals.astype(np.float32)), + # sensor=sensor, + # slm=slm, + # downsample=downsample_fact, + # flipud=self.rotate or flipud, # TODO separate commands? + # use_waveprop=simulation_config.get("use_waveprop", False), + # scene2mask=simulation_config.get("scene2mask", None), + # mask2sensor=simulation_config.get("mask2sensor", None), + # deadspace=simulation_config.get("deadspace", True), + # ) + # self.psf[label] = mask.get_psf().detach() assert ( self.psf[label].shape[-3:-1] == lensless.shape[:2] @@ -1541,6 +1549,30 @@ def __init__( def __len__(self): return len(self.dataset) + def get_mask_vals(self, idx): + assert self.multimask + assert idx in self.mask_labels + mask_fp = hf_hub_download( + repo_id=self.huggingface_repo, + filename=f"masks/mask_{idx}.npy", + repo_type="dataset", + ) + return np.load(mask_fp) + + def simulate_psf(self, mask_vals): + mask = AdafruitLCD( + initial_vals=torch.from_numpy(mask_vals.astype(np.float32)), + sensor=self.sensor, + slm=self.slm, + downsample=self.downsample_fact, + flipud=self.rotate or self.flipud, # TODO separate commands? + use_waveprop=self.simulation_config.get("use_waveprop", False), + scene2mask=self.simulation_config.get("scene2mask", None), + mask2sensor=self.simulation_config.get("mask2sensor", None), + deadspace=self.simulation_config.get("deadspace", True), + ) + return mask.get_psf().detach() + def _get_images_pair(self, idx): # load images diff --git a/scripts/data/authenticate.py b/scripts/data/authenticate.py index 5599c696..70e4a681 100644 --- a/scripts/data/authenticate.py +++ b/scripts/data/authenticate.py @@ -246,7 +246,7 @@ def authen(config): if i in save_idx: res_np = res[0].cpu().numpy() res_np = res_np / res_np.max() - fp = os.path.join(save_dir, f"{psf_idx}.png") + fp = os.path.join(save_dir, f"psf{psf_idx}.png") save_image(res_np, fp) scores[str(mask_label)].append(np.array(scores_i).tolist()) diff --git a/scripts/recon/digicam_mirflickr.py b/scripts/recon/digicam_mirflickr.py index c74de25a..927811a0 100644 --- a/scripts/recon/digicam_mirflickr.py +++ b/scripts/recon/digicam_mirflickr.py @@ -29,14 +29,13 @@ def apply_pretrained(config): model_config = yaml.safe_load(stream) else: - model_path = download_model( - camera="digicam", dataset="mirflickr_single_25k", model=model_name - ) + model_path = download_model(camera="digicam", dataset=config.dataset, model=model_name) config_path = os.path.join(model_path, ".hydra", "config.yaml") with open(config_path, "r") as stream: model_config = yaml.safe_load(stream) # load data + # TODO try with multi-mask, should load single mask dataset... test_set = HFDataset( huggingface_repo=model_config["files"]["dataset"], psf=( diff --git a/scripts/recon/digicam_mirflickr_psf_err.py b/scripts/recon/digicam_mirflickr_psf_err.py new file mode 100644 index 00000000..5b55740c --- /dev/null +++ b/scripts/recon/digicam_mirflickr_psf_err.py @@ -0,0 +1,259 @@ +import hydra +import yaml +import torch +from lensless import ADMM +from lensless.utils.dataset import HFDataset +import os +from lensless.utils.io import save_image +from tqdm import tqdm +from lensless.recon.model_dict import download_model, load_model +import numpy as np +from torchmetrics import StructuralSimilarityIndexMeasure +from torchmetrics.image import lpip, psnr +import json +from matplotlib import pyplot as plt + + +@hydra.main( + version_base=None, config_path="../../configs", config_name="recon_digicam_mirflickr_err" +) +def apply_pretrained(config): + device = config.device + model_name = config.model + percent_pixels_wrong = config.percent_pixels_wrong + + if config.metrics_fp is not None: + + # load metrics from file + with open(config.metrics_fp, "r") as f: + metrics_values = json.load(f) + + # # if not normalized... all PSFs have roughtly same norm + # metrics_values["psf_err"] = np.array(metrics_values["psf_err"]) / 1.7302357e-06 + # metrics_values["psf_err"] = metrics_values["psf_err"].tolist() + + # # resave metrics dict to JSON + # with open(config.metrics_fp, "w") as f: + # json.dump(metrics_values, f, indent=4) + + else: + + # load config + if model_name == "admm": + # take config from unrolled ADMM for dataset + model_path = download_model( + camera="digicam", dataset="mirflickr_multi_25k", model="Unet4M+U5+Unet4M_wave" + ) + config_path = os.path.join(model_path, ".hydra", "config.yaml") + with open(config_path, "r") as stream: + model_config = yaml.safe_load(stream) + + else: + model_path = download_model( + camera="digicam", dataset="mirflickr_multi_25k", model=model_name + ) + config_path = os.path.join(model_path, ".hydra", "config.yaml") + with open(config_path, "r") as stream: + model_config = yaml.safe_load(stream) + + metrics = { + "PSNR": psnr.PeakSignalNoiseRatio(reduction=None, dim=(1, 2, 3), data_range=(0, 1)).to( + device + ), + "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=(0, 1)).to(device), + "LPIPS_Vgg": lpip.LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=True, reduction="sum" + ).to(device), + "psf_err": torch.nn.functional.mse_loss, + } + + # load data + # TODO missing simulation parameters??? + test_set = HFDataset( + huggingface_repo=model_config["files"]["dataset"] + if config.hf_repo is None + else config.hf_repo, + psf=( + model_config["files"]["huggingface_psf"] + if "huggingface_psf" in model_config["files"] + else None + ), + split="test", + display_res=model_config["files"]["image_res"], + rotate=model_config["files"]["rotate"], + flipud=model_config["files"]["flipud"], + flip_lensed=model_config["files"]["flip_lensed"], + downsample=model_config["files"]["downsample"], + alignment=model_config["alignment"], + simulation_config=model_config["simulation"], + force_rgb=model_config["files"].get("force_rgb", False), + cache_dir=config.cache_dir, + save_psf=False, + return_mask_label=True, + ) + + # # create Dataset loader + # batch_size = 4 + # dataloader = torch.utils.data.DataLoader( + # dataset=test_set, + # batch_size=batch_size, + # shuffle=False, + # pin_memory=(device != "cpu"), + # ) + + psf_norms = [] + for mask_label in test_set.psf.keys(): + psf_norms.append(np.mean(test_set.psf[mask_label].cpu().numpy().flatten() ** 2)) + psf_norms = np.array(psf_norms) + + n_files = config.n_files + if n_files is None: + n_files = len(test_set) + percent_pixels_wrong = config.percent_pixels_wrong + + # initialize metrics dict + metrics_values = {k: np.zeros((len(percent_pixels_wrong), n_files)) for k in metrics.keys()} + + for i in config.save_idx: + # make folder + save_dir = str(i) + os.makedirs(save_dir, exist_ok=True) + + for idx in tqdm(range(n_files)): + + # get data + lensless, lensed, mask_label = test_set[idx] + lensless = lensless.to(device) + + if idx in config.save_idx: + if lensed is not None: + lensed_np = lensed[0].cpu().numpy() + save_image(lensed_np, os.path.join(str(idx), f"original_idx{idx}.png")) + save_image( + lensless[0].cpu().numpy(), os.path.join(str(idx), f"lensless_idx{idx}.png") + ) + + # -- reshape for torchmetrics + lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) + lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True) + lensed = lensed / lensed_max + lensed = lensed.to(device) + + _metrics_idx = {k: [] for k in metrics.keys()} + + for percent_wrong in percent_pixels_wrong: + + # perturb mask + mask_vals = test_set.get_mask_vals(mask_label) + + noisy_mask_vals = mask_vals.copy() + if percent_wrong > 0: + + n_pixels = mask_vals.size + n_wrong_pixels = int(n_pixels * percent_wrong / 100) + wrong_pixels = np.random.choice(n_pixels, n_wrong_pixels, replace=False) + noisy_mask_vals = noisy_mask_vals.flatten() + + if config.flip: + noisy_mask_vals[wrong_pixels] = ( + 1 - noisy_mask_vals[wrong_pixels] + ) # flip pixel value + else: + # reset values randomly + noisy_mask_vals[wrong_pixels] = np.random.uniform(size=n_wrong_pixels) + noisy_mask_vals = noisy_mask_vals.reshape(mask_vals.shape) + + # noise = np.random.uniform(size=mask_vals.shape) + # # -- rescale noise to desired SNR + # mask_var = ndimage.variance(mask_vals) + # noise_var = ndimage.variance(noise) + # fact = np.sqrt(mask_var / noise_var / (10 ** (mask_snr_db / 10))) + # noisy_mask_vals = mask_vals + fact * noise + # # -- clip to [0, 1] + # noisy_mask_vals = np.clip(noisy_mask_vals, 0, 1) + + # simulate PSF + psf = test_set.simulate_psf(noisy_mask_vals) + psf = psf.to(device) + + # compute L2 error with normal PSF + _metrics_idx["psf_err"].append( + metrics["psf_err"](psf, test_set.psf[mask_label].to(device)).item() + / psf_norms[mask_label] + ) + + # load model + if model_name == "admm": + recon = ADMM(psf, n_iter=config.n_iter) + else: + # load best model + recon = load_model(model_path, psf, device, verbose=False) + + # reconstruct + with torch.no_grad(): + recon.set_data(lensless) + res = recon.apply( + disp_iter=-1, + save=False, + gamma=None, + plot=False, + ) + recon = res[0] + + # prepare for metrics + # -- convert to [N*D, C, H, W] for torchmetrics + prediction = recon.reshape(-1, *recon.shape[-3:]).movedim(-1, -3) + # - extract ROI + prediction = test_set.extract_roi(prediction, axis=(-2, -1)) + # -- normalize + prediction_max = torch.amax(prediction, dim=(1, 2, 3), keepdim=True) + prediction = prediction / prediction_max + + for k, metric in metrics.items(): + if k == "psf_err": + continue + _metrics_idx[k].append(metric(prediction, lensed).item()) + + # save + if idx in config.save_idx: + img = recon.cpu().numpy().squeeze() + alignment = test_set.alignment + top_left = alignment["top_left"] + height = alignment["height"] + width = alignment["width"] + res_np = img[ + top_left[0] : top_left[0] + height, top_left[1] : top_left[1] + width + ] + fp = os.path.join(str(idx), f"{model_name}_percentwrong{percent_wrong}.png") + save_image(res_np, fp) + + # save metrics + for k, v in _metrics_idx.items(): + metrics_values[k][:, idx] = v + + # save metric dict to JSON + # -- make sure to convert numpy arrays to lists + for k, v in metrics_values.items(): + metrics_values[k] = v.tolist() + with open(f"{model_name}_metrics.json", "w") as f: + json.dump(metrics_values, f, indent=4) + + # plot each metrics vs percent_wrong + for k, v in metrics_values.items(): + plt.figure() + plt.xlabel("Percent pixels wrong [%]") + if k == "psf_err": + plt.plot(percent_pixels_wrong, np.mean(v, axis=1) * 100) + plt.ylabel("Relative PSF error [%]") + else: + plt.plot(percent_pixels_wrong, np.mean(v, axis=1)) + plt.ylabel(k) + + # save plot + # - tight + plt.tight_layout() + plt.savefig(f"{k}_{model_name}.png") + + +if __name__ == "__main__": + apply_pretrained() diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 591e3a78..53ddbacb 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -503,6 +503,18 @@ def train_learned(config): if name1 in dict_params2_post: dict_params2_post[name1].data.copy_(param1.data) + # network for PSF + psf_network = None + if config.reconstruction.psf_network: + # create UnetRes for PSF + psf_network, _ = create_process_network( + network="UnetRes", + depth=len(config.reconstruction.psf_network), + nc=config.reconstruction.psf_network, + device=device, + device_ids=device_ids, + ) + # check/prepare background subtraction background_network = None if config.reconstruction.direct_background_subtraction: @@ -565,6 +577,8 @@ def train_learned(config): pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, background_network=background_network, + psf_network=psf_network, + psf_residual=config.reconstruction.psf_residual, skip_unrolled=config.reconstruction.skip_unrolled, return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False @@ -585,6 +599,8 @@ def train_learned(config): pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, background_network=background_network, + psf_network=psf_network, + psf_residual=config.reconstruction.psf_residual, skip_unrolled=config.reconstruction.skip_unrolled, return_intermediate=( True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False @@ -596,6 +612,7 @@ def train_learned(config): ) elif config.reconstruction.method == "trainable_inv": assert config.trainable_mask.mask_type == "TrainablePSF" + assert psf_network is None recon = TrainableInversion( psf, K=config.reconstruction.trainable_inv.K, @@ -619,6 +636,7 @@ def train_learned(config): assert config.reconstruction.direct_background_subtraction is False, "Not supported" assert config.reconstruction.learned_background_subtraction is None, "Not supported" assert config.reconstruction.integrated_background_subtraction is None, "Not supported" + assert psf_network is None, "Not supported" recon = MultiWiener( in_channels=3, @@ -658,6 +676,9 @@ def train_learned(config): if background_network is not None: n_param = sum(p.numel() for p in background_network.parameters() if p.requires_grad) log.info(f"-- Background subtraction model with {n_param} parameters") + if psf_network is not None: + n_param = sum(p.numel() for p in psf_network.parameters() if p.requires_grad) + log.info(f"-- PSF network model with {n_param} parameters") log.info(f"Setup time : {time.time() - start_time} s") log.info(f"PSF shape : {psf.shape}")