diff --git a/data/base_dataset.py b/data/base_dataset.py
index 633d2679d..d1970af26 100644
--- a/data/base_dataset.py
+++ b/data/base_dataset.py
@@ -1036,10 +1036,21 @@ class ComposeMaskList(transforms.Compose):
     >>> ])
     """
 
-    def __call__(self, imgs, masks=None):
+    def __call__(self, imgs, masks=None, bbox=None):
+        if bbox is None:
+            w, h = imgs[0].size
+            bbox = np.array([0, 0, w, h])  # sets bbox to full image size
+        if torch.__version__[0] == "2":
+            tbbox = datapoints.BoundingBox(
+                bbox,
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=imgs[0].size,
+            )
+        else:
+            tbbox = bbox  # placeholder
         for t in self.transforms:
-            imgs, masks = t(imgs, masks)
-        return imgs, masks
+            imgs, masks, tbbox = t(imgs, masks, tbbox)
+        return imgs, masks, tbbox
 
 
 class GrayscaleMaskList(transforms.Grayscale):
@@ -1058,7 +1069,7 @@ class GrayscaleMaskList(transforms.Grayscale):
     def __init__(self, num_output_channels=1):
         self.num_output_channels = num_output_channels
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be converted to grayscale.
@@ -1072,7 +1083,7 @@ def __call__(self, imgs, masks):
                 F.to_grayscale(img, num_output_channels=self.num_output_channels)
             )
 
-        return return_imgs, masks
+        return return_imgs, masks, bbox
 
     def __repr__(self):
         return self.__class__.__name__ + "(num_output_channels={0})".format(
@@ -1093,7 +1104,7 @@ class ResizeMaskList(transforms.Resize):
         ``PIL.Image.BILINEAR``
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be scaled.
@@ -1101,6 +1112,7 @@ def __call__(self, imgs, masks):
         Returns:
             PIL Image: Rescaled image.
         """
+
         return_imgs = []
         return_masks = []
@@ -1115,7 +1127,7 @@ def __call__(self, imgs, masks):
             return_masks.append(
                 F.resize(mask, self.size, interpolation=InterpolationMode.NEAREST)
             )
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.resize(bbox, self.size)
 
 
 class RandomCropMaskList(transforms.RandomCrop):
@@ -1154,7 +1166,7 @@ class RandomCropMaskList(transforms.RandomCrop):
 
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be cropped.
@@ -1186,7 +1198,7 @@ def __call__(self, imgs, masks):
         else:
             for mask in masks:
                 return_masks.append(F.crop(mask, i, j, h, w))
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.crop(bbox, i, j, h, w)
 
 
 class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
@@ -1196,7 +1208,7 @@ class RandomHorizontalFlipMaskList(transforms.RandomHorizontalFlip):
         p (float): probability of the image being flipped. Default value is 0.5
     """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be flipped.
@@ -1214,9 +1226,9 @@ def __call__(self, imgs, masks):
             else:
                 return_masks = None
 
-            return return_imgs, return_masks
+            return return_imgs, return_masks, F2.hflip(bbox)
         else:
-            return imgs, masks
+            return imgs, masks, bbox
 
 
 class ToTensorMaskList(transforms.ToTensor):
@@ -1230,7 +1242,7 @@ class ToTensorMaskList(transforms.ToTensor):
     In the other cases, tensors are returned without scaling.
    """
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
@@ -1248,7 +1260,7 @@ def __call__(self, imgs, masks):
             )
         else:
             return_masks = None
-        return return_imgs, return_masks
+        return return_imgs, return_masks, bbox.data
 
 
 class RandomRotationMaskList(transforms.RandomRotation):
@@ -1286,7 +1298,7 @@ def get_params(degrees):
 
         return angle
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         """
         Args:
             img (PIL Image): Image to be rotated.
@@ -1305,7 +1317,7 @@ def __call__(self, imgs, masks):
         else:
             return_masks = None
 
-        return return_imgs, return_masks
+        return return_imgs, return_masks, F2.rotate(bbox, angle)
 
 
 class NormalizeMaskList(transforms.Normalize):
@@ -1324,7 +1336,7 @@ class NormalizeMaskList(transforms.Normalize):
 
     """
 
-    def __call__(self, tensor_imgs, tensor_masks):
+    def __call__(self, tensor_imgs, tensor_masks, tensor_bbox):
         """
         Args:
             tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
@@ -1339,7 +1351,7 @@ def __call__(self, tensor_imgs, tensor_masks):
                 F.normalize(tensor_img, self.mean, self.std, self.inplace)
             )
 
-        return return_imgs, tensor_masks
+        return return_imgs, tensor_masks, tensor_bbox
 
     def __repr__(self):
         return self.__class__.__name__ + "(mean={0}, std={1})".format(
@@ -1357,7 +1369,7 @@ def set_params(self, p, translate, scale_min, scale_max, shear):
         self.scale_max = scale_max
         self.shear = shear
 
-    def __call__(self, imgs, masks):
+    def __call__(self, imgs, masks, bbox):
         if random.random() > 1.0 - self.p:
 
             affine_params = self.get_params(
@@ -1376,6 +1388,6 @@ def __call__(self, imgs, masks):
             else:
                 return_masks = None
 
-            return return_imgs, return_masks
+            return return_imgs, return_masks, F2.affine(bbox, *affine_params)
         else:
-            return imgs, masks
+            return imgs, masks, bbox
diff --git a/data/temporal_labeled_mask_online_dataset.py b/data/temporal_labeled_mask_online_dataset.py
index 3532efed8..297bd98a4 100644
--- a/data/temporal_labeled_mask_online_dataset.py
+++ b/data/temporal_labeled_mask_online_dataset.py
@@ -34,7 +34,6 @@ def __init__(self, opt, phase):
         self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
             self.dir_A, "/paths.txt"
         )  # load images from '/path/to/data/trainA/paths.txt' as well as labels
-
        if self.use_domain_B:
             self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset(
                 self.dir_B, "/paths.txt"
             )
@@ -142,7 +141,7 @@ def get_img(
                     fixed_mask_size=self.opt.data_online_fixed_mask_size,
                 )
 
-                cur_A_img, cur_A_label, ref_A_bbox = crop_image(
+                cur_A_img, cur_A_label, A_bbox = crop_image(
                     cur_A_img_path,
                     cur_A_label_path,
                     mask_delta=mask_delta_A,
@@ -156,6 +155,8 @@ def get_img(
                     crop_coordinates=crop_coordinates,
                     fixed_mask_size=self.opt.data_online_fixed_mask_size,
                 )
+                if i == 0:
+                    A_ref_bbox = A_bbox[1:]
 
             except Exception as e:
                 print(e, f"{i+1}th frame of domain A in temporal dataloading")
@@ -164,10 +165,9 @@ def get_img(
             images_A.append(cur_A_img)
             labels_A.append(cur_A_label)
 
-        images_A, labels_A = self.transform(images_A, labels_A)
-
+        images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
+        A_ref_label = labels_A[0]
         images_A = torch.stack(images_A)
-
         labels_A = torch.stack(labels_A)
 
         if self.use_domain_B:
@@ -222,7 +222,7 @@ def get_img(
                         get_crop_coordinates=True,
                     )
 
-                    cur_B_img, cur_B_label, ref_B_bbox = crop_image(
+                    cur_B_img, cur_B_label, B_bbox = crop_image(
                         cur_B_img_path,
                         cur_B_label_path,
                         mask_delta=mask_delta_B,
@@ -236,6 +236,8 @@ def get_img(
                         crop_coordinates=crop_coordinates,
                         fixed_mask_size=self.opt.data_online_fixed_mask_size,
                     )
+                    if i == 0:
+                        B_ref_bbox = B_bbox[1:]
 
                 except Exception as e:
                     print(e, f"{i+1}th frame of domain B in temporal dataloading")
@@ -244,10 +246,11 @@ def get_img(
                 images_B.append(cur_B_img)
                 labels_B.append(cur_B_label)
 
-            images_B, labels_B = self.transform(images_B, labels_B)
-
+            images_B, labels_B, B_ref_bbox = self.transform(
+                images_B, labels_B, B_ref_bbox
+            )
+            B_ref_label = labels_B[0]
             images_B = torch.stack(images_B)
-
             labels_B = torch.stack(labels_B)
 
         else:
@@ -258,14 +261,18 @@ def get_img(
         result = {
             "A": images_A,
             "A_img_paths": ref_A_img_path,
+            "A_ref_bbox": A_ref_bbox,
             "B": images_B,
             "B_img_paths": ref_B_img_path,
+            "B_ref_bbox": B_ref_bbox,
         }
 
         result.update(
             {
-                "A_label_mask": labels_A,
-                "B_label_mask": labels_B,
+                "A_label_masks": labels_A,
+                "B_label_masks": labels_B,
+                "A_label_mask": A_ref_label,
+                "B_label_mask": B_ref_label,
             }
         )
diff --git a/models/base_gan_model.py b/models/base_gan_model.py
index 8b69a7a7e..b420c7cdd 100644
--- a/models/base_gan_model.py
+++ b/models/base_gan_model.py
@@ -208,7 +208,6 @@ def forward_GAN(self):
 
         if self.use_temporal:
             self.compute_temporal_fake(objective_domain="B")
-
             if hasattr(self, "netG_B"):
                 self.compute_temporal_fake(objective_domain="A")
 
@@ -480,35 +479,6 @@ def compute_G_loss_GAN(self):
         if self.opt.train_temporal_criterion:
             self.compute_temporal_criterion_loss()
 
-    def compute_fake_with_context(self, fake_name, real_name):
-        setattr(
-            self,
-            fake_name + "_with_context",
-            torch.nn.functional.pad(
-                getattr(self, fake_name),
-                (
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                    self.opt.data_online_context_pixels,
-                ),
-            ),
-        )
-
-        setattr(
-            self,
-            fake_name + "_with_context",
-            getattr(self, fake_name + "_with_context")
-            + self.mask_context * getattr(self, real_name + "_with_context"),
-        )
-        setattr(
-            self,
-            fake_name + "_with_context_vis",
-            torch.nn.functional.interpolate(
-                getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
-            ),
-        )
-
     # compute_real_fake_with_depth
     def compute_fake_real_with_depth(self, fake_name, real_name):
         fake_depth = predict_depth(
diff --git a/models/base_model.py b/models/base_model.py
index e4f5da393..07fbcd458 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -147,7 +147,6 @@ def __init__(self, opt, rank):
             self.lpips_metric = LPIPS().to(self.device)
 
     def init_metrics(self, dataloader_test):
-
         self.use_inception = any(
             metric in self.opt.train_metrics_list for metric in ["KID", "FID", "MSID"]
         )
@@ -386,9 +385,9 @@ def set_input(self, data):
                 self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels,
             ]
 
-            self.real_B_with_context_vis = torch.nn.functional.interpolate(
-                self.real_B_with_context, size=self.real_A.shape[2:]
-            )
+            self.real_B_with_context_vis = torch.nn.functional.interpolate(
+                self.real_B_with_context, size=self.real_B.shape[2:]
+            )
 
         self.image_paths = data["A_img_paths"]
@@ -492,7 +491,8 @@ def set_input_temporal(self, data_temporal):
                     "temporal_real_A_" + str(i) + "_with_context_vis",
                     torch.nn.functional.interpolate(
                         getattr(self, "temporal_real_A_" + str(i) + "_with_context"),
-                        size=self.real_A.shape[2:],
+                        # size=self.temporal_real_A.shape[2:]
+                        size=getattr(self, "temporal_real_A_" + str(i)).shape[2:],
                     ),
                 )
             else:
@@ -531,7 +531,8 @@ def set_input_temporal(self, data_temporal):
                     "temporal_real_B_" + str(i) + "_with_context_vis",
                     torch.nn.functional.interpolate(
                         getattr(self, "temporal_real_B_" + str(i) + "_with_context"),
-                        size=self.real_B.shape[2:],
+                        # size=self.real_B.shape[2:],
+                        size=getattr(self, "temporal_real_B_" + str(i)).shape[2:],
                     ),
                 )
             else:
@@ -544,10 +545,92 @@ def set_input_temporal(self, data_temporal):
                     ],
                 )
 
+        self.image_paths = data_temporal["A_img_paths"]
+
+        self.input_A_ref_bbox = None  # set by semantics
+        self.input_B_ref_bbox = None
+
+        # first image from temporal sequence is the reference image
+        self.real_A = self.temporal_real_A_0.clone()
+        self.real_A_with_context = self.temporal_real_A_0_with_context.clone()
+        self.real_B = self.temporal_real_B_0.clone()
+        self.real_B_with_context = self.temporal_real_B_0_with_context.clone()
+        if self.opt.data_online_context_pixels > 0:
+            self.real_A_with_context_vis = torch.nn.functional.interpolate(
+                self.real_A_with_context, size=self.real_A.shape[2:]
+            )
+            self.real_B_with_context_vis = torch.nn.functional.interpolate(
+                self.real_B_with_context, size=self.real_B.shape[2:]
+            )
+
+        if self.opt.train_semantic_mask:
+            self.set_input_semantic_mask(data_temporal)
+        if self.opt.train_semantic_cls:
+            self.set_input_semantic_cls(data_temporal)
+
     def forward(self):
         for forward_function in self.forward_functions:
             getattr(self, forward_function)()
 
+    def compute_fake_with_context(self, fake_name, real_name):
+        setattr(
+            self,
+            fake_name + "_with_context",
+            torch.nn.functional.pad(
+                getattr(self, fake_name),
+                (
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                ),
+            ),
+        )
+
+        setattr(
+            self,
+            fake_name + "_with_context",
+            getattr(self, fake_name + "_with_context")
+            + self.mask_context * getattr(self, real_name + "_with_context"),
+        )
+        setattr(
+            self,
+            fake_name + "_with_context_vis",
+            torch.nn.functional.interpolate(
+                getattr(self, fake_name + "_with_context"), size=self.real_A.shape[2:]
+            ),
+        )
+
+    def compute_temporal_fake_with_context(self, fake_name, real_name):
+        setattr(
+            self,
+            fake_name + "_with_context",
+            torch.nn.functional.pad(
+                getattr(self, fake_name),
+                (
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                    self.opt.data_online_context_pixels,
+                ),
+            ),
+        )
+
+        setattr(
+            self,
+            fake_name + "_with_context",
+            getattr(self, fake_name + "_with_context")
+            + self.mask_context * getattr(self, real_name + "_with_context"),
+        )
+        setattr(
+            self,
+            fake_name + "_with_context_vis",
+            torch.nn.functional.interpolate(
+                getattr(self, fake_name + "_with_context"),
+                size=getattr(self, real_name).shape[2:],
+            ),
+        )
+
     def compute_temporal_fake(self, objective_domain):
         origin_domain = "B" if objective_domain == "A" else "A"
         netG = getattr(self, "netG_" + origin_domain)
@@ -567,7 +650,7 @@ def compute_temporal_fake(self, objective_domain):
                 temporal_fake[:, i],
             )
             if self.opt.data_online_context_pixels > 0:
-                self.compute_fake_with_context(
+                self.compute_temporal_fake_with_context(
                     fake_name="temporal_fake_" + objective_domain + "_" + str(i),
                     real_name="temporal_real_" + origin_domain + "_" + str(i),
                 )
@@ -841,7 +924,6 @@ def load_networks(self, epoch):
                 for key in list(
                     state_dict.keys()
                 ):  # need to copy keys here because we mutate in loop
-
                     self.__patch_instance_norm_state_dict(
                         state_dict, net, key.split(".")
                     )
@@ -972,7 +1054,6 @@ def optimize_parameters(self):
                     True  # automatic fall back to eager mode
                 )
 
-            # print(self.niter,self.opt.train_iter_size, self.niter % self.opt.train_iter_size, self.niter % self.opt.train_iter_size != 0, )
             if len(self.opt.gpu_ids) > 1 and self.niter % self.opt.train_iter_size != 0:
                 for network in self.model_names:
                     stack.enter_context(getattr(self, "net" + network).no_sync())
@@ -1337,7 +1418,6 @@ def get_current_metrics(self):
         return metrics
 
     def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter):
-
         dims = 2048
         batch = 1
@@ -1363,18 +1443,13 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter):
             ):  # inner loop (minibatch) within one epoch
                 data_test = data_test_list[0]
 
-                use_temporal = (
-                    "temporal" in self.opt.D_netDs
-                ) or self.opt.train_temporal_criterion
-
-                if use_temporal:
+                if self.use_temporal:
                     temporal_data_test = data_test_list[1]
-
-                self.set_input(
-                    data_test
-                )  # unpack data from dataloader and apply preprocessing
-                if use_temporal:
                     self.set_input_temporal(temporal_data_test)
+                else:
+                    self.set_input(
+                        data_test
+                    )  # unpack data from dataloader and apply preprocessing
 
                 self.inference()
@@ -1446,7 +1521,6 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter):
             self.lpips_test = self.lpips_metric(real_tensor, fake_tensor).mean()
 
     def compute_metrics_generic(self, real_act, fake_act):
-
         # FID
         if "FID" in self.opt.train_metrics_list:
             fid = self.fid_metric(real_act, fake_act)
@@ -1478,7 +1552,10 @@ def compute_metrics_generic(self, real_act, fake_act):
         return fid, msid, kid
 
     def set_input_first_gpu(self, data):
-        self.set_input(data)
+        if self.use_temporal:
+            self.set_input_temporal(data)
+        else:
+            self.set_input(data)
         self.bs_per_gpu = self.real_A.size(0)
         self.real_A = self.real_A[: self.bs_per_gpu]
         self.real_B = self.real_B[: self.bs_per_gpu]
diff --git a/models/cut_model.py b/models/cut_model.py
index 05e3719cf..fb0369a5e 100644
--- a/models/cut_model.py
+++ b/models/cut_model.py
@@ -423,6 +423,7 @@ def data_dependent_initialize(self, data):
             real_A_with_z = torch.cat([self.real_A, z_real], 1)
         else:
             real_A_with_z = self.real_A
+
         feat_temp = self.netG_A.get_feats(real_A_with_z.cpu(), self.nce_layers)
 
         self.netF.data_dependent_initialize(feat_temp)
@@ -532,6 +533,10 @@ def forward_cut(self):
             self.fake_B = self.fake[: self.real_A.size(0)]
 
         if self.opt.data_online_context_pixels > 0:
+            if self.use_temporal:
+                self.compute_temporal_fake_with_context(
+                    fake_name="temporal_fake_B_0", real_name="temporal_real_A_0"
+                )
             self.compute_fake_with_context(fake_name="fake_B", real_name="real_A")
 
         if self.use_depth:
@@ -550,6 +555,9 @@ def forward_cut(self):
         if self.opt.data_online_context_pixels > 0:
             context = "_with_context"
 
+        # if self.use_temporal:
+        #     names = ["temporal_fake_B_0", "temporal_real_B_0"]
+        # else:
         names = ["fake_B", "real_B"]
 
         for name in names:
             setattr(
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 25a1882e2..bc59bd66b 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -278,6 +278,10 @@ def forward_cycle_gan(self):
         self.fake_B = self.netG_A(self.real_A)  # G_A(A)
 
         if self.opt.data_online_context_pixels > 0:
+            if self.use_temporal:
+                self.compute_temporal_fake_with_context(
+                    fake_name="temporal_fake_B_0", real_name="temporal_real_A_0"
+                )
             self.compute_fake_with_context(fake_name="fake_B", real_name="real_A")
 
         # Rec A
diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh
index cce4db8e0..aedf499e5 100644
--- a/scripts/run_tests.sh
+++ b/scripts/run_tests.sh
@@ -84,9 +84,9 @@ fi
 
 ####### mask cls semantics test with online dataloading
 echo "Running mask online semantics training tests"
-URL=https://joligen.com/datasets/online_mario2sonic_lite.zip
-ZIP_FILE=$DIR/online_mario2sonic_lite.zip
-TARGET_MASK_SEM_ONLINE_DIR=$DIR/online_mario2sonic_lite
+URL=https://joligen.com/datasets/online_mario2sonic_lite2.zip
+ZIP_FILE=$DIR/online_mario2sonic_lite2.zip
+TARGET_MASK_SEM_ONLINE_DIR=$DIR/online_mario2sonic_lite2
 wget -N $URL -O $ZIP_FILE
 mkdir $TARGET_MASK_SEM_ONLINE_DIR
 unzip $ZIP_FILE -d $DIR
diff --git a/tests/test_run_semantic_mask_online.py b/tests/test_run_semantic_mask_online.py
index a817b40c7..3532db68d 100644
--- a/tests/test_run_semantic_mask_online.py
+++ b/tests/test_run_semantic_mask_online.py
@@ -26,12 +26,13 @@
     "data_online_creation_crop_delta_B": 50,
     "data_online_creation_load_size_A": [2500, 1000],
     "data_online_creation_load_size_B": [2500, 1000],
+    "data_online_context_pixels": 0,
     "train_n_epochs": 1,
     "train_n_epochs_decay": 0,
     "data_max_dataset_size": 10,
     "train_mask_out_mask": True,
     "f_s_net": "unet",
-    "f_s_semantic_nclasses": 2,
+    "f_s_semantic_nclasses": 7,
     "dataaug_D_noise": 0.001,
     "train_sem_use_label_B": True,
     "data_relative_paths": True,
@@ -43,10 +44,10 @@
     "dataaug_no_rotate": True,
     "train_mask_compute_miou": True,
     "train_mask_miou_every": 1,
-    "data_temporal_number_frames": 2,
+    "data_temporal_number_frames": 4,
     "data_temporal_frame_step": 2,
     "train_semantic_mask": True,
-    "train_temporal_criterion": True,
+    "train_temporal_criterion": False,
     "train_export_jit": True,
     "train_save_latest_freq": 10,
 }
@@ -60,14 +61,26 @@
 
 D_proj_network_type = ["efficientnet", "vitsmall"]
 
-D_netDs = [["basic", "projected_d", "temporal"], ["sam"]]
+D_netDs = [
+    ["basic", "projected_d"],
+    ["basic", "projected_d", "temporal"],
+    ["projected_d", "sam"],
+]
 
 f_s_net = ["unet"]
 
 model_type_sam = ["mobile_sam"]
 
+data_online_context_pixels = [0, 10]
+
 product_list = product(
-    models_semantic_mask, G_netG, D_proj_network_type, D_netDs, f_s_net, model_type_sam
+    models_semantic_mask,
+    G_netG,
+    D_proj_network_type,
+    D_netDs,
+    f_s_net,
+    model_type_sam,
+    data_online_context_pixels,
 )
@@ -75,7 +88,15 @@ def test_semantic_mask_online(dataroot):
     json_like_dict["dataroot"] = dataroot
     json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])
 
-    for model, Gtype, Dtype, Dnet, f_s_type, sam_type in product_list:
+    for (
+        model,
+        Gtype,
+        Dtype,
+        Dnet,
+        f_s_type,
+        sam_type,
+        data_online_context_pixels,
+    ) in product_list:
         if model == "cycle_gan" and "sam" in Dnet:
             continue
         json_like_dict_c = json_like_dict.copy()
@@ -84,8 +105,11 @@ def test_semantic_mask_online(dataroot):
         json_like_dict_c["G_netG"] = Gtype
         json_like_dict_c["D_proj_network_type"] = Dtype
         json_like_dict_c["D_netDs"] = Dnet
+        if "temporal" in Dnet:
+            json_like_dict_c["data_dataset_mode"] = "temporal_labeled_mask_online"
         json_like_dict_c["f_s_net"] = f_s_type
         json_like_dict_c["model_type_sam"] = sam_type
+        json_like_dict_c["data_online_context_pixels"] = data_online_context_pixels
 
         opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
         train.launch_training(opt)
diff --git a/train.py b/train.py
index 46717c9b6..3056cbf0b 100644
--- a/train.py
+++ b/train.py
@@ -190,15 +190,17 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
             dataloaders
         ):  # inner loop (minibatch) within one epoch
             data = data_list[0]
-            if use_temporal:
-                temporal_data = data_list[1]
 
             iter_start_time = time.time()  # timer for computation per iteration
             t_data_mini_batch = iter_start_time - iter_data_time
 
-            model.set_input(data)  # unpack data from dataloader and apply preprocessing
             if use_temporal:
+                temporal_data = data_list[1]
                 model.set_input_temporal(temporal_data)
+            else:
+                model.set_input(
+                    data
+                )  # unpack data from dataloader and apply preprocessing
 
             model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
 
             t_comp = (time.time() - iter_start_time) / opt.train_batch_size
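
For context, a minimal usage sketch of the new three-value transform pipeline introduced in `data/base_dataset.py`. This is illustrative only, not part of the patch: the frame/mask file names are placeholders, and it assumes torch 2.x so that `ComposeMaskList` wraps the box in a torchvision transforms-v2 `datapoints.BoundingBox` (the pre-2.x branch in the patch only passes the raw array through as a placeholder).

```python
# Hypothetical driver for the updated mask-list transforms; the file names
# below are placeholders, the class names come from the diff above.
import numpy as np
from PIL import Image

from data.base_dataset import (
    ComposeMaskList,
    RandomHorizontalFlipMaskList,
    ResizeMaskList,
    ToTensorMaskList,
)

# A short temporal sequence: a list of frames and their label masks.
imgs = [Image.open(f"frame_{i}.png").convert("RGB") for i in range(3)]
masks = [Image.open(f"mask_{i}.png") for i in range(3)]

transform = ComposeMaskList(
    [
        ResizeMaskList(size=256),
        RandomHorizontalFlipMaskList(p=0.5),
        ToTensorMaskList(),
    ]
)

# bbox is optional; when omitted, ComposeMaskList falls back to the full image
# extent. Each geometric transform (resize, crop, flip, rotate, affine) now
# returns the box transformed consistently with the images and masks.
imgs_t, masks_t, bbox_t = transform(imgs, masks, bbox=np.array([10, 20, 110, 140]))
```

This is the same calling convention `temporal_labeled_mask_online_dataset.py` now uses to thread `A_ref_bbox`/`B_ref_bbox` through `self.transform`, so the reference bounding boxes stay aligned with the augmented frames.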