diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 73c03b09..45c3d04f 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -8,12 +8,14 @@ from lensless.utils.dataset import DiffuserCamTestDataset +from lensless.utils.dataset import HFDataset from lensless.utils.io import save_image from waveprop.noise import add_shot_noise from tqdm import tqdm import os import numpy as np import wandb +from lensless.eval.metric import clip_iqa try: import torch @@ -102,6 +104,7 @@ def benchmark( ), "SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=(0, 1)).to(device), "ReconstructionError": None, + "CLIP-IQA": clip_iqa } metrics_values = {key: [] for key in metrics} @@ -241,6 +244,10 @@ def benchmark( metrics_values[metric].append( metrics[metric](prediction, lensed).cpu().item() ) + elif metric == "CLIP-IQA": + metrics_values[metric].append( + metrics[metric](prediction, lensed).cpu().item() + ) elif metric == "MSE": metrics_values[metric].append( metrics[metric](prediction, lensed).cpu().item() * len(batch[0]) @@ -350,7 +357,44 @@ def benchmark( device = "cpu" # prepare dataset - dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) + #dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample) + dataset = HFDataset( + huggingface_repo='Lensless/TapeCam-Mirflickr-Ambient-100', + cache_dir=None, + psf='psf.png', + single_channel_psf=False, + split="test", + display_res=[600, 600], + rotate=False, + flipud=False, + flip_lensed=False, + downsample=1, + downsample_lensed=2, + alignment={'top_left': [85, 185], 'height': 178}, + save_psf=True, + n_files=None, + simulation_config={ + 'grayscale': False, + 'output_dim': None, + 'object_height': 0.04, + 'flip': True, + 'random_shift': False, + 'random_vflip': 0.5, + 'random_hflip': 0.5, + 'random_rotate': False, + 'scene2mask': 0.1, + 'mask2sensor': 0.009, + 'deadspace': True, + 'use_waveprop': False, + 'sensor': 'rgb', # Replace with the correct value if different + }, + per_pixel_color_shift=True, + per_pixel_color_shift_range=[0.8, 1.2], + bg_snr_range=None, + bg_fp=None, + force_rgb=False, + simulate_lensless=False, + ) # prepare model psf = dataset.psf.to(device) diff --git a/lensless/eval/metric.py b/lensless/eval/metric.py index bd1746bb..74a338b1 100644 --- a/lensless/eval/metric.py +++ b/lensless/eval/metric.py @@ -112,10 +112,19 @@ from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity import lpips as lpips_lib import torch +import torch.nn.functional as F +from torchmetrics.multimodal import CLIPImageQualityAssessment from scipy.ndimage import rotate from lensless.utils.image import resize +# Initialize CLIP-IQA model +clip_iqa_model = CLIPImageQualityAssessment( + model_name_or_path=("clip_iqa"), + prompts=("noisiness", ), # TODO change if different metric is required + ).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + + def mse(true, est, normalize=True): """ Compute the mean-squared error between two images. The closer to 0, the @@ -260,7 +269,6 @@ def lpips(true, est, normalize=True): ) return loss_fn.forward(true, est).squeeze().item() - def extract( estimate, original, vertical_crop=None, horizontal_crop=None, rotation=0, verbose=False ): @@ -329,3 +337,36 @@ def extract( print(img_resize.max()) return estimate, img_resize + +def clip_iqa(true, est, normalize=True): + """ + Computes the CLIP Image Quality Assessment (CLIP-IQA) score between the true and estimated images. + Args: + true (Tensor): The ground truth image tensor. + est (Tensor): The estimated image tensor. + normalize (bool, optional): If True, normalize the images before computing the CLIP-IQA score. Default is True. + Returns: + float: The CLIP-IQA score. + """ + # if normalize: + # true = np.array(true, dtype=np.float32) + # est = np.array(est, dtype=np.float32) + # true /= true.max() + # est /= est.max() + + # Compute CLIP-IQA + with torch.no_grad(): + # Resize images to 224x224 for CLIP-IQA + outputs_resized = F.interpolate( + est, size=(224, 224), mode="bilinear", align_corners=False + ) + + outputs_3d = outputs_resized + + #clip_iqa_scores = self.clip_iqa(outputs_3d) + + + return clip_iqa_model(outputs_3d) + + # Compute CLIP-IQA scores over the batch + clip_iqa = clip_iqa_scores.mean().item() \ No newline at end of file