Skip to content

Commit

Permalink
Merge branch 'jolibrain:master' into feat_api_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
alx authored Aug 21, 2023
2 parents 3087c55 + bd16f1e commit 9dd1a69
Show file tree
Hide file tree
Showing 48 changed files with 1,935 additions and 565 deletions.
82 changes: 81 additions & 1 deletion data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import torchvision.transforms.functional as F

from data.image_folder import make_dataset, make_dataset_path, make_labeled_path_dataset
from data.online_creation import sanitize_paths, write_paths_file

from abc import ABC, abstractmethod
import imgaug as ia
import imgaug.augmenters as iaa
Expand Down Expand Up @@ -137,7 +140,7 @@ def __getitem__(self, index):
if B_label_mask_path is not None:
B_label_mask_path = os.path.join(self.root, B_label_mask_path)

return self.get_img(
results = self.get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
Expand All @@ -147,6 +150,8 @@ def __getitem__(self, index):
index,
)

return results

def set_dataset_dirs_and_dims(self):
btoA = self.opt.data_direction == "BtoA"
self.input_nc = (
Expand Down Expand Up @@ -249,6 +254,81 @@ def get_validation_set(self, size):

return return_A_list, return_B_list

def sanitize(self):
paths_sanitized_train_A = os.path.join(
self.sv_dir, "paths_sanitized_train_A.txt"
)
if hasattr(self, "B_img_paths"):
paths_sanitized_train_B = os.path.join(
self.sv_dir, "paths_sanitized_train_B.txt"
)
if hasattr(self, "B_img_paths"):
train_sanitized_exist = os.path.exists(
paths_sanitized_train_A
) and os.path.exists(paths_sanitized_train_B)
else:
train_sanitized_exist = os.path.exists(paths_sanitized_train_A)

if train_sanitized_exist:
self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
self.sv_dir, "/paths_sanitized_train_A.txt"
)
if hasattr(self, "B_img_paths"):
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.sv_dir, "/paths_sanitized_train_B.txt"
)
else:
print("--------------")
print("Sanitizing images and labels paths")
print("--- DOMAIN A ---")

self.A_img_paths, self.A_label_mask_paths = sanitize_paths(
self.A_img_paths,
self.A_label_mask_paths,
mask_delta=self.opt.data_online_creation_mask_delta_A,
mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
crop_delta=self.opt.data_online_creation_crop_delta_A,
mask_square=self.opt.data_online_creation_mask_square_A,
crop_dim=self.opt.data_online_creation_crop_size_A,
output_dim=self.opt.data_load_size,
max_dataset_size=self.opt.data_max_dataset_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_A,
select_cat=self.opt.data_online_select_category,
data_relative_paths=self.opt.data_relative_paths,
data_root_path=self.opt.dataroot,
)
write_paths_file(
self.A_img_paths,
self.A_label_mask_paths,
paths_sanitized_train_A,
)

print("--- DOMAIN B ---")
if hasattr(self, "B_img_paths"):
self.B_img_paths, self.B_label_mask_paths = sanitize_paths(
self.B_img_paths,
self.B_label_mask_paths,
mask_delta=self.opt.data_online_creation_mask_delta_B,
mask_random_offset=self.opt.data_online_creation_mask_random_offset_B,
crop_delta=self.opt.data_online_creation_crop_delta_B,
mask_square=self.opt.data_online_creation_mask_square_B,
crop_dim=self.opt.data_online_creation_crop_size_B,
output_dim=self.opt.data_load_size,
max_dataset_size=self.opt.data_max_dataset_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_B,
data_relative_paths=self.opt.data_relative_paths,
data_root_path=self.opt.dataroot,
)
write_paths_file(
self.B_img_paths,
self.B_label_mask_paths,
paths_sanitized_train_B,
)

print("--------------")


def get_params(opt, size):
w, h = size
Expand Down
25 changes: 13 additions & 12 deletions data/self_supervised_temporal_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import torch


from data.temporal_dataset import TemporalDataset
from data.temporal_labeled_mask_online_dataset import TemporalLabeledMaskOnlineDataset
from data.online_creation import fill_mask_with_random, fill_mask_with_color


class SelfSupervisedTemporalDataset(TemporalDataset):
class SelfSupervisedTemporalDataset(TemporalLabeledMaskOnlineDataset):
"""
This dataset class can create datasets with mask labels from one domain.
"""
Expand All @@ -28,15 +27,17 @@ def get_img(
B_label_cls=None,
index=None,
):
result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
)
result = None
while result is None:
result = super().get_img(
A_img_path,
A_label_mask_path,
A_label_cls,
B_img_path,
B_label_mask_path,
B_label_cls,
index,
)

try:
A_img_list = [result["A"][0]]
Expand Down
25 changes: 12 additions & 13 deletions data/temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ def __init__(self, opt, phase):
self.B_img_paths.sort(key=natural_keys)
self.B_label_mask_paths.sort(key=natural_keys)

self.A_img_paths, self.A_label_mask_paths = (
self.A_img_paths[: opt.data_max_dataset_size],
self.A_label_mask_paths[: opt.data_max_dataset_size],
)
if self.opt.data_sanitize_paths:
self.sanitize()
elif opt.data_max_dataset_size != float("inf"):
self.A_img_paths, self.A_label_mask_paths = (
self.A_img_paths[: opt.data_max_dataset_size],
self.A_label_mask_paths[: opt.data_max_dataset_size],
)

if self.use_domain_B:
self.B_img_paths, self.B_label_mask_paths = (
Expand Down Expand Up @@ -118,13 +121,11 @@ def get_img(
cur_A_label_path = os.path.join(self.root, cur_A_label_path)

try:
if (
len(self.opt.data_online_creation_mask_delta_A_ratio[0]) == 1
and self.opt.data_online_creation_mask_delta_A_ratio[0][0] == 0
):
if self.opt.data_online_creation_mask_delta_A_ratio == [[]]:
mask_delta_A = self.opt.data_online_creation_mask_delta_A
else:
mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio

if i == 0:
crop_coordinates = crop_image(
cur_A_img_path,
Expand All @@ -140,6 +141,7 @@ def get_img(
get_crop_coordinates=True,
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)

cur_A_img, cur_A_label, ref_A_bbox = crop_image(
cur_A_img_path,
cur_A_label_path,
Expand Down Expand Up @@ -201,10 +203,7 @@ def get_img(
cur_B_label_path = os.path.join(self.root, cur_B_label_path)

try:
if (
len(self.opt.data_online_creation_mask_delta_B_ratio[0]) == 1
and self.opt.data_online_creation_mask_delta_B_ratio[0][0] == 0
):
if self.opt.data_online_creation_mask_delta_B_ratio == [[]]:
mask_delta_B = self.opt.data_online_creation_mask_delta_B
else:
mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio
Expand Down Expand Up @@ -254,7 +253,7 @@ def get_img(
else:
images_B = None
labels_B = None
ref_B_img_path = None
ref_B_img_path = ""

result = {
"A": images_A,
Expand Down
88 changes: 3 additions & 85 deletions data/unaligned_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from data.base_dataset import BaseDataset, get_transform, get_transform_seg
from data.image_folder import make_dataset, make_dataset_path, make_labeled_path_dataset
from data.online_creation import crop_image, sanitize_paths, write_paths_file
from data.online_creation import crop_image


class UnalignedLabeledMaskOnlineDataset(BaseDataset):
Expand Down Expand Up @@ -77,81 +77,6 @@ def __init__(self, opt, phase):

self.header = ["img", "mask"]

def sanitize(self):
paths_sanitized_train_A = os.path.join(
self.sv_dir, "paths_sanitized_train_A.txt"
)
if hasattr(self, "B_img_paths"):
paths_sanitized_train_B = os.path.join(
self.sv_dir, "paths_sanitized_train_B.txt"
)
if hasattr(self, "B_img_paths"):
train_sanitized_exist = os.path.exists(
paths_sanitized_train_A
) and os.path.exists(paths_sanitized_train_B)
else:
train_sanitized_exist = os.path.exists(paths_sanitized_train_A)

if train_sanitized_exist:
self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
self.sv_dir, "/paths_sanitized_train_A.txt"
)
if hasattr(self, "B_img_paths"):
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.sv_dir, "/paths_sanitized_train_B.txt"
)
else:
print("--------------")
print("Sanitizing images and labels paths")
print("--- DOMAIN A ---")

self.A_img_paths, self.A_label_mask_paths = sanitize_paths(
self.A_img_paths,
self.A_label_mask_paths,
mask_delta=self.opt.data_online_creation_mask_delta_A,
mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
crop_delta=self.opt.data_online_creation_crop_delta_A,
mask_square=self.opt.data_online_creation_mask_square_A,
crop_dim=self.opt.data_online_creation_crop_size_A,
output_dim=self.opt.data_load_size,
max_dataset_size=self.opt.data_max_dataset_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_A,
select_cat=self.opt.data_online_select_category,
data_relative_paths=self.opt.data_relative_paths,
data_root_path=self.opt.dataroot,
)
write_paths_file(
self.A_img_paths,
self.A_label_mask_paths,
paths_sanitized_train_A,
)

print("--- DOMAIN B ---")
if hasattr(self, "B_img_paths"):
self.B_img_paths, self.B_label_mask_paths = sanitize_paths(
self.B_img_paths,
self.B_label_mask_paths,
mask_delta=self.opt.data_online_creation_mask_delta_B,
mask_random_offset=self.opt.data_online_creation_mask_random_offset_B,
crop_delta=self.opt.data_online_creation_crop_delta_B,
mask_square=self.opt.data_online_creation_mask_square_B,
crop_dim=self.opt.data_online_creation_crop_size_B,
output_dim=self.opt.data_load_size,
max_dataset_size=self.opt.data_max_dataset_size,
context_pixels=self.opt.data_online_context_pixels,
load_size=self.opt.data_online_creation_load_size_B,
data_relative_paths=self.opt.data_relative_paths,
data_root_path=self.opt.root,
)
write_paths_file(
self.B_img_paths,
self.B_label_mask_paths,
paths_sanitized_train_B,
)

print("--------------")

def get_img(
self,
A_img_path,
Expand All @@ -164,12 +89,8 @@ def get_img(
clamp_semantics=True,
):
# Domain A

try:
if (
len(self.opt.data_online_creation_mask_delta_A_ratio[0]) == 1
and self.opt.data_online_creation_mask_delta_A_ratio[0][0] == 0
):
if self.opt.data_online_creation_mask_delta_A_ratio == [[]]:
mask_delta_A = self.opt.data_online_creation_mask_delta_A
else:
mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio
Expand Down Expand Up @@ -218,10 +139,7 @@ def get_img(
# Domain B
if B_img_path is not None:
try:
if (
len(self.opt.data_online_creation_mask_delta_B_ratio[0]) == 1
and self.opt.data_online_creation_mask_delta_B_ratio[0][0] == 0
):
if self.opt.data_online_creation_mask_delta_B_ratio == [[]]:
mask_delta_B = self.opt.data_online_creation_mask_delta_B
else:
mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio
Expand Down
1 change: 1 addition & 0 deletions docker/Dockerfile.devel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && \
python3-pip \
python3-opencv \
python3-pytest \
ninja-build \
sudo \
wget \
git \
Expand Down
Loading

0 comments on commit 9dd1a69

Please sign in to comment.