Skip to content

Commit

Permalink
fix: temporal discriminator with masks and bboxes
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 4, 2023
1 parent 3577c8f commit bf7674f
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 142 deletions.
4 changes: 2 additions & 2 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pipeline {
dockerfile {
filename 'docker/Dockerfile.devel'
additionalBuildArgs '--no-cache'
args '-u jenkins'
args '--shm-size=8gb -u jenkins'
}

}
Expand Down Expand Up @@ -50,6 +50,6 @@ pipeline {
}
}
environment {
DOCKER_PARAMS = '"--runtime nvidia -u jenkins"'
DOCKER_PARAMS = '"--runtime nvidia --shm-size=8gb -u jenkins"'
}
}
54 changes: 33 additions & 21 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,10 +1066,21 @@ class ComposeMaskList(transforms.Compose):
>>> ])
"""

def __call__(self, imgs, masks=None):
def __call__(self, imgs, masks=None, bbox=None):
if bbox is None:
w, h = imgs[0].size
bbox = np.array([0, 0, w, h]) # sets bbox to full image size
if torch.__version__[0] == "2":
tbbox = datapoints.BoundingBox(
bbox,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=imgs[0].size,
)
else:
tbbox = bbox # placeholder
for t in self.transforms:
imgs, masks = t(imgs, masks)
return imgs, masks
imgs, masks, tbbox = t(imgs, masks, tbbox)
return imgs, masks, tbbox


class GrayscaleMaskList(transforms.Grayscale):
Expand All @@ -1088,7 +1099,7 @@ class GrayscaleMaskList(transforms.Grayscale):
def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be converted to grayscale.
Expand All @@ -1102,7 +1113,7 @@ def __call__(self, imgs, masks):
F.to_grayscale(img, num_output_channels=self.num_output_channels)
)

return return_imgs, masks
return return_imgs, masks, bbox

def __repr__(self):
return self.__class__.__name__ + "(num_output_channels={0})".format(
Expand All @@ -1123,14 +1134,15 @@ class ResizeMaskList(transforms.Resize):
``PIL.Image.BILINEAR``
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be scaled.
Returns:
PIL Image: Rescaled image.
"""

return_imgs = []
return_masks = []

Expand All @@ -1145,7 +1157,7 @@ def __call__(self, imgs, masks):
return_masks.append(
F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST)
)
return return_imgs, return_masks
return return_imgs, return_masks, F2.resize(bbox, self.size)


class RandomCropMaskList(transforms.RandomCrop):
Expand Down Expand Up @@ -1184,7 +1196,7 @@ class RandomCropMaskList(transforms.RandomCrop):
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be cropped.
Expand Down Expand Up @@ -1216,7 +1228,7 @@ def __call__(self, imgs, masks):
else:
for mask in masks:
return_masks.append(F.crop(mask, i, j, h, w))
return return_imgs, return_masks
return return_imgs, return_masks, F2.crop(bbox, i, j, h, w)


class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
Expand All @@ -1226,7 +1238,7 @@ class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
p (float): probability of the image being flipped. Default value is 0.5
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be flipped.
Expand All @@ -1244,9 +1256,9 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.hflip(bbox)
else:
return imgs, masks
return imgs, masks, bbox


class ToTensorMaskList(transforms.ToTensor):
Expand All @@ -1260,7 +1272,7 @@ class ToTensorMaskList(transforms.ToTensor):
In the other cases, tensors are returned without scaling.
"""

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Expand All @@ -1278,7 +1290,7 @@ def __call__(self, imgs, masks):
)
else:
return_masks = None
return return_imgs, return_masks
return return_imgs, return_masks, bbox.data


class RandomRotationMaskList(transforms.RandomRotation):
Expand Down Expand Up @@ -1316,7 +1328,7 @@ def get_params(degrees):

return angle

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
"""
Args:
img (PIL Image): Image to be rotated.
Expand All @@ -1335,7 +1347,7 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.rotate(bbox, angle)


class NormalizeMaskList(transforms.Normalize):
Expand All @@ -1354,7 +1366,7 @@ class NormalizeMaskList(transforms.Normalize):
"""

def __call__(self, tensor_imgs, tensor_masks):
def __call__(self, tensor_imgs, tensor_masks, tensor_bbox):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Expand All @@ -1369,7 +1381,7 @@ def __call__(self, tensor_imgs, tensor_masks):
F.normalize(tensor_img, self.mean, self.std, self.inplace)
)

return return_imgs, tensor_masks
return return_imgs, tensor_masks, tensor_bbox

def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
Expand All @@ -1387,7 +1399,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear):
self.scale_max = scale_max
self.shear = shear

def __call__(self, imgs, masks):
def __call__(self, imgs, masks, bbox):
if random.random() > 1.0 - self.p:
affine_params = self.get_params(
(0, 0),
Expand All @@ -1405,6 +1417,6 @@ def __call__(self, imgs, masks):
else:
return_masks = None

return return_imgs, return_masks
return return_imgs, return_masks, F2.affine(bbox, *affine_params)
else:
return imgs, masks
return imgs, masks, bbox
2 changes: 1 addition & 1 deletion data/online_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def crop_image(
min_crop_bbox_ratio=None,
):
margin = context_pixels * 2

try:
img = load_image(img_path)
if load_size != []:
Expand Down
27 changes: 20 additions & 7 deletions data/temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(self, opt, phase):
self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
self.dir_A, "/paths.txt"
) # load images from '/path/to/data/trainA/paths.txt' as well as labels

if self.use_domain_B:
self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
self.dir_B, "/paths.txt"
Expand Down Expand Up @@ -156,6 +155,8 @@ def get_img(
crop_coordinates=crop_coordinates,
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)
if i == 0:
A_ref_bbox = ref_A_bbox[1:]

except Exception as e:
print(e, f"{i+1}th frame of domain A in temporal dataloading")
Expand All @@ -164,10 +165,10 @@ def get_img(
images_A.append(cur_A_img)
labels_A.append(cur_A_label)

images_A, labels_A = self.transform(images_A, labels_A)

images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
A_ref_label = labels_A[0]
A_ref_img = images_A[0]
images_A = torch.stack(images_A)

labels_A = torch.stack(labels_A)

if self.use_domain_B:
Expand Down Expand Up @@ -236,6 +237,8 @@ def get_img(
crop_coordinates=crop_coordinates,
fixed_mask_size=self.opt.data_online_fixed_mask_size,
)
if i == 0:
B_ref_bbox = ref_B_bbox[1:]

except Exception as e:
print(e, f"{i+1}th frame of domain B in temporal dataloading")
Expand All @@ -244,28 +247,38 @@ def get_img(
images_B.append(cur_B_img)
labels_B.append(cur_B_label)

images_B, labels_B = self.transform(images_B, labels_B)

images_B, labels_B, B_ref_bbox = self.transform(
images_B, labels_B, B_ref_bbox
)
B_ref_label = labels_B[0]
B_ref_img = images_B[0]
images_B = torch.stack(images_B)

labels_B = torch.stack(labels_B)

else:
images_B = None
labels_B = None
ref_B_img_path = ""
B_ref_bbox = ""
B_ref_label = ""

result = {
"A_ref": A_ref_img,
"A": images_A,
"A_img_paths": ref_A_img_path,
"A_ref_bbox": A_ref_bbox,
"B_ref": B_ref_img,
"B": images_B,
"B_img_paths": ref_B_img_path,
"B_ref_bbox": B_ref_bbox,
}

result.update(
{
"A_label_mask": labels_A,
"B_label_mask": labels_B,
"A_ref_label_mask": A_ref_label,
"B_ref_label_mask": B_ref_label,
}
)

Expand Down
43 changes: 7 additions & 36 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def forward_GAN(self):

if self.use_temporal:
self.compute_temporal_fake(objective_domain="B")

if hasattr(self, "netG_B"):
self.compute_temporal_fake(objective_domain="A")

Expand Down Expand Up @@ -480,35 +479,6 @@ def compute_G_loss_GAN(self):
if self.opt.train_temporal_criterion:
self.compute_temporal_criterion_loss()

def compute_fake_with_context(self, fake_name, real_name):
setattr(
self,
fake_name + "_with_context",
torch.nn.functional.pad(
getattr(self, fake_name),
(
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
self.opt.data_online_context_pixels,
),
),
)

setattr(
self,
fake_name + "_with_context",
getattr(self, fake_name + "_with_context")
+ self.mask_context * getattr(self, real_name + "_with_context"),
)
setattr(
self,
fake_name + "_with_context_vis",
torch.nn.functional.interpolate(
getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
),
)

# compute_fake_real_with_depth
def compute_fake_real_with_depth(self, fake_name, real_name):
fake_depth = predict_depth(
Expand Down Expand Up @@ -622,6 +592,13 @@ def set_discriminators_info(self):
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)

setattr(
self,
loss_calculator_name,
loss_calculator,
)


if "depth" in discriminator_name:
fake_name = "fake_depth"
real_name = "real_depth"
Expand All @@ -632,12 +609,6 @@ def set_discriminators_info(self):
fake_name = "fake_sam"
real_name = "real_sam"

setattr(
self,
loss_calculator_name,
loss_calculator,
)

self.objects_to_update.append(getattr(self, loss_calculator_name))

self.discriminators.append(
Expand Down
Loading

0 comments on commit bf7674f

Please sign in to comment.