
Commit

Fix resize. (#97)
ebezzam authored Nov 17, 2023
1 parent 4dbeaf7 commit 8878bc7
Showing 2 changed files with 26 additions and 135 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -28,6 +28,7 @@ Bugfix

- Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS.
- Fix bad train/test split for DiffuserCamMirflickr in unrolled training.
+- Resize utility.


1.0.5 - (2023-09-05)
160 changes: 25 additions & 135 deletions lensless/utils/image.py
@@ -8,6 +8,7 @@


import cv2
+import scipy.signal
import numpy as np
from lensless.hardware.constants import RPI_HQ_CAMERA_CCM_MATRIX, RPI_HQ_CAMERA_BLACK_LEVEL

@@ -49,23 +50,23 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC):
    img_shape = np.array(img.shape)[-3:-1]

    assert not ((factor is None) and (shape is None)), "Must specify either factor or shape"
-    new_shape = tuple((img_shape * factor).astype(int)) if (shape is None) else shape[-3:-1]
+    new_shape = tuple(img_shape * factor) if shape is None else shape[-3:-1]
+    new_shape = [int(i) for i in new_shape]

    if np.array_equal(img_shape, new_shape):
        return img

    if torch_available:
        # torch resize expects an input of form [color, depth, width, height]
        tmp = np.moveaxis(img, -1, 0)
-        resized = tf.Resize(size=new_shape, interpolation=interpolation)(
-            torch.from_numpy(tmp)
-        ).numpy()
+        tmp = torch.from_numpy(tmp.copy())
+        resized = tf.Resize(size=new_shape, antialias=True)(tmp).numpy()
        resized = np.moveaxis(resized, 0, -1)

    else:
        resized = np.array(
            [
-                cv2.resize(img[i], dsize=new_shape[::-1], interpolation=interpolation)
+                cv2.resize(img[i], dsize=tuple(new_shape[::-1]), interpolation=interpolation)
                for i in range(img.shape[-4])
            ]
        )
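The change above casts the factor-scaled shape to integers in a separate step and switches the torch path to an explicit array copy with antialiasing. A minimal sketch of the patched resize utility in use, assuming the (depth, height, width, channel) input convention from the code above; the array values and shapes are made up for illustration:

import numpy as np
from lensless.utils.image import resize

# dummy (depth, height, width, channel) image
img = np.random.rand(1, 480, 640, 3).astype(np.float32)

# resize by a non-integer factor; the fix casts the scaled shape to ints
out = resize(img, factor=0.4)
print(out.shape)  # expected: (1, 192, 256, 3)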
@@ -327,6 +328,25 @@ def autocorr2d(vals, pad_mode="reflect"):
    return autocorr[shape[0] // 2 : -shape[0] // 2, shape[1] // 2 : -shape[1] // 2]


+def corr2d(im1, im2):
+    """
+    Source: https://stackoverflow.com/a/24769222
+    """
+
+    # get rid of the color channels by performing a grayscale transform
+    # the type cast into 'float' is to avoid overflows
+    im1_gray = np.sum(im1.astype("float"), axis=2)
+    im2_gray = np.sum(im2.astype("float"), axis=2)
+
+    # get rid of the averages, otherwise the results are not good
+    im1_gray -= np.mean(im1_gray)
+    im2_gray -= np.mean(im2_gray)
+
+    # calculate the correlation image; note the flipping of one of the images
+    return scipy.signal.fftconvolve(im1_gray, im2_gray[::-1, ::-1], mode="same")


def rgb2bayer(img, pattern):
    """
    Converting RGB image to separated Bayer channels.
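A quick sketch of how the newly added corr2d might be used to estimate the relative shift between two images; the test arrays below are fabricated for illustration:

import numpy as np
from lensless.utils.image import corr2d

# im2 is a circularly shifted copy of im1
rng = np.random.default_rng(0)
im1 = rng.random((64, 64, 3))
im2 = np.roll(im1, shift=(5, -3), axis=(0, 1))

corr = corr2d(im1, im2)
# the offset of the correlation peak from the array center reflects the shift
peak = np.unravel_index(np.argmax(corr), corr.shape)
print(peak)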
@@ -442,133 +462,3 @@ def bayer2rgb(X_bayer, pattern):
    X_rgb[:, :, 2] = X_bayer[:, :, b_channel]

    return X_rgb


-def load_drunet(model_path, n_channels=3, requires_grad=False):
-    """
-    Load a pre-trained DruNet model.
-    Parameters
-    ----------
-    model_path : str
-        Path to pre-trained model.
-    n_channels : int
-        Number of channels in input image.
-    requires_grad : bool
-        Whether to require gradients for model parameters.
-    Returns
-    -------
-    model : :py:class:`~torch.nn.Module`
-        Loaded model.
-    """
-    from lensless.recon.drunet.network_unet import UNetRes
-
-    model = UNetRes(
-        in_nc=n_channels + 1,
-        out_nc=n_channels,
-        nc=[64, 128, 256, 512],
-        nb=4,
-        act_mode="R",
-        downsample_mode="strideconv",
-        upsample_mode="convtranspose",
-    )
-    model.load_state_dict(torch.load(model_path), strict=True)
-    model.eval()
-    for k, v in model.named_parameters():
-        v.requires_grad = requires_grad
-
-    return model
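For reference, a minimal sketch of how this now-removed loader was invoked; the checkpoint path is a placeholder, and DruNet weights must be obtained separately:

import torch

model = load_drunet("weights/drunet_color.pth", n_channels=3, requires_grad=False)
print(sum(p.numel() for p in model.parameters()))  # size of the loaded UNetRes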


-def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"):
-    """
-    Apply a pre-trained denoising model with input in the format Channel, Height, Width.
-    An additional channel is added for the noise level, as done in DruNet.
-    Parameters
-    ----------
-    model : :py:class:`~torch.nn.Module`
-        DruNet-compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image, both in CHW format.
-    image : :py:class:`~torch.Tensor`
-        Input image.
-    noise_level : float or :py:class:`~torch.Tensor`
-        Noise level in the image.
-    device : str
-        Device to use for computation. Can be "cpu" or "cuda".
-    mode : str
-        Mode to use for model. Can be "inference" or "train".
-    Returns
-    -------
-    image : :py:class:`~torch.Tensor`
-        Reconstructed image.
-    """
-    # convert from NDHWC to NCHW
-    depth = image.shape[-4]
-    image = image.movedim(-1, -3)
-    image = image.reshape(-1, *image.shape[-3:])
-    # pad image H and W to next multiple of 8
-    top = (8 - image.shape[-2] % 8) // 2
-    bottom = (8 - image.shape[-2] % 8) - top
-    left = (8 - image.shape[-1] % 8) // 2
-    right = (8 - image.shape[-1] % 8) - left
-    image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0)
-    # add noise level as extra channel
-    image = image.to(device)
-    if isinstance(noise_level, torch.Tensor):
-        noise_level = noise_level / 255.0
-    else:
-        noise_level = torch.tensor([noise_level / 255.0]).to(device)
-    image = torch.cat(
-        (
-            image,
-            noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]),
-        ),
-        dim=1,
-    )
-
-    # apply model
-    if mode == "inference":
-        with torch.no_grad():
-            image = model(image)
-    elif mode == "train":
-        image = model(image)
-    else:
-        raise ValueError("mode must be 'inference' or 'train'")
-
-    # remove padding
-    image = image[:, :, top:-bottom, left:-right]
-    # convert back to NDHWC
-    image = image.movedim(-3, -1)
-    image = image.reshape(-1, depth, *image.shape[-3:])
-    return image
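Likewise, a sketch of the removed apply_denoiser in action, using a dummy NDHWC tensor whose shape is chosen only to exercise the padding logic above; the checkpoint path is again hypothetical:

import torch

model = load_drunet("weights/drunet_color.pth")  # hypothetical checkpoint, as above
x = torch.rand(1, 1, 100, 100, 3)  # (batch, depth, height, width, channel)
out = apply_denoiser(model, x, noise_level=10, device="cpu", mode="inference")
print(out.shape)  # padding is cropped away: (1, 1, 100, 100, 3)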


-def process_with_DruNet(model, device="cpu", mode="inference"):
-    """
-    Return a processing function that applies the DruNet model to an image.
-    Parameters
-    ----------
-    model : torch.nn.Module
-        DruNet-like denoiser model.
-    device : str
-        Device to use for computation. Can be "cpu" or "cuda".
-    mode : str
-        Mode to use for model. Can be "inference" or "train".
-    """
-
-    def process(image, noise_level):
-        x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6
-        image = apply_denoiser(
-            model,
-            image,
-            noise_level=noise_level,
-            device=device,
-            mode="train",
-        )
-        image = torch.clip(image, min=0.0) * x_max
-        return image
-
-    return process
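And a sketch tying the removed pieces together; note that the inner call above passes mode="train" to apply_denoiser regardless of the mode argument given here. Paths and shapes are illustrative:

import torch

model = load_drunet("weights/drunet_color.pth")  # hypothetical checkpoint path
denoise = process_with_DruNet(model, device="cpu", mode="inference")
x = torch.rand(1, 1, 100, 100, 3)  # (batch, depth, height, width, channel)
y = denoise(x, noise_level=10)  # clipped to >= 0 and rescaled by the per-image max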
