feat: input and output multiple and different channels
beniz committed Sep 12, 2024
1 parent 92ad57d commit b39edda
Showing 11 changed files with 163 additions and 33 deletions.
20 changes: 16 additions & 4 deletions data/aligned_dataset.py
@@ -1,8 +1,10 @@
import sys
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform
from data.utils import load_image
from data.image_folder import make_dataset
from PIL import Image
import tifffile


class AlignedDataset(BaseDataset):
@@ -32,6 +34,11 @@ def __init__(self, opt, phase, name=""):
"aligned dataset: domain A and domain B should have the same number of images"
)

if opt.data_image_bits > 8 and opt.model_input_nc > 1:
self.use_tiff = True # multi-channel images > 8bit
else:
self.use_tiff = False

def __getitem__(self, index):
"""Return a data point and its metadata information.
@@ -48,8 +55,13 @@ def __getitem__(self, index):
A_path = self.A_paths[index]
B_path = self.B_paths[index]

A = Image.open(A_path)
B = Image.open(B_path)
if self.use_tiff:
A = tifffile.imread(A_path)
B = tifffile.imread(B_path)
else:
A = Image.open(A_path)
B = Image.open(B_path)

if self.opt.data_image_bits == 8:
A = A.convert("RGB")
B = B.convert("RGB")
Expand All @@ -58,13 +70,13 @@ def __getitem__(self, index):
grayscale = False

# apply the same transform to both A and B
transform_params = get_params(self.opt, A.size)
transform_params = get_params(self.opt, A.shape[:2] if self.use_tiff else A.size)

A_transform = get_transform(self.opt, transform_params, grayscale=grayscale)
B_transform = get_transform(self.opt, transform_params, grayscale=grayscale)

# resize B to A's size with PIL
if A.size != B.size:
if not self.use_tiff and A.size != B.size:
B = B.resize(A.size, Image.NEAREST)

A = A_transform(A)
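A minimal sketch of the two loading paths and the size-query change above, assuming a 16-bit 4-channel TIFF; file names are illustrative:

import tifffile
from PIL import Image

# tifffile keeps multi-channel, >8-bit data intact and returns a numpy array
a = tifffile.imread("imgA.tif")  # e.g. shape (H, W, 4), dtype uint16
h, w = a.shape[:2]               # numpy arrays expose (height, width) via .shape

# PIL covers the 8-bit case, but reports (width, height) via .size
b = Image.open("imgB.png").convert("RGB")
w2, h2 = b.size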
8 changes: 4 additions & 4 deletions data/base_dataset.py
@@ -395,10 +395,10 @@ def get_transform(
)
)

if opt.data_preprocess == "none":
transform_list.append(
transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))
)
# if opt.data_preprocess == "none":
# transform_list.append(
# transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))
# )

if opt.dataaug_flip != "none":
if params is None:
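For reference, a hedged sketch of what the disabled branch did, assuming __make_power_2 matches the usual round-to-multiple-of-base resize helper; the PIL-based resize does not apply to multi-channel numpy inputs, which is presumably why the branch is commented out:

from PIL import Image

def make_power_2_sketch(img, base=4, method=Image.BICUBIC):
    # round both dimensions to the nearest multiple of `base`
    ow, oh = img.size
    w = int(round(ow / base) * base)
    h = int(round(oh / base) * base)
    if (w, h) == (ow, oh):
        return img
    return img.resize((w, h), method)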
6 changes: 6 additions & 0 deletions models/base_model.py
@@ -1639,6 +1639,12 @@ def compute_metrics_test(
real_tensor = real_tensor[:, 1]
fake_tensor = fake_tensor[:, 1]
lpips_test = self.lpips_metric(real_tensor, fake_tensor).mean()
elif real_tensor.shape[1] > 3: # more than 3 channels
real_tensor_3c = real_tensor[:, :3, :, :] # keep the first 3 channels for LPIPS
fake_tensor_3c = fake_tensor[:, :3, :, :]
lpips_test = self.lpips_metric(
real_tensor_3c, fake_tensor_3c
).mean() # TODO: per-channel and sum
else:
lpips_test = self.lpips_metric(real_tensor, fake_tensor).mean()
setattr(self, "lpips_test_" + test_name, lpips_test)
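A hedged sketch of the per-channel variant the TODO mentions: score each channel as a replicated 3-channel image (most LPIPS backbones expect 3 channels) and sum; the helper name is illustrative:

def lpips_per_channel_sum(lpips_metric, real, fake):
    # real, fake: (N, C, H, W) with C > 3
    total = 0.0
    for c in range(real.shape[1]):
        r = real[:, c : c + 1].repeat(1, 3, 1, 1)  # replicate channel to 3
        f = fake[:, c : c + 1].repeat(1, 3, 1, 1)
        total = total + lpips_metric(r, f).mean()
    return total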
42 changes: 37 additions & 5 deletions models/cut_model.py
@@ -693,7 +693,19 @@ def compute_G_loss_cut(self):
"""Calculate NCE loss for the generator"""

# Fake losses
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B)
if self.real_A.size(1) != self.fake_B.size(1):
# hack: fake_B and real_A do not have the same number of channels
diffc = self.fake_B.size(1) - self.real_A.size(1)
assert diffc > 0
add1 = torch.zeros(
self.real_A.size(0), 1, self.real_A.size(2), self.real_A.size(3)
).to(self.device)
fake_B_nc = self.fake_B
for _ in range(diffc):
fake_B_nc = torch.cat((fake_B_nc, add1), dim=1)
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, fake_B_nc)
else:
feat_q_pool, feat_k_pool = self.calculate_feats(self.real_A, self.fake_B)

if self.opt.alg_cut_lambda_SRC > 0.0 or self.opt.alg_cut_nce_loss == "SRC_hDCE":
self.loss_G_SRC, weight = self.calculate_R_loss(feat_q_pool, feat_k_pool)
@@ -741,13 +753,33 @@ def compute_G_loss_cut(self):
else:
self.loss_G_supervised_norm = 0
if "LPIPS" in self.opt.alg_cut_supervised_loss:
self.loss_G_supervised_lpips = torch.mean(
self.criterionLPIPS(self.real_B, self.fake_B)
)
if self.real_B.size(1) > 3: # more than 3 channels
self.loss_G_supervised_lpips = 0.0
for c in range(self.real_B.size(1)): # per-channel loss, summed
real_Bc = self.real_B[:, c, :, :].unsqueeze(1)
fake_Bc = self.fake_B[:, c, :, :].unsqueeze(1)
self.loss_G_supervised_lpips += torch.mean(
self.criterionLPIPS(real_Bc, fake_Bc)
)
else:
self.loss_G_supervised_lpips = torch.mean(
self.criterionLPIPS(self.real_B, self.fake_B)
)
else:
self.loss_G_supervised_lpips = 0
if "DISTS" in self.opt.alg_cut_supervised_loss:
self.loss_G_supervised_dists = self.criterionDISTS(self.real_B, self.fake_B)
if self.real_B.size(1) > 3: # more than 3 channels
self.loss_G_supervised_dists = 0.0
for c in range(self.real_B.size(1)): # per-channel loss, summed
real_Bc = self.real_B[:, c, :, :].unsqueeze(1)
fake_Bc = self.fake_B[:, c, :, :].unsqueeze(1)
self.loss_G_supervised_dists += self.criterionDISTS(
real_Bc, fake_Bc
)
else:
self.loss_G_supervised_dists = torch.mean(
self.criterionDISTS(self.real_B, self.fake_B)
)
else:
self.loss_G_supervised_dists = 0
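
A self-contained shape check of the zero-channel padding used by the hack above; a single torch.cat with a (N, diffc, H, W) block of zeros is equivalent to the per-channel loop (tensor sizes illustrative):

import torch

narrow = torch.randn(2, 1, 64, 64)  # fewer channels
wide = torch.randn(2, 4, 64, 64)    # more channels

diffc = wide.size(1) - narrow.size(1)
pad = torch.zeros(narrow.size(0), diffc, narrow.size(2), narrow.size(3))
padded = torch.cat((narrow, pad), dim=1)  # append all-zero channels
assert padded.shape == wide.shape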

2 changes: 1 addition & 1 deletion models/diffusion_networks.py
@@ -96,7 +96,7 @@ def define_G(
if (
alg_diffusion_cond_embed != "" and alg_diffusion_cond_embed != "y_t"
) or alg_diffusion_task == "pix2pix":
in_channel *= 2
in_channel = model_input_nc + model_output_nc

if "mask" in alg_diffusion_cond_embed:
in_channel += alg_diffusion_cond_embed_dim
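
The conditioned U-Net input is the channel-wise concatenation of the conditioning image and the noisy target, so summing distinct channel counts replaces the old in_channel *= 2; a hedged sketch with illustrative counts:

import torch

model_input_nc, model_output_nc = 4, 3
x_cond = torch.randn(1, model_input_nc, 64, 64)  # conditioning image
y_t = torch.randn(1, model_output_nc, 64, 64)    # noisy target at step t

unet_in = torch.cat([x_cond, y_t], dim=1)
assert unet_in.size(1) == model_input_nc + model_output_nc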
1 change: 0 additions & 1 deletion models/modules/cm_generator.py
@@ -345,7 +345,6 @@ def forward(
)

def restoration(self, y, y_cond, sigmas, mask, clip_denoised=True):

if mask is not None:
mask = torch.clamp(
mask, min=0.0, max=1.0
2 changes: 1 addition & 1 deletion models/modules/unet_generator_attn/unet_generator_attn.py
@@ -446,7 +446,7 @@ def __init__(

if num_heads_upsample == -1:
num_heads_upsample = num_heads

self.image_size = image_size
self.in_channel = in_channel
self.inner_channel = inner_channel
5 changes: 2 additions & 3 deletions options/common_options.py
@@ -125,14 +125,13 @@ def initialize(self, parser):
"--model_input_nc",
type=int,
default=3,
choices=[1, 3],
help="# of input image channels: 3 for RGB and 1 for grayscale",
help="# of input image channels: 3 for RGB and 1 for grayscale, more supported",
)
parser.add_argument(
"--model_output_nc",
type=int,
default=3,
choices=[1, 3],
# choices=[1, 3],
help="# of output image channels: 3 for RGB and 1 for grayscale",
)
parser.add_argument(
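With the choices constraints relaxed, asymmetric channel counts should now pass option parsing; an illustrative invocation, assuming the repository's usual train.py entry point:

python train.py --model_input_nc 4 --model_output_nc 3 [other options]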
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
torch==2.3.0
torchvision==0.18.0
torch==2.4.0
torchvision==0.19.0
xformers
numpy==1.23.1
pluggy==1.3.0
@@ -34,3 +34,4 @@ transformers
diffusers==0.25.1
peft
bitsandbytes
tifffile
39 changes: 39 additions & 0 deletions util/util.py
@@ -340,3 +340,42 @@ def pairs_of_floats(arg):

def pairs_of_ints(arg):
return [int(x) for x in arg.split(",")]


def rgbn_float_img_to_8bits_display(rgbn_img, gamma: float = 1.0):
"""
rgbn float in [0., 1.] to rgb 8-bit + nrg 8-bit
"""
rgb_img = rgbn_img[:, :, [0, 1, 2]]
if gamma != 1:
rgb_img = rgb_img**gamma

rgb_img = (rgb_img.clip(0.0, 1.0) * 255.0).astype(np.uint8)

nrg_img = rgbn_img[:, :, [3, 0, 1]]
if gamma != 1:
nrg_img = nrg_img**gamma

nrg_img = (nrg_img.clip(0.0, 1.0) * 255.0).astype(np.uint8)

return rgb_img, nrg_img


def img_12bits_to_float(img: np.ndarray) -> np.ndarray:
"""
convert a 12-bit image to np.float32 in [0., 1.]
"""
img = img.astype(np.float32) / 4095.0

return img


def pan_float_img_to_8bits_display(img, gamma=0.7):
"""
float in [0., 1.] to 8-bit
optional gamma transform
"""
if gamma != 1:
img = img**gamma

return (img.clip(0.0, 1.0) * 255.0).astype(np.uint8)
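
A hedged end-to-end usage sketch of the three helpers above, assuming a 12-bit RGBN TIFF (file name illustrative):

import tifffile

raw = tifffile.imread("rgbn_12bit.tif")  # (H, W, 4) uint16 holding 12-bit values
rgbn = img_12bits_to_float(raw)          # float32 in [0., 1.]

rgb8, nrg8 = rgbn_float_img_to_8bits_display(rgbn, gamma=0.7)  # RGB and NRG previews
pan8 = pan_float_img_to_8bits_display(rgbn[:, :, 0])           # single-band preview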
66 changes: 54 additions & 12 deletions util/visualizer.py
@@ -9,6 +9,11 @@
import json
from torchinfo import summary
import math
from .util import (
rgbn_float_img_to_8bits_display,
img_12bits_to_float,
pan_float_img_to_8bits_display,
)

if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
@@ -179,8 +184,8 @@ def display_current_results_visdom(

if ncols == 0:
ncols = max_ncol
else:
ncols = min(ncols, max_ncol)
# else:
# ncols = min(ncols, max_ncol)

h, w = next(iter(visuals[0].values())).shape[:2]
table_css = """<style>
@@ -212,12 +217,21 @@ def display_current_results_visdom(
else:
imtype = np.float32
image_numpy = util.tensor2im(image, imtype=imtype)
label_html_row += "<td>%s</td>" % label
if image_numpy.shape[2] == 5:
npos = 3
elif image_numpy.shape[2] == 4:
npos = 2
else:
npos = 1
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += "<tr>%s</tr>" % label_html_row
label_html_row = ""
pos = 0
while pos < npos:
label_html_row += "<td>%s</td>" % label_html_row
idx += 1
if idx % ncols == 0:
label_html += "<tr>%s</tr>" % label_html_row
label_html_row = ""
pos += 1
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
@@ -240,11 +254,39 @@
if im.shape[0] == 3:
mapped_images.append(im)
continue
mapped_im = np.squeeze(gray_cm(im, bytes=True))
mapped_im = mapped_im.transpose([2, 0, 1])
# remove the alpha channel
mapped_im = mapped_im[:3, :, :]
mapped_images.append(mapped_im)
elif im.shape[0] == 1:
mapped_im = np.squeeze(gray_cm(im, bytes=True))
mapped_im = mapped_im.transpose([2, 0, 1])
# remove the alpha channel
mapped_im = mapped_im[:3, :, :]
mapped_images.append(mapped_im)
elif im.shape[0] == 5: # pan + RGBN: three display panels
im = im.transpose([1, 2, 0]) # (C, H, W) -> (H, W, C)
c_im = im[:, :, [1, 2, 3, 4]] # RGBN channels
rgb_im, nrg_im = rgbn_float_img_to_8bits_display(
c_im, gamma=0.7
)
rgb_im = rgb_im.transpose([2, 0, 1]) # back to (C, H, W)
nrg_im = nrg_im.transpose([2, 0, 1])

pan_c1_im = im[:, :, 0] # panchromatic channel, colormapped
pan_c1_im = np.squeeze(gray_cm(pan_c1_im, bytes=True))
pan_c1_im = pan_c1_im.transpose([2, 0, 1])
pan_c1_im = pan_c1_im[:3, :, :] # remove the alpha channel
mapped_images.append(pan_c1_im)

mapped_images.append(rgb_im)
mapped_images.append(nrg_im)
elif im.shape[0] == 4: # RGBN: two display panels
im = im.transpose([1, 2, 0]) # (C, H, W) -> (H, W, C)
rgb_im, nrg_im = rgbn_float_img_to_8bits_display(
im, gamma=0.7
)
rgb_im = rgb_im.transpose([2, 0, 1]) # back to (C, H, W)
nrg_im = nrg_im.transpose([2, 0, 1])
mapped_images.append(rgb_im)
mapped_images.append(nrg_im)

images = mapped_images

self.vis.images(
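A hedged summary of how many grid panels each visual contributes, matching the npos bookkeeping and the channel mapping above (illustrative helper):

def panels_per_visual(channels: int) -> int:
    if channels == 5:
        return 3  # colormapped pan + RGB + NRG
    if channels == 4:
        return 2  # RGB + NRG
    return 1      # RGB or colormapped grayscale, shown as-is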
