Feature: Add new CLIP-IQA metric #157

Draft · wants to merge 1 commit into main
46 changes: 45 additions & 1 deletion lensless/eval/benchmark.py
@@ -8,12 +8,14 @@


from lensless.utils.dataset import DiffuserCamTestDataset
from lensless.utils.dataset import HFDataset
from lensless.utils.io import save_image
from waveprop.noise import add_shot_noise
from tqdm import tqdm
import os
import numpy as np
import wandb
from lensless.eval.metric import clip_iqa

try:
import torch
@@ -102,6 +104,7 @@ def benchmark(
),
"SSIM": StructuralSimilarityIndexMeasure(reduction=None, data_range=(0, 1)).to(device),
"ReconstructionError": None,
"CLIP-IQA": clip_iqa
}

metrics_values = {key: [] for key in metrics}
@@ -241,6 +244,10 @@ def benchmark(
metrics_values[metric].append(
metrics[metric](prediction, lensed).cpu().item()
)
elif metric == "CLIP-IQA":
# clip_iqa averages over the batch and already returns a float,
# so no .cpu().item() is needed here
metrics_values[metric].append(metrics[metric](prediction, lensed))
elif metric == "MSE":
metrics_values[metric].append(
metrics[metric](prediction, lensed).cpu().item() * len(batch[0])
@@ -350,7 +357,44 @@ def benchmark(
device = "cpu"

# prepare dataset
dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample)
# dataset = DiffuserCamTestDataset(n_files=n_files, downsample=downsample)
dataset = HFDataset(
    huggingface_repo="Lensless/TapeCam-Mirflickr-Ambient-100",
    cache_dir=None,
    psf="psf.png",
    single_channel_psf=False,
    split="test",
    display_res=[600, 600],
    rotate=False,
    flipud=False,
    flip_lensed=False,
    downsample=1,
    downsample_lensed=2,
    alignment={"top_left": [85, 185], "height": 178},
    save_psf=True,
    n_files=None,
    simulation_config={
        "grayscale": False,
        "output_dim": None,
        "object_height": 0.04,
        "flip": True,
        "random_shift": False,
        "random_vflip": 0.5,
        "random_hflip": 0.5,
        "random_rotate": False,
        "scene2mask": 0.1,
        "mask2sensor": 0.009,
        "deadspace": True,
        "use_waveprop": False,
        "sensor": "rgb",  # Replace with the correct value if different
    },
    per_pixel_color_shift=True,
    per_pixel_color_shift_range=[0.8, 1.2],
    bg_snr_range=None,
    bg_fp=None,
    force_rgb=False,
    simulate_lensless=False,
)

# prepare model
psf = dataset.psf.to(device)
43 changes: 42 additions & 1 deletion lensless/eval/metric.py
@@ -112,10 +112,19 @@
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
import lpips as lpips_lib
import torch
import torch.nn.functional as F
from torchmetrics.multimodal import CLIPImageQualityAssessment
from scipy.ndimage import rotate
from lensless.utils.image import resize


# Initialize the CLIP-IQA model once at import time, on GPU if available
clip_iqa_model = CLIPImageQualityAssessment(
    model_name_or_path="clip_iqa",
    prompts=("noisiness",),  # TODO change if a different quality attribute is required
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
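# NOTE: with a single prompt, the metric returns one score per image in the
# batch; with multiple prompts, it returns a dict of score tensors keyed by prompt.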


def mse(true, est, normalize=True):
"""
Compute the mean-squared error between two images. The closer to 0, the
@@ -260,7 +269,6 @@ def lpips(true, est, normalize=True):
)
return loss_fn.forward(true, est).squeeze().item()


def extract(
estimate, original, vertical_crop=None, horizontal_crop=None, rotation=0, verbose=False
):
@@ -329,3 +337,36 @@ def extract(
print(img_resize.max())

return estimate, img_resize

def clip_iqa(true, est, normalize=True):
    """
    Computes the CLIP Image Quality Assessment (CLIP-IQA) score of the estimated image.

    CLIP-IQA is a no-reference metric, so only ``est`` is scored; ``true`` is
    accepted to keep the same signature as the other metrics in this module.

    Args:
        true (Tensor): The ground truth image tensor of shape (N, C, H, W). Unused.
        est (Tensor): The estimated image tensor of shape (N, C, H, W).
        normalize (bool, optional): If True, rescale ``est`` to [0, 1] before
            computing the score. Default is True.
    Returns:
        float: The CLIP-IQA score, averaged over the batch.
    """
    if normalize:
        est = est / est.max()

    # Compute CLIP-IQA
    with torch.no_grad():
        # Resize images to 224x224, the input resolution expected by CLIP
        est_resized = F.interpolate(
            est, size=(224, 224), mode="bilinear", align_corners=False
        )
        # Run on the same device as the model; with a single prompt this
        # returns one score per image in the batch
        clip_iqa_scores = clip_iqa_model(est_resized.to(clip_iqa_model.device))

    # Average the CLIP-IQA scores over the batch
    return clip_iqa_scores.mean().item()
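
For reference, a minimal sketch of how the new metric could be exercised on its own, outside the benchmark loop. This is not part of the diff: the random tensors are placeholders purely for illustration, and the first call will download the CLIP weights. With the built-in "noisiness" prompt pair, the score should be the probability assigned to the "noisy" anchor, so higher means noisier.

import torch

from lensless.eval.metric import clip_iqa

# Batch of 4 RGB reconstructions, shape (N, C, H, W), values in [0, 1]
est = torch.rand(4, 3, 270, 480)
true = torch.rand(4, 3, 270, 480)  # ignored: CLIP-IQA is no-reference

score = clip_iqa(true, est)
print(f"CLIP-IQA ('noisiness'): {score:.3f}")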