Skip to content

Commit

Permalink
fix: temporal discriminator with masks and bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Aug 29, 2023
1 parent 911fc20 commit fcfc253
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 96 deletions.
54 changes: 33 additions & 21 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,10 +1036,21 @@ class ComposeMaskList(transforms.Compose):
>>> ])
"""

def __call__(self, imgs, masks=None):
def __call__(self, imgs, masks=None, bbox=None):
if bbox is None:
w, h = imgs[0].size
bbox = np.array([0, 0, w, h]) # sets bbox to full image size
if torch.__version__[0] == "2":
tbbox = datapoints.BoundingBox(
bbox,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=imgs[0].size,
)
else:
tbbox = bbox # placeholder
for t in self.transforms:
imgs, masks = t(imgs, masks)
return imgs, masks
imgs, masks, tbbox = t(imgs, masks, tbbox)
return imgs, masks, tbbox


class GrayscaleMaskList(transforms.Grayscale):
Expand All @@ -1058,7 +1069,7 @@ class GrayscaleMaskList(transforms.Grayscale):
def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be converted to grayscale.
Expand All @@ -1072,7 +1083,7 @@ def __call__(self, imgs, masks):
F.to_grayscale(img, num_output_channels=self.num_output_channels)
)

return return_imgs, masks
return return_imgs, masks, bbox

def __repr__(self):
return self.__class__.__name__ + "(num_output_channels={0})".format(
Expand All @@ -1093,14 +1104,15 @@ class ResizeMaskList(transforms.Resize):
``PIL.Image.BILINEAR``
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be scaled.
Returns:
PIL Image: Rescaled image.
"""

return_imgs = []
return_masks = []

Expand All @@ -1115,7 +1127,7 @@ def __call__(self, imgs, masks):
return_masks.append(
F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST)
)
return return_imgs, return_masks
return return_imgs, return_masks, F2.resize(bbox, self.size)


class RandomCropMaskList(transforms.RandomCrop):
Expand Down Expand Up @@ -1154,7 +1166,7 @@ class RandomCropMaskList(transforms.RandomCrop):
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be cropped.
Expand Down Expand Up @@ -1186,7 +1198,7 @@ def __call__(self, imgs, masks):
else:
for mask in masks:
return_masks.append(F.crop(mask, i, j, h, w))
return return_imgs, return_masks
return return_imgs, return_masks, F2.crop(bbox, i, j, h, w)


class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
Expand All @@ -1196,7 +1208,7 @@ class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
p (float): probability of the image being flipped. Default value is 0.5
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be flipped.
Expand All @@ -1214,9 +1226,9 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.hflip(bbox)
else:
return imgs, masks
return imgs, masks, bbox


class ToTensorMaskList(transforms.ToTensor):
Expand All @@ -1230,7 +1242,7 @@ class ToTensorMaskList(transforms.ToTensor):
In the other cases, tensors are returned without scaling.
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Expand All @@ -1248,7 +1260,7 @@ def __call__(self, imgs, masks):
)
else:
return_masks = None
return return_imgs, return_masks
return return_imgs, return_masks, bbox.data


class RandomRotationMaskList(transforms.RandomRotation):
Expand Down Expand Up @@ -1286,7 +1298,7 @@ def get_params(degrees):

return angle

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be rotated.
Expand All @@ -1305,7 +1317,7 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.rotate(bbox, angle)


class NormalizeMaskList(transforms.Normalize):
Expand All @@ -1324,7 +1336,7 @@ class NormalizeMaskList(transforms.Normalize):
"""

def __call__(self, tensor_imgs, tensor_masks):
def __call__(self, tensor_imgs, tensor_masks, tensor_bbox):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Expand All @@ -1339,7 +1351,7 @@ def __call__(self, tensor_imgs, tensor_masks):
F.normalize(tensor_img, self.mean, self.std, self.inplace)
)

return return_imgs, tensor_masks
return return_imgs, tensor_masks, tensor_bbox

def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
Expand All @@ -1357,7 +1369,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear):
self.scale_max = scale_max
self.shear = shear

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):

if random.random() > 1.0 - self.p:
affine_params = self.get_params(
Expand All @@ -1376,6 +1388,6 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.affine(bbox, *affine_params)
else:
return imgs, masks
return imgs, masks, bbox
29 changes: 18 additions & 11 deletions data/temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self, opt, phase):
self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
self.dir_A, "/paths.txt"
) # load images from '/path/to/data/trainA/paths.txt' as well as labels

if self.use_domain_B:
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.dir_B, "/paths.txt"
Expand Down Expand Up @@ -142,7 +141,7 @@ def get_img(
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)

cur_A_img, cur_A_label, ref_A_bbox = crop_image(
cur_A_img, cur_A_label, A_bbox = crop_image(
cur_A_img_path,
cur_A_label_path,
mask_delta=mask_delta_A,
Expand All @@ -156,6 +155,8 @@ def get_img(
crop_coordinates=crop_coordinates,
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)
if i == 0:
A_ref_bbox = A_bbox[1:]

except Exception as e:
print(e, f"{i+1}th frame of domain A in temporal dataloading")
Expand All @@ -164,10 +165,9 @@ def get_img(
images_A.append(cur_A_img)
labels_A.append(cur_A_label)

images_A, labels_A = self.transform(images_A, labels_A)

images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
A_ref_label = labels_A[0]
images_A = torch.stack(images_A)

labels_A = torch.stack(labels_A)

if self.use_domain_B:
Expand Down Expand Up @@ -222,7 +222,7 @@ def get_img(
get_crop_coordinates=True,
)

cur_B_img, cur_B_label, ref_B_bbox = crop_image(
cur_B_img, cur_B_label, B_bbox = crop_image(
cur_B_img_path,
cur_B_label_path,
mask_delta=mask_delta_B,
Expand All @@ -236,6 +236,8 @@ def get_img(
crop_coordinates=crop_coordinates,
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)
if i == 0:
B_ref_bbox = B_bbox[1:]

except Exception as e:
print(e, f"{i+1}th frame of domain B in temporal dataloading")
Expand All @@ -244,10 +246,11 @@ def get_img(
images_B.append(cur_B_img)
labels_B.append(cur_B_label)

images_B, labels_B = self.transform(images_B, labels_B)

images_B, labels_B, B_ref_bbox = self.transform(
images_B, labels_B, B_ref_bbox
)
B_ref_label = labels_B[0]
images_B = torch.stack(images_B)

labels_B = torch.stack(labels_B)

else:
Expand All @@ -258,14 +261,18 @@ def get_img(
result = {
"A": images_A,
"A_img_paths": ref_A_img_path,
"A_ref_bbox": A_ref_bbox,
"B": images_B,
"B_img_paths": ref_B_img_path,
"B_ref_bbox": B_ref_bbox,
}

result.update(
{
"A_label_mask": labels_A,
"B_label_mask": labels_B,
"A_label_masks": labels_A,
"B_label_masks": labels_B,
"A_label_mask": A_ref_label,
"B_label_mask": B_ref_label,
}
)

Expand Down
30 changes: 0 additions & 30 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def forward_GAN(self):

if self.use_temporal:
self.compute_temporal_fake(objective_domain="B")

if hasattr(self, "netG_B"):
self.compute_temporal_fake(objective_domain="A")

Expand Down Expand Up @@ -480,35 +479,6 @@ def compute_G_loss_GAN(self):
if self.opt.train_temporal_criterion:
self.compute_temporal_criterion_loss()

def compute_fake_with_context(self, fake_name, real_name):
setattr(
self,
fake_name + "_with_context",
torch.nn.functional.pad(
getattr(self, fake_name),
(
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
),
),
)

setattr(
self,
fake_name + "_with_context",
getattr(self, fake_name + "_with_context")
+ self.mask_context * getattr(self, real_name + "_with_context"),
)
setattr(
self,
fake_name + "_with_context_vis",
torch.nn.functional.interpolate(
getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
),
)

# compute_real_fake_with_depth
def compute_fake_real_with_depth(self, fake_name, real_name):
fake_depth = predict_depth(
Expand Down
Loading

0 comments on commit fcfc253

Please sign in to comment.