From ba9feac4b056220155c044adb491f3224cec719b Mon Sep 17 00:00:00 2001
From: Emmanuel Benazera
Date: Thu, 24 Aug 2023 12:45:25 +0200
Subject: [PATCH] fix: temporal discriminator with masks and bboxes

---
 data/base_dataset.py                         |  54 +++++----
 data/temporal_labeled_mask_online_dataset.py |  29 +++--
 models/base_gan_model.py                     |  32 +-----
 models/base_model.py                         | 112 ++++++++++++++++---
 models/cut_model.py                          |   8 ++
 models/cycle_gan_model.py                    |   4 +
 scripts/run_tests.sh                         |   6 +-
 tests/test_run_semantic_mask_online.py       |   6 +-
 train.py                                     |   8 +-
 9 files changed, 170 insertions(+), 89 deletions(-)

diff --git a/data/base_dataset.py b/data/base_dataset.py
index 633d2679d..d1970af26 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -1036,10 +1036,21 @@ class ComposeMaskList(transforms.Compose):
     >>>     ])
     """
 
-    def __call__(self, imgs, masks=None):
+    def __call__(self, imgs, masks=None, bbox=None):
+        if bbox is None:
+            w, h = imgs[0].size
+            bbox = np.array([0, 0, w, h])  # defaults to a bbox covering the full image
+        if torch.__version__[0] == "2":
+            tbbox = datapoints.BoundingBox(
+                bbox,
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=imgs[0].size,
+            )
+        else:
+            tbbox = bbox  # pre-2.0 torchvision: keep the raw bbox array
         for t in self.transforms:
-            imgs, masks = t(imgs, masks)
-        return imgs, masks
+            imgs, masks, tbbox = t(imgs, masks, tbbox)
+        return imgs, masks, tbbox
 
 
 class GrayscaleMaskList(transforms.Grayscale):
@@ -1058,7 +1069,7 @@ class GrayscaleMaskList(transforms.Grayscale):
     def __init__(self, num_output_channels=1):
         self.num_output_channels = num_output_channels
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be converted to grayscale.
@@ -1072,7 +1083,7 @@ def __call__(self, imgs, masks):
                 F.to_grayscale(img, num_output_channels=self.num_output_channels)
             )
 
-        return return_imgs, masks
+        return return_imgs, masks, bbox
 
     def __repr__(self):
         return self.__class__.__name__ + "(num_output_channels={0})".format(
@@ -1093,7 +1104,7 @@ class ResizeMaskList(transforms.Resize):
             ``PIL.Image.BILINEAR``
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be scaled.
@@ -1101,6 +1112,7 @@ def __call__(self, imgs, masks):
         Returns:
             PIL Image: Rescaled image.
         """
+
         return_imgs = []
         return_masks = []
 
@@ -1115,7 +1127,7 @@ def __call__(self, imgs, masks):
             return_masks.append(
                 F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST)
             )
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.resize(bbox, self.size)
 
 
 class RandomCropMaskList(transforms.RandomCrop):
@@ -1154,7 +1166,7 @@ class RandomCropMaskList(transforms.RandomCrop):
 
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be cropped.
@@ -1186,7 +1198,7 @@ def __call__(self, imgs, masks):
         else:
             for mask in masks:
                 return_masks.append(F.crop(mask, i, j, h, w))
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.crop(bbox, i, j, h, w)
 
 
 class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
@@ -1196,7 +1208,7 @@ class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
         p (float): probability of the image being flipped. Default value is 0.5
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
        """
         Args:
             img (PIL Image): Image to be flipped.
@@ -1214,9 +1226,9 @@ def __call__(self, imgs, masks):
             else:
                 return_masks = None
 
-            return return_imgs, return_masks
+            return return_imgs, return_masks, F2.hflip(bbox)
         else:
-            return imgs, masks
+            return imgs, masks, bbox
 
 
 class ToTensorMaskList(transforms.ToTensor):
@@ -1230,7 +1242,7 @@ class ToTensorMaskList(transforms.ToTensor):
     In the other cases, tensors are returned without scaling.
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
@@ -1248,7 +1260,7 @@ def __call__(self, imgs, masks):
             )
         else:
             return_masks = None
-        return return_imgs, return_masks
+        return return_imgs, return_masks, bbox.data
 
 
 class RandomRotationMaskList(transforms.RandomRotation):
@@ -1286,7 +1298,7 @@ def get_params(degrees):
 
         return angle
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be rotated.
@@ -1305,7 +1317,7 @@ def __call__(self, imgs, masks):
         else:
             return_masks = None
 
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.rotate(bbox, angle)
 
 
 class NormalizeMaskList(transforms.Normalize):
@@ -1324,7 +1336,7 @@ class NormalizeMaskList(transforms.Normalize):
 
     """
 
-    def __call__(self, tensor_imgs, tensor_masks):
+    def __call__(self, tensor_imgs, tensor_masks, tensor_bbox):
         """
         Args:
             tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
@@ -1339,7 +1351,7 @@ def __call__(self, tensor_imgs, tensor_masks):
                 F.normalize(tensor_img, self.mean, self.std, self.inplace)
             )
 
-        return return_imgs, tensor_masks
+        return return_imgs, tensor_masks, tensor_bbox
 
     def __repr__(self):
         return self.__class__.__name__ + "(mean={0}, std={1})".format(
@@ -1357,7 +1369,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear):
         self.scale_max = scale_max
         self.shear = shear
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
 
         if random.random() > 1.0 - self.p:
             affine_params = self.get_params(
@@ -1376,6 +1388,6 @@ def __call__(self, imgs, masks):
             else:
                 return_masks = None
 
-            return return_imgs, return_masks
+            return return_imgs, return_masks, F2.affine(bbox, *affine_params)
         else:
-            return imgs, masks
+            return imgs, masks, bbox
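Note on usage: the transform list above now threads a bounding box through every geometric transform, so the temporal discriminator can receive a crop-consistent reference box. A minimal sketch of how the chain is called (illustrative only; it assumes these classes are importable from data/base_dataset.py and uses made-up sizes):

    import numpy as np
    from PIL import Image

    from data.base_dataset import (
        ComposeMaskList,
        ResizeMaskList,
        RandomHorizontalFlipMaskList,
        ToTensorMaskList,
        NormalizeMaskList,
    )

    imgs = [Image.new("RGB", (256, 256)) for _ in range(3)]  # one temporal sequence
    masks = [Image.new("L", (256, 256)) for _ in range(3)]   # one mask per frame
    bbox = np.array([50, 60, 120, 140])                      # XYXY box on the reference frame

    transform = ComposeMaskList(
        [
            ResizeMaskList(128),
            RandomHorizontalFlipMaskList(),
            ToTensorMaskList(),
            NormalizeMaskList((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # each geometric op is applied identically to the images, the masks and the box
    imgs, masks, bbox = transform(imgs, masks, bbox)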
diff --git a/data/temporal_labeled_mask_online_dataset.py b/data/temporal_labeled_mask_online_dataset.py
index 3532efed8..297bd98a4 100644
--- a/data/temporal_labeled_mask_online_dataset.py
+++ b/data/temporal_labeled_mask_online_dataset.py
@@ -34,7 +34,6 @@ def __init__(self, opt, phase):
         self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
             self.dir_A, "/paths.txt"
         )  # load images from '/path/to/data/trainA/paths.txt' as well as labels
-
        if self.use_domain_B:
            self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
                self.dir_B, "/paths.txt"
@@ -142,7 +141,7 @@ def get_img(
                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
                )
 
-                cur_A_img, cur_A_label, ref_A_bbox = crop_image(
+                cur_A_img, cur_A_label, A_bbox = crop_image(
                    cur_A_img_path,
                    cur_A_label_path,
                    mask_delta=mask_delta_A,
@@ -156,6 +155,8 @@ def get_img(
                    crop_coordinates=crop_coordinates,
                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
                )
+                if i == 0:
+                    A_ref_bbox = A_bbox[1:]
 
            except Exception as e:
                print(e, f"{i+1}th frame of domain A in temporal dataloading")
@@ -164,10 +165,9 @@ def get_img(
            images_A.append(cur_A_img)
            labels_A.append(cur_A_label)
 
-        images_A, labels_A = self.transform(images_A, labels_A)
-
+        images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
+        A_ref_label = labels_A[0]
        images_A = torch.stack(images_A)
-
        labels_A = torch.stack(labels_A)
 
        if self.use_domain_B:
@@ -222,7 +222,7 @@ def get_img(
                        get_crop_coordinates=True,
                    )
 
-                cur_B_img, cur_B_label, ref_B_bbox = crop_image(
+                cur_B_img, cur_B_label, B_bbox = crop_image(
                    cur_B_img_path,
                    cur_B_label_path,
                    mask_delta=mask_delta_B,
@@ -236,6 +236,8 @@ def get_img(
                    crop_coordinates=crop_coordinates,
                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
                )
+                if i == 0:
+                    B_ref_bbox = B_bbox[1:]
 
            except Exception as e:
                print(e, f"{i+1}th frame of domain B in temporal dataloading")
@@ -244,10 +246,11 @@ def get_img(
            images_B.append(cur_B_img)
            labels_B.append(cur_B_label)
 
-        images_B, labels_B = self.transform(images_B, labels_B)
-
+        images_B, labels_B, B_ref_bbox = self.transform(
+            images_B, labels_B, B_ref_bbox
+        )
+        B_ref_label = labels_B[0]
        images_B = torch.stack(images_B)
-
        labels_B = torch.stack(labels_B)
 
        else:
@@ -258,14 +261,18 @@ def get_img(
        result = {
            "A": images_A,
            "A_img_paths": ref_A_img_path,
+            "A_ref_bbox": A_ref_bbox,
            "B": images_B,
            "B_img_paths": ref_B_img_path,
+            "B_ref_bbox": B_ref_bbox,
        }
 
        result.update(
            {
-                "A_label_mask": labels_A,
-                "B_label_mask": labels_B,
+                "A_label_masks": labels_A,
+                "B_label_masks": labels_B,
+                "A_label_mask": A_ref_label,
+                "B_label_mask": B_ref_label,
            }
        )
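For reference, a sample returned by this dataset after the change carries both the per-frame mask stack and the reference-frame entries (keys as added above; shapes are illustrative, assuming data_temporal_number_frames=4 and 128x128 crops):

    sample = {
        "A": ...,              # frames tensor, e.g. [4, 3, 128, 128]
        "A_img_paths": ...,    # path of the reference (first) frame
        "A_ref_bbox": ...,     # reference bbox, tracked through the transforms
        "A_label_masks": ...,  # stacked masks for all frames, e.g. [4, 1, 128, 128]
        "A_label_mask": ...,   # mask of the reference frame only
        # same keys for domain B
    }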
diff --git a/models/base_gan_model.py b/models/base_gan_model.py
index 8b69a7a7e..cf81452de 100644
--- a/models/base_gan_model.py
+++ b/models/base_gan_model.py
@@ -208,7 +208,6 @@ def forward_GAN(self):
 
        if self.use_temporal:
            self.compute_temporal_fake(objective_domain="B")
-
            if hasattr(self, "netG_B"):
                self.compute_temporal_fake(objective_domain="A")
 
@@ -480,35 +479,6 @@ def compute_G_loss_GAN(self):
        if self.opt.train_temporal_criterion:
            self.compute_temporal_criterion_loss()
 
-    def compute_fake_with_context(self, fake_name, real_name):
-        setattr(
-            self,
-            fake_name + "_with_context",
-            torch.nn.functional.pad(
-                getattr(self, fake_name),
-                (
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                ),
-            ),
-        )
-
-        setattr(
-            self,
-            fake_name + "_with_context",
-            getattr(self, fake_name + "_with_context")
-            + self.mask_context * getattr(self, real_name + "_with_context"),
-        )
-        setattr(
-            self,
-            fake_name + "_with_context_vis",
-            torch.nn.functional.interpolate(
-                getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
-            ),
-        )
-
    # compute_real_fake_with_depth
    def compute_fake_real_with_depth(self, fake_name, real_name):
        fake_depth = predict_depth(
@@ -769,7 +739,7 @@ def compute_G_loss_semantic_mask_generic(self, domain_fake):
        else:
            label_fake = getattr(self, "input_%s_label_mask" % domain_real)
        # logits
-        loss_G_sem_mask = self.opt.train_sem_mask_lambda * self.criterionf_s(
+        loss_G_sem_mask = self.opt.train_sem_mask_lambda * self.criterionf_s(
            getattr(self, "pred_f_s_fake_%s" % domain_fake), label_fake
        )
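compute_fake_with_context is not dropped here: it moves to BaseModel in the next diff, next to a new temporal variant. For readers unfamiliar with the helper, its effect is to paste the generated center crop back into the border of the real crop. A standalone sketch (sizes are illustrative; it assumes mask_context is 1 on the border and 0 in the center, which is what the compositing addition implies):

    import torch
    import torch.nn.functional as F

    p = 16                                  # data_online_context_pixels
    fake = torch.randn(1, 3, 128, 128)      # generated central crop
    real_ctx = torch.randn(1, 3, 160, 160)  # real crop with a p-pixel context border
    mask_context = torch.ones_like(real_ctx)
    mask_context[:, :, p:-p, p:-p] = 0      # keep only the border of the real crop

    # pad the fake to the context size, then composite the real border around it
    fake_ctx = F.pad(fake, (p, p, p, p)) + mask_context * real_ctx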
diff --git a/models/base_model.py b/models/base_model.py
index e4f5da393..f442c6c1e 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -386,9 +386,9 @@ def set_input(self, data):
                self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels,
            ]
 
-            self.real_B_with_context_vis = torch.nn.functional.interpolate(
-                self.real_B_with_context, size=self.real_A.shape[2:]
-            )
+            self.real_B_with_context_vis = torch.nn.functional.interpolate(
+                self.real_B_with_context, size=self.real_B.shape[2:]
+            )
 
        self.image_paths = data["A_img_paths"]
 
@@ -492,7 +492,8 @@ def set_input_temporal(self, data_temporal):
                    "temporal_real_A_" + str(i) + "_with_context_vis",
                    torch.nn.functional.interpolate(
                        getattr(self, "temporal_real_A_" + str(i) + "_with_context"),
-                        size=self.real_A.shape[2:],
+                        # size=self.temporal_real_A.shape[2:]
+                        size=getattr(self, "temporal_real_A_" + str(i)).shape[2:],
                    ),
                )
            else:
@@ -531,7 +532,8 @@ def set_input_temporal(self, data_temporal):
                    "temporal_real_B_" + str(i) + "_with_context_vis",
                    torch.nn.functional.interpolate(
                        getattr(self, "temporal_real_B_" + str(i) + "_with_context"),
-                        size=self.real_B.shape[2:],
+                        # size=self.real_B.shape[2:],
+                        size=getattr(self, "temporal_real_B_" + str(i)).shape[2:],
                    ),
                )
            else:
@@ -544,10 +546,92 @@ def set_input_temporal(self, data_temporal):
                ],
            )
 
+        self.image_paths = data_temporal["A_img_paths"]
+
+        self.input_A_ref_bbox = None  # set by semantics
+        self.input_B_ref_bbox = None
+
+        # first image from temporal sequence is the reference image
+        self.real_A = self.temporal_real_A_0.clone()
+        self.real_A_with_context = self.temporal_real_A_0_with_context.clone()
+        self.real_B = self.temporal_real_B_0.clone()
+        self.real_B_with_context = self.temporal_real_B_0_with_context.clone()
+        if self.opt.data_online_context_pixels > 0:
+            self.real_A_with_context_vis = torch.nn.functional.interpolate(
+                self.real_A_with_context, size=self.real_A.shape[2:]
+            )
+            self.real_B_with_context_vis = torch.nn.functional.interpolate(
+                self.real_B_with_context, size=self.real_B.shape[2:]
+            )
+
+        if self.opt.train_semantic_mask:
+            self.set_input_semantic_mask(data_temporal)
+        if self.opt.train_semantic_cls:
+            self.set_input_semantic_cls(data_temporal)
+
    def forward(self):
        for forward_function in self.forward_functions:
            getattr(self, forward_function)()
 
+    def compute_fake_with_context(self, fake_name, real_name):
+        setattr(
+            self,
+            fake_name + "_with_context",
+            torch.nn.functional.pad(
+                getattr(self, fake_name),
+                (
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                ),
+            ),
+        )
+
+        setattr(
+            self,
+            fake_name + "_with_context",
+            getattr(self, fake_name + "_with_context")
+            + self.mask_context * getattr(self, real_name + "_with_context"),
+        )
+        setattr(
+            self,
+            fake_name + "_with_context_vis",
+            torch.nn.functional.interpolate(
+                getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
+            ),
+        )
+
+    def compute_temporal_fake_with_context(self, fake_name, real_name):
+        setattr(
+            self,
+            fake_name + "_with_context",
+            torch.nn.functional.pad(
+                getattr(self, fake_name),
+                (
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                ),
+            ),
+        )
+
+        setattr(
+            self,
+            fake_name + "_with_context",
+            getattr(self, fake_name + "_with_context")
+            + self.mask_context * getattr(self, real_name + "_with_context"),
+        )
+        setattr(
+            self,
+            fake_name + "_with_context_vis",
+            torch.nn.functional.interpolate(
+                getattr(self, fake_name + "_with_context"),
+                size=getattr(self, real_name).shape[2:],
+            ),
+        )
+
    def compute_temporal_fake(self, objective_domain):
        origin_domain = "B" if objective_domain == "A" else "A"
        netG = getattr(self, "netG_" + origin_domain)
@@ -567,7 +651,7 @@ def compute_temporal_fake(self, objective_domain):
                    temporal_fake[:, i],
                )
                if self.opt.data_online_context_pixels > 0:
-                    self.compute_fake_with_context(
+                    self.compute_temporal_fake_with_context(
                        fake_name="temporal_fake_" + objective_domain + "_" + str(i),
                        real_name="temporal_real_" + origin_domain + "_" + str(i),
                    )
@@ -972,7 +1056,6 @@ def optimize_parameters(self):
                True  # automatic fall back to eager mode
            )
 
-        # print(self.niter,self.opt.train_iter_size, self.niter % self.opt.train_iter_size, self.niter % self.opt.train_iter_size != 0, )
        if len(self.opt.gpu_ids) > 1 and self.niter % self.opt.train_iter_size != 0:
            for network in self.model_names:
                stack.enter_context(getattr(self, "net" + network).no_sync())
@@ -1363,18 +1446,13 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter):
            ):  # inner loop (minibatch) within one epoch
                data_test = data_test_list[0]
 
-                use_temporal = (
-                    "temporal" in self.opt.D_netDs
-                ) or self.opt.train_temporal_criterion
-
-                if use_temporal:
+                if self.use_temporal:
                    temporal_data_test = data_test_list[1]
-
-                self.set_input(
-                    data_test
-                )  # unpack data from dataloader and apply preprocessing
-                if use_temporal:
                    self.set_input_temporal(temporal_data_test)
+                else:
+                    self.set_input(
+                        data_test
+                    )  # unpack data from dataloader and apply preprocessing
 
                self.inference()
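The new compute_temporal_fake_with_context duplicates compute_fake_with_context except for the tensor whose spatial size drives the visualization resize (real_A versus the per-frame reference). A possible shared helper, sketched here only as a follow-up idea (not part of this patch; _compute_fake_with_context and vis_size_from are hypothetical names):

    def _compute_fake_with_context(self, fake_name, real_name, vis_size_from):
        pad = self.opt.data_online_context_pixels
        # paste the fake crop into the real context border
        fake_ctx = torch.nn.functional.pad(
            getattr(self, fake_name), (pad, pad, pad, pad)
        ) + self.mask_context * getattr(self, real_name + "_with_context")
        setattr(self, fake_name + "_with_context", fake_ctx)
        # resize for display to the spatial size of the chosen reference tensor
        setattr(
            self,
            fake_name + "_with_context_vis",
            torch.nn.functional.interpolate(
                fake_ctx, size=getattr(self, vis_size_from).shape[2:]
            ),
        )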
diff --git a/models/cut_model.py b/models/cut_model.py
index 05e3719cf..fb0369a5e 100644
--- a/models/cut_model.py
+++ b/models/cut_model.py
@@ -423,6 +423,7 @@ def data_dependent_initialize(self, data):
            real_A_with_z = torch.cat([self.real_A, z_real], 1)
        else:
            real_A_with_z = self.real_A
+
        feat_temp = self.netG_A.get_feats(real_A_with_z.cpu(), self.nce_layers)
 
        self.netF.data_dependent_initialize(feat_temp)
@@ -532,6 +533,10 @@ def forward_cut(self):
            self.fake_B = self.fake[: self.real_A.size(0)]
 
        if self.opt.data_online_context_pixels > 0:
+            if self.use_temporal:
+                self.compute_temporal_fake_with_context(
+                    fake_name="temporal_fake_B_0", real_name="temporal_real_A_0"
+                )
            self.compute_fake_with_context(fake_name="fake_B", real_name="real_A")
 
        if self.use_depth:
@@ -550,6 +555,9 @@ def forward_cut(self):
        if self.opt.data_online_context_pixels > 0:
            context = "_with_context"
 
+        # if self.use_temporal:
+        #    names = ["temporal_fake_B_0", "temporal_real_B_0"]
+        # else:
        names = ["fake_B", "real_B"]
        for name in names:
            setattr(
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 25a1882e2..bc59bd66b 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -278,6 +278,10 @@ def forward_cycle_gan(self):
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
 
        if self.opt.data_online_context_pixels > 0:
+            if self.use_temporal:
+                self.compute_temporal_fake_with_context(
+                    fake_name="temporal_fake_B_0", real_name="temporal_real_A_0"
+                )
            self.compute_fake_with_context(fake_name="fake_B", real_name="real_A")
 
        # Rec A
diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh
index cce4db8e0..aedf499e5 100644
--- a/scripts/run_tests.sh
+++ b/scripts/run_tests.sh
@@ -84,9 +84,9 @@ fi
 ####### mask cls semantics test with online dataloading
 echo "Running mask online semantics training tests"
 
-URL=https://joligen.com/datasets/online_mario2sonic_lite.zip
-ZIP_FILE=$DIR/online_mario2sonic_lite.zip
-TARGET_MASK_SEM_ONLINE_DIR=$DIR/online_mario2sonic_lite
+URL=https://joligen.com/datasets/online_mario2sonic_lite2.zip
+ZIP_FILE=$DIR/online_mario2sonic_lite2.zip
+TARGET_MASK_SEM_ONLINE_DIR=$DIR/online_mario2sonic_lite2
 wget -N $URL -O $ZIP_FILE
 mkdir $TARGET_MASK_SEM_ONLINE_DIR
 unzip $ZIP_FILE -d $DIR
"data_max_dataset_size": 10, "train_mask_out_mask": True, "f_s_net": "unet", - "f_s_semantic_nclasses": 2, + "f_s_semantic_nclasses": 7, "dataaug_D_noise": 0.001, "train_sem_use_label_B": True, "data_relative_paths": True, @@ -43,10 +43,10 @@ "dataaug_no_rotate": True, "train_mask_compute_miou": True, "train_mask_miou_every": 1, - "data_temporal_number_frames": 2, + "data_temporal_number_frames": 4, "data_temporal_frame_step": 2, "train_semantic_mask": True, - "train_temporal_criterion": True, + "train_temporal_criterion": False, "train_export_jit": True, "train_save_latest_freq": 10, } diff --git a/train.py b/train.py index 46717c9b6..3056cbf0b 100644 --- a/train.py +++ b/train.py @@ -190,15 +190,17 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): dataloaders ): # inner loop (minibatch) within one epoch data = data_list[0] - if use_temporal: - temporal_data = data_list[1] iter_start_time = time.time() # timer for computation per iteration t_data_mini_batch = iter_start_time - iter_data_time - model.set_input(data) # unpack data from dataloader and apply preprocessing if use_temporal: + temporal_data = data_list[1] model.set_input_temporal(temporal_data) + else: + model.set_input( + data + ) # unpack data from dataloader and apply preprocessing model.optimize_parameters() # calculate loss functions, get gradients, update network weights t_comp = (time.time() - iter_start_time) / opt.train_batch_size