Skip to content

Commit

Permalink
feat: reference image conditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Jul 18, 2023
1 parent 7819fdf commit 6c92cab
Show file tree
Hide file tree
Showing 13 changed files with 1,982 additions and 19 deletions.
35 changes: 35 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, opt, phase):
self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode

self.root = opt.dataroot

if not self.root.endswith("/"):
self.root += "/"

self.sv_dir = os.path.join(opt.checkpoints_dir, opt.name)
self.warning_mode = self.opt.warning_mode
self.set_dataset_dirs_and_dims()
Expand Down Expand Up @@ -342,6 +346,37 @@ def get_transform(
return transforms.Compose(transform_list)


def get_transform_ref(
opt,
params=None,
grayscale=False,
method=InterpolationMode.BICUBIC,
convert=True,
crop=True,
):

transform_list = []

if grayscale:
transform_list.append(transforms.Grayscale(1))

osize = [opt.data_crop_size, opt.data_crop_size]
transform_list.append(transforms.Resize(osize, interpolation=method))

if convert:
transform_list += [transforms.ToTensor()]
"""if grayscale:
transform_list += [transforms.Normalize((0.5,), (0.5,))]
else:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]"""

transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
)
return transforms.Compose(transform_list)


def __make_power_2(img, base, method=InterpolationMode.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
Expand Down
36 changes: 29 additions & 7 deletions data/image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,24 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")):
for line in paths_list:
line_split = line.split(" ")

if len(line_split) == 2:
if (
len(line_split) == 1 and len(line_split[0]) > 0
): # we allow B not having a label
images.append(line_split[0])
labels.append(line_split[1])
if len(line_split) == 3:

elif len(line_split) == 2:
images.append(line_split[0])
labels.append(line_split[1] + " " + line_split[2])
labels.append(line_split[1])

elif (
len(line_split) == 1 and len(line_split[0]) > 0
): # we allow B not having a label
elif len(line_split) > 2:
images.append(line_split[0])

label_line = line_split[1]
for i in range(2, len(line_split)):
label_line += " " + line_split[i]

labels.append(label_line)

return (
images[: min(max_dataset_size, len(images))],
labels[: min(max_dataset_size, len(images))],
Expand All @@ -123,6 +129,22 @@ def make_dataset_path(dir, paths, max_dataset_size=float("inf")):
return images[: min(max_dataset_size, len(images))]


def make_ref_path(dir, paths, max_dataset_size=float("inf")):
ref = {}
assert os.path.isdir(dir), "%s is not a valid directory" % dir

with open(dir + paths, "r") as f:
paths_list = f.read().split("\n")

for line in paths_list:
line_split = line.split(" ")

if len(line_split) == 2:
ref[line_split[0]] = line_split[1]

return ref


def default_loader(path):
return Image.open(path).convert("RGB")

Expand Down
67 changes: 67 additions & 0 deletions data/self_supervised_labeled_mask_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os.path
from data.unaligned_labeled_mask_ref_dataset import UnalignedLabeledMaskRefDataset
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from PIL import Image
import numpy as np
import torch
import warnings


class SelfSupervisedLabeledMaskRefDataset(UnalignedLabeledMaskRefDataset):
"""
This dataset class can create paired datasets with mask labels from only one domain.
"""

def __init__(self, opt, phase):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
super().__init__(opt, phase)

def get_img(
self,
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path=None,
B_label_mask_path=None,
B_label_cls=None,
index=None,
):
result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
clamp_semantics=False,
)

try:

if self.opt.data_online_creation_rand_mask_A:
A_img = fill_mask_with_random(result["A"], result["A_label_mask"], -1)
elif self.opt.data_online_creation_color_mask_A:
A_img = fill_mask_with_color(result["A"], result["A_label_mask"], {})
else:
raise Exception(
"self supervised dataset: no self supervised method specified"
)

result.update(
{
"A": A_img,
"B": result["A"],
"B_img_paths": result["A_img_paths"],
"B_label_mask": result["A_label_mask"].clone(),
}
)
except Exception as e:
print(e, "self supervised data loading")
return None

return result
67 changes: 67 additions & 0 deletions data/unaligned_labeled_mask_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

from torchvision.transforms.functional import resize
from PIL import Image

from data.base_dataset import get_transform_ref
from data.unaligned_labeled_mask_dataset import UnalignedLabeledMaskDataset
from data.image_folder import make_ref_path


class UnalignedLabeledMaskRefDataset(UnalignedLabeledMaskDataset):
def __init__(self, opt, phase):
super().__init__(opt, phase)

self.A_img_ref = make_ref_path(self.dir_A, "/conditions.txt")

self.transform_ref = get_transform_ref(opt)

def get_img(
self,
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path=None,
B_label_mask_path=None,
B_label_cls=None,
index=None,
clamp_semantics=True,
):

result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
clamp_semantics,
)

img_path = result["A_img_paths"]

if self.opt.data_relative_paths:
img_path = img_path.replace(self.root, "")

ref_A_path = self.A_img_ref[img_path]

if self.opt.data_relative_paths:
ref_A_path = os.path.join(self.root, ref_A_path)

try:
ref_A = Image.open(ref_A_path).convert("RGB")

except Exception as e:
print(
"failure with reading A domain image ref ",
ref_A_path,
)
print(e)
return None

ref_A = self.transform_ref(ref_A)

result.update({"ref_A": ref_A})

return result
23 changes: 18 additions & 5 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
else:
self.cond_embed_dim = cond_embed_dim

if "class" in self.denoise_fn.conditioning:
if any(cond in self.denoise_fn.conditioning for cond in ["class", "ref"]):
self.cond_embed_gammas = self.cond_embed_dim // 2
else:
self.cond_embed_gammas = self.cond_embed_dim
Expand All @@ -86,6 +86,7 @@ def restoration(
mask=None,
sample_num=8,
cls=None,
ref=None,
guidance_scale=0.0,
ddim_num_steps=10,
ddim_eta=0.5,
Expand Down Expand Up @@ -149,6 +150,7 @@ def restoration_ddpm(
phase=phase,
cls=cls,
mask=mask,
ref=ref,
guidance_scale=guidance_scale,
)

Expand Down Expand Up @@ -180,6 +182,7 @@ def p_mean_variance(
clip_denoised: bool,
cls,
mask,
ref,
y_cond=None,
guidance_scale=0.0,
):
Expand All @@ -197,7 +200,11 @@ def p_mean_variance(
y_t,
t=t,
noise=self.denoise_fn(
input, torch.zeros_like(embed_noise_level), cls=None, mask=None
input,
torch.zeros_like(embed_noise_level),
cls=None,
mask=None,
ref=ref,
),
phase=phase,
)
Expand All @@ -206,7 +213,9 @@ def p_mean_variance(
self.denoise_fn.model,
y_t,
t=t,
noise=self.denoise_fn(input, embed_noise_level, cls=cls, mask=mask),
noise=self.denoise_fn(
input, embed_noise_level, cls=cls, mask=mask, ref=ref
),
phase=phase,
)

Expand All @@ -232,6 +241,7 @@ def p_sample(
phase,
cls,
mask,
ref,
clip_denoised=True,
y_cond=None,
guidance_scale=0.0,
Expand All @@ -245,6 +255,7 @@ def p_sample(
phase=phase,
cls=cls,
mask=mask,
ref=ref,
guidance_scale=guidance_scale,
)

Expand Down Expand Up @@ -409,7 +420,7 @@ def ddim_p_mean_variance(

return model_mean, posterior_log_variance

def forward(self, y_0, y_cond, mask, noise, cls, dropout_prob=0.0):
def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):

b, *_ = y_0.shape
t = torch.randint(
Expand Down Expand Up @@ -438,7 +449,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, dropout_prob=0.0):

input = torch.cat([y_cond, y_noisy], dim=1)

noise_hat = self.denoise_fn(input, embed_sample_gammas, cls=cls, mask=mask)
noise_hat = self.denoise_fn(
input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
)

return noise, noise_hat

Expand Down
Empty file.
Loading

0 comments on commit 6c92cab

Please sign in to comment.