Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: reference image conditioning [2306.08276] #491

Merged
merged 4 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(self, opt, phase):
self.use_domain_B = not "self_supervised" in self.opt.data_dataset_mode

self.root = opt.dataroot

if not self.root.endswith("/"):
self.root += "/"

self.sv_dir = os.path.join(opt.checkpoints_dir, opt.name)
self.warning_mode = self.opt.warning_mode
self.set_dataset_dirs_and_dims()
Expand Down Expand Up @@ -342,6 +346,37 @@ def get_transform(
return transforms.Compose(transform_list)


def get_transform_ref(
    opt,
    params=None,
    grayscale=False,
    method=InterpolationMode.BICUBIC,
    convert=True,
    crop=True,
):
    """Build the preprocessing pipeline applied to reference images.

    The pipeline optionally converts to a single-channel grayscale image,
    resizes to a square of side ``opt.data_crop_size`` and, when ``convert``
    is set, turns the image into a tensor.

    Parameters:
        opt -- options object; only ``data_crop_size`` is read here
        params -- not used by this function
        grayscale (bool) -- if True, start with a 1-channel ``Grayscale`` step
        method -- interpolation mode handed to ``Resize``
        convert (bool) -- if True, finish with ``ToTensor``
        crop -- not used by this function

    Returns:
        a ``transforms.Compose`` chaining the selected steps
    """
    steps = [transforms.Grayscale(1)] if grayscale else []

    side = opt.data_crop_size
    steps.append(transforms.Resize([side, side], interpolation=method))

    if convert:
        steps.append(transforms.ToTensor())

        # A (0.5, 0.5) mean/std normalization (1 or 3 channels depending on
        # `grayscale`) used to be sketched here; it is disabled.

        # NOTE(review): this Normalize is instantiated but never added to
        # `steps`, so the returned pipeline does NOT normalize its output.
        # Looks like a missing append -- confirm the intent before changing
        # it, since doing so alters the tensors fed to the model.
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )

    return transforms.Compose(steps)


def __make_power_2(img, base, method=InterpolationMode.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
Expand Down
63 changes: 56 additions & 7 deletions data/image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,24 @@ def make_labeled_path_dataset(dir, paths, max_dataset_size=float("inf")):
for line in paths_list:
line_split = line.split(" ")

if len(line_split) == 2:
if (
len(line_split) == 1 and len(line_split[0]) > 0
): # we allow B not having a label
images.append(line_split[0])
labels.append(line_split[1])
if len(line_split) == 3:

elif len(line_split) == 2:
images.append(line_split[0])
labels.append(line_split[1] + " " + line_split[2])
labels.append(line_split[1])

elif (
len(line_split) == 1 and len(line_split[0]) > 0
): # we allow B not having a label
elif len(line_split) > 2:
images.append(line_split[0])

label_line = line_split[1]
for i in range(2, len(line_split)):
label_line += " " + line_split[i]

labels.append(label_line)

return (
images[: min(max_dataset_size, len(images))],
labels[: min(max_dataset_size, len(images))],
Expand All @@ -123,6 +129,49 @@ def make_dataset_path(dir, paths, max_dataset_size=float("inf")):
return images[: min(max_dataset_size, len(images))]


def make_ref_path(dir, paths, max_dataset_size=float("inf")):
    """Read a listing file and map each image path to its reference path.

    The file ``dir + paths`` is split on newlines; every line made of exactly
    two space-separated fields contributes one ``first -> second`` entry.
    Lines with any other number of fields are ignored.

    Parameters:
        dir (str) -- existing directory holding the listing file
        paths (str) -- suffix appended verbatim to `dir` to get the file name
        max_dataset_size -- accepted for signature parity, not used here

    Returns:
        dict mapping the first field of each valid line to the second one
    """
    assert os.path.isdir(dir), "%s is not a valid directory" % dir

    with open(dir + paths, "r") as listing:
        content = listing.read()

    mapping = {}
    for entry in content.split("\n"):
        fields = entry.split(" ")
        if len(fields) != 2:
            continue
        image, reference = fields
        mapping[image] = reference

    return mapping


def make_ref_path_list(dir, paths, max_dataset_size=float("inf")):
    """Read a listing file and map each image path to a list of references.

    Every line of ``dir + paths`` made of exactly two space-separated fields
    is read as ``<image path> <reference list file>``. The reference list
    file is resolved against the parent of `dir` (everything before the last
    "/" of `dir`) and contributes one reference path per non-empty line.

    Parameters:
        dir (str) -- existing directory holding the listing file
        paths (str) -- suffix appended verbatim to `dir` to get the file name
        max_dataset_size -- accepted for signature parity, not used here

    Returns:
        dict mapping each image path to the list of its reference paths, in
        file order
    """
    assert os.path.isdir(dir), "%s is not a valid directory" % dir

    with open(dir + paths, "r") as f:
        paths_list = f.read().split("\n")

    # Reference list files are given relative to the parent of `dir`.
    root = "/".join(dir.split("/")[:-1])

    ref = {}
    for line in paths_list:
        line_split = line.split(" ")
        if len(line_split) != 2:
            continue

        img_path, path_to_ref = line_split

        with open(os.path.join(root, path_to_ref), "r") as ref_file:
            # Drop blank lines instead of calling `.remove("")`, which raised
            # ValueError for files without a trailing newline and left extra
            # empty entries behind when several blank lines were present.
            ref[img_path] = [p for p in ref_file.read().split("\n") if p]

    return ref


def default_loader(path):
    """Open the image stored at `path` and return it converted to RGB."""
    image = Image.open(path)
    return image.convert("RGB")

Expand Down
2 changes: 1 addition & 1 deletion data/online_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def crop_image(
int(ref_bbox[3] * (output_dim + margin) / crop_size),
]

return img, mask, ref_bbox
return img, mask, ref_bbox, idx_bbox_ref


def fill_mask_with_random(img, mask, cls):
Expand Down
69 changes: 69 additions & 0 deletions data/self_supervised_labeled_mask_online_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os.path
from data.unaligned_labeled_mask_online_ref_dataset import (
UnalignedLabeledMaskOnlineRefDataset,
)
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from PIL import Image
import numpy as np
import torch
import warnings


class SelfSupervisedLabeledMaskOnlineRefDataset(UnalignedLabeledMaskOnlineRefDataset):
    """
    Self-supervised dataset with mask labels and reference images: paired
    samples are created from a single domain, the masked version of A being
    the input and the original A being the target.
    """

    def __init__(self, opt, phase):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
            phase -- forwarded unchanged to the parent constructor
        """
        super().__init__(opt, phase)

    def get_img(
        self,
        A_img_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
    ):
        """Load a sample through the parent class and turn it into a pair.

        The parent is queried with ``clamp_semantics=False``. The masked area
        of "A" is then filled (randomly or with a color, depending on the
        options) and the untouched A image, path and mask are exposed under
        the "B", "B_img_paths" and "B_label_mask" keys.

        Returns:
            the updated sample dict, or None when anything fails while
            building the pair (the error is printed)
        """
        result = super().get_img(
            A_img_path,
            A_label_mask_path,
            A_label_cls,
            B_img_path,
            B_label_mask_path,
            B_label_cls,
            index,
            clamp_semantics=False,
        )

        try:
            # Select the fill strategy first, then apply it to the sample.
            if self.opt.data_online_creation_rand_mask_A:
                fill, fill_arg = fill_mask_with_random, -1
            elif self.opt.data_online_creation_color_mask_A:
                fill, fill_arg = fill_mask_with_color, {}
            else:
                raise Exception(
                    "self supervised dataset: no self supervised method specified"
                )

            masked_A = fill(result["A"], result["A_label_mask"], fill_arg)

            # Built before the update, so "B" still receives the original A.
            paired = {
                "A": masked_A,
                "B": result["A"],
                "B_img_paths": result["A_img_paths"],
                "B_label_mask": result["A_label_mask"].clone(),
            }
            result.update(paired)
        except Exception as e:
            print(e, "self supervised data loading")
            return None

        return result
67 changes: 67 additions & 0 deletions data/self_supervised_labeled_mask_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os.path
from data.unaligned_labeled_mask_ref_dataset import UnalignedLabeledMaskRefDataset
from data.online_creation import fill_mask_with_random, fill_mask_with_color
from PIL import Image
import numpy as np
import torch
import warnings


class SelfSupervisedLabeledMaskRefDataset(UnalignedLabeledMaskRefDataset):
    """
    Self-supervised dataset with mask labels and reference images: paired
    samples are created from a single domain, the masked version of A being
    the input and the original A being the target.
    """

    def __init__(self, opt, phase):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
            phase -- forwarded unchanged to the parent constructor
        """
        super().__init__(opt, phase)

    def get_img(
        self,
        A_img_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
    ):
        """Load a sample through the parent class and turn it into a pair.

        The parent is queried with ``clamp_semantics=False``. The masked area
        of "A" is then filled (randomly or with a color, depending on the
        options) and the untouched A image, path and mask are exposed under
        the "B", "B_img_paths" and "B_label_mask" keys.

        Returns:
            the updated sample dict, or None when anything fails while
            building the pair (the error is printed)
        """
        result = super().get_img(
            A_img_path,
            A_label_mask_path,
            A_label_cls,
            B_img_path,
            B_label_mask_path,
            B_label_cls,
            index,
            clamp_semantics=False,
        )

        try:
            # Select the fill strategy first, then apply it to the sample.
            if self.opt.data_online_creation_rand_mask_A:
                fill, fill_arg = fill_mask_with_random, -1
            elif self.opt.data_online_creation_color_mask_A:
                fill, fill_arg = fill_mask_with_color, {}
            else:
                raise Exception(
                    "self supervised dataset: no self supervised method specified"
                )

            masked_A = fill(result["A"], result["A_label_mask"], fill_arg)

            # Built before the update, so "B" still receives the original A.
            paired = {
                "A": masked_A,
                "B": result["A"],
                "B_img_paths": result["A_img_paths"],
                "B_label_mask": result["A_label_mask"].clone(),
            }
            result.update(paired)
        except Exception as e:
            print(e, "self supervised data loading")
            return None

        return result
7 changes: 5 additions & 2 deletions data/unaligned_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_img(
else:
mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio

A_img, A_label_mask, A_ref_bbox = crop_image(
A_img, A_label_mask, A_ref_bbox, A_ref_bbox_id = crop_image(
A_img_path,
A_label_mask_path,
mask_delta=mask_delta_A,
Expand All @@ -190,6 +190,7 @@ def get_img(
inverted_mask=self.opt.data_inverted_mask,
single_bbox=self.opt.data_online_single_bbox,
)

self.cat_A_ref_bbox = torch.tensor(A_ref_bbox[0])
A_ref_bbox = A_ref_bbox[1:]

Expand All @@ -213,6 +214,7 @@ def get_img(
"A_img_paths": A_img_path,
"A_label_mask": A_label_mask,
"A_ref_bbox": A_ref_bbox,
"A_ref_bbox_id": A_ref_bbox_id,
}

# Domain B
Expand All @@ -227,7 +229,7 @@ def get_img(
mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio

if B_label_mask_path is not None:
B_img, B_label_mask, B_ref_bbox = crop_image(
B_img, B_label_mask, B_ref_bbox, B_ref_bbox_id = crop_image(
B_img_path,
B_label_mask_path,
mask_delta=mask_delta_B,
Expand Down Expand Up @@ -283,6 +285,7 @@ def get_img(
{
"B_label_mask": B_label_mask,
"B_ref_bbox": B_ref_bbox,
"B_ref_bbox_id": B_ref_bbox_id,
}
)

Expand Down
66 changes: 66 additions & 0 deletions data/unaligned_labeled_mask_online_ref_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
from PIL import Image

from data.base_dataset import get_transform_ref
from data.unaligned_labeled_mask_online_dataset import UnalignedLabeledMaskOnlineDataset
from data.image_folder import make_ref_path_list


class UnalignedLabeledMaskOnlineRefDataset(UnalignedLabeledMaskOnlineDataset):
    """
    Online labeled-mask dataset that additionally returns, for each domain A
    sample, a reference image listed in ``conditions.txt`` under ``dir_A``.
    """

    def __init__(self, opt, phase):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags
            phase -- forwarded unchanged to the parent constructor
        """
        super().__init__(opt, phase)

        # image path -> list of reference image paths
        self.A_img_ref = make_ref_path_list(self.dir_A, "/conditions.txt")

        self.transform_ref = get_transform_ref(opt)

    def get_img(
        self,
        A_img_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
        clamp_semantics=True,
    ):
        """Load a sample through the parent class and attach its reference.

        The reference is picked in ``self.A_img_ref`` using the sample's
        "A_img_paths" (made relative to ``self.root`` when
        ``opt.data_relative_paths`` is set) and its "A_ref_bbox_id", then
        loaded as RGB, transformed and stored under the "ref_A" key.

        Returns:
            the sample dict extended with "ref_A", or None when the parent
            returned None or the reference image could not be read

        Raises:
            KeyError / IndexError when no reference is listed for the sample
        """
        result = super().get_img(
            A_img_path,
            A_label_mask_path,
            A_label_cls,
            B_img_path,
            B_label_mask_path,
            B_label_cls,
            index,
            clamp_semantics,
        )

        # None is the "skip this sample" signal used by get_img in this file;
        # presumably the parent can return it too -- pass it on instead of
        # failing on the subscript below.
        if result is None:
            return None

        img_path = result["A_img_paths"]

        if self.opt.data_relative_paths:
            img_path = img_path.replace(self.root, "")

        A_ref_bbox_id = result["A_ref_bbox_id"]

        ref_A_path = self.A_img_ref[img_path][A_ref_bbox_id]

        if self.opt.data_relative_paths:
            ref_A_path = os.path.join(self.root, ref_A_path)

        try:
            # The context manager closes the underlying file as soon as the
            # RGB copy has been made.
            with Image.open(ref_A_path) as ref_file:
                ref_A = ref_file.convert("RGB")

        except Exception as e:
            print(
                "failure with reading A domain image ref ",
                ref_A_path,
            )
            print(e)
            return None

        ref_A = self.transform_ref(ref_A)

        result.update({"ref_A": ref_A})

        return result
Loading
Loading