diff --git a/data/base_dataset.py b/data/base_dataset.py index bec36ff30..633d2679d 100644 --- a/data/base_dataset.py +++ b/data/base_dataset.py @@ -20,6 +20,9 @@ import torchvision.transforms.functional as F +from data.image_folder import make_dataset, make_dataset_path, make_labeled_path_dataset +from data.online_creation import sanitize_paths, write_paths_file + from abc import ABC, abstractmethod import imgaug as ia import imgaug.augmenters as iaa @@ -137,7 +140,7 @@ def __getitem__(self, index): if B_label_mask_path is not None: B_label_mask_path = os.path.join(self.root, B_label_mask_path) - return self.get_img( + results = self.get_img( A_img_path, A_label_mask_path, A_label_cls, @@ -147,6 +150,8 @@ def __getitem__(self, index): index, ) + return results + def set_dataset_dirs_and_dims(self): btoA = self.opt.data_direction == "BtoA" self.input_nc = ( @@ -249,6 +254,81 @@ def get_validation_set(self, size): return return_A_list, return_B_list + def sanitize(self): + paths_sanitized_train_A = os.path.join( + self.sv_dir, "paths_sanitized_train_A.txt" + ) + if hasattr(self, "B_img_paths"): + paths_sanitized_train_B = os.path.join( + self.sv_dir, "paths_sanitized_train_B.txt" + ) + if hasattr(self, "B_img_paths"): + train_sanitized_exist = os.path.exists( + paths_sanitized_train_A + ) and os.path.exists(paths_sanitized_train_B) + else: + train_sanitized_exist = os.path.exists(paths_sanitized_train_A) + + if train_sanitized_exist: + self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset( + self.sv_dir, "/paths_sanitized_train_A.txt" + ) + if hasattr(self, "B_img_paths"): + self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset( + self.sv_dir, "/paths_sanitized_train_B.txt" + ) + else: + print("--------------") + print("Sanitizing images and labels paths") + print("--- DOMAIN A ---") + + self.A_img_paths, self.A_label_mask_paths = sanitize_paths( + self.A_img_paths, + self.A_label_mask_paths, + mask_delta=self.opt.data_online_creation_mask_delta_A, + mask_random_offset=self.opt.data_online_creation_mask_random_offset_A, + crop_delta=self.opt.data_online_creation_crop_delta_A, + mask_square=self.opt.data_online_creation_mask_square_A, + crop_dim=self.opt.data_online_creation_crop_size_A, + output_dim=self.opt.data_load_size, + max_dataset_size=self.opt.data_max_dataset_size, + context_pixels=self.opt.data_online_context_pixels, + load_size=self.opt.data_online_creation_load_size_A, + select_cat=self.opt.data_online_select_category, + data_relative_paths=self.opt.data_relative_paths, + data_root_path=self.opt.dataroot, + ) + write_paths_file( + self.A_img_paths, + self.A_label_mask_paths, + paths_sanitized_train_A, + ) + + print("--- DOMAIN B ---") + if hasattr(self, "B_img_paths"): + self.B_img_paths, self.B_label_mask_paths = sanitize_paths( + self.B_img_paths, + self.B_label_mask_paths, + mask_delta=self.opt.data_online_creation_mask_delta_B, + mask_random_offset=self.opt.data_online_creation_mask_random_offset_B, + crop_delta=self.opt.data_online_creation_crop_delta_B, + mask_square=self.opt.data_online_creation_mask_square_B, + crop_dim=self.opt.data_online_creation_crop_size_B, + output_dim=self.opt.data_load_size, + max_dataset_size=self.opt.data_max_dataset_size, + context_pixels=self.opt.data_online_context_pixels, + load_size=self.opt.data_online_creation_load_size_B, + data_relative_paths=self.opt.data_relative_paths, + data_root_path=self.opt.dataroot, + ) + write_paths_file( + self.B_img_paths, + self.B_label_mask_paths, + paths_sanitized_train_B, + ) + + print("--------------") + def get_params(opt, size): w, h = size diff --git a/data/self_supervised_temporal_dataset.py b/data/self_supervised_temporal_dataset.py index d181c7697..555d0b8cf 100644 --- a/data/self_supervised_temporal_dataset.py +++ b/data/self_supervised_temporal_dataset.py @@ -1,11 +1,10 @@ import torch - -from data.temporal_dataset import TemporalDataset +from data.temporal_labeled_mask_online_dataset import TemporalLabeledMaskOnlineDataset from data.online_creation import fill_mask_with_random, fill_mask_with_color -class SelfSupervisedTemporalDataset(TemporalDataset): +class SelfSupervisedTemporalDataset(TemporalLabeledMaskOnlineDataset): """ This dataset class can create datasets with mask labels from one domain. """ @@ -28,15 +27,17 @@ def get_img( B_label_cls=None, index=None, ): - result = super().get_img( - A_img_path, - A_label_mask_path, - A_label_cls, - B_img_path, - B_label_mask_path, - B_label_cls, - index, - ) + result = None + while result is None: + result = super().get_img( + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path, + B_label_mask_path, + B_label_cls, + index, + ) try: A_img_list = [result["A"][0]] diff --git a/data/temporal_labeled_mask_online_dataset.py b/data/temporal_labeled_mask_online_dataset.py index bbf5d6e1c..3532efed8 100644 --- a/data/temporal_labeled_mask_online_dataset.py +++ b/data/temporal_labeled_mask_online_dataset.py @@ -48,10 +48,13 @@ def __init__(self, opt, phase): self.B_img_paths.sort(key=natural_keys) self.B_label_mask_paths.sort(key=natural_keys) - self.A_img_paths, self.A_label_mask_paths = ( - self.A_img_paths[: opt.data_max_dataset_size], - self.A_label_mask_paths[: opt.data_max_dataset_size], - ) + if self.opt.data_sanitize_paths: + self.sanitize() + elif opt.data_max_dataset_size != float("inf"): + self.A_img_paths, self.A_label_mask_paths = ( + self.A_img_paths[: opt.data_max_dataset_size], + self.A_label_mask_paths[: opt.data_max_dataset_size], + ) if self.use_domain_B: self.B_img_paths, self.B_label_mask_paths = ( @@ -118,13 +121,11 @@ def get_img( cur_A_label_path = os.path.join(self.root, cur_A_label_path) try: - if ( - len(self.opt.data_online_creation_mask_delta_A_ratio[0]) == 1 - and self.opt.data_online_creation_mask_delta_A_ratio[0][0] == 0 - ): + if self.opt.data_online_creation_mask_delta_A_ratio == [[]]: mask_delta_A = self.opt.data_online_creation_mask_delta_A else: mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio + if i == 0: crop_coordinates = crop_image( cur_A_img_path, @@ -140,6 +141,7 @@ def get_img( get_crop_coordinates=True, fixed_mask_size=self.opt.data_online_fixed_mask_size, ) + cur_A_img, cur_A_label, ref_A_bbox = crop_image( cur_A_img_path, cur_A_label_path, @@ -201,10 +203,7 @@ def get_img( cur_B_label_path = os.path.join(self.root, cur_B_label_path) try: - if ( - len(self.opt.data_online_creation_mask_delta_B_ratio[0]) == 1 - and self.opt.data_online_creation_mask_delta_B_ratio[0][0] == 0 - ): + if self.opt.data_online_creation_mask_delta_B_ratio == [[]]: mask_delta_B = self.opt.data_online_creation_mask_delta_B else: mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio @@ -254,7 +253,7 @@ def get_img( else: images_B = None labels_B = None - ref_B_img_path = None + ref_B_img_path = "" result = { "A": images_A, diff --git a/data/unaligned_labeled_mask_online_dataset.py b/data/unaligned_labeled_mask_online_dataset.py index 77510ff46..f2a8b2558 100644 --- a/data/unaligned_labeled_mask_online_dataset.py +++ b/data/unaligned_labeled_mask_online_dataset.py @@ -10,7 +10,7 @@ from data.base_dataset import BaseDataset, get_transform, get_transform_seg from data.image_folder import make_dataset, make_dataset_path, make_labeled_path_dataset -from data.online_creation import crop_image, sanitize_paths, write_paths_file +from data.online_creation import crop_image class UnalignedLabeledMaskOnlineDataset(BaseDataset): @@ -77,81 +77,6 @@ def __init__(self, opt, phase): self.header = ["img", "mask"] - def sanitize(self): - paths_sanitized_train_A = os.path.join( - self.sv_dir, "paths_sanitized_train_A.txt" - ) - if hasattr(self, "B_img_paths"): - paths_sanitized_train_B = os.path.join( - self.sv_dir, "paths_sanitized_train_B.txt" - ) - if hasattr(self, "B_img_paths"): - train_sanitized_exist = os.path.exists( - paths_sanitized_train_A - ) and os.path.exists(paths_sanitized_train_B) - else: - train_sanitized_exist = os.path.exists(paths_sanitized_train_A) - - if train_sanitized_exist: - self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset( - self.sv_dir, "/paths_sanitized_train_A.txt" - ) - if hasattr(self, "B_img_paths"): - self.B_img_paths, self.B_label_mask_paths = make_labeled_path_dataset( - self.sv_dir, "/paths_sanitized_train_B.txt" - ) - else: - print("--------------") - print("Sanitizing images and labels paths") - print("--- DOMAIN A ---") - - self.A_img_paths, self.A_label_mask_paths = sanitize_paths( - self.A_img_paths, - self.A_label_mask_paths, - mask_delta=self.opt.data_online_creation_mask_delta_A, - mask_random_offset=self.opt.data_online_creation_mask_random_offset_A, - crop_delta=self.opt.data_online_creation_crop_delta_A, - mask_square=self.opt.data_online_creation_mask_square_A, - crop_dim=self.opt.data_online_creation_crop_size_A, - output_dim=self.opt.data_load_size, - max_dataset_size=self.opt.data_max_dataset_size, - context_pixels=self.opt.data_online_context_pixels, - load_size=self.opt.data_online_creation_load_size_A, - select_cat=self.opt.data_online_select_category, - data_relative_paths=self.opt.data_relative_paths, - data_root_path=self.opt.dataroot, - ) - write_paths_file( - self.A_img_paths, - self.A_label_mask_paths, - paths_sanitized_train_A, - ) - - print("--- DOMAIN B ---") - if hasattr(self, "B_img_paths"): - self.B_img_paths, self.B_label_mask_paths = sanitize_paths( - self.B_img_paths, - self.B_label_mask_paths, - mask_delta=self.opt.data_online_creation_mask_delta_B, - mask_random_offset=self.opt.data_online_creation_mask_random_offset_B, - crop_delta=self.opt.data_online_creation_crop_delta_B, - mask_square=self.opt.data_online_creation_mask_square_B, - crop_dim=self.opt.data_online_creation_crop_size_B, - output_dim=self.opt.data_load_size, - max_dataset_size=self.opt.data_max_dataset_size, - context_pixels=self.opt.data_online_context_pixels, - load_size=self.opt.data_online_creation_load_size_B, - data_relative_paths=self.opt.data_relative_paths, - data_root_path=self.opt.root, - ) - write_paths_file( - self.B_img_paths, - self.B_label_mask_paths, - paths_sanitized_train_B, - ) - - print("--------------") - def get_img( self, A_img_path, @@ -164,12 +89,8 @@ def get_img( clamp_semantics=True, ): # Domain A - try: - if ( - len(self.opt.data_online_creation_mask_delta_A_ratio[0]) == 1 - and self.opt.data_online_creation_mask_delta_A_ratio[0][0] == 0 - ): + if self.opt.data_online_creation_mask_delta_A_ratio == [[]]: mask_delta_A = self.opt.data_online_creation_mask_delta_A else: mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio @@ -218,10 +139,7 @@ def get_img( # Domain B if B_img_path is not None: try: - if ( - len(self.opt.data_online_creation_mask_delta_B_ratio[0]) == 1 - and self.opt.data_online_creation_mask_delta_B_ratio[0][0] == 0 - ): + if self.opt.data_online_creation_mask_delta_B_ratio == [[]]: mask_delta_B = self.opt.data_online_creation_mask_delta_B else: mask_delta_B = self.opt.data_online_creation_mask_delta_B_ratio diff --git a/docker/Dockerfile.devel b/docker/Dockerfile.devel index a2ca5041a..5cd9bda87 100644 --- a/docker/Dockerfile.devel +++ b/docker/Dockerfile.devel @@ -13,6 +13,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && \ python3-pip \ python3-opencv \ python3-pytest \ + ninja-build \ sudo \ wget \ git \ diff --git a/docs/options.md b/docs/options.md index c05ff4e2c..dfdf21e3f 100644 --- a/docs/options.md +++ b/docs/options.md @@ -37,14 +37,14 @@ Here are all the available options to call with `train.py` | --D_spectral | flag | | whether to use spectral norm in the discriminator | | --D_temporal_every | int | 4 | apply temporal discriminator every x steps | | --D_vision_aided_backbones | string | clip+dino+swin | specify vision aided discriminators architectures, they are frozen then output are combined and fitted with a linear network on top, choose from dino, clip, swin, det_coco, seg_ade and combine them with + | -| --D_weight_sam | string | | path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth | +| --D_weight_sam | string | | path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM | ## Generator | Parameter | Type | Default | Description | | --- | --- | --- | --- | -| --G_attn_nb_mask_attn | int | 10 | | -| --G_attn_nb_mask_input | int | 1 | | +| --G_attn_nb_mask_attn | int | 10 | number of attention masks in _attn model architectures | +| --G_attn_nb_mask_input | int | 1 | number of mask dedicated to input in _attn model architectures | | --G_backward_compatibility_twice_resnet_blocks | flag | | if true, feats will go througt resnet blocks two times for resnet_attn generators. This option will be deleted, it's for backward compatibility (old models were trained that way). | | --G_config_segformer | string | models/configs/segformer/segformer_config_b0.json | path to segformer configuration file for G | | --G_diff_n_timestep_test | int | 1000 | Number of timesteps used for UNET mha inference (test time). | @@ -62,8 +62,8 @@ Here are all the available options to call with `train.py` | --G_unet_mha_channel_mults | array | [1, 2, 4, 8] | channel multiplier for each level of the UNET mha | | --G_unet_mha_group_norm_size | int | 32 | | | --G_unet_mha_norm_layer | string | groupnorm |

**Values:** groupnorm, batchnorm, layernorm, instancenorm, switchablenorm | -| --G_unet_mha_num_head_channels | int | 32 | | -| --G_unet_mha_num_heads | int | 1 | | +| --G_unet_mha_num_head_channels | int | 32 | number of channels in each head of the mha architecture | +| --G_unet_mha_num_heads | int | 1 | number of heads in the mha architecture | | --G_unet_mha_res_blocks | array | [2, 2, 2, 2] | distribution of resnet blocks across the UNet stages, should have same size as --G_unet_mha_channel_mults | | --G_unet_mha_vit_efficient | flag | | if true, use efficient attention in UNet and UViT | | --G_uvit_num_transformer_blocks | int | 6 | Number of transformer blocks in UViT | @@ -208,7 +208,7 @@ Here are all the available options to call with `train.py` | --f_s_nf | int | 64 | \# of filters in the first conv layer of classifier | | --f_s_semantic_nclasses | int | 2 | number of classes of the semantic loss classifier | | --f_s_semantic_threshold | float | 1.0 | threshold of the semantic classifier loss below with semantic loss is applied | -| --f_s_weight_sam | string | | path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth | +| --f_s_weight_sam | string | | path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM | | --f_s_weight_segformer | string | | path to segformer weight for f_s, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth | ## Semantic classification network @@ -264,6 +264,7 @@ Here are all the available options to call with `train.py` | --model_multimodal | flag | | multimodal model with random latent input vector | | --model_output_nc | int | 3 | \# of output image channels: 3 for RGB and 1 for grayscale

**Values:** 1, 3 | | --model_prior_321_backwardcompatibility | flag | | whether to load models from previous version of JG. | +| --model_type_sam | string | mobile_sam | which model to use for segment-anything mask generation

**Values:** sam, mobile_sam | ## Training @@ -280,18 +281,19 @@ Here are all the available options to call with `train.py` | --train_cls_l1_regression | flag | | if true l1 loss will be used to compute regressor loss | | --train_cls_regression | flag | | if true cls will be a regressor and not a classifier | | --train_compute_D_accuracy | flag | | whether to compute D accuracy explicitely | -| --train_compute_metrics_test | flag | | | +| --train_compute_metrics_test | flag | | whether to compute test metrics, e.g. FID, ... | | --train_continue | flag | | continue training: load the latest model | | --train_epoch | string | latest | which epoch to load? set to latest to use latest cached model | | --train_epoch_count | int | 1 | the starting epoch count, we save the model by \, \+\, ... | | --train_export_jit | flag | | whether to export model in jit format | +| --train_feat_wavelet | flag | | if true, train in wavelet features space (Note: this may not include all discriminators, when training GANs) | | --train_gan_mode | string | lsgan | the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.

**Values:** vanilla, lsgan, wgangp, projected | | --train_iter_size | int | 1 | backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size\*iter_size | | --train_load_iter | int | 0 | which iteration to load? if load_iter \> 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] | | --train_lr_decay_iters | int | 50 | multiply by a gamma every lr_decay_iters iterations | | --train_lr_policy | string | linear | learning rate policy.

**Values:** linear, step, plateau, cosine | -| --train_metrics_every | int | 1000 | | -| --train_metrics_list | array | ['FID'] |

**Values:** FID, KID, MSID, PSNR | +| --train_metrics_every | int | 1000 | compute metrics every N iterations | +| --train_metrics_list | array | ['FID'] | metrics on results quality to compute

**Values:** FID, KID, MSID, PSNR, LPIPS | | --train_mm_lambda_z | float | 0.5 | weight for random z loss | | --train_mm_nz | int | 8 | number of latent vectors | | --train_n_epochs | int | 100 | number of epochs with the initial learning rate | @@ -328,13 +330,13 @@ Here are all the available options to call with `train.py` | Parameter | Type | Default | Description | | --- | --- | --- | --- | | --train_mask_charbonnier_eps | float | 1e-06 | Charbonnier loss epsilon value | -| --train_mask_compute_miou | flag | | | +| --train_mask_compute_miou | flag | | whether to compute mIoU on semantic masks prediction | | --train_mask_disjoint_f_s | flag | | whether to use a disjoint f_s with the same exact structure | | --train_mask_f_s_B | flag | | if true f_s will be trained not only on domain A but also on domain B | | --train_mask_for_removal | flag | | if true, object removal mode, domain B images with label 0, cut models only | | --train_mask_lambda_out_mask | float | 10.0 | weight for loss out mask | | --train_mask_loss_out_mask | string | L1 | loss for out mask content (which should not change).

**Values:** L1, MSE, Charbonnier | -| --train_mask_miou_every | int | 1000 | | +| --train_mask_miou_every | int | 1000 | compute mIoU every n iterations | | --train_mask_no_train_f_s_A | flag | | if true f_s wont be trained on domain A | | --train_mask_out_mask | flag | | use loss out mask | diff --git a/docs/source/_static/openapi.json b/docs/source/_static/openapi.json index 04c32d328..db2c89680 100644 --- a/docs/source/_static/openapi.json +++ b/docs/source/_static/openapi.json @@ -1 +1 @@ -{"openapi":"3.1.0","info":{"title":"JoliGEN server","description":"*commit:* [afdde75f](https://github.com/jolibrain/joliGEN/commit/afdde75fe61e376a845cee49ff004fb5d9951bee)\n\nThis is the JoliGEN server API documentation.\n","version":"0.1.0"},"paths":{"/train/{name}":{"get":{"summary":"Get the status of a training process","operationId":"get_train_train__name__get","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}},"post":{"summary":"Start a training process with given name.","description":"The training process will be created using the same options as command line","operationId":"train_train__name__post","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/TrainOptions"}}}},"responses":{"201":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}},"delete":{"summary":"Delete a training process.","description":"If the process is running, it will be stopped.","operationId":"delete_train_train__name__delete","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/train":{"get":{"summary":"Get the status of all training processes","operationId":"get_train_processes_train_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/info":{"get":{"summary":"Get the server status","operationId":"get_info_info_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/fs/":{"delete":{"summary":"Delete a file or a directory in the filesystem","description":"This endpoint can be dangerous, use it with extreme caution","operationId":"delete_path_fs__delete","parameters":[{"required":true,"schema":{"type":"string","title":"Path"},"name":"path","in":"query"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}}},"components":{"schemas":{"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"},"TrainOptions":{"title":"TrainBody","type":"object","properties":{"server":{"title":"Server","default":{"sync":false},"allOf":[{"$ref":"#/definitions/ServerTrainOptions"}]},"train_options":{"title":"TrainOptions","type":"object","properties":{"D":{"title":"Discriminator","type":"object","properties":{"dropout":{"default":false,"type":"boolean","description":"whether to use dropout in the discriminator"},"n_layers":{"default":3,"type":"integer","description":"only used if netD==n_layers"},"ndf":{"default":64,"type":"integer","description":"\\# of discrim filters in the first conv layer"},"netDs":{"default":["projected_d","basic"],"type":"array","items":{"enum":null,"type":"string"},"description":"specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam"},"no_antialias":{"default":false,"type":"boolean","description":"if specified, use stride=2 convs instead of antialiased-downsampling (sad)"},"no_antialias_up":{"default":false,"type":"boolean","description":"if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]"},"norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for D","enum":["instance","batch","none"]},"proj_config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file"},"proj_interp":{"default":-1,"type":"integer","description":"whether to force projected discriminator interpolation to a value \\> 224, -1 means no interpolation"},"proj_network_type":{"default":"efficientnet","type":"string","description":"projected discriminator architecture","enum":["efficientnet","segformer","vitbase","vitsmall","vitsmall2","vitclip16","depth"]},"proj_weight_segformer":{"default":"models/configs/segformer/pretrain/segformer_mit-b0.pth","type":"string","description":"path to segformer weight"},"spectral":{"default":false,"type":"boolean","description":"whether to use spectral norm in the discriminator"},"temporal_every":{"default":4,"type":"integer","description":"apply temporal discriminator every x steps"},"vision_aided_backbones":{"default":"clip+dino+swin","type":"string","description":"specify vision aided discriminators architectures, they are frozen then output are combined and fitted with a linear network on top, choose from dino, clip, swin, det_coco, seg_ade and combine them with +"},"weight_sam":{"default":"","type":"string","description":"path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth"}}},"G":{"title":"Generator","type":"object","properties":{"attn_nb_mask_attn":{"default":10,"type":"integer","description":""},"attn_nb_mask_input":{"default":1,"type":"integer","description":""},"backward_compatibility_twice_resnet_blocks":{"default":false,"type":"boolean","description":"if true, feats will go througt resnet blocks two times for resnet_attn generators. This option will be deleted, it's for backward compatibility (old models were trained that way)."},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for G"},"diff_n_timestep_test":{"default":1000,"type":"integer","description":"Number of timesteps used for UNET mha inference (test time)."},"diff_n_timestep_train":{"default":2000,"type":"integer","description":"Number of timesteps used for UNET mha training."},"dropout":{"default":false,"type":"boolean","description":"dropout for the generator"},"nblocks":{"default":9,"type":"integer","description":"\\# of layer blocks in G, applicable to resnets"},"netE":{"default":"resnet_256","type":"string","description":"specify multimodal latent vector encoder","enum":["resnet_128","resnet_256","resnet_512","conv_128","conv_256","conv_512"]},"netG":{"default":"mobile_resnet_attn","type":"string","description":"specify generator architecture","enum":["resnet","resnet_attn","mobile_resnet","mobile_resnet_attn","unet_256","unet_128","stylegan2","smallstylegan2","segformer_attn_conv","segformer_conv","ittr","unet_mha","uvit"]},"ngf":{"default":64,"type":"integer","description":"\\# of gen filters in the last conv layer"},"norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for G","enum":["instance","batch","none"]},"padding_type":{"default":"reflect","type":"string","description":"whether to use padding in the generator","enum":["reflect","replicate","zeros"]},"spectral":{"default":false,"type":"boolean","description":"whether to use spectral norm in the generator"},"stylegan2_num_downsampling":{"default":1,"type":"integer","description":"Number of downsampling layers used by StyleGAN2Generator"},"unet_mha_attn_res":{"default":[16],"type":"array","items":{"enum":null,"type":"string"},"description":"downrate samples at which attention takes place"},"unet_mha_channel_mults":{"default":[1,2,4,8],"type":"array","items":{"enum":null,"type":"string"},"description":"channel multiplier for each level of the UNET mha"},"unet_mha_group_norm_size":{"default":32,"type":"integer","description":""},"unet_mha_norm_layer":{"default":"groupnorm","type":"string","description":"","enum":["groupnorm","batchnorm","layernorm","instancenorm","switchablenorm"]},"unet_mha_num_head_channels":{"default":32,"type":"integer","description":""},"unet_mha_num_heads":{"default":1,"type":"integer","description":""},"unet_mha_res_blocks":{"default":[2,2,2,2],"type":"array","items":{"enum":null,"type":"string"},"description":"distribution of resnet blocks across the UNet stages, should have same size as --G_unet_mha_channel_mults"},"unet_mha_vit_efficient":{"default":false,"type":"boolean","description":"if true, use efficient attention in UNet and UViT"},"uvit_num_transformer_blocks":{"default":6,"type":"integer","description":"Number of transformer blocks in UViT"}}},"alg":{"title":"Algorithm-specific","type":"object","properties":{"gan":{"title":"GAN model","type":"object","properties":{"lambda":{"default":1.0,"type":"number","description":"weight for GAN loss:GAN(G(X))"}}},"cut":{"title":"CUT model","type":"object","properties":{"HDCE_gamma":{"default":1.0,"type":"number","description":""},"HDCE_gamma_min":{"default":1.0,"type":"number","description":""},"MSE_idt":{"default":false,"type":"boolean","description":"use MSENCE loss for identity mapping: MSE(G(Y), Y))"},"flip_equivariance":{"default":false,"type":"boolean","description":"Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT"},"lambda_MSE_idt":{"default":1.0,"type":"number","description":"weight for MSE identity loss: MSE(G(X), X)"},"lambda_NCE":{"default":1.0,"type":"number","description":"weight for NCE loss: NCE(G(X), X)"},"lambda_SRC":{"default":0.0,"type":"number","description":"weight for SRC (semantic relation consistency) loss: NCE(G(X), X)"},"nce_T":{"default":0.07,"type":"number","description":"temperature for NCE loss"},"nce_idt":{"default":true,"type":"boolean","description":"use NCE loss for identity mapping: NCE(G(Y), Y))"},"nce_includes_all_negatives_from_minibatch":{"default":false,"type":"boolean","description":"(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details."},"nce_layers":{"default":"0,4,8,12,16","type":"string","description":"compute NCE loss on which layers"},"nce_loss":{"default":"monce","type":"string","description":"CUT contrastice loss","enum":["patchnce","monce","SRC_hDCE"]},"netF":{"default":"mlp_sample","type":"string","description":"how to downsample the feature map","enum":["sample","mlp_sample","sample_qsattn","mlp_sample_qsattn"]},"netF_dropout":{"default":false,"type":"boolean","description":"whether to use dropout with F"},"netF_nc":{"default":256,"type":"integer","description":""},"netF_norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for F","enum":["instance","batch","none"]},"num_patches":{"default":256,"type":"integer","description":"number of patches per layer"}}},"cyclegan":{"title":"CycleGAN model","type":"object","properties":{"lambda_A":{"default":10.0,"type":"number","description":"weight for cycle loss (A -\\> B -\\> A)"},"lambda_B":{"default":10.0,"type":"number","description":"weight for cycle loss (B -\\> A -\\> B)"},"lambda_identity":{"default":0.5,"type":"number","description":"use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1"},"rec_noise":{"default":0.0,"type":"number","description":"whether to add noise to reconstruction"}}},"re":{"title":"ReCUT / ReCycleGAN","type":"object","properties":{"P_lr":{"default":0.0002,"type":"number","description":"initial learning rate for P networks"},"adversarial_loss_p":{"default":false,"type":"boolean","description":"if True, also train the prediction model with an adversarial loss"},"netP":{"default":"unet_128","type":"string","description":"specify P architecture","enum":["resnet_9blocks","resnet_6blocks","resnet_attn","unet_256","unet_128"]},"no_train_P_fake_images":{"default":false,"type":"boolean","description":"if True, P wont be trained over fake images projections"},"nuplet_size":{"default":3,"type":"integer","description":"Number of frames loaded"},"projection_threshold":{"default":1.0,"type":"number","description":"threshold of the real images projection loss below with fake projection and fake reconstruction losses are applied"}}},"palette":{"title":"Diffusion model","type":"object","properties":{"computed_sketch_list":{"default":["canny","hed"],"type":"array","items":{"enum":null,"type":"string"},"description":"what to use for random sketch"},"cond_embed_dim":{"default":32,"type":"integer","description":"nb of examples processed for inference"},"cond_image_creation":{"default":"y_t","type":"string","description":"how cond_image is created","enum":["y_t","previous_frame","computed_sketch","low_res"]},"conditioning":{"default":"","type":"string","description":"whether to use conditioning or not","enum":["","mask","class","mask_and_class"]},"ddim_eta":{"default":0.5,"type":"number","description":"eta for ddim sampling variance"},"ddim_num_steps":{"default":10,"type":"integer","description":"number of steps for ddim sampling"},"dropout_prob":{"default":0.0,"type":"number","description":"dropout probability for classifier-free guidance"},"generate_per_class":{"default":false,"type":"boolean","description":"whether to generate samples of each images"},"inference_num":{"default":-1,"type":"integer","description":"nb of examples processed for inference"},"lambda_G":{"default":1.0,"type":"number","description":"weight for supervised loss"},"loss":{"default":"MSE","type":"string","description":"loss for denoising model","enum":["L1","MSE","multiscale"]},"prob_use_previous_frame":{"default":0.5,"type":"number","description":"prob to use previous frame as y cond"},"sam_crop_delta":{"default":true,"type":"boolean","description":"extend crop's width and height by 2\\*crop_delta before computing masks"},"sam_final_canny":{"default":false,"type":"boolean","description":"whether to perform a Canny edge detection on sam sketch to soften the edges"},"sam_max_mask_area":{"default":0.99,"type":"number","description":"maximum area in proportion of image size for a mask to be kept"},"sam_min_mask_area":{"default":0.001,"type":"number","description":"minimum area in proportion of image size for a mask to be kept"},"sam_no_output_binary_sam":{"default":false,"type":"boolean","description":"whether to not output binary sketch before Canny"},"sam_no_sample_points_in_ellipse":{"default":false,"type":"boolean","description":"whether to not sample the points inside an ellipse to avoid the corners of the image"},"sam_no_sobel_filter":{"default":false,"type":"boolean","description":"whether to not use a Sobel filter on each SAM masks"},"sam_points_per_side":{"default":16,"type":"integer","description":"number of points per side of image to prompt SAM with (\\# of prompted points will be points_per_side\\*\\*2)"},"sam_redundancy_threshold":{"default":0.62,"type":"number","description":"redundancy threshold above which redundant masks are not kept"},"sam_sobel_threshold":{"default":0.7,"type":"number","description":"sobel threshold in % of gradient magintude"},"sam_use_gaussian_filter":{"default":false,"type":"boolean","description":"whether to apply a gaussian blur to each SAM masks"},"sampling_method":{"default":"ddpm","type":"string","description":"choose the sampling method between ddpm and ddim","enum":["ddpm","ddim"]},"sketch_canny_range":{"default":[0,765],"type":"array","items":{"enum":null,"type":"string"},"description":"range for Canny thresholds"},"super_resolution_scale":{"default":2.0,"type":"number","description":"scale for super resolution"},"task":{"default":"inpainting","type":"string","description":"","enum":["inpainting","super_resolution"]}}}}},"data":{"title":"Datasets","type":"object","properties":{"online_creation":{"title":"Online created datasets","type":"object","properties":{"color_mask_A":{"default":false,"type":"boolean","description":"Perform task of replacing color-filled masks by objects"},"crop_delta_A":{"default":50,"type":"integer","description":"size of crops are random, values allowed are online_creation_crop_size more or less online_creation_crop_delta for domain A"},"crop_delta_B":{"default":50,"type":"integer","description":"size of crops are random, values allowed are online_creation_crop_size more or less online_creation_crop_delta for domain B"},"crop_size_A":{"default":512,"type":"integer","description":"crop to this size during online creation, it needs to be greater than bbox size for domain A"},"crop_size_B":{"default":512,"type":"integer","description":"crop to this size during online creation, it needs to be greater than bbox size for domain B"},"load_size_A":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"load to this size during online creation, format : width height or only one size if square"},"load_size_B":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"load to this size during online creation, format : width height or only one size if square"},"mask_delta_A":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"mask offset (in pixels) to allow generation of a bigger object in domain B (for semantic loss) for domain A, format : 'width (x),height (y)' for each class or only one size if square, e.g. '125, 55 100, 100' for 2 classes"},"mask_delta_A_ratio":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask offset to allow generation of a bigger object in domain B (for semantic loss) for domain A, format : width (x),height (y) for each class or only one size if square"},"mask_delta_B":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"mask offset (in pixels) to allow generation of a bigger object in domain A (for semantic loss) for domain B, format : 'width (x),height (y)' for each class or only one size if square, e.g. '125, 55 100, 100' for 2 classes"},"mask_delta_B_ratio":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask offset to allow generation of a bigger object in domain A (for semantic loss) for domain B, format : 'width (x),height (y)' for each class or only one size if square"},"mask_random_offset_A":{"default":[0.0],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask size randomization (only to make bigger one) to robustify the image generation in domain A, format : width (x) height (y) or only one size if square"},"mask_random_offset_B":{"default":[0.0],"type":"array","items":{"enum":null,"type":"string"},"description":"mask size randomization (only to make bigger one) to robustify the image generation in domain B, format : width (y) height (x) or only one size if square"},"mask_square_A":{"default":false,"type":"boolean","description":"whether masks should be squared for domain A"},"mask_square_B":{"default":false,"type":"boolean","description":"whether masks should be squared for domain B"},"rand_mask_A":{"default":false,"type":"boolean","description":"Perform task of replacing noised masks by objects"}}},"crop_size":{"default":256,"type":"integer","description":"then crop to this size"},"dataset_mode":{"default":"unaligned","type":"string","description":"chooses how datasets are loaded.","enum":["unaligned","unaligned_labeled_cls","unaligned_labeled_mask","self_supervised_labeled_mask","unaligned_labeled_mask_cls","self_supervised_labeled_mask_cls","unaligned_labeled_mask_online","self_supervised_labeled_mask_online","unaligned_labeled_mask_cls_online","self_supervised_labeled_mask_cls_online","aligned","nuplet_unaligned_labeled_mask","temporal_labeled_mask_online","self_supervised_temporal","single"]},"direction":{"default":"AtoB","type":"string","description":"AtoB or BtoA","enum":["AtoB","BtoA"]},"inverted_mask":{"default":false,"type":"boolean","description":"whether to invert the mask, i.e. around the bbox"},"load_size":{"default":286,"type":"integer","description":"scale images to this size"},"max_dataset_size":{"default":1000000000,"type":"integer","description":"Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded."},"num_threads":{"default":4,"type":"integer","description":"\\# threads for loading data"},"online_context_pixels":{"default":0,"type":"integer","description":"context pixel band around the crop, unused for generation, only for disc "},"online_fixed_mask_size":{"default":-1,"type":"integer","description":"if \\>0, it will be used as fixed bbox size (warning: in dataset resolution ie before resizing) "},"online_select_category":{"default":-1,"type":"integer","description":"category to select for bounding boxes, -1 means all boxes selected"},"online_single_bbox":{"default":false,"type":"boolean","description":"whether to only allow a single bbox per online crop"},"preprocess":{"default":"resize_and_crop","type":"string","description":"scaling and cropping of images at load time","enum":["resize_and_crop","crop","scale_width","scale_width_and_crop","none"]},"refined_mask":{"default":false,"type":"boolean","description":"whether to use refined mask with sam"},"relative_paths":{"default":false,"type":"boolean","description":"whether paths to images are relative to dataroot"},"sanitize_paths":{"default":false,"type":"boolean","description":"if true, wrong images or labels paths will be removed before training"},"serial_batches":{"default":false,"type":"boolean","description":"if true, takes images in order to make batches, otherwise takes them randomly"},"temporal_frame_step":{"default":30,"type":"integer","description":"how many frames between successive frames selected"},"temporal_num_common_char":{"default":-1,"type":"integer","description":"how many characters (the first ones) are used to identify a video; if =-1 natural sorting is used "},"temporal_number_frames":{"default":5,"type":"integer","description":"how many successive frames use for temporal loader"}}},"f_s":{"title":"Semantic segmentation network","type":"object","properties":{"all_classes_as_one":{"default":false,"type":"boolean","description":"if true, all classes will be considered as the same one (ie foreground vs background)"},"class_weights":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"class weights for imbalanced semantic classes"},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for f_s"},"dropout":{"default":false,"type":"boolean","description":"dropout for the semantic network"},"net":{"default":"vgg","type":"string","description":"specify f_s network [vgg|unet|segformer|sam]","enum":["vgg","unet","segformer","sam"]},"nf":{"default":64,"type":"integer","description":"\\# of filters in the first conv layer of classifier"},"semantic_nclasses":{"default":2,"type":"integer","description":"number of classes of the semantic loss classifier"},"semantic_threshold":{"default":1.0,"type":"number","description":"threshold of the semantic classifier loss below with semantic loss is applied"},"weight_sam":{"default":"","type":"string","description":"path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth"},"weight_segformer":{"default":"","type":"string","description":"path to segformer weight for f_s, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth"}}},"cls":{"title":"Semantic classification network","type":"object","properties":{"all_classes_as_one":{"default":false,"type":"boolean","description":"if true, all classes will be considered as the same one (ie foreground vs background)"},"class_weights":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"class weights for imbalanced semantic classes"},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for cls"},"dropout":{"default":false,"type":"boolean","description":"dropout for the semantic network"},"net":{"default":"vgg","type":"string","description":"specify cls network [vgg|unet|segformer]","enum":["vgg","unet","segformer"]},"nf":{"default":64,"type":"integer","description":"\\# of filters in the first conv layer of classifier"},"semantic_nclasses":{"default":2,"type":"integer","description":"number of classes of the semantic loss classifier"},"semantic_threshold":{"default":1.0,"type":"number","description":"threshold of the semantic classifier loss below with semantic loss is applied"},"weight_segformer":{"default":"","type":"string","description":"path to segformer weight for cls, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth"}}},"output":{"title":"Output","type":"object","properties":{"display":{"title":"Visdom display","type":"object","properties":{"G_attention_masks":{"default":false,"type":"boolean","description":""},"aim_port":{"default":53800,"type":"integer","description":"aim port of the web display"},"aim_server":{"default":"http://localhost","type":"string","description":"aim server of the web display"},"diff_fake_real":{"default":false,"type":"boolean","description":"if True x - G(x) is displayed"},"env":{"default":"","type":"string","description":"visdom display environment name (default is \"main\")"},"freq":{"default":400,"type":"integer","description":"frequency of showing training results on screen"},"id":{"default":1,"type":"integer","description":"window id of the web display"},"ncols":{"default":0,"type":"integer","description":"if positive, display all images in a single visdom web panel with certain number of images per row.(if == 0 ncols will be computed automatically)"},"networks":{"default":false,"type":"boolean","description":"Set True if you want to display networks on port 8000"},"type":{"default":["visdom"],"type":"array","items":{"enum":null,"type":"string"},"description":"output display, either visdom, aim or no output","enum":["visdom","aim","none"]},"visdom_autostart":{"default":false,"type":"boolean","description":"whether to start a visdom server automatically"},"visdom_port":{"default":8097,"type":"integer","description":"visdom port of the web display"},"visdom_server":{"default":"http://localhost","type":"string","description":"visdom server of the web display"},"winsize":{"default":256,"type":"integer","description":"display window size for both visdom and HTML"}}},"no_html":{"default":false,"type":"boolean","description":"do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/"},"print_freq":{"default":100,"type":"integer","description":"frequency of showing training results on console"},"update_html_freq":{"default":1000,"type":"integer","description":"frequency of saving training results to html"},"verbose":{"default":false,"type":"boolean","description":"if specified, print more debugging information"}}},"model":{"title":"Model","type":"object","properties":{"depth_network":{"default":"DPT_Large","type":"string","description":"specify depth prediction network architecture","enum":["DPT_Large","DPT_Hybrid","MiDaS_small","DPT_BEiT_L_512","DPT_BEiT_L_384","DPT_BEiT_B_384","DPT_SwinV2_L_384","DPT_SwinV2_B_384","DPT_SwinV2_T_256","DPT_Swin_L_384","DPT_Next_ViT_L_384","DPT_LeViT_224"]},"init_gain":{"default":0.02,"type":"number","description":"scaling factor for normal, xavier and orthogonal."},"init_type":{"default":"normal","type":"string","description":"network initialization","enum":["normal","xavier","kaiming","orthogonal"]},"input_nc":{"default":3,"type":"integer","description":"\\# of input image channels: 3 for RGB and 1 for grayscale","enum":[1,3]},"multimodal":{"default":false,"type":"boolean","description":"multimodal model with random latent input vector"},"output_nc":{"default":3,"type":"integer","description":"\\# of output image channels: 3 for RGB and 1 for grayscale","enum":[1,3]},"prior_321_backwardcompatibility":{"default":false,"type":"boolean","description":"whether to load models from previous version of JG."}}},"train":{"title":"Training","type":"object","properties":{"sem":{"title":"Semantic training","type":"object","properties":{"cls_B":{"default":false,"type":"boolean","description":"if true cls will be trained not only on domain A but also on domain B"},"cls_lambda":{"default":1.0,"type":"number","description":"weight for semantic class loss"},"cls_pretrained":{"default":false,"type":"boolean","description":"whether to use a pretrained model, available for non \"basic\" model only"},"cls_template":{"default":"basic","type":"string","description":"classifier/regressor model type, from torchvision (resnet18, ...), default is custom simple model"},"idt":{"default":false,"type":"boolean","description":"if true apply semantic loss on identity"},"lr_cls":{"default":0.0002,"type":"number","description":"cls learning rate"},"lr_f_s":{"default":0.0002,"type":"number","description":"f_s learning rate"},"mask_lambda":{"default":1.0,"type":"number","description":"weight for semantic mask loss"},"net_output":{"default":false,"type":"boolean","description":"if true apply generator semantic loss on network output for real image rather than on label."},"use_label_B":{"default":false,"type":"boolean","description":"if true domain B has labels too"}}},"mask":{"title":"Semantic training with masks","type":"object","properties":{"charbonnier_eps":{"default":1e-06,"type":"number","description":"Charbonnier loss epsilon value"},"compute_miou":{"default":false,"type":"boolean","description":""},"disjoint_f_s":{"default":false,"type":"boolean","description":"whether to use a disjoint f_s with the same exact structure"},"f_s_B":{"default":false,"type":"boolean","description":"if true f_s will be trained not only on domain A but also on domain B"},"for_removal":{"default":false,"type":"boolean","description":"if true, object removal mode, domain B images with label 0, cut models only"},"lambda_out_mask":{"default":10.0,"type":"number","description":"weight for loss out mask"},"loss_out_mask":{"default":"L1","type":"string","description":"loss for out mask content (which should not change).","enum":["L1","MSE","Charbonnier"]},"miou_every":{"default":1000,"type":"integer","description":""},"no_train_f_s_A":{"default":false,"type":"boolean","description":"if true f_s wont be trained on domain A"},"out_mask":{"default":false,"type":"boolean","description":"use loss out mask"}}},"D_accuracy_every":{"default":1000,"type":"integer","description":"compute D accuracy every N iterations"},"D_lr":{"default":0.0001,"type":"number","description":"discriminator separate learning rate"},"G_ema":{"default":false,"type":"boolean","description":"whether to build G via exponential moving average"},"G_ema_beta":{"default":0.999,"type":"number","description":"exponential decay for ema"},"G_lr":{"default":0.0002,"type":"number","description":"initial learning rate for generator"},"batch_size":{"default":1,"type":"integer","description":"input batch size"},"beta1":{"default":0.9,"type":"number","description":"momentum term of adam"},"beta2":{"default":0.999,"type":"number","description":"momentum term of adam"},"cls_l1_regression":{"default":false,"type":"boolean","description":"if true l1 loss will be used to compute regressor loss"},"cls_regression":{"default":false,"type":"boolean","description":"if true cls will be a regressor and not a classifier"},"compute_D_accuracy":{"default":false,"type":"boolean","description":"whether to compute D accuracy explicitely"},"compute_metrics_test":{"default":false,"type":"boolean","description":""},"continue":{"default":false,"type":"boolean","description":"continue training: load the latest model"},"epoch":{"default":"latest","type":"string","description":"which epoch to load? set to latest to use latest cached model"},"epoch_count":{"default":1,"type":"integer","description":"the starting epoch count, we save the model by \\, \\+\\, ..."},"export_jit":{"default":false,"type":"boolean","description":"whether to export model in jit format"},"gan_mode":{"default":"lsgan","type":"string","description":"the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.","enum":["vanilla","lsgan","wgangp","projected"]},"iter_size":{"default":1,"type":"integer","description":"backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size\\*iter_size"},"load_iter":{"default":0,"type":"integer","description":"which iteration to load? if load_iter \\> 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"},"lr_decay_iters":{"default":50,"type":"integer","description":"multiply by a gamma every lr_decay_iters iterations"},"lr_policy":{"default":"linear","type":"string","description":"learning rate policy.","enum":["linear","step","plateau","cosine"]},"metrics_every":{"default":1000,"type":"integer","description":""},"metrics_list":{"default":["FID"],"type":"array","items":{"enum":null,"type":"string"},"description":"","enum":["FID","KID","MSID","PSNR"]},"mm_lambda_z":{"default":0.5,"type":"number","description":"weight for random z loss"},"mm_nz":{"default":8,"type":"integer","description":"number of latent vectors"},"n_epochs":{"default":100,"type":"integer","description":"number of epochs with the initial learning rate"},"n_epochs_decay":{"default":100,"type":"integer","description":"number of epochs to linearly decay learning rate to zero"},"nb_img_max_fid":{"default":1000000000,"type":"integer","description":"Maximum number of samples allowed per dataset to compute fid. If the dataset directory contains more than nb_img_max_fid, only a subset is used."},"optim":{"default":"adam","type":"string","description":"optimizer (adam, radam, adamw, ...)","enum":["adam","radam","adamw","lion"]},"pool_size":{"default":50,"type":"integer","description":"the size of image buffer that stores previously generated images"},"save_by_iter":{"default":false,"type":"boolean","description":"whether saves model by iteration"},"save_epoch_freq":{"default":1,"type":"integer","description":"frequency of saving checkpoints at the end of epochs"},"save_latest_freq":{"default":5000,"type":"integer","description":"frequency of saving the latest results"},"semantic_cls":{"default":false,"type":"boolean","description":"if true semantic class losses will be used"},"semantic_mask":{"default":false,"type":"boolean","description":"if true semantic mask losses will be used"},"temporal_criterion":{"default":false,"type":"boolean","description":"if true, MSE loss will be computed between successive frames"},"temporal_criterion_lambda":{"default":1.0,"type":"number","description":"lambda for MSE loss that will be computed between successive frames"},"use_contrastive_loss_D":{"default":false,"type":"boolean","description":""}}},"dataaug":{"title":"Data augmentation","type":"object","properties":{"APA":{"default":false,"type":"boolean","description":"if true, G will be used as augmentation during D training adaptively to D overfitting between real and fake images"},"APA_every":{"default":4,"type":"integer","description":"How often to perform APA adjustment?"},"APA_nimg":{"default":50,"type":"integer","description":"APA adjustment speed, measured in how many images it takes for p to increase/decrease by one unit."},"APA_p":{"default":0,"type":"integer","description":"initial value of probability APA"},"APA_target":{"default":0.6,"type":"number","description":""},"D_diffusion":{"default":false,"type":"boolean","description":"whether to apply diffusion noise augmentation to discriminator inputs, projected discriminator only"},"D_diffusion_every":{"default":4,"type":"integer","description":"How often to perform diffusion augmentation adjustment"},"D_label_smooth":{"default":false,"type":"boolean","description":"whether to use one-sided label smoothing with discriminator"},"D_noise":{"default":0.0,"type":"number","description":"whether to add instance noise to discriminator inputs"},"affine":{"default":0.0,"type":"number","description":"if specified, apply random affine transforms to the images for data augmentation"},"affine_scale_max":{"default":1.2,"type":"number","description":"if random affine specified, max scale range value"},"affine_scale_min":{"default":0.8,"type":"number","description":"if random affine specified, min scale range value"},"affine_shear":{"default":45,"type":"integer","description":"if random affine specified, shear range (0,value)"},"affine_translate":{"default":0.2,"type":"number","description":"if random affine specified, translation range (-value\\*img_size,+value\\*img_size) value"},"diff_aug_policy":{"default":"","type":"string","description":"choose the augmentation policy : color randaffine randperspective. If you want more than one, please write them separated by a comma with no space (e.g. color,randaffine)"},"diff_aug_proba":{"default":0.5,"type":"number","description":"proba of using each transformation"},"imgaug":{"default":false,"type":"boolean","description":"whether to apply random image augmentation"},"no_flip":{"default":false,"type":"boolean","description":"if specified, do not flip the images for data augmentation"},"no_rotate":{"default":false,"type":"boolean","description":"if specified, do not rotate the images for data augmentation"}}},"checkpoints_dir":{"default":"./checkpoints","type":"string","description":"models are saved here"},"dataroot":{"default":"None","type":"string","description":"path to images (should have subfolders trainA, trainB, valA, valB, etc)"},"ddp_port":{"default":"12355","type":"string","description":""},"gpu_ids":{"default":"0","type":"string","description":"gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"},"model_type":{"default":"cut","type":"string","description":"chooses which model to use.","enum":["cut","cycle_gan","palette"]},"name":{"default":"experiment_name","type":"string","description":"name of the experiment. It decides where to store samples and models"},"phase":{"default":"train","type":"string","description":"train, val, test, etc"},"suffix":{"default":"","type":"string","description":"customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}"},"test_batch_size":{"default":1,"type":"integer","description":"input batch size"},"warning_mode":{"default":false,"type":"boolean","description":"whether to display warning"},"with_amp":{"default":false,"type":"boolean","description":"whether to activate torch amp on forward passes"},"with_tf32":{"default":false,"type":"boolean","description":"whether to activate tf32 for faster computations (Ampere GPU and beyond only)"},"with_torch_compile":{"default":false,"type":"boolean","description":"whether to activate torch.compile for some forward and backward functions (experimental)"}}}},"definitions":{"ServerTrainOptions":{"title":"ServerTrainOptions","type":"object","properties":{"sync":{"title":"Sync","description":"if false, the call returns immediately and train process is executed in the background. If true, the call returns only when training process is finished","default":false,"type":"boolean"}}}}}}},"definitions":{"ServerTrainOptions":{"title":"ServerTrainOptions","type":"object","properties":{"sync":{"title":"Sync","description":"if false, the call returns immediately and train process is executed in the background. If true, the call returns only when training process is finished","default":false,"type":"boolean"}}}}} \ No newline at end of file +{"openapi":"3.1.0","info":{"title":"JoliGEN server","description":"*commit:* [f1e05266](https://github.com/jolibrain/joliGEN/commit/f1e05266a96597e8d3ee456907ce55fc7d3511e1)\n\nThis is the JoliGEN server API documentation.\n","version":"0.1.0"},"paths":{"/train/{name}":{"get":{"summary":"Get the status of a training process","operationId":"get_train_train__name__get","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}},"post":{"summary":"Start a training process with given name.","description":"The training process will be created using the same options as command line","operationId":"train_train__name__post","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"requestBody":{"content":{"application/json":{"schema":{"$ref":"#/components/schemas/TrainOptions"}}}},"responses":{"201":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}},"delete":{"summary":"Delete a training process.","description":"If the process is running, it will be stopped.","operationId":"delete_train_train__name__delete","parameters":[{"required":true,"schema":{"type":"string","title":"Name"},"name":"name","in":"path"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}},"/train":{"get":{"summary":"Get the status of all training processes","operationId":"get_train_processes_train_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/info":{"get":{"summary":"Get the server status","operationId":"get_info_info_get","responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}}}}},"/fs/":{"delete":{"summary":"Delete a file or a directory in the filesystem","description":"This endpoint can be dangerous, use it with extreme caution","operationId":"delete_path_fs__delete","parameters":[{"required":true,"schema":{"type":"string","title":"Path"},"name":"path","in":"query"}],"responses":{"200":{"description":"Successful Response","content":{"application/json":{"schema":{}}}},"422":{"description":"Validation Error","content":{"application/json":{"schema":{"$ref":"#/components/schemas/HTTPValidationError"}}}}}}}},"components":{"schemas":{"HTTPValidationError":{"properties":{"detail":{"items":{"$ref":"#/components/schemas/ValidationError"},"type":"array","title":"Detail"}},"type":"object","title":"HTTPValidationError"},"ValidationError":{"properties":{"loc":{"items":{"anyOf":[{"type":"string"},{"type":"integer"}]},"type":"array","title":"Location"},"msg":{"type":"string","title":"Message"},"type":{"type":"string","title":"Error Type"}},"type":"object","required":["loc","msg","type"],"title":"ValidationError"},"TrainOptions":{"title":"TrainBody","type":"object","properties":{"server":{"title":"Server","default":{"sync":false},"allOf":[{"$ref":"#/definitions/ServerTrainOptions"}]},"train_options":{"title":"TrainOptions","type":"object","properties":{"D":{"title":"Discriminator","type":"object","properties":{"dropout":{"default":false,"type":"boolean","description":"whether to use dropout in the discriminator"},"n_layers":{"default":3,"type":"integer","description":"only used if netD==n_layers"},"ndf":{"default":64,"type":"integer","description":"\\# of discrim filters in the first conv layer"},"netDs":{"default":["projected_d","basic"],"type":"array","items":{"enum":null,"type":"string"},"description":"specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam"},"no_antialias":{"default":false,"type":"boolean","description":"if specified, use stride=2 convs instead of antialiased-downsampling (sad)"},"no_antialias_up":{"default":false,"type":"boolean","description":"if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]"},"norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for D","enum":["instance","batch","none"]},"proj_config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file"},"proj_interp":{"default":-1,"type":"integer","description":"whether to force projected discriminator interpolation to a value \\> 224, -1 means no interpolation"},"proj_network_type":{"default":"efficientnet","type":"string","description":"projected discriminator architecture","enum":["efficientnet","segformer","vitbase","vitsmall","vitsmall2","vitclip16","depth"]},"proj_weight_segformer":{"default":"models/configs/segformer/pretrain/segformer_mit-b0.pth","type":"string","description":"path to segformer weight"},"spectral":{"default":false,"type":"boolean","description":"whether to use spectral norm in the discriminator"},"temporal_every":{"default":4,"type":"integer","description":"apply temporal discriminator every x steps"},"vision_aided_backbones":{"default":"clip+dino+swin","type":"string","description":"specify vision aided discriminators architectures, they are frozen then output are combined and fitted with a linear network on top, choose from dino, clip, swin, det_coco, seg_ade and combine them with +"},"weight_sam":{"default":"","type":"string","description":"path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM"}}},"G":{"title":"Generator","type":"object","properties":{"attn_nb_mask_attn":{"default":10,"type":"integer","description":"number of attention masks in _attn model architectures"},"attn_nb_mask_input":{"default":1,"type":"integer","description":"number of mask dedicated to input in _attn model architectures"},"backward_compatibility_twice_resnet_blocks":{"default":false,"type":"boolean","description":"if true, feats will go througt resnet blocks two times for resnet_attn generators. This option will be deleted, it's for backward compatibility (old models were trained that way)."},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for G"},"diff_n_timestep_test":{"default":1000,"type":"integer","description":"Number of timesteps used for UNET mha inference (test time)."},"diff_n_timestep_train":{"default":2000,"type":"integer","description":"Number of timesteps used for UNET mha training."},"dropout":{"default":false,"type":"boolean","description":"dropout for the generator"},"nblocks":{"default":9,"type":"integer","description":"\\# of layer blocks in G, applicable to resnets"},"netE":{"default":"resnet_256","type":"string","description":"specify multimodal latent vector encoder","enum":["resnet_128","resnet_256","resnet_512","conv_128","conv_256","conv_512"]},"netG":{"default":"mobile_resnet_attn","type":"string","description":"specify generator architecture","enum":["resnet","resnet_attn","mobile_resnet","mobile_resnet_attn","unet_256","unet_128","stylegan2","smallstylegan2","segformer_attn_conv","segformer_conv","ittr","unet_mha","uvit"]},"ngf":{"default":64,"type":"integer","description":"\\# of gen filters in the last conv layer"},"norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for G","enum":["instance","batch","none"]},"padding_type":{"default":"reflect","type":"string","description":"whether to use padding in the generator","enum":["reflect","replicate","zeros"]},"spectral":{"default":false,"type":"boolean","description":"whether to use spectral norm in the generator"},"stylegan2_num_downsampling":{"default":1,"type":"integer","description":"Number of downsampling layers used by StyleGAN2Generator"},"unet_mha_attn_res":{"default":[16],"type":"array","items":{"enum":null,"type":"string"},"description":"downrate samples at which attention takes place"},"unet_mha_channel_mults":{"default":[1,2,4,8],"type":"array","items":{"enum":null,"type":"string"},"description":"channel multiplier for each level of the UNET mha"},"unet_mha_group_norm_size":{"default":32,"type":"integer","description":""},"unet_mha_norm_layer":{"default":"groupnorm","type":"string","description":"","enum":["groupnorm","batchnorm","layernorm","instancenorm","switchablenorm"]},"unet_mha_num_head_channels":{"default":32,"type":"integer","description":"number of channels in each head of the mha architecture"},"unet_mha_num_heads":{"default":1,"type":"integer","description":"number of heads in the mha architecture"},"unet_mha_res_blocks":{"default":[2,2,2,2],"type":"array","items":{"enum":null,"type":"string"},"description":"distribution of resnet blocks across the UNet stages, should have same size as --G_unet_mha_channel_mults"},"unet_mha_vit_efficient":{"default":false,"type":"boolean","description":"if true, use efficient attention in UNet and UViT"},"uvit_num_transformer_blocks":{"default":6,"type":"integer","description":"Number of transformer blocks in UViT"}}},"alg":{"title":"Algorithm-specific","type":"object","properties":{"gan":{"title":"GAN model","type":"object","properties":{"lambda":{"default":1.0,"type":"number","description":"weight for GAN loss:GAN(G(X))"}}},"cut":{"title":"CUT model","type":"object","properties":{"HDCE_gamma":{"default":1.0,"type":"number","description":""},"HDCE_gamma_min":{"default":1.0,"type":"number","description":""},"MSE_idt":{"default":false,"type":"boolean","description":"use MSENCE loss for identity mapping: MSE(G(Y), Y))"},"flip_equivariance":{"default":false,"type":"boolean","description":"Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT"},"lambda_MSE_idt":{"default":1.0,"type":"number","description":"weight for MSE identity loss: MSE(G(X), X)"},"lambda_NCE":{"default":1.0,"type":"number","description":"weight for NCE loss: NCE(G(X), X)"},"lambda_SRC":{"default":0.0,"type":"number","description":"weight for SRC (semantic relation consistency) loss: NCE(G(X), X)"},"nce_T":{"default":0.07,"type":"number","description":"temperature for NCE loss"},"nce_idt":{"default":true,"type":"boolean","description":"use NCE loss for identity mapping: NCE(G(Y), Y))"},"nce_includes_all_negatives_from_minibatch":{"default":false,"type":"boolean","description":"(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details."},"nce_layers":{"default":"0,4,8,12,16","type":"string","description":"compute NCE loss on which layers"},"nce_loss":{"default":"monce","type":"string","description":"CUT contrastice loss","enum":["patchnce","monce","SRC_hDCE"]},"netF":{"default":"mlp_sample","type":"string","description":"how to downsample the feature map","enum":["sample","mlp_sample","sample_qsattn","mlp_sample_qsattn"]},"netF_dropout":{"default":false,"type":"boolean","description":"whether to use dropout with F"},"netF_nc":{"default":256,"type":"integer","description":""},"netF_norm":{"default":"instance","type":"string","description":"instance normalization or batch normalization for F","enum":["instance","batch","none"]},"num_patches":{"default":256,"type":"integer","description":"number of patches per layer"}}},"cyclegan":{"title":"CycleGAN model","type":"object","properties":{"lambda_A":{"default":10.0,"type":"number","description":"weight for cycle loss (A -\\> B -\\> A)"},"lambda_B":{"default":10.0,"type":"number","description":"weight for cycle loss (B -\\> A -\\> B)"},"lambda_identity":{"default":0.5,"type":"number","description":"use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1"},"rec_noise":{"default":0.0,"type":"number","description":"whether to add noise to reconstruction"}}},"re":{"title":"ReCUT / ReCycleGAN","type":"object","properties":{"P_lr":{"default":0.0002,"type":"number","description":"initial learning rate for P networks"},"adversarial_loss_p":{"default":false,"type":"boolean","description":"if True, also train the prediction model with an adversarial loss"},"netP":{"default":"unet_128","type":"string","description":"specify P architecture","enum":["resnet_9blocks","resnet_6blocks","resnet_attn","unet_256","unet_128"]},"no_train_P_fake_images":{"default":false,"type":"boolean","description":"if True, P wont be trained over fake images projections"},"nuplet_size":{"default":3,"type":"integer","description":"Number of frames loaded"},"projection_threshold":{"default":1.0,"type":"number","description":"threshold of the real images projection loss below with fake projection and fake reconstruction losses are applied"}}},"palette":{"title":"Diffusion model","type":"object","properties":{"computed_sketch_list":{"default":["canny","hed"],"type":"array","items":{"enum":null,"type":"string"},"description":"what to use for random sketch"},"cond_embed_dim":{"default":32,"type":"integer","description":"nb of examples processed for inference"},"cond_image_creation":{"default":"y_t","type":"string","description":"how cond_image is created","enum":["y_t","previous_frame","computed_sketch","low_res"]},"conditioning":{"default":"","type":"string","description":"whether to use conditioning or not","enum":["","mask","class","mask_and_class"]},"ddim_eta":{"default":0.5,"type":"number","description":"eta for ddim sampling variance"},"ddim_num_steps":{"default":10,"type":"integer","description":"number of steps for ddim sampling"},"dropout_prob":{"default":0.0,"type":"number","description":"dropout probability for classifier-free guidance"},"generate_per_class":{"default":false,"type":"boolean","description":"whether to generate samples of each images"},"inference_num":{"default":-1,"type":"integer","description":"nb of examples processed for inference"},"lambda_G":{"default":1.0,"type":"number","description":"weight for supervised loss"},"loss":{"default":"MSE","type":"string","description":"loss for denoising model","enum":["L1","MSE","multiscale"]},"prob_use_previous_frame":{"default":0.5,"type":"number","description":"prob to use previous frame as y cond"},"sam_crop_delta":{"default":true,"type":"boolean","description":"extend crop's width and height by 2\\*crop_delta before computing masks"},"sam_final_canny":{"default":false,"type":"boolean","description":"whether to perform a Canny edge detection on sam sketch to soften the edges"},"sam_max_mask_area":{"default":0.99,"type":"number","description":"maximum area in proportion of image size for a mask to be kept"},"sam_min_mask_area":{"default":0.001,"type":"number","description":"minimum area in proportion of image size for a mask to be kept"},"sam_no_output_binary_sam":{"default":false,"type":"boolean","description":"whether to not output binary sketch before Canny"},"sam_no_sample_points_in_ellipse":{"default":false,"type":"boolean","description":"whether to not sample the points inside an ellipse to avoid the corners of the image"},"sam_no_sobel_filter":{"default":false,"type":"boolean","description":"whether to not use a Sobel filter on each SAM masks"},"sam_points_per_side":{"default":16,"type":"integer","description":"number of points per side of image to prompt SAM with (\\# of prompted points will be points_per_side\\*\\*2)"},"sam_redundancy_threshold":{"default":0.62,"type":"number","description":"redundancy threshold above which redundant masks are not kept"},"sam_sobel_threshold":{"default":0.7,"type":"number","description":"sobel threshold in % of gradient magintude"},"sam_use_gaussian_filter":{"default":false,"type":"boolean","description":"whether to apply a gaussian blur to each SAM masks"},"sampling_method":{"default":"ddpm","type":"string","description":"choose the sampling method between ddpm and ddim","enum":["ddpm","ddim"]},"sketch_canny_range":{"default":[0,765],"type":"array","items":{"enum":null,"type":"string"},"description":"range for Canny thresholds"},"super_resolution_scale":{"default":2.0,"type":"number","description":"scale for super resolution"},"task":{"default":"inpainting","type":"string","description":"","enum":["inpainting","super_resolution"]}}}}},"data":{"title":"Datasets","type":"object","properties":{"online_creation":{"title":"Online created datasets","type":"object","properties":{"color_mask_A":{"default":false,"type":"boolean","description":"Perform task of replacing color-filled masks by objects"},"crop_delta_A":{"default":50,"type":"integer","description":"size of crops are random, values allowed are online_creation_crop_size more or less online_creation_crop_delta for domain A"},"crop_delta_B":{"default":50,"type":"integer","description":"size of crops are random, values allowed are online_creation_crop_size more or less online_creation_crop_delta for domain B"},"crop_size_A":{"default":512,"type":"integer","description":"crop to this size during online creation, it needs to be greater than bbox size for domain A"},"crop_size_B":{"default":512,"type":"integer","description":"crop to this size during online creation, it needs to be greater than bbox size for domain B"},"load_size_A":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"load to this size during online creation, format : width height or only one size if square"},"load_size_B":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"load to this size during online creation, format : width height or only one size if square"},"mask_delta_A":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"mask offset (in pixels) to allow generation of a bigger object in domain B (for semantic loss) for domain A, format : 'width (x),height (y)' for each class or only one size if square, e.g. '125, 55 100, 100' for 2 classes"},"mask_delta_A_ratio":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask offset to allow generation of a bigger object in domain B (for semantic loss) for domain A, format : width (x),height (y) for each class or only one size if square"},"mask_delta_B":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"mask offset (in pixels) to allow generation of a bigger object in domain A (for semantic loss) for domain B, format : 'width (x),height (y)' for each class or only one size if square, e.g. '125, 55 100, 100' for 2 classes"},"mask_delta_B_ratio":{"default":[[]],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask offset to allow generation of a bigger object in domain A (for semantic loss) for domain B, format : 'width (x),height (y)' for each class or only one size if square"},"mask_random_offset_A":{"default":[0.0],"type":"array","items":{"enum":null,"type":"string"},"description":"ratio mask size randomization (only to make bigger one) to robustify the image generation in domain A, format : width (x) height (y) or only one size if square"},"mask_random_offset_B":{"default":[0.0],"type":"array","items":{"enum":null,"type":"string"},"description":"mask size randomization (only to make bigger one) to robustify the image generation in domain B, format : width (y) height (x) or only one size if square"},"mask_square_A":{"default":false,"type":"boolean","description":"whether masks should be squared for domain A"},"mask_square_B":{"default":false,"type":"boolean","description":"whether masks should be squared for domain B"},"rand_mask_A":{"default":false,"type":"boolean","description":"Perform task of replacing noised masks by objects"}}},"crop_size":{"default":256,"type":"integer","description":"then crop to this size"},"dataset_mode":{"default":"unaligned","type":"string","description":"chooses how datasets are loaded.","enum":["unaligned","unaligned_labeled_cls","unaligned_labeled_mask","self_supervised_labeled_mask","unaligned_labeled_mask_cls","self_supervised_labeled_mask_cls","unaligned_labeled_mask_online","self_supervised_labeled_mask_online","unaligned_labeled_mask_cls_online","self_supervised_labeled_mask_cls_online","aligned","nuplet_unaligned_labeled_mask","temporal_labeled_mask_online","self_supervised_temporal","single"]},"direction":{"default":"AtoB","type":"string","description":"AtoB or BtoA","enum":["AtoB","BtoA"]},"inverted_mask":{"default":false,"type":"boolean","description":"whether to invert the mask, i.e. around the bbox"},"load_size":{"default":286,"type":"integer","description":"scale images to this size"},"max_dataset_size":{"default":1000000000,"type":"integer","description":"Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded."},"num_threads":{"default":4,"type":"integer","description":"\\# threads for loading data"},"online_context_pixels":{"default":0,"type":"integer","description":"context pixel band around the crop, unused for generation, only for disc "},"online_fixed_mask_size":{"default":-1,"type":"integer","description":"if \\>0, it will be used as fixed bbox size (warning: in dataset resolution ie before resizing) "},"online_select_category":{"default":-1,"type":"integer","description":"category to select for bounding boxes, -1 means all boxes selected"},"online_single_bbox":{"default":false,"type":"boolean","description":"whether to only allow a single bbox per online crop"},"preprocess":{"default":"resize_and_crop","type":"string","description":"scaling and cropping of images at load time","enum":["resize_and_crop","crop","scale_width","scale_width_and_crop","none"]},"refined_mask":{"default":false,"type":"boolean","description":"whether to use refined mask with sam"},"relative_paths":{"default":false,"type":"boolean","description":"whether paths to images are relative to dataroot"},"sanitize_paths":{"default":false,"type":"boolean","description":"if true, wrong images or labels paths will be removed before training"},"serial_batches":{"default":false,"type":"boolean","description":"if true, takes images in order to make batches, otherwise takes them randomly"},"temporal_frame_step":{"default":30,"type":"integer","description":"how many frames between successive frames selected"},"temporal_num_common_char":{"default":-1,"type":"integer","description":"how many characters (the first ones) are used to identify a video; if =-1 natural sorting is used "},"temporal_number_frames":{"default":5,"type":"integer","description":"how many successive frames use for temporal loader"}}},"f_s":{"title":"Semantic segmentation network","type":"object","properties":{"all_classes_as_one":{"default":false,"type":"boolean","description":"if true, all classes will be considered as the same one (ie foreground vs background)"},"class_weights":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"class weights for imbalanced semantic classes"},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for f_s"},"dropout":{"default":false,"type":"boolean","description":"dropout for the semantic network"},"net":{"default":"vgg","type":"string","description":"specify f_s network [vgg|unet|segformer|sam]","enum":["vgg","unet","segformer","sam"]},"nf":{"default":64,"type":"integer","description":"\\# of filters in the first conv layer of classifier"},"semantic_nclasses":{"default":2,"type":"integer","description":"number of classes of the semantic loss classifier"},"semantic_threshold":{"default":1.0,"type":"number","description":"threshold of the semantic classifier loss below with semantic loss is applied"},"weight_sam":{"default":"","type":"string","description":"path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM"},"weight_segformer":{"default":"","type":"string","description":"path to segformer weight for f_s, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth"}}},"cls":{"title":"Semantic classification network","type":"object","properties":{"all_classes_as_one":{"default":false,"type":"boolean","description":"if true, all classes will be considered as the same one (ie foreground vs background)"},"class_weights":{"default":[],"type":"array","items":{"enum":null,"type":"string"},"description":"class weights for imbalanced semantic classes"},"config_segformer":{"default":"models/configs/segformer/segformer_config_b0.json","type":"string","description":"path to segformer configuration file for cls"},"dropout":{"default":false,"type":"boolean","description":"dropout for the semantic network"},"net":{"default":"vgg","type":"string","description":"specify cls network [vgg|unet|segformer]","enum":["vgg","unet","segformer"]},"nf":{"default":64,"type":"integer","description":"\\# of filters in the first conv layer of classifier"},"semantic_nclasses":{"default":2,"type":"integer","description":"number of classes of the semantic loss classifier"},"semantic_threshold":{"default":1.0,"type":"number","description":"threshold of the semantic classifier loss below with semantic loss is applied"},"weight_segformer":{"default":"","type":"string","description":"path to segformer weight for cls, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth"}}},"output":{"title":"Output","type":"object","properties":{"display":{"title":"Visdom display","type":"object","properties":{"G_attention_masks":{"default":false,"type":"boolean","description":""},"aim_port":{"default":53800,"type":"integer","description":"aim port of the web display"},"aim_server":{"default":"http://localhost","type":"string","description":"aim server of the web display"},"diff_fake_real":{"default":false,"type":"boolean","description":"if True x - G(x) is displayed"},"env":{"default":"","type":"string","description":"visdom display environment name (default is \"main\")"},"freq":{"default":400,"type":"integer","description":"frequency of showing training results on screen"},"id":{"default":1,"type":"integer","description":"window id of the web display"},"ncols":{"default":0,"type":"integer","description":"if positive, display all images in a single visdom web panel with certain number of images per row.(if == 0 ncols will be computed automatically)"},"networks":{"default":false,"type":"boolean","description":"Set True if you want to display networks on port 8000"},"type":{"default":["visdom"],"type":"array","items":{"enum":null,"type":"string"},"description":"output display, either visdom, aim or no output","enum":["visdom","aim","none"]},"visdom_autostart":{"default":false,"type":"boolean","description":"whether to start a visdom server automatically"},"visdom_port":{"default":8097,"type":"integer","description":"visdom port of the web display"},"visdom_server":{"default":"http://localhost","type":"string","description":"visdom server of the web display"},"winsize":{"default":256,"type":"integer","description":"display window size for both visdom and HTML"}}},"no_html":{"default":false,"type":"boolean","description":"do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/"},"print_freq":{"default":100,"type":"integer","description":"frequency of showing training results on console"},"update_html_freq":{"default":1000,"type":"integer","description":"frequency of saving training results to html"},"verbose":{"default":false,"type":"boolean","description":"if specified, print more debugging information"}}},"model":{"title":"Model","type":"object","properties":{"depth_network":{"default":"DPT_Large","type":"string","description":"specify depth prediction network architecture","enum":["DPT_Large","DPT_Hybrid","MiDaS_small","DPT_BEiT_L_512","DPT_BEiT_L_384","DPT_BEiT_B_384","DPT_SwinV2_L_384","DPT_SwinV2_B_384","DPT_SwinV2_T_256","DPT_Swin_L_384","DPT_Next_ViT_L_384","DPT_LeViT_224"]},"init_gain":{"default":0.02,"type":"number","description":"scaling factor for normal, xavier and orthogonal."},"init_type":{"default":"normal","type":"string","description":"network initialization","enum":["normal","xavier","kaiming","orthogonal"]},"input_nc":{"default":3,"type":"integer","description":"\\# of input image channels: 3 for RGB and 1 for grayscale","enum":[1,3]},"multimodal":{"default":false,"type":"boolean","description":"multimodal model with random latent input vector"},"output_nc":{"default":3,"type":"integer","description":"\\# of output image channels: 3 for RGB and 1 for grayscale","enum":[1,3]},"prior_321_backwardcompatibility":{"default":false,"type":"boolean","description":"whether to load models from previous version of JG."},"type_sam":{"default":"mobile_sam","type":"string","description":"which model to use for segment-anything mask generation","enum":["sam","mobile_sam"]}}},"train":{"title":"Training","type":"object","properties":{"sem":{"title":"Semantic training","type":"object","properties":{"cls_B":{"default":false,"type":"boolean","description":"if true cls will be trained not only on domain A but also on domain B"},"cls_lambda":{"default":1.0,"type":"number","description":"weight for semantic class loss"},"cls_pretrained":{"default":false,"type":"boolean","description":"whether to use a pretrained model, available for non \"basic\" model only"},"cls_template":{"default":"basic","type":"string","description":"classifier/regressor model type, from torchvision (resnet18, ...), default is custom simple model"},"idt":{"default":false,"type":"boolean","description":"if true apply semantic loss on identity"},"lr_cls":{"default":0.0002,"type":"number","description":"cls learning rate"},"lr_f_s":{"default":0.0002,"type":"number","description":"f_s learning rate"},"mask_lambda":{"default":1.0,"type":"number","description":"weight for semantic mask loss"},"net_output":{"default":false,"type":"boolean","description":"if true apply generator semantic loss on network output for real image rather than on label."},"use_label_B":{"default":false,"type":"boolean","description":"if true domain B has labels too"}}},"mask":{"title":"Semantic training with masks","type":"object","properties":{"charbonnier_eps":{"default":1e-06,"type":"number","description":"Charbonnier loss epsilon value"},"compute_miou":{"default":false,"type":"boolean","description":"whether to compute mIoU on semantic masks prediction"},"disjoint_f_s":{"default":false,"type":"boolean","description":"whether to use a disjoint f_s with the same exact structure"},"f_s_B":{"default":false,"type":"boolean","description":"if true f_s will be trained not only on domain A but also on domain B"},"for_removal":{"default":false,"type":"boolean","description":"if true, object removal mode, domain B images with label 0, cut models only"},"lambda_out_mask":{"default":10.0,"type":"number","description":"weight for loss out mask"},"loss_out_mask":{"default":"L1","type":"string","description":"loss for out mask content (which should not change).","enum":["L1","MSE","Charbonnier"]},"miou_every":{"default":1000,"type":"integer","description":"compute mIoU every n iterations"},"no_train_f_s_A":{"default":false,"type":"boolean","description":"if true f_s wont be trained on domain A"},"out_mask":{"default":false,"type":"boolean","description":"use loss out mask"}}},"D_accuracy_every":{"default":1000,"type":"integer","description":"compute D accuracy every N iterations"},"D_lr":{"default":0.0001,"type":"number","description":"discriminator separate learning rate"},"G_ema":{"default":false,"type":"boolean","description":"whether to build G via exponential moving average"},"G_ema_beta":{"default":0.999,"type":"number","description":"exponential decay for ema"},"G_lr":{"default":0.0002,"type":"number","description":"initial learning rate for generator"},"batch_size":{"default":1,"type":"integer","description":"input batch size"},"beta1":{"default":0.9,"type":"number","description":"momentum term of adam"},"beta2":{"default":0.999,"type":"number","description":"momentum term of adam"},"cls_l1_regression":{"default":false,"type":"boolean","description":"if true l1 loss will be used to compute regressor loss"},"cls_regression":{"default":false,"type":"boolean","description":"if true cls will be a regressor and not a classifier"},"compute_D_accuracy":{"default":false,"type":"boolean","description":"whether to compute D accuracy explicitely"},"compute_metrics_test":{"default":false,"type":"boolean","description":"whether to compute test metrics, e.g. FID, ..."},"continue":{"default":false,"type":"boolean","description":"continue training: load the latest model"},"epoch":{"default":"latest","type":"string","description":"which epoch to load? set to latest to use latest cached model"},"epoch_count":{"default":1,"type":"integer","description":"the starting epoch count, we save the model by \\, \\+\\, ..."},"export_jit":{"default":false,"type":"boolean","description":"whether to export model in jit format"},"feat_wavelet":{"default":false,"type":"boolean","description":"if true, train in wavelet features space (Note: this may not include all discriminators, when training GANs)"},"gan_mode":{"default":"lsgan","type":"string","description":"the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.","enum":["vanilla","lsgan","wgangp","projected"]},"iter_size":{"default":1,"type":"integer","description":"backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size\\*iter_size"},"load_iter":{"default":0,"type":"integer","description":"which iteration to load? if load_iter \\> 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]"},"lr_decay_iters":{"default":50,"type":"integer","description":"multiply by a gamma every lr_decay_iters iterations"},"lr_policy":{"default":"linear","type":"string","description":"learning rate policy.","enum":["linear","step","plateau","cosine"]},"metrics_every":{"default":1000,"type":"integer","description":"compute metrics every N iterations"},"metrics_list":{"default":["FID"],"type":"array","items":{"enum":null,"type":"string"},"description":"metrics on results quality to compute","enum":["FID","KID","MSID","PSNR","LPIPS"]},"mm_lambda_z":{"default":0.5,"type":"number","description":"weight for random z loss"},"mm_nz":{"default":8,"type":"integer","description":"number of latent vectors"},"n_epochs":{"default":100,"type":"integer","description":"number of epochs with the initial learning rate"},"n_epochs_decay":{"default":100,"type":"integer","description":"number of epochs to linearly decay learning rate to zero"},"nb_img_max_fid":{"default":1000000000,"type":"integer","description":"Maximum number of samples allowed per dataset to compute fid. If the dataset directory contains more than nb_img_max_fid, only a subset is used."},"optim":{"default":"adam","type":"string","description":"optimizer (adam, radam, adamw, ...)","enum":["adam","radam","adamw","lion"]},"pool_size":{"default":50,"type":"integer","description":"the size of image buffer that stores previously generated images"},"save_by_iter":{"default":false,"type":"boolean","description":"whether saves model by iteration"},"save_epoch_freq":{"default":1,"type":"integer","description":"frequency of saving checkpoints at the end of epochs"},"save_latest_freq":{"default":5000,"type":"integer","description":"frequency of saving the latest results"},"semantic_cls":{"default":false,"type":"boolean","description":"if true semantic class losses will be used"},"semantic_mask":{"default":false,"type":"boolean","description":"if true semantic mask losses will be used"},"temporal_criterion":{"default":false,"type":"boolean","description":"if true, MSE loss will be computed between successive frames"},"temporal_criterion_lambda":{"default":1.0,"type":"number","description":"lambda for MSE loss that will be computed between successive frames"},"use_contrastive_loss_D":{"default":false,"type":"boolean","description":""}}},"dataaug":{"title":"Data augmentation","type":"object","properties":{"APA":{"default":false,"type":"boolean","description":"if true, G will be used as augmentation during D training adaptively to D overfitting between real and fake images"},"APA_every":{"default":4,"type":"integer","description":"How often to perform APA adjustment?"},"APA_nimg":{"default":50,"type":"integer","description":"APA adjustment speed, measured in how many images it takes for p to increase/decrease by one unit."},"APA_p":{"default":0,"type":"integer","description":"initial value of probability APA"},"APA_target":{"default":0.6,"type":"number","description":""},"D_diffusion":{"default":false,"type":"boolean","description":"whether to apply diffusion noise augmentation to discriminator inputs, projected discriminator only"},"D_diffusion_every":{"default":4,"type":"integer","description":"How often to perform diffusion augmentation adjustment"},"D_label_smooth":{"default":false,"type":"boolean","description":"whether to use one-sided label smoothing with discriminator"},"D_noise":{"default":0.0,"type":"number","description":"whether to add instance noise to discriminator inputs"},"affine":{"default":0.0,"type":"number","description":"if specified, apply random affine transforms to the images for data augmentation"},"affine_scale_max":{"default":1.2,"type":"number","description":"if random affine specified, max scale range value"},"affine_scale_min":{"default":0.8,"type":"number","description":"if random affine specified, min scale range value"},"affine_shear":{"default":45,"type":"integer","description":"if random affine specified, shear range (0,value)"},"affine_translate":{"default":0.2,"type":"number","description":"if random affine specified, translation range (-value\\*img_size,+value\\*img_size) value"},"diff_aug_policy":{"default":"","type":"string","description":"choose the augmentation policy : color randaffine randperspective. If you want more than one, please write them separated by a comma with no space (e.g. color,randaffine)"},"diff_aug_proba":{"default":0.5,"type":"number","description":"proba of using each transformation"},"imgaug":{"default":false,"type":"boolean","description":"whether to apply random image augmentation"},"no_flip":{"default":false,"type":"boolean","description":"if specified, do not flip the images for data augmentation"},"no_rotate":{"default":false,"type":"boolean","description":"if specified, do not rotate the images for data augmentation"}}},"checkpoints_dir":{"default":"./checkpoints","type":"string","description":"models are saved here"},"dataroot":{"default":"None","type":"string","description":"path to images (should have subfolders trainA, trainB, valA, valB, etc)"},"ddp_port":{"default":"12355","type":"string","description":""},"gpu_ids":{"default":"0","type":"string","description":"gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU"},"model_type":{"default":"cut","type":"string","description":"chooses which model to use.","enum":["cut","cycle_gan","palette"]},"name":{"default":"experiment_name","type":"string","description":"name of the experiment. It decides where to store samples and models"},"phase":{"default":"train","type":"string","description":"train, val, test, etc"},"suffix":{"default":"","type":"string","description":"customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}"},"test_batch_size":{"default":1,"type":"integer","description":"input batch size"},"warning_mode":{"default":false,"type":"boolean","description":"whether to display warning"},"with_amp":{"default":false,"type":"boolean","description":"whether to activate torch amp on forward passes"},"with_tf32":{"default":false,"type":"boolean","description":"whether to activate tf32 for faster computations (Ampere GPU and beyond only)"},"with_torch_compile":{"default":false,"type":"boolean","description":"whether to activate torch.compile for some forward and backward functions (experimental)"}}}},"definitions":{"ServerTrainOptions":{"title":"ServerTrainOptions","type":"object","properties":{"sync":{"title":"Sync","description":"if false, the call returns immediately and train process is executed in the background. If true, the call returns only when training process is finished","default":false,"type":"boolean"}}}}}}},"definitions":{"ServerTrainOptions":{"title":"ServerTrainOptions","type":"object","properties":{"sync":{"title":"Sync","description":"if false, the call returns immediately and train process is executed in the background. If true, the call returns only when training process is finished","default":false,"type":"boolean"}}}}} \ No newline at end of file diff --git a/docs/source/export.rst b/docs/source/export.rst index f02eb0c9e..a504e8da9 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -1,10 +1,6 @@ -############# +############## Model export -############ - -************ -Model export -************ +############## .. code:: bash diff --git a/docs/source/index.rst b/docs/source/index.rst index 890dffe32..ca9ebebe2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -178,6 +178,7 @@ Contact: contact@jolibrain.com export inference + metrics .. toctree:: :maxdepth: 2 diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst new file mode 100644 index 000000000..b19c7ec8d --- /dev/null +++ b/docs/source/metrics.rst @@ -0,0 +1,45 @@ +################# + JoliGEN Metrics +################# + +JoliGEN reads the model configuration from a generated ``train_config.json`` file that is stored in the model directory. +When testing a previously trained model, make sure the ``train_config.json`` file is in the directory. + +.. code:: bash + + python3 test.py \ + --test_model_dir /path/to/model/directory \ + --test_epoch 1 \ + --test_metrics_list FID KID MSID PSNR LPIPS \ + --test_nb_img 1000 \ + --test_batch_size 16 \ + --test_seed 42 + +This will output the selected metrics: + +.. code:: text + + fidB_test: 136.3628652179921 + msidB_test: 32.10317674393986 + kidB_test: 0.036237239837646484 + psnr_test: 20.68259048461914 + +The metrics are also saved in a ``/path/to/model/directory/metrics/date.json`` file: + +.. code:: json + + { + "fidB_test": 136.3628652179921, + "msidB_test": 32.10317674393986, + "kidB_test": 0.036237239837646484, + "psnr_test": 20.68259048461914 + } + +The following options are available: + +- ``test_model_dir``: path to the checkpoints for the model, should contain a ``train_config.json`` file +- ``test_epoch``: which epoch to load, defaults to latest checkpoint +- ``test_metrics``: list of metrics to compute, defaults to all metrics +- ``test_nb_img``: number of images to generate to compute metrics, defaults to dataset size +- ``test_batch_size``: input batch size +- ``test_seed``: seed to use for reproducible results diff --git a/docs/source/options.rst b/docs/source/options.rst index 388059ef9..05210e648 100644 --- a/docs/source/options.rst +++ b/docs/source/options.rst @@ -67,7 +67,7 @@ Discriminator +----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | --D_vision_aided_backbones | string | clip+dino+swin | specify vision aided discriminators architectures, they are frozen then output are combined and fitted with a linear network on top, choose from dino, clip, swin, det_coco, seg_ade and combine them with + | +----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --D_weight_sam | string | | path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth | +| --D_weight_sam | string | | path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM | +----------------------------+-----------------+--------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ Generator @@ -76,9 +76,9 @@ Generator +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Parameter | Type | Default | Description | +================================================+=================+===================================================+=============================================================================================================================================================================================================+ -| --G_attn_nb_mask_attn | int | 10 | | +| --G_attn_nb_mask_attn | int | 10 | number of attention masks in \_attn model architectures | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --G_attn_nb_mask_input | int | 1 | | +| --G_attn_nb_mask_input | int | 1 | number of mask dedicated to input in \_attn model architectures | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | --G_backward_compatibility_twice_resnet_blocks | flag | | if true, feats will go througt resnet blocks two times for resnet_attn generators. This option will be deleted, it’s for backward compatibility (old models were trained that way). | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -114,9 +114,9 @@ Generator +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | --G_unet_mha_norm_layer | string | groupnorm | **Values:** groupnorm, batchnorm, layernorm, instancenorm, switchablenorm | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --G_unet_mha_num_head_channels | int | 32 | | +| --G_unet_mha_num_head_channels | int | 32 | number of channels in each head of the mha architecture | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --G_unet_mha_num_heads | int | 1 | | +| --G_unet_mha_num_heads | int | 1 | number of heads in the mha architecture | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | --G_unet_mha_res_blocks | array | [2, 2, 2, 2] | distribution of resnet blocks across the UNet stages, should have same size as --G_unet_mha_channel_mults | +------------------------------------------------+-----------------+---------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -360,29 +360,29 @@ Online created datasets Semantic segmentation network ----------------------------- -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| Parameter | Type | Default | Description | -+==========================+=================+===================================================+===============================================================================================+ -| --f_s_all_classes_as_one | flag | | if true, all classes will be considered as the same one (ie foreground vs background) | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_class_weights | array | [] | class weights for imbalanced semantic classes | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_config_segformer | string | models/configs/segformer/segformer_config_b0.json | path to segformer configuration file for f_s | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_dropout | flag | | dropout for the semantic network | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_net | string | vgg | specify f_s network [vgg,unet,segformer,sam] **Values:** vgg, unet, segformer, sam | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_nf | int | 64 | # of filters in the first conv layer of classifier | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_semantic_nclasses | int | 2 | number of classes of the semantic loss classifier | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_semantic_threshold | float | 1.0 | threshold of the semantic classifier loss below with semantic loss is applied | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_weight_sam | string | | path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ -| --f_s_weight_segformer | string | | path to segformer weight for f_s, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth | -+--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------+ ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| Parameter | Type | Default | Description | ++==========================+=================+===================================================+===============================================================================================================================================+ +| --f_s_all_classes_as_one | flag | | if true, all classes will be considered as the same one (ie foreground vs background) | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_class_weights | array | [] | class weights for imbalanced semantic classes | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_config_segformer | string | models/configs/segformer/segformer_config_b0.json | path to segformer configuration file for f_s | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_dropout | flag | | dropout for the semantic network | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_net | string | vgg | specify f_s network [vgg,unet,segformer,sam] **Values:** vgg, unet, segformer, sam | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_nf | int | 64 | # of filters in the first conv layer of classifier | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_semantic_nclasses | int | 2 | number of classes of the semantic loss classifier | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_semantic_threshold | float | 1.0 | threshold of the semantic classifier loss below with semantic loss is applied | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_weight_sam | string | | path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ +| --f_s_weight_segformer | string | | path to segformer weight for f_s, e.g. models/configs/segformer/pretrain/segformer_mit-b0.pth | ++--------------------------+-----------------+---------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------+ Semantic classification network ------------------------------- @@ -479,6 +479,8 @@ Model +-----------------------------------------+-----------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | --model_prior_321_backwardcompatibility | flag | | whether to load models from previous version of JG. | +-----------------------------------------+-----------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| --model_type_sam | string | mobile_sam | which model to use for segment-anything mask generation **Values:** sam, mobile_sam | ++-----------------------------------------+-----------------+-----------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ Training -------- @@ -508,7 +510,7 @@ Training +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_compute_D_accuracy | flag | | whether to compute D accuracy explicitely | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --train_compute_metrics_test | flag | | | +| --train_compute_metrics_test | flag | | whether to compute test metrics, e.g. FID, … | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_continue | flag | | continue training: load the latest model | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -518,6 +520,8 @@ Training +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_export_jit | flag | | whether to export model in jit format | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ +| --train_feat_wavelet | flag | | if true, train in wavelet features space (Note: this may not include all discriminators, when training GANs) | ++-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_gan_mode | string | lsgan | the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper. **Values:** vanilla, lsgan, wgangp, projected | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_iter_size | int | 1 | backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size*iter_size | @@ -528,9 +532,9 @@ Training +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_lr_policy | string | linear | learning rate policy. **Values:** linear, step, plateau, cosine | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --train_metrics_every | int | 1000 | | +| --train_metrics_every | int | 1000 | compute metrics every N iterations | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ -| --train_metrics_list | array | [‘FID’] | **Values:** FID, KID, MSID, PSNR | +| --train_metrics_list | array | [‘FID’] | metrics on results quality to compute **Values:** FID, KID, MSID, PSNR, LPIPS | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ | --train_mm_lambda_z | float | 0.5 | weight for random z loss | +-----------------------------------+-----------------+-----------------+----------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -598,7 +602,7 @@ Semantic training with masks +==============================+=================+=================+=======================================================================================+ | --train_mask_charbonnier_eps | float | 1e-06 | Charbonnier loss epsilon value | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ -| --train_mask_compute_miou | flag | | | +| --train_mask_compute_miou | flag | | whether to compute mIoU on semantic masks prediction | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ | --train_mask_disjoint_f_s | flag | | whether to use a disjoint f_s with the same exact structure | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ @@ -610,7 +614,7 @@ Semantic training with masks +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ | --train_mask_loss_out_mask | string | L1 | loss for out mask content (which should not change). **Values:** L1, MSE, Charbonnier | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ -| --train_mask_miou_every | int | 1000 | | +| --train_mask_miou_every | int | 1000 | compute mIoU every n iterations | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ | --train_mask_no_train_f_s_A | flag | | if true f_s wont be trained on domain A | +------------------------------+-----------------+-----------------+---------------------------------------------------------------------------------------+ diff --git a/docs/source/training.rst b/docs/source/training.rst index 73e045176..90a1bd2b1 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -44,7 +44,7 @@ Dataset: https://joligen.com/datasets/horse2zebra.zip .. code:: bash - python3 train.py --dataroot /home/beniz/tmp/joligan/datasets/horse2zebra --checkpoints_dir /home/beniz/tmp/joligan/checkpoints --name horse2zebra --config_json examples/example_gan_horse2zebra.json + python3 train.py --dataroot /path/to/horse2zebra --checkpoints_dir /path/to/checkpoints --name horse2zebra --config_json examples/example_gan_horse2zebra.json .. _training-im2im-with-class-semantics: @@ -97,7 +97,7 @@ Trains a diffusion model to insert glasses onto faces. .. code:: bash - python3 train.py --dataroot noglasses2glasses_ffhq --checkpoints_dir ./checkpoints --name glasses2noglasses --output_display_env glasses2noglasses --config_json examples/example_ddpm_glasses2noglasses.json + python3 train.py --dataroot /path/to/data/noglasses2glasses_ffhq --checkpoints_dir /path/to/checkpoints --name noglasses2glasses --config_json examples/example_ddpm_noglasses2glasses.json diff --git a/examples/example_ddpm_noglasses2glasses.json b/examples/example_ddpm_noglasses2glasses.json index c078d66cf..7be8921fd 100644 --- a/examples/example_ddpm_noglasses2glasses.json +++ b/examples/example_ddpm_noglasses2glasses.json @@ -115,10 +115,10 @@ "load_size_A": [], "load_size_B": [], "mask_delta_A": [ - 0 + [] ], "mask_delta_B": [ - 0 + [] ], "mask_random_offset_A": [ 0.0 @@ -159,17 +159,6 @@ "weight_sam": "", "weight_segformer": "" }, - "cls": { - "all_classes_as_one": false, - "class_weights": [], - "config_segformer": "models/configs/segformer/segformer_config_b0.py", - "dropout": false, - "net": "vgg", - "nf": 64, - "semantic_nclasses": 2, - "semantic_threshold": 1.0, - "weight_segformer": "" - }, "output": { "display": { "G_attention_masks": false, diff --git a/examples/example_gan_glasses2noglasses.json b/examples/example_gan_glasses2noglasses.json index 9a2132414..b46eab8c5 100644 --- a/examples/example_gan_glasses2noglasses.json +++ b/examples/example_gan_glasses2noglasses.json @@ -105,10 +105,10 @@ "load_size_A": [], "load_size_B": [], "mask_delta_A": [ - 0 + 0 ], "mask_delta_B": [ - 0 + 0 ], "mask_random_offset_A": [ 0.0 diff --git a/examples/example_gan_mario2sonic.json b/examples/example_gan_mario2sonic.json index f5dc279ba..48694fb31 100644 --- a/examples/example_gan_mario2sonic.json +++ b/examples/example_gan_mario2sonic.json @@ -62,17 +62,13 @@ "load_size_A": [], "load_size_B": [], "mask_delta_A": [ - 50 + [50] ], "mask_delta_B": [ - 15 - ], - "mask_random_offset_A": [ - 0.0 - ], - "mask_random_offset_B": [ - 0.0 + [15] ], + "mask_random_offset_A": [0.0], + "mask_random_offset_B": [0.0], "mask_square_A": false, "mask_square_B": false, "rand_mask_A": false @@ -99,22 +95,11 @@ "dropout": false, "net": "unet", "nf": 64, - "semantic_nclasses": 2, + "semantic_nclasses": 7, "semantic_threshold": 1.0, "weight_sam": "", "weight_segformer": "" }, - "cls": { - "all_classes_as_one": false, - "class_weights": [], - "config_segformer": "models/configs/segformer/segformer_config_b0.py", - "dropout": false, - "net": "vgg", - "nf": 64, - "semantic_nclasses": 2, - "semantic_threshold": 1.0, - "weight_segformer": "" - }, "output": { "display": { "G_attention_masks": false, diff --git a/models/base_diffusion_model.py b/models/base_diffusion_model.py index f33e786b7..5197afb3d 100644 --- a/models/base_diffusion_model.py +++ b/models/base_diffusion_model.py @@ -9,7 +9,6 @@ import numpy as np import torch import torch.nn.functional as F -from segment_anything import SamPredictor from torchviz import make_dot from tqdm import tqdm @@ -17,8 +16,13 @@ from util.util import save_image, tensor2im from .base_model import BaseModel -from .modules.sam.sam_inference import load_sam_weight, predict_sam_mask -from .modules.utils import download_sam_weight +from .modules.sam.sam_inference import ( + init_sam_net, + load_mobile_sam_weight, + load_sam_weight, + predict_sam_mask, +) +from .modules.utils import download_mobile_sam_weight, download_sam_weight class BaseDiffusionModel(BaseModel): @@ -75,10 +79,9 @@ def __init__(self, opt, rank): else: self.use_sam_edge = False if opt.data_refined_mask or "sam" in opt.alg_palette_computed_sketch_list: - if not os.path.exists(opt.f_s_weight_sam): - download_sam_weight(path=opt.f_s_weight_sam) - self.freezenet_sam, _ = load_sam_weight(model_path=opt.f_s_weight_sam) - self.freezenet_sam = self.freezenet_sam.to(self.device) + self.freezenet_sam, _ = init_sam_net( + opt.model_type_sam, opt.f_s_weight_sam, self.device + ) def init_semantic_cls(self, opt): # specify the training losses you want to print out. diff --git a/models/base_gan_model.py b/models/base_gan_model.py index d945a5256..8b69a7a7e 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -1,29 +1,35 @@ -import os import copy -import torch -from collections import OrderedDict +import os from abc import abstractmethod -from . import gan_networks -from .modules.utils import get_scheduler, predict_depth, download_midas_weight -from .modules.sam.sam_inference import load_sam_weight, predict_sam -from torchviz import make_dot -from .base_model import BaseModel +from collections import OrderedDict -from util.network_group import NetworkGroup +import numpy as np +import torch +import torch.nn.functional as F +from torchviz import make_dot # for FID from data.base_dataset import get_transform -from util.util import save_image, tensor2im -import numpy as np from util.diff_aug import DiffAugment +from util.discriminator import DiscriminatorInfo # for D accuracy from util.image_pool import ImagePool -import torch.nn.functional as F +from util.network_group import NetworkGroup +from util.util import save_image, tensor2im + +from . import gan_networks +from .base_model import BaseModel # For D loss computing from .modules import loss -from util.discriminator import DiscriminatorInfo +from .modules.sam.sam_inference import ( + init_sam_net, + load_mobile_sam_weight, + load_sam_weight, + predict_sam, +) +from .modules.utils import download_midas_weight, get_scheduler, predict_depth class BaseGanModel(BaseModel): @@ -108,10 +114,10 @@ def __init__(self, opt, rank): else: self.use_depth = False - if "sam" in opt.D_netDs: + if "sam" in opt.D_netDs or opt.data_refined_mask: self.use_sam = True - self.netfreeze_sam, self.predictor_sam = load_sam_weight( - self.opt.D_weight_sam + self.netfreeze_sam, self.predictor_sam = init_sam_net( + opt.model_type_sam, self.opt.D_weight_sam, self.device ) else: self.use_sam = False @@ -142,14 +148,12 @@ def __init__(self, opt, rank): self.loss_functions_G.append("compute_temporal_criterion_loss") def init_semantic_cls(self, opt): - # specify the training losses you want to print out. # The training/test scripts will call super().init_semantic_cls(opt) def init_semantic_mask(self, opt): - # specify the training losses you want to print out. # The training/test scripts will call @@ -172,7 +176,6 @@ def forward_GAN(self): setattr(self, "image_" + str(i), cur_image) if self.opt.data_online_context_pixels > 0: - bs = self.get_current_batch_size() self.mask_context = torch.ones( [ @@ -413,7 +416,6 @@ def compute_G_loss_GAN_generic( real = getattr(self, real_name) if hasattr(self, "diff_augment"): - real = self.diff_augment(real) fake = self.diff_augment(fake) @@ -544,7 +546,6 @@ def set_discriminators_info(self): self.discriminators = [] for discriminator_name in self.discriminators_names: - loss_calculator_name = "D_" + discriminator_name + "_loss_calculator" if "temporal" in discriminator_name or "projected" in discriminator_name: diff --git a/models/base_model.py b/models/base_model.py index 291121e71..e4f5da393 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -19,8 +19,9 @@ from data.base_dataset import get_transform from util.metrics import _compute_statistics_of_dataloader - +from tqdm import tqdm from piq import MSID, KID, FID, psnr +from lpips import LPIPS from util.util import save_image, tensor2im, delete_flop_param @@ -37,7 +38,7 @@ # Iter Calculator from util.iter_calculator import IterCalculator from util.network_group import NetworkGroup -from util.util import delete_flop_param, save_image, tensor2im +from util.util import delete_flop_param, save_image, tensor2im, MAX_INT from . import base_networks, semantic_networks @@ -142,8 +143,10 @@ def __init__(self, opt, rank): self.msid_metric = MSID() if "KID" in self.opt.train_metrics_list: self.kid_metric = KID() + if "LPIPS" in self.opt.train_metrics_list: + self.lpips_metric = LPIPS().to(self.device) - def init_metrics(self, dataloader, dataloader_test): + def init_metrics(self, dataloader_test): self.use_inception = any( metric in self.opt.train_metrics_list for metric in ["KID", "FID", "MSID"] @@ -720,8 +723,13 @@ def export_networks(self, epoch): input_nc += self.opt.train_mm_nz # onnx - if not "ittr" in self.opt.G_netG and not ( - torch.__version__[0] == "2" and "segformer" in self.opt.G_netG + if ( + not self.opt.train_feat_wavelet + and not "ittr" in self.opt.G_netG + and not ( + torch.__version__[0] == "2" + and "segformer" in self.opt.G_netG + ) ): # XXX: segformer export fails with ONNX and Pytorch2 export_path_onnx = save_path.replace(".pth", ".onnx") @@ -1084,8 +1092,11 @@ def one_hot(self, tensor): def compute_fake_real_masks(self): fake_mask = self.netf_s(self.real_A) fake_mask = F.gumbel_softmax(fake_mask, tau=1.0, hard=True, dim=1) - real_mask = self.netf_s(self.real_B) + real_mask = self.netf_s( + self.real_B + ) # f_s(B) is a good approximation of the real mask when task is easy real_mask = F.gumbel_softmax(real_mask, tau=1.0, hard=True, dim=1) + setattr(self, "fake_mask_B_inv", fake_mask.argmax(dim=1)) setattr(self, "real_mask_B_inv", real_mask.argmax(dim=1)) setattr(self, "fake_mask_B", fake_mask) @@ -1133,7 +1144,16 @@ def compute_f_s_loss(self): f_s = self.netf_s_B else: f_s = self.netf_s - label_B = self.input_B_label_mask + + if self.opt.data_refined_mask: + # get mask with sam instead of label from self.real_B and self.input_B_ref_bbox + self.label_sam_B = ( + predict_sam(self.real_B, self.predictor_sam, self.input_B_ref_bbox) + > 0.0 + ) + label_B = self.label_sam_B.long() + else: + label_B = self.input_B_label_mask pred_B = f_s(self.real_B) self.loss_f_s += self.criterionf_s(pred_B, label_B) # .squeeze(1)) @@ -1303,6 +1323,11 @@ def get_current_metrics(self): "psnr_test", ] + if "LPIPS" in self.opt.train_metrics_list: + metrics_names += [ + "lpips_test", + ] + for name in metrics_names: if isinstance(name, str): metrics[name] = float( @@ -1324,6 +1349,15 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter): fake_list = [] real_list = [] + if self.opt.train_nb_img_max_fid != MAX_INT: + progress = tqdm( + desc="compute metrics test", + position=1, + total=self.opt.train_nb_img_max_fid, + ) + else: + progress = None + for i, data_test_list in enumerate( dataloaders_test ): # inner loop (minibatch) within one epoch @@ -1367,7 +1401,21 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter): for sub_list in self.visual_names: for name in sub_list: - setattr(self, name + "_test", getattr(self, name)) + if hasattr(self, name): + setattr(self, name + "_test", getattr(self, name)) + + if progress: + progress.n = min(len(fake_list), progress.total) + progress.refresh() + + if len(fake_list) >= self.opt.train_nb_img_max_fid: + break + + fake_list = fake_list[: self.opt.train_nb_img_max_fid] + real_list = real_list[: self.opt.train_nb_img_max_fid] + + if progress: + progress.close() if self.use_inception: self.fakeactB_test = _compute_statistics_of_dataloader( @@ -1390,15 +1438,18 @@ def compute_metrics_test(self, dataloaders_test, n_epoch, n_iter): real_tensor = (torch.cat(real_list) + 1) / 2 fake_tensor = (torch.clamp(torch.cat(fake_list), min=-1, max=1) + 1) / 2 - self.psnr_test = psnr(real_tensor, fake_tensor) + if "LPIPS" in self.opt.train_metrics_list: + real_tensor = torch.cat(real_list) + fake_tensor = torch.clamp(torch.cat(fake_list), min=-1, max=1) + self.lpips_test = self.lpips_metric(real_tensor, fake_tensor).mean() + def compute_metrics_generic(self, real_act, fake_act): # FID if "FID" in self.opt.train_metrics_list: fid = self.fid_metric(real_act, fake_act) - else: fid = None diff --git a/models/cut_model.py b/models/cut_model.py index 0a4f87395..05e3719cf 100644 --- a/models/cut_model.py +++ b/models/cut_model.py @@ -17,6 +17,7 @@ from util.util import gaussian import itertools +import warnings class CUTModel(BaseGanModel): @@ -139,7 +140,6 @@ def modify_commandline_options(parser, is_train=True): return parser def __init__(self, opt, rank): - super().__init__(opt, rank) # Vanilla cut @@ -150,6 +150,10 @@ def __init__(self, opt, rank): if "segformer" in self.opt.G_netG: self.opt.alg_cut_nce_layers = "0,1,2,3" + self.opt.alg_cut_nce_T = 0.2 # default 0.07 is too low, https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Understanding_the_Behaviour_of_Contrastive_Loss_CVPR_2021_paper.pdf for a related study + warnings.warn( + "cut with segformer requires nce_layers 0,1,2,3 and nce_T set to 0.2, these values are enforced" + ) elif "ittr" in self.opt.G_netG: self.opt.alg_cut_nce_layers = ",".join( [str(k) for k in range(self.opt.G_nblocks)] @@ -356,12 +360,13 @@ def __init__(self, opt, rank): losses_E = ["G_z"] losses_G += ["G_z"] - for discriminator in self.discriminators: - losses_D.append(discriminator.loss_name_D) - if "mask" in discriminator.name: - continue - else: - losses_G.append(discriminator.loss_name_G) + if self.isTrain: + for discriminator in self.discriminators: + losses_D.append(discriminator.loss_name_D) + if "mask" in discriminator.name: + continue + else: + losses_G.append(discriminator.loss_name_G) self.loss_names_G += losses_G self.loss_names_D += losses_D @@ -453,6 +458,8 @@ def data_dependent_initialize_semantic_mask(self, data): if "mask" in self.opt.D_netDs: visual_names_seg_B += ["real_mask_B_inv", "fake_mask_B_inv"] + if self.opt.data_refined_mask: + visual_names_seg_B += ["label_sam_B"] self.visual_names += [visual_names_seg_A, visual_names_seg_B] diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py index bea5d8b4b..03e02b279 100644 --- a/models/diffusion_networks.py +++ b/models/diffusion_networks.py @@ -47,6 +47,7 @@ def define_G( resblock_updown=True, use_new_attention_order=False, f_s_semantic_nclasses=-1, + train_feat_wavelet=False, **unused_options ): """Create a generator @@ -102,6 +103,7 @@ def define_G( group_norm_size=G_unet_mha_group_norm_size, efficient=G_unet_mha_vit_efficient, cond_embed_dim=cond_embed_dim, + freq_space=train_feat_wavelet, ) elif G_netG == "uvit": @@ -124,6 +126,7 @@ def define_G( num_transformer_blocks=G_uvit_num_transformer_blocks, efficient=G_unet_mha_vit_efficient, cond_embed_dim=alg_palette_cond_embed_dim, + freq_space=train_feat_wavelet, ) cond_embed_dim = alg_palette_cond_embed_dim diff --git a/models/gan_networks.py b/models/gan_networks.py index e37b6bdd4..556fcd7a0 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -69,6 +69,7 @@ def define_G( G_unet_mha_channel_mults, G_unet_mha_norm_layer, G_unet_mha_group_norm_size, + train_feat_wavelet, **unused_options ): """Create a generator @@ -148,6 +149,7 @@ def define_G( use_spectral=G_spectral, padding_type=G_padding_type, twice_resnet_blocks=G_backward_compatibility_twice_resnet_blocks, + freq_space=train_feat_wavelet, ) elif G_netG == "mobile_resnet_attn": net = ResnetGenerator_attn( @@ -161,6 +163,7 @@ def define_G( padding_type=G_padding_type, mobile=True, twice_resnet_blocks=G_backward_compatibility_twice_resnet_blocks, + freq_space=train_feat_wavelet, ) elif G_netG == "stylegan2": net = StyleGAN2Generator( @@ -264,6 +267,7 @@ def define_D( dataaug_D_diffusion, f_s_semantic_nclasses, model_depth_network, + train_feat_wavelet, **unused_options ): @@ -316,6 +320,7 @@ def define_D( norm_layer=norm_layer, use_dropout=D_dropout, use_spectral=D_spectral, + freq_space=train_feat_wavelet, ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index f90aa5306..58ae92f5f 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -286,7 +286,7 @@ def restoration_ddim( tlist = torch.zeros([y_t.shape[0]], device=y_t.device).long() for i in tqdm( - reversed(range(num_steps)), + range(num_steps), desc="sampling loop time step", total=num_steps, ): diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 99635f364..5662df779 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -4,8 +4,6 @@ from torch import nn from torch.nn import functional as F -# from .spade_architecture.normalization import get_nonspade_norm_layer - from .utils import spectral_norm, normal_init @@ -20,6 +18,7 @@ def __init__( norm_layer=nn.BatchNorm2d, use_dropout=False, use_spectral=False, + freq_space=False, ): """Construct a PatchGAN discriminator @@ -39,6 +38,14 @@ def __init__( else: use_bias = norm_layer == nn.InstanceNorm2d + self.freq_space = freq_space + if self.freq_space: + from .freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(input_nc) + self.dwt = HaarTransform(input_nc) + input_nc *= 4 + kw = 4 padw = 1 sequence = [ @@ -103,7 +110,12 @@ def __init__( def forward(self, input): """Standard forward.""" - return self.model(input) + if self.freq_space: + x = self.dwt(input) + else: + x = input + x = self.model(x) + return x class PixelDiscriminator(nn.Module): diff --git a/models/modules/freq_utils.py b/models/modules/freq_utils.py new file mode 100644 index 000000000..5aae7e86f --- /dev/null +++ b/models/modules/freq_utils.py @@ -0,0 +1,59 @@ +import torch +from torch import nn +from .op import upfirdn2d + +### functions for frequency space + + +def get_haar_wavelet(in_channels): + haar_wav_l = 1 / (2**0.5) * torch.ones(1, 2) + haar_wav_h = 1 / (2**0.5) * torch.ones(1, 2) + haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0] + + haar_wav_ll = haar_wav_l.T * haar_wav_l + haar_wav_lh = haar_wav_h.T * haar_wav_l + haar_wav_hl = haar_wav_l.T * haar_wav_h + haar_wav_hh = haar_wav_h.T * haar_wav_h + + return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh + + +class HaarTransform(nn.Module): + def __init__(self, in_channels): + super().__init__() + + ll, lh, hl, hh = get_haar_wavelet(in_channels) + + self.register_buffer("ll", ll) + self.register_buffer("lh", lh) + self.register_buffer("hl", hl) + self.register_buffer("hh", hh) + + def forward(self, input): + ll = upfirdn2d(input, self.ll, down=2) + lh = upfirdn2d(input, self.lh, down=2) + hl = upfirdn2d(input, self.hl, down=2) + hh = upfirdn2d(input, self.hh, down=2) + + return torch.cat((ll, lh, hl, hh), 1) + + +class InverseHaarTransform(nn.Module): + def __init__(self, in_channels): + super().__init__() + + ll, lh, hl, hh = get_haar_wavelet(in_channels) + + self.register_buffer("ll", ll) + self.register_buffer("lh", -lh) + self.register_buffer("hl", -hl) + self.register_buffer("hh", hh) + + def forward(self, input): + ll, lh, hl, hh = input.chunk(4, 1) + ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0)) + lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0)) + hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0)) + hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0)) + + return ll + lh + hl + hh diff --git a/models/modules/op/__init__.py b/models/modules/op/__init__.py new file mode 100755 index 000000000..cd1fe3d14 --- /dev/null +++ b/models/modules/op/__init__.py @@ -0,0 +1,2 @@ +# from .fused_act import FusedLeakyReLU, fused_leaky_relu +from .upfirdn2d import upfirdn2d diff --git a/models/modules/op/upfirdn2d.cpp b/models/modules/op/upfirdn2d.cpp new file mode 100755 index 000000000..d667b2cfc --- /dev/null +++ b/models/modules/op/upfirdn2d.cpp @@ -0,0 +1,31 @@ +#include +#include + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, + int up_x, int up_y, int down_x, int down_y, int pad_x0, + int pad_x1, int pad_y0, int pad_y1) { + CHECK_INPUT(input); + CHECK_INPUT(kernel); + + at::DeviceGuard guard(input.device()); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, + pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} \ No newline at end of file diff --git a/models/modules/op/upfirdn2d.py b/models/modules/op/upfirdn2d.py new file mode 100755 index 000000000..6e4f03b54 --- /dev/null +++ b/models/modules/op/upfirdn2d.py @@ -0,0 +1,209 @@ +from collections import abc +import os + +import torch +from torch.nn import functional as F +from torch.autograd import Function +from torch.utils.cpp_extension import load + + +module_path = os.path.dirname(__file__) +upfirdn2d_op = load( + "upfirdn2d", + sources=[ + os.path.join(module_path, "upfirdn2d.cpp"), + os.path.join(module_path, "upfirdn2d_kernel.cu"), + ], +) + + +class UpFirDn2dBackward(Function): + @staticmethod + def forward( + ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size + ): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_op.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + (kernel,) = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view( + ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] + ) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 + ) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if not isinstance(up, abc.Iterable): + up = (up, up) + + if not isinstance(down, abc.Iterable): + down = (down, down) + + if len(pad) == 2: + pad = (pad[0], pad[1], pad[0], pad[1]) + + if input.device.type == "cpu": + out = upfirdn2d_native(input, kernel, *up, *down, *pad) + + else: + out = UpFirDn2d.apply(input, kernel, up, down, pad) + + return out + + +def upfirdn2d_native( + input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 +): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad( + out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] + ) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape( + [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] + ) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + + return out.view(-1, channel, out_h, out_w) diff --git a/models/modules/op/upfirdn2d_kernel.cu b/models/modules/op/upfirdn2d_kernel.cu new file mode 100755 index 000000000..61d3dc413 --- /dev/null +++ b/models/modules/op/upfirdn2d_kernel.cu @@ -0,0 +1,369 @@ +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} \ No newline at end of file diff --git a/models/modules/resnet_architecture/resnet_generator.py b/models/modules/resnet_architecture/resnet_generator.py index a713c8e8d..5dc07b1b4 100644 --- a/models/modules/resnet_architecture/resnet_generator.py +++ b/models/modules/resnet_architecture/resnet_generator.py @@ -400,6 +400,7 @@ def __init__( padding_type="reflect", mobile=False, twice_resnet_blocks=False, + freq_space=False, ): super(ResnetGenerator_attn, self).__init__( nb_mask_attn=nb_mask_attn, nb_mask_input=nb_mask_input @@ -415,48 +416,70 @@ def __init__( self.nb = n_blocks self.padding_type = padding_type self.twice_resnet_blocks = twice_resnet_blocks + self.freq_space = freq_space - self.conv1 = spectral_norm(nn.Conv2d(input_nc, ngf, 7, 1, 0), use_spectral) + if freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(self.input_nc) + self.dwt = HaarTransform(self.input_nc) + self.input_nc = input_nc * 4 + + self.conv1 = spectral_norm( + nn.Conv2d(self.input_nc, self.ngf, 7, 1, 0), use_spectral + ) self.input_nc = output_nc # hack - self.conv1_norm = nn.InstanceNorm2d(ngf) - self.conv2 = spectral_norm(nn.Conv2d(ngf, ngf * 2, 3, 2, 1), use_spectral) - self.conv2_norm = nn.InstanceNorm2d(ngf * 2) - self.conv3 = spectral_norm(nn.Conv2d(ngf * 2, ngf * 4, 3, 2, 1), use_spectral) - self.conv3_norm = nn.InstanceNorm2d(ngf * 4) + self.conv1_norm = nn.InstanceNorm2d(self.ngf) + self.conv2 = spectral_norm( + nn.Conv2d(self.ngf, self.ngf * 2, 3, 2, 1), use_spectral + ) + self.conv2_norm = nn.InstanceNorm2d(self.ngf * 2) + self.conv3 = spectral_norm( + nn.Conv2d(self.ngf * 2, self.ngf * 4, 3, 2, 1), use_spectral + ) + self.conv3_norm = nn.InstanceNorm2d(self.ngf * 4) self.resnet_blocks = [] for i in range(n_blocks): self.resnet_blocks.append( - resnet_block_attn(ngf * 4, 3, 1, self.padding_type, conv=conv) + resnet_block_attn(self.ngf * 4, 3, 1, self.padding_type, conv=conv) ) self.resnet_blocks[i].weight_init(0, 0.02) self.resnet_blocks = nn.Sequential(*self.resnet_blocks) self.deconv1_content = spectral_norm( - nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1), use_spectral + nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 3, 2, 1, 1), use_spectral ) - self.deconv1_norm_content = nn.InstanceNorm2d(ngf * 2) + self.deconv1_norm_content = nn.InstanceNorm2d(self.ngf * 2) self.deconv2_content = spectral_norm( - nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1), use_spectral + nn.ConvTranspose2d(self.ngf * 2, self.ngf, 3, 2, 1, 1), use_spectral ) - self.deconv2_norm_content = nn.InstanceNorm2d(ngf) + self.deconv2_norm_content = nn.InstanceNorm2d(self.ngf) + if self.freq_space: + deconv3_ngf = int(self.ngf / 4) + else: + deconv3_ngf = self.ngf self.deconv3_content = spectral_norm( nn.Conv2d( - ngf, self.input_nc * (self.nb_mask_attn - self.nb_mask_input), 7, 1, 0 + deconv3_ngf, + self.input_nc * (self.nb_mask_attn - self.nb_mask_input), + 7, + 1, + 0, ), use_spectral, ) self.deconv1_attention = spectral_norm( - nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, 1), use_spectral + nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 3, 2, 1, 1), use_spectral ) - self.deconv1_norm_attention = nn.InstanceNorm2d(ngf * 2) + self.deconv1_norm_attention = nn.InstanceNorm2d(self.ngf * 2) self.deconv2_attention = spectral_norm( - nn.ConvTranspose2d(ngf * 2, ngf, 3, 2, 1, 1), use_spectral + nn.ConvTranspose2d(self.ngf * 2, self.ngf, 3, 2, 1, 1), use_spectral ) - self.deconv2_norm_attention = nn.InstanceNorm2d(ngf) - self.deconv3_attention = nn.Conv2d(ngf, self.nb_mask_attn, 1, 1, 0) + self.deconv2_norm_attention = nn.InstanceNorm2d(self.ngf) + self.deconv3_attention = nn.Conv2d(self.ngf, self.nb_mask_attn, 1, 1, 0) self.tanh = nn.Tanh() @@ -470,6 +493,10 @@ def compute_feats(self, input, extract_layer_ids=[]): x = F.pad(input, (3, 3, 3, 3), "reflect") else: x = F.pad(input, (3, 3, 3, 3), "constant", 0) + + if self.freq_space: + x = self.dwt(x) + x = F.relu(self.conv1_norm(self.conv1(x))) x = F.relu(self.conv2_norm(self.conv2(x))) x = F.relu(self.conv3_norm(self.conv3(x))) @@ -495,11 +522,16 @@ def compute_attention_content(self, feat): x_content = F.relu(self.deconv1_norm_content(self.deconv1_content(x))) x_content = F.relu(self.deconv2_norm_content(self.deconv2_content(x_content))) + + if self.freq_space: + x_content = self.iwt(x_content) + if self.padding_type == "reflect": x_content = F.pad(x_content, (3, 3, 3, 3), "reflect") else: x_content = F.pad(x_content, (3, 3, 3, 3), "constant", 0) content = self.deconv3_content(x_content) + image = self.tanh(content) images = [] diff --git a/models/modules/sam/sam_inference.py b/models/modules/sam/sam_inference.py index 80ec43ed2..3f4510719 100644 --- a/models/modules/sam/sam_inference.py +++ b/models/modules/sam/sam_inference.py @@ -1,12 +1,19 @@ import os import random import sys -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import cv2 import numpy as np import scipy import torch +from mobile_sam.modeling import ( + ImageEncoderViT, + MaskDecoder, + PromptEncoder, + TinyViT, + TwoWayTransformer, +) from numpy.random import PCG64, Generator from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry from segment_anything.modeling.image_encoder import ImageEncoderViT @@ -16,6 +23,7 @@ from torch import nn from torch.nn import functional as F +from models.modules.utils import download_mobile_sam_weight, download_sam_weight from util.util import tensor2im @@ -443,6 +451,228 @@ def reset_image(self) -> None: self.input_w = None +class MobileSam(nn.Module): + """ + The MobileSAM related code has been adapted to our needs from the official + MobileSAM repository (https://github.com/ChaoningZhang/MobileSAM). Many thanks to + their team for this great work! + """ + + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: Union[ImageEncoderViT, TinyViT], + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False + ) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack( + [self.preprocess(x["image"]) for x in batched_input], dim=0 + ) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate( + masks, original_size, mode="bilinear", align_corners=False + ) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = MobileSam( + image_encoder=TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + ######### JoliGEN level functions def load_sam_weight(model_path): if "vit_h" in model_path: @@ -456,6 +686,12 @@ def load_sam_weight(model_path): return sam, sam_predictor +def load_mobile_sam_weight(model_path): + sam = build_sam_vit_t(checkpoint=model_path) + sam_predictor = SamPredictorG(sam) + return sam, sam_predictor + + def predict_sam(img, sam_predictor, bbox=None): # - img in RBG value space img = torch.clamp(img, min=-1.0, max=1.0) @@ -724,7 +960,6 @@ def predict_sam_edges( batched_edges = [] for k in range(len(image)): - pass masked_imgs = [] for mask in batched_output[k]["non_redundant_masks"]: assert ( @@ -874,14 +1109,25 @@ def compute_mask_with_sam(img, rect_mask, sam_model, device, batched=True): predictor=predictor, cat=categories[i], ) - """ - xmin, ymin, xmax, ymax = boxes[i].cpu() - mask[: int(ymin), :] = 0 - mask[int(ymax) :, :] = 0 - mask[:, : int(xmin)] = 0 - mask[:, int(xmax) :] = 0 - """ sam_masks[i] = torch.from_numpy(mask).unsqueeze(0) else: sam_masks[i] = rect_mask[i] return sam_masks + + +def init_sam_net(model_type_sam, model_path, device): + if model_type_sam == "sam": + download_sam_weight(path=model_path) + freezenet_sam, predictor_sam = load_sam_weight(model_path=model_path) + if device is not None: + freezenet_sam = freezenet_sam.to(device) + elif model_type_sam == "mobile_sam": + download_mobile_sam_weight(path=model_path) + freezenet_sam, predictor_sam = load_mobile_sam_weight(model_path=model_path) + if device is not None: + freezenet_sam.to(device) + else: + raise ValueError( + f'{model_type_sam} is not a correct choice for model_type_sam.\nChoices: ["sam", "mobile_sam"]' + ) + return freezenet_sam, predictor_sam diff --git a/models/modules/unet_generator_attn/unet_generator_attn.py b/models/modules/unet_generator_attn/unet_generator_attn.py index c961eed65..c6ec73bae 100644 --- a/models/modules/unet_generator_attn/unet_generator_attn.py +++ b/models/modules/unet_generator_attn/unet_generator_attn.py @@ -58,16 +58,31 @@ class Upsample(nn.Module): """ - def __init__(self, channels, use_conv, out_channel=None, efficient=False): + def __init__( + self, channels, use_conv, out_channel=None, efficient=False, freq_space=False + ): super().__init__() self.channels = channels self.out_channel = out_channel or channels self.use_conv = use_conv + self.freq_space = freq_space + + if freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + if use_conv: self.conv = nn.Conv2d(self.channels, self.out_channel, 3, padding=1) self.efficient = efficient def forward(self, x): + if self.freq_space: + x = self.iwt(x) + assert x.shape[1] == self.channels if not self.efficient: x = F.interpolate(x, scale_factor=2, mode="nearest") @@ -75,6 +90,10 @@ def forward(self, x): x = self.conv(x) if self.efficient: # if efficient, we do the interpolation after the conv x = F.interpolate(x, scale_factor=2, mode="nearest") + + if self.freq_space: + x = self.dwt(x) + return x @@ -85,11 +104,21 @@ class Downsample(nn.Module): :param use_conv: a bool determining if a convolution is applied. """ - def __init__(self, channels, use_conv, out_channel=None): + def __init__(self, channels, use_conv, out_channel=None, freq_space=False): super().__init__() self.channels = channels self.out_channel = out_channel or channels self.use_conv = use_conv + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + stride = 2 if use_conv: self.op = nn.Conv2d( @@ -100,8 +129,16 @@ def __init__(self, channels, use_conv, out_channel=None): self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) def forward(self, x): + if self.freq_space: + x = self.iwt(x) + assert x.shape[1] == self.channels - return self.op(x) + opx = self.op(x) + + if self.freq_space: + opx = self.dwt(opx) + + return opx class ResBlock(EmbedBlock): @@ -132,6 +169,7 @@ def __init__( up=False, down=False, efficient=False, + freq_space=False, ): super().__init__() self.channels = channels @@ -143,21 +181,21 @@ def __init__( self.use_scale_shift_norm = use_scale_shift_norm self.up = up self.efficient = efficient + self.freq_space = freq_space + self.updown = up or down self.in_layers = nn.Sequential( - normalization(channels, norm), + normalization(self.channels, norm), torch.nn.SiLU(), - nn.Conv2d(channels, self.out_channel, 3, padding=1), + nn.Conv2d(self.channels, self.out_channel, 3, padding=1), ) - self.updown = up or down - if up: - self.h_upd = Upsample(channels, False) - self.x_upd = Upsample(channels, False) + self.h_upd = Upsample(channels, False, freq_space=self.freq_space) + self.x_upd = Upsample(channels, False, freq_space=self.freq_space) elif down: - self.h_upd = Downsample(channels, False) - self.x_upd = Downsample(channels, False) + self.h_upd = Downsample(channels, False, freq_space=self.freq_space) + self.x_upd = Downsample(channels, False, freq_space=self.freq_space) else: self.h_upd = self.x_upd = nn.Identity() @@ -196,7 +234,9 @@ def forward(self, x, emb): def _forward(self, x, emb): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + if self.efficient and self.up: h = in_conv(h) h = self.h_upd(h) @@ -205,6 +245,7 @@ def _forward(self, x, emb): h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) + else: h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) @@ -400,6 +441,7 @@ def __init__( resblock_updown=True, use_new_attention_order=False, efficient=False, + freq_space=False, ): super().__init__() @@ -420,13 +462,22 @@ def __init__( self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + in_channel *= 4 + out_channel *= 4 if norm == "groupnorm": norm = norm + str(group_norm_size) self.cond_embed_dim = cond_embed_dim - ch = input_ch = int(channel_mults[0] * inner_channel) + ch = input_ch = int(channel_mults[0] * self.inner_channel) self.input_blocks = nn.ModuleList( [EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))] ) @@ -440,14 +491,15 @@ def __init__( ch, self.cond_embed_dim, dropout, - out_channel=int(mult * inner_channel), + out_channel=int(mult * self.inner_channel), use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, norm=norm, efficient=efficient, + freq_space=self.freq_space, ) ] - ch = int(mult * inner_channel) + ch = int(mult * self.inner_channel) if ds in attn_res: layers.append( AttentionBlock( @@ -475,9 +527,15 @@ def __init__( down=True, norm=norm, efficient=efficient, + freq_space=self.freq_space, ) if resblock_updown - else Downsample(ch, conv_resample, out_channel=out_ch) + else Downsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) ) ) ch = out_ch @@ -494,6 +552,7 @@ def __init__( use_scale_shift_norm=use_scale_shift_norm, norm=norm, efficient=efficient, + freq_space=self.freq_space, ), AttentionBlock( ch, @@ -510,6 +569,7 @@ def __init__( use_scale_shift_norm=use_scale_shift_norm, norm=norm, efficient=efficient, + freq_space=self.freq_space, ), ) self._feature_size += ch @@ -523,14 +583,15 @@ def __init__( ch + ich, self.cond_embed_dim, dropout, - out_channel=int(inner_channel * mult), + out_channel=int(self.inner_channel * mult), use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, norm=norm, efficient=efficient, + freq_space=self.freq_space, ) ] - ch = int(inner_channel * mult) + ch = int(self.inner_channel * mult) if ds in attn_res: layers.append( AttentionBlock( @@ -554,9 +615,15 @@ def __init__( up=True, norm=norm, efficient=efficient, + freq_space=self.freq_space, ) if resblock_updown - else Upsample(ch, conv_resample, out_channel=out_ch) + else Upsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) ) ds //= 2 self.output_blocks.append(EmbedSequential(*layers)) @@ -601,6 +668,10 @@ def compute_feats(self, input, embed_gammas): hs = [] h = input.type(torch.float32) + + if self.freq_space: + h = self.dwt(h) + for module in self.input_blocks: h = module(h, emb) @@ -617,7 +688,12 @@ def forward(self, input, embed_gammas=None): h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) h = h.type(input.dtype) - return self.out(h) + outh = self.out(h) + + if self.freq_space: + outh = self.iwt(outh) + + return outh def get_feats(self, input, extract_layer_ids): _, hs, _ = self.compute_feats(input, embed_gammas=None) @@ -719,6 +795,7 @@ def __init__( use_new_attention_order=False, num_transformer_blocks=6, efficient=False, + freq_space=False, ): super().__init__() @@ -739,6 +816,15 @@ def __init__( self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + in_channel *= 4 + out_channel *= 4 if norm == "groupnorm": norm = norm + str(group_norm_size) @@ -746,7 +832,7 @@ def __init__( self.cond_embed_dim = cond_embed_dim self.inner_channel = inner_channel - ch = input_ch = int(channel_mults[0] * inner_channel) + ch = input_ch = int(channel_mults[0] * self.inner_channel) self.input_blocks = nn.ModuleList( [EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))] ) @@ -760,20 +846,28 @@ def __init__( ch, cond_embed_dim, dropout, - out_channel=int(mult * inner_channel), + out_channel=int(mult * self.inner_channel), use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, norm=norm, + freq_space=self.freq_space, ) ] - ch = int(mult * inner_channel) + ch = int(mult * self.inner_channel) self.input_blocks.append(EmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mults) - 1: out_ch = ch self.input_blocks.append( - EmbedSequential(Downsample(ch, conv_resample, out_channel=out_ch)) + EmbedSequential( + Downsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) + ) ) ch = out_ch input_block_chans.append(ch) @@ -807,18 +901,23 @@ def __init__( ch + ich, cond_embed_dim, dropout, - out_channel=int(inner_channel * mult), + out_channel=int(self.inner_channel * mult), use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, norm=norm, + freq_space=self.freq_space, ) ] - ch = int(inner_channel * mult) + ch = int(self.inner_channel * mult) if level and i == res_blocks[level]: out_ch = ch layers.append( Upsample( - ch, conv_resample, out_channel=out_ch, efficient=efficient + ch, + conv_resample, + out_channel=out_ch, + efficient=efficient, + freq_space=self.freq_space, ) ) ds //= 2 @@ -864,6 +963,10 @@ def compute_feats(self, input, embed_gammas): hs = [] h = input.type(torch.float32) + + if self.freq_space: + h = self.dwt(h) + for module in self.input_blocks: h = module(h, emb) hs.append(h) @@ -886,7 +989,12 @@ def forward(self, input, embed_gammas=None): h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) h = h.type(input.dtype) - return self.out(h) + outh = self.out(h) + + if self.freq_space: + outh = self.iwt(outh) + + return outh def get_feats(self, input, extract_layer_ids): _, hs, _ = self.compute_feats(input, embed_gammas=None) diff --git a/models/modules/utils.py b/models/modules/utils.py index 8ebe7ed8d..6b995ebe5 100644 --- a/models/modules/utils.py +++ b/models/modules/utils.py @@ -241,24 +241,46 @@ def download_midas_weight(model_type="DPT_Large"): def download_sam_weight(path): - sam_weights = { - "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", - "sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", - "sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", - } - for i in range(2, len(path.split("/"))): - temp = path.split("/")[:i] - cur_path = "/".join(temp) - if not os.path.isdir(cur_path): - os.mkdir(cur_path) - model_name = path.split("/")[-1] - if model_name in sam_weights: - wget.download(sam_weights[model_name], path) - else: - raise NameError( - "There is no pretrained weight to download for %s, you need to provide a path to segformer weights." - % model_name - ) + if not os.path.exists(path): + if not "mobile_sam" in path: + sam_weights = { + "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", + "sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", + "sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", + } + for i in range(2, len(path.split("/"))): + temp = path.split("/")[:i] + cur_path = "/".join(temp) + if not os.path.isdir(cur_path): + os.mkdir(cur_path) + model_name = path.split("/")[-1] + if model_name in sam_weights: + wget.download(sam_weights[model_name], path) + else: + raise NameError( + "There is no pretrained weight to download for %s, you need to provide a path to sam weights." + % model_name + ) + else: + download_mobile_sam_weight(path) + + +def download_mobile_sam_weight(path): + if not os.path.exists(path): + sam_weights = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt" + for i in range(2, len(path.split("/"))): + temp = path.split("/")[:i] + cur_path = "/".join(temp) + if not os.path.isdir(cur_path): + os.mkdir(cur_path) + model_name = path.split("/")[-1] + if model_name in sam_weights: + wget.download(sam_weights, path) + else: + raise NameError( + "There is no pretrained weight to download for %s, you need to provide a path to mobileSam weights." + % model_name + ) def predict_depth(img, midas, model_type): diff --git a/models/palette_model.py b/models/palette_model.py index a56b3944d..f9816eed4 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -1,17 +1,15 @@ import copy +import itertools import math import random import warnings import torch import torchvision.transforms as T -from torch import nn - - -import itertools import tqdm +from torch import nn -from data.online_creation import fill_mask_with_color +from data.online_creation import fill_mask_with_color, fill_mask_with_random from models.modules.sam.sam_inference import compute_mask_with_sam from util.iter_calculator import IterCalculator from util.mask_generation import random_edge_mask @@ -247,12 +245,14 @@ def __init__(self, opt, rank): (self.opt.data_crop_size, self.opt.data_crop_size) ) + if opt.isTrain: + batch_size = self.opt.train_batch_size + else: + batch_size = self.opt.test_batch_size if self.opt.alg_palette_inference_num == -1: - self.inference_num = self.opt.train_batch_size + self.inference_num = batch_size else: - self.inference_num = min( - self.opt.alg_palette_inference_num, self.opt.train_batch_size - ) + self.inference_num = min(self.opt.alg_palette_inference_num, batch_size) self.ddim_num_steps = self.opt.alg_palette_ddim_num_steps self.ddim_eta = self.opt.alg_palette_ddim_eta @@ -315,14 +315,14 @@ def __init__(self, opt, rank): G_parameters = itertools.chain(*G_parameters) # Define optimizer - self.optimizer_G = opt.optim( - opt, - G_parameters, - lr=opt.train_G_lr, - betas=(opt.train_beta1, opt.train_beta2), - ) - - self.optimizers.append(self.optimizer_G) + if opt.isTrain: + self.optimizer_G = opt.optim( + opt, + G_parameters, + lr=opt.train_G_lr, + betas=(opt.train_beta1, opt.train_beta2), + ) + self.optimizers.append(self.optimizer_G) # Define loss functions if self.opt.alg_palette_loss == "MSE": @@ -384,14 +384,28 @@ def set_input(self, data): self.gt_image = data["B"].to(self.device)[:, 1] if self.task == "inpainting": self.previous_frame_mask = data["B_label_mask"].to(self.device)[:, 0] + ### Note: the sam related stuff should eventually go into the dataloader if self.use_sam_mask: + if self.opt.data_inverted_mask: + temp_mask = data["B_label_mask"].clone() + temp_mask[temp_mask > 0] = 2 + temp_mask[temp_mask == 0] = 1 + temp_mask[temp_mask == 2] = 0 + else: + temp_mask = data["B_label_mask"].clone() self.mask = compute_mask_with_sam( self.gt_image, - data["B_label_mask"].to(self.device)[:, 1], + temp_mask.to(self.device)[:, 1], self.freezenet_sam, self.device, batched=True, ) + + if self.opt.data_inverted_mask: + self.mask[self.mask > 0] = 2 + self.mask[self.mask == 0] = 1 + self.mask[self.mask == 2] = 0 + self.y_t = fill_mask_with_random(self.gt_image, self.mask, -1) else: self.mask = data["B_label_mask"].to(self.device)[:, 1] else: @@ -400,14 +414,28 @@ def set_input(self, data): if self.task == "inpainting": self.y_t = data["A"].to(self.device) self.gt_image = data["B"].to(self.device) + ### Note: the sam related stuff should eventually go into the dataloader if self.use_sam_mask: + if self.opt.data_inverted_mask: + temp_mask = data["B_label_mask"].clone() + temp_mask[temp_mask > 0] = 2 + temp_mask[temp_mask == 0] = 1 + temp_mask[temp_mask == 2] = 0 + else: + temp_mask = data["B_label_mask"].clone() self.mask = compute_mask_with_sam( self.gt_image, - data["B_label_mask"].to(self.device), + temp_mask.to(self.device), self.freezenet_sam, self.device, batched=True, ) + if self.opt.data_inverted_mask: + self.mask[self.mask > 0] = 2 + self.mask[self.mask == 0] = 1 + self.mask[self.mask == 2] = 0 + self.y_t = fill_mask_with_random(self.gt_image, self.mask, -1) + else: self.mask = data["B_label_mask"].to(self.device) else: # e.g. super-resolution diff --git a/models/semantic_networks.py b/models/semantic_networks.py index 65739204d..51e55c441 100644 --- a/models/semantic_networks.py +++ b/models/semantic_networks.py @@ -1,16 +1,19 @@ import os from .modules.classifiers import ( + TORCH_MODEL_CLASSES, Classifier, VGG16_FCN8s, torch_model, - TORCH_MODEL_CLASSES, ) -from .modules.UNet_classification import UNet +from .modules.sam.sam_inference import ( + init_sam_net, + load_mobile_sam_weight, + load_sam_weight, +) from .modules.segformer.segformer_generator import Segformer - -from .modules.utils import init_net, get_weights -from .modules.sam.sam_inference import load_sam_weight +from .modules.UNet_classification import UNet +from .modules.utils import get_weights, init_net def define_C( @@ -22,7 +25,7 @@ def define_C( model_init_type, model_init_gain, train_sem_cls_pretrained, - **unused_options + **unused_options, ): img_size = data_crop_size if train_sem_cls_template == "basic": @@ -43,6 +46,7 @@ def define_f( f_s_net, model_input_nc, f_s_semantic_nclasses, + model_type_sam, model_init_type, model_init_gain, f_s_config_segformer, @@ -50,7 +54,7 @@ def define_f( f_s_weight_sam, jg_dir, data_crop_size, - **unused_options + **unused_options, ): if f_s_net == "vgg": net = VGG16_FCN8s( @@ -95,7 +99,7 @@ def define_f( net.net.load_state_dict(weights, strict=False) return net elif f_s_net == "sam": - net, mg = load_sam_weight(f_s_weight_sam) + net, mg = init_sam_net(model_type_sam, f_s_weight_sam, device=None) return net, mg return init_net(net, model_init_type, model_init_gain) diff --git a/options/base_options.py b/options/base_options.py index 6e7d15b9c..fe26392c4 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -11,6 +11,7 @@ import data import models from models.modules.classifiers import TORCH_MODEL_CLASSES +from models.modules.utils import download_mobile_sam_weight, download_sam_weight from util import util from util.util import MAX_INT, flatten_json, pairs_of_floats, pairs_of_ints @@ -190,7 +191,7 @@ def initialize(self, parser): "--D_weight_sam", type=str, default="", - help="path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth", + help="path to sam weight for D, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM", ) # generator @@ -261,9 +262,19 @@ def initialize(self, parser): default="models/configs/segformer/segformer_config_b0.json", help="path to segformer configuration file for G", ) - parser.add_argument("--G_attn_nb_mask_attn", default=10, type=int) + parser.add_argument( + "--G_attn_nb_mask_attn", + default=10, + type=int, + help="number of attention masks in _attn model architectures", + ) - parser.add_argument("--G_attn_nb_mask_input", default=1, type=int) + parser.add_argument( + "--G_attn_nb_mask_input", + default=1, + type=int, + help="number of mask dedicated to input in _attn model architectures", + ) parser.add_argument( "--G_backward_compatibility_twice_resnet_blocks", @@ -285,8 +296,18 @@ def initialize(self, parser): help="specify multimodal latent vector encoder", ) - parser.add_argument("--G_unet_mha_num_head_channels", default=32, type=int) - parser.add_argument("--G_unet_mha_num_heads", default=1, type=int) + parser.add_argument( + "--G_unet_mha_num_head_channels", + default=32, + type=int, + help="number of channels in each head of the mha architecture", + ) + parser.add_argument( + "--G_unet_mha_num_heads", + default=1, + type=int, + help="number of heads in the mha architecture", + ) parser.add_argument( "--G_uvit_num_transformer_blocks", @@ -523,10 +544,9 @@ def initialize(self, parser): "--f_s_weight_sam", type=str, default="", - help="path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth", + help="path to sam weight for f_s, e.g. models/configs/sam/pretrain/sam_vit_b_01ec64.pth, or models/configs/sam/pretrain/mobile_sam.pt for MobileSAM", ) - # cls semantic network parser.add_argument( "--cls_net", type=str, @@ -654,6 +674,14 @@ def initialize(self, parser): help="whether to use refined mask with sam", ) + parser.add_argument( + "--model_type_sam", + type=str, + default="mobile_sam", + choices=["sam", "mobile_sam"], + help="which model to use for segment-anything mask generation", + ) + # Online dataset creation options parser.add_argument( "--data_online_select_category", @@ -994,10 +1022,12 @@ def _after_parse(self, opt, set_device=True): opt.D_proj_interp = 224 # Dsam requires D_weight_sam - if "sam" in opt.D_netDs and opt.D_weight_sam == "": - raise ValueError( - "Dsam requires D_weight_sam, please specify a path to a pretrained sam model" - ) + if "sam" in opt.D_netDs: + if opt.D_weight_sam == "": + raise ValueError( + "Dsam requires D_weight_sam, please specify a path to a pretrained SAM or MobileSAM model" + ) + download_sam_weight(opt.D_weight_sam) # diffusion D + vitsmall check if opt.dataaug_D_diffusion and "vit" in opt.D_proj_network_type: @@ -1013,54 +1043,57 @@ def _after_parse(self, opt, set_device=True): ): raise ValueError("SAM with masks and bbox prompting requires Pytorch 2") if opt.f_s_net == "sam" and opt.data_dataset_mode == "unaligned_labeled_mask": - raise warning.warn("SAM with direct masks does not use mask/bbox prompting") + warnings.warn("SAM with direct masks does not use mask/bbox prompting") # mask delta check - if opt.data_online_creation_mask_delta_A == None: + if opt.data_online_creation_mask_delta_A == [[]]: pass else: if ( - len(opt.data_online_creation_mask_delta_A) > 1 - and len(opt.data_online_creation_mask_delta_A) + len(opt.data_online_creation_mask_delta_A) < opt.f_s_semantic_nclasses - 1 ): + if len(opt.data_online_creation_mask_delta_A) == 1: + warnings.warn( + "Mask delta A list should be of length f_s_semantic_nclasses, distributing single value across all classes" + ) + opt.data_online_creation_mask_delta_A = ( + opt.data_online_creation_mask_delta_A + * (opt.f_s_semantic_nclasses - 1) + ) + elif len(opt.data_online_creation_mask_delta_A) > 1: raise ValueError( "Mask delta A list must be of length f_s_semantic_nclasses" ) - if ( - len(opt.data_online_creation_mask_delta_A) - == opt.f_s_semantic_nclasses - 1 - ): - if ( - len(opt.data_online_creation_mask_delta_A[0]) == 1 - and not opt.data_online_creation_mask_square_A - ): - raise ValueError( - "Mask delta A has a single value per dimension but --data_online_creation_mask_square_A is not set, please set it to True" - ) - if opt.data_online_creation_mask_delta_B == None: + + if opt.data_online_creation_mask_delta_B == [[]]: pass else: if ( - len(opt.data_online_creation_mask_delta_B) > 1 - and len(opt.data_online_creation_mask_delta_B) + len(opt.data_online_creation_mask_delta_B) < opt.f_s_semantic_nclasses - 1 ): + if len(opt.data_online_creation_mask_delta_B) == 1: + warnings.warn( + "Mask delta B list should be of length f_s_semantic_nclasses, distributing single value across all classes" + ) + opt.data_online_creation_mask_delta_B = ( + opt.data_online_creation_mask_delta_B + * (opt.f_s_semantic_nclasses - 1) + ) + elif len(opt.data_online_creation_mask_delta_B) > 1: raise ValueError( "Mask delta B list must be of length f_s_semantic_nclasses" ) - if ( - len(opt.data_online_creation_mask_delta_B) - == opt.f_s_semantic_nclasses - 1 - ): - if ( - len(opt.data_online_creation_mask_delta_B[0]) == 1 - and not opt.data_online_creation_mask_square_B - ): - raise ValueError( - "Mask delta B has a single value per dimension but --data_online_creation_mask_square_B is not set, please set it to True" - ) + # training is frequency space only available for a few architectures atm + if opt.train_feat_wavelet: + if not opt.G_netG in ["mobile_resnet_attn", "unet_mha", "uvit"]: + raise ValueError( + "Wavelet training is only available for mobile_resnet_attn, unet_mha and uvit architectures" + ) + + # register options self.opt = opt return self.opt diff --git a/options/test_options.py b/options/test_options.py deleted file mode 100644 index aade3861c..000000000 --- a/options/test_options.py +++ /dev/null @@ -1,39 +0,0 @@ -from .base_options import BaseOptions - - -class TestOptions(BaseOptions): - """This class includes test options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) # define shared options - parser.add_argument( - "--test_ntest", type=int, default=float("inf"), help="# of test examples." - ) - parser.add_argument( - "--test_results_dir", - type=str, - default="./results/", - help="saves results here.", - ) - parser.add_argument( - "--test_aspect_ratio", - type=float, - default=1.0, - help="aspect ratio of result images", - ) - # Dropout and Batchnorm has different behaviour during training and test. - parser.add_argument( - "--test_eval", action="store_true", help="use eval mode during test time." - ) - parser.add_argument( - "--test_num_test", type=int, default=50, help="how many test images to run" - ) - # rewrite devalue values - parser.set_defaults(model="test") - # To avoid cropping, the load_size should be the same as crop_size - parser.set_defaults(load_size=parser.get_default("crop_size")) - self.isTrain = False - return parser diff --git a/options/train_options.py b/options/train_options.py index 968b62fce..f26dcef91 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -180,14 +180,24 @@ def initialize(self, parser): help="which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]", ) - parser.add_argument("--train_compute_metrics_test", action="store_true") - parser.add_argument("--train_metrics_every", type=int, default=1000) + parser.add_argument( + "--train_compute_metrics_test", + action="store_true", + help="whether to compute test metrics, e.g. FID, ...", + ) + parser.add_argument( + "--train_metrics_every", + type=int, + default=1000, + help="compute metrics every N iterations", + ) parser.add_argument( "--train_metrics_list", type=str, default=["FID"], nargs="*", - choices=["FID", "KID", "MSID", "PSNR"], + choices=["FID", "KID", "MSID", "PSNR", "LPIPS"], + help="metrics on results quality to compute", ) parser.add_argument( @@ -282,6 +292,13 @@ def initialize(self, parser): ) parser.add_argument("--train_use_contrastive_loss_D", action="store_true") + # frequency space training + parser.add_argument( + "--train_feat_wavelet", + action="store_true", + help="if true, train in wavelet features space (Note: this may not include all discriminators, when training GANs)", + ) + # multimodal training parser.add_argument( "--train_mm_lambda_z", @@ -419,8 +436,17 @@ def initialize(self, parser): help="if true, object removal mode, domain B images with label 0, cut models only", ) - parser.add_argument("--train_mask_compute_miou", action="store_true") - parser.add_argument("--train_mask_miou_every", type=int, default=1000) + parser.add_argument( + "--train_mask_compute_miou", + action="store_true", + help="whether to compute mIoU on semantic masks prediction", + ) + parser.add_argument( + "--train_mask_miou_every", + type=int, + default=1000, + help="compute mIoU every n iterations", + ) # train with temporal criterion loss parser.add_argument( diff --git a/requirements.txt b/requirements.txt index bcd8505a5..eb69275c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,6 @@ onnx aim git+https://github.com/facebookresearch/segment-anything.git piq - +git+https://github.com/ChaoningZhang/MobileSAM.git +Ninja +lpips diff --git a/scripts/gen_single_image_diffusion.py b/scripts/gen_single_image_diffusion.py index 446edc441..c1fd1cf23 100644 --- a/scripts/gen_single_image_diffusion.py +++ b/scripts/gen_single_image_diffusion.py @@ -27,6 +27,7 @@ from models.modules.diffusion_utils import set_new_noise_schedule from models.modules.sam.sam_inference import ( compute_mask_with_sam, + init_sam_net, load_sam_weight, predict_sam_mask, ) @@ -433,16 +434,22 @@ def generate( mask = mask.to(device).clone().detach() if mask is not None: - if data_refined_mask: + if data_refined_mask or opt.data_refined_mask: opt.f_s_weight_sam = "../" + opt.f_s_weight_sam - if not os.path.exists(opt.f_s_weight_sam): - download_sam_weight(path=opt.f_s_weight_sam) - sam_model, _ = load_sam_weight(model_path=opt.f_s_weight_sam) - sam_model = sam_model.to(device) + sam_model, _ = init_sam_net( + model_type_sam=opt.model_type_sam, + model_path=opt.f_s_weight_sam, + device=device, + ) mask = compute_mask_with_sam( img_tensor, mask, sam_model, device, batched=False ).unsqueeze(0) + if opt.data_inverted_mask: + mask[mask > 0] = 2 + mask[mask == 0] = 1 + mask[mask == 2] = 0 + if opt.data_online_creation_rand_mask_A: y_t = fill_mask_with_random( img_tensor.clone().detach(), mask.clone().detach(), -1 diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index 7b40d9b24..cce4db8e0 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -91,6 +91,8 @@ wget -N $URL -O $ZIP_FILE mkdir $TARGET_MASK_SEM_ONLINE_DIR unzip $ZIP_FILE -d $DIR rm $ZIP_FILE +ln -s $TARGET_MASK_SEM_ONLINE_DIR/trainA $TARGET_MASK_SEM_ONLINE_DIR/testA +ln -s $TARGET_MASK_SEM_ONLINE_DIR/trainB $TARGET_MASK_SEM_ONLINE_DIR/testB python3 -m pytest -p no:cacheprovider -s "${current_dir}/../tests/test_run_semantic_mask_online.py" --dataroot "$TARGET_MASK_SEM_ONLINE_DIR" @@ -108,6 +110,39 @@ if [ $OUT != 0 ]; then exit 1 fi +###### test cut +echo "Running test cut" +python3 "${current_dir}/../test.py" \ + --test_model_dir $DIR/joligen_utest_cut/ \ + --test_metrics_list FID KID PSNR LPIPS +OUT=$? + +if [ $OUT != 0 ]; then + exit 1 +fi + +###### test cycle_gan +echo "Running test cycle_gan" +python3 "${current_dir}/../test.py" \ + --test_model_dir $DIR/joligen_utest_cycle_gan/ \ + --test_metrics_list FID KID PSNR LPIPS +OUT=$? + +if [ $OUT != 0 ]; then + exit 1 +fi + +###### test palette +echo "Running test palette" +python3 "${current_dir}/../test.py" \ + --test_model_dir $DIR/joligen_utest_palette/ \ + --test_metrics_list FID KID PSNR LPIPS +OUT=$? + +if [ $OUT != 0 ]; then + exit 1 +fi + ####### mask cls semantics test echo "Running mask and class semantics training tests" URL=https://joligen.com/datasets/daytime2dawn_dusk_lite.zip diff --git a/test.py b/test.py index 5d5cf7958..3c1dd5fe3 100644 --- a/test.py +++ b/test.py @@ -1,85 +1,127 @@ -"""General-purpose test script for image-to-image translation. +import argparse +import os +import torch +import random +import numpy as np +import time +import json -Once you have trained your model with train.py, you can use this script to test the model. -It will load a saved model from --checkpoints_dir and save the results to --results_dir. +from data import ( + create_dataloader, + create_dataset, + create_dataset_temporal, + create_iterable_dataloader, +) +from models import create_model +from util.parser import get_opt +from util.util import MAX_INT -It first creates model and dataset given the option. It will hard-code some parameters. -It then runs inference for --num_test images and save results to an HTML file. -Example (You need to train models first or download pre-trained models from our website): - Test a CycleGAN model (both sides): - python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan +def launch_testing(opt): + rank = 0 - Test a CycleGAN model (one side only): - python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test + opt.jg_dir = os.path.join("/".join(__file__.split("/")[:-1])) + opt.use_cuda = torch.cuda.is_available() and opt.gpu_ids and opt.gpu_ids[0] >= 0 + if opt.use_cuda: + torch.cuda.set_device(opt.gpu_ids[rank]) + opt.isTrain = False - The option '--model test' is used for generating CycleGAN results only for one side. - This option will automatically set '--dataset_mode single', which only loads the images from one set. - On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, - which is sometimes unnecessary. The results will be saved at ./results/. - Use '--results_dir ' to specify the results directory. + testset = create_dataset(opt, phase="test") + print("The number of testing images = %d" % len(testset)) + opt.train_nb_img_max_fid = min(opt.train_nb_img_max_fid, len(testset)) - Test a pix2pix model: - python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA + dataloader_test = create_dataloader( + opt, rank, testset, batch_size=opt.test_batch_size + ) # create a dataset given opt.dataset_mode and other options + + use_temporal = ("temporal" in opt.D_netDs) or opt.train_temporal_criterion + + if use_temporal: + testset_temporal = create_dataset_temporal(opt, phase="test") + + dataloader_test_temporal = create_iterable_dataloader( + opt, rank, testset_temporal, batch_size=opt.test_batch_size + ) + else: + dataloader_test_temporal = None + + model = create_model(opt, rank) # create a model given opt.model and other options + model.setup(opt) # regular setup: load and print networks; create schedulers + model.use_temporal = use_temporal + model.eval() + if opt.use_cuda: + model.single_gpu() + model.init_metrics(dataloader_test) + + if use_temporal: + dataloaders_test = zip(dataloader_test, dataloader_test_temporal) + else: + dataloaders_test = zip(dataloader_test) + + epoch = "test" + total_iters = "test" + with torch.no_grad(): + model.compute_metrics_test(dataloaders_test, epoch, total_iters) + + metrics = model.get_current_metrics() + for metric, value in metrics.items(): + print(f"{metric}: {value}") + + metrics_dir = os.path.join(opt.test_model_dir, "metrics") + os.makedirs(metrics_dir, exist_ok=True) + metrics_file = os.path.join(metrics_dir, time.strftime("%Y%m%d-%H%M%S") + ".json") + with open(metrics_file, "w") as f: + f.write(json.dumps(metrics, indent=4)) + print("metrics written to:", metrics_file) -See options/base_options.py and options/test_options.py for more test options. -See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md -See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md -""" -import os -from options.test_options import TestOptions -from data import create_dataset -from models import create_model -from util.visualizer import save_images -from util import html_util if __name__ == "__main__": - opt = TestOptions().parse() # get test options - # hard-code some parameters for test - opt.num_threads = 0 # test code only supports num_threads = 1 - opt.batch_size = 1 # test code only supports batch_size = 1 - opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. - opt.no_flip = ( - True # no flip; comment this line if results on flipped images are needed. + main_parser = argparse.ArgumentParser() + + main_parser.add_argument( + "--test_model_dir", type=str, required=True, help="path to model directory" ) - opt.display_id = ( - -1 - ) # no visdom display; the test code saves the results to a HTML file. - dataset = create_dataset( - opt - ) # create a dataset given opt.dataset_mode and other options - model = create_model(opt) # create a model given opt.model and other options - model.setup(opt) # regular setup: load and print networks; create schedulers - # create a website - web_dir = os.path.join( - opt.results_dir, opt.name, "{}_{}".format(opt.phase, opt.epoch) - ) # define the website directory - if opt.load_iter > 0: # load_iter is 0 by default - web_dir = "{:s}_iter{:d}".format(web_dir, opt.load_iter) - print("creating web directory", web_dir) - webpage = html_util.HTML( - web_dir, - "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch), + main_parser.add_argument( + "--test_epoch", + type=str, + default="latest", + help="which epoch to load? set to latest to use latest cached model", ) - # test with eval mode. This only affects layers like batchnorm and dropout. - # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. - # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. - if opt.eval: - model.eval() - for i, data in enumerate(dataset): - if i >= opt.num_test: # only apply our model to opt.num_test images. - break - model.set_input(data) # unpack data from data loader - model.test() # run inference - visuals = model.get_current_visuals() # get image results - img_path = model.get_image_paths() # get image paths - if i % 5 == 0: # save images to an HTML file - print("processing (%04d)-th image... %s" % (i, img_path)) - save_images( - webpage, - visuals, - img_path, - aspect_ratio=opt.aspect_ratio, - width=opt.display_winsize, - ) - webpage.save() # save the HTML + main_parser.add_argument( + "--test_metrics_list", + type=str, + nargs="*", + choices=["FID", "KID", "MSID", "PSNR", "LPIPS"], + default=["FID", "KID", "MSID", "PSNR", "LPIPS"], + ) + main_parser.add_argument( + "--test_nb_img", + type=int, + default=MAX_INT, + help="Number of samples to compute metrics. If the dataset directory contains more, only a subset is used.", + ) + main_parser.add_argument( + "--test_batch_size", type=int, default=1, help="input batch size" + ) + main_parser.add_argument( + "--test_seed", type=int, default=42, help="seed to use for tests" + ) + + main_opt, remaining_args = main_parser.parse_known_args() + main_opt.config_json = os.path.join(main_opt.test_model_dir, "train_config.json") + + opt = get_opt(main_opt, remaining_args) + + # override global options with local test options + opt.train_compute_metrics_test = True + opt.test_model_dir = main_opt.test_model_dir + opt.train_epoch = main_opt.test_epoch + opt.train_metrics_list = main_opt.test_metrics_list + opt.train_nb_img_max_fid = main_opt.test_nb_img + opt.test_batch_size = main_opt.test_batch_size + + random.seed(main_opt.test_seed) + torch.manual_seed(main_opt.test_seed) + np.random.seed(main_opt.test_seed) + + launch_testing(opt) diff --git a/tests/test_run_diffusion.py b/tests/test_run_diffusion.py index e69f21f82..16510e332 100644 --- a/tests/test_run_diffusion.py +++ b/tests/test_run_diffusion.py @@ -35,6 +35,7 @@ models_diffusion = ["palette"] G_netG = ["unet_mha", "uvit"] G_efficient = [True, False] +train_feat_wavelet = [False, True] G_unet_mha_norm_layer = [ "groupnorm", @@ -55,6 +56,7 @@ G_unet_mha_norm_layer, alg_palette_conditioning, G_efficient, + train_feat_wavelet, ) @@ -62,7 +64,14 @@ def test_semantic_mask(dataroot): json_like_dict["dataroot"] = dataroot json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) - for model, Gtype, G_norm, alg_palette_conditioning, G_efficient in product_list: + for ( + model, + Gtype, + G_norm, + alg_palette_conditioning, + G_efficient, + train_feat_wavelet, + ) in product_list: json_like_dict_c = json_like_dict.copy() json_like_dict_c["model_type"] = model @@ -73,6 +82,7 @@ def test_semantic_mask(dataroot): json_like_dict_c["G_unet_mha_vit_efficient"] = G_efficient json_like_dict_c["alg_palette_conditioning"] = alg_palette_conditioning + json_like_dict_c["train_feat_wavelet"] = train_feat_wavelet opt = TrainOptions().parse_json(json_like_dict_c) train.launch_training(opt) diff --git a/tests/test_run_diffusion_online.py b/tests/test_run_diffusion_online.py index d77a3c1e5..ad20bf896 100644 --- a/tests/test_run_diffusion_online.py +++ b/tests/test_run_diffusion_online.py @@ -14,7 +14,6 @@ "output_display_env": "joligen_utest", "output_display_id": 0, "gpu_ids": "0", - "data_dataset_mode": "self_supervised_labeled_mask_online", "data_load_size": 128, "data_crop_size": 128, "data_online_creation_crop_size_A": 420, @@ -39,24 +38,32 @@ "data_online_creation_rand_mask_A": True, "train_export_jit": True, "train_save_latest_freq": 10, + "G_diff_n_timestep_test": 10, } models_diffusion = ["palette"] G_netG = ["unet_mha", "uvit"] +data_dataset_mode = ["self_supervised_labeled_mask_online", "self_supervised_temporal"] - -product_list = product(models_diffusion, G_netG) +product_list = product(models_diffusion, G_netG, data_dataset_mode) def test_diffusion_online(dataroot): json_like_dict["dataroot"] = dataroot json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) - for model, Gtype in product_list: + for model, Gtype, dataset_mode in product_list: json_like_dict_c = json_like_dict.copy() json_like_dict_c["model_type"] = model json_like_dict_c["name"] += "_" + model json_like_dict_c["G_netG"] = Gtype - opt = TrainOptions().parse_json(json_like_dict_c) + json_like_dict_c["data_dataset_mode"] = dataset_mode + if dataset_mode == "self_supervised_temporal": + json_like_dict_c["data_temporal_number_frames"] = 2 + json_like_dict_c["data_temporal_frame_step"] = 1 + json_like_dict_c["data_temporal_num_common_char"] = 3 + json_like_dict_c["alg_palette_cond_image_creation"] = "previous_frame" + + opt = TrainOptions().parse_json(json_like_dict_c, save_config=True) train.launch_training(opt) diff --git a/tests/test_run_nosemantic.py b/tests/test_run_nosemantic.py index 157438e88..145c586a9 100644 --- a/tests/test_run_nosemantic.py +++ b/tests/test_run_nosemantic.py @@ -16,8 +16,8 @@ "output_display_id": 0, "gpu_ids": "0", "data_dataset_mode": "unaligned", - "data_load_size": 180, - "data_crop_size": 180, + "data_load_size": 128, + "data_crop_size": 128, "train_n_epochs": 1, "train_n_epochs_decay": 0, "data_max_dataset_size": 10, @@ -33,16 +33,19 @@ D_netDs = [["projected_d", "basic"], ["projected_d", "basic", "depth"]] -product_list = product(models_nosemantic, D_netDs) +train_feat_wavelet = [False, True] + +product_list = product(models_nosemantic, D_netDs, train_feat_wavelet) def test_nosemantic(dataroot): json_like_dict["dataroot"] = dataroot json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) - for model, Dtype in product_list: + for model, Dtype, train_feat_wavelet in product_list: json_like_dict["model_type"] = model json_like_dict["D_netDs"] = Dtype + json_like_dict["train_feat_wavelet"] = train_feat_wavelet if model == "cycle_gan" and "depth" in Dtype: continue # skip diff --git a/tests/test_run_semantic_mask_online.py b/tests/test_run_semantic_mask_online.py index 22bbe299b..a817b40c7 100644 --- a/tests/test_run_semantic_mask_online.py +++ b/tests/test_run_semantic_mask_online.py @@ -36,6 +36,7 @@ "train_sem_use_label_B": True, "data_relative_paths": True, "D_netDs": ["basic", "projected_d", "temporal"], + "D_weight_sam": "models/configs/sam/pretrain/mobile_sam.pt", "train_gan_mode": "projected", "D_proj_interp": 256, "train_G_ema": True, @@ -59,22 +60,32 @@ D_proj_network_type = ["efficientnet", "vitsmall"] +D_netDs = [["basic", "projected_d", "temporal"], ["sam"]] + f_s_net = ["unet"] -product_list = product(models_semantic_mask, G_netG, D_proj_network_type, f_s_net) +model_type_sam = ["mobile_sam"] + +product_list = product( + models_semantic_mask, G_netG, D_proj_network_type, D_netDs, f_s_net, model_type_sam +) def test_semantic_mask_online(dataroot): json_like_dict["dataroot"] = dataroot json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1]) - for model, Gtype, Dtype, f_s_type in product_list: + for model, Gtype, Dtype, Dnet, f_s_type, sam_type in product_list: + if model == "cycle_gan" and "sam" in Dnet: + continue json_like_dict_c = json_like_dict.copy() json_like_dict_c["model_type"] = model json_like_dict_c["name"] += "_" + model json_like_dict_c["G_netG"] = Gtype json_like_dict_c["D_proj_network_type"] = Dtype + json_like_dict_c["D_netDs"] = Dnet json_like_dict_c["f_s_net"] = f_s_type + json_like_dict_c["model_type_sam"] = sam_type - opt = TrainOptions().parse_json(json_like_dict_c) + opt = TrainOptions().parse_json(json_like_dict_c, save_config=True) train.launch_training(opt) diff --git a/train.py b/train.py index 3548d22ea..46717c9b6 100644 --- a/train.py +++ b/train.py @@ -25,6 +25,7 @@ import time import warnings import copy +import sys import torch import torch.distributed as dist @@ -37,8 +38,7 @@ create_iterable_dataloader, ) from models import create_model -from options.train_options import TrainOptions -from util.util import flatten_json +from util.parser import get_opt from util.visualizer import Visualizer from util.lion_pytorch import Lion @@ -127,7 +127,7 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): rank_0 = rank == 0 if rank_0: - model.init_metrics(dataloader, dataloader_test) + model.init_metrics(dataloader_test) model.setup(opt) # regular setup: load and print networks; create schedulers @@ -158,6 +158,20 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): for path in model.save_networks_img(data): visualizer.display_img(path + ".png") + if rank_0: + # Get the command line arguments + command_line_arguments = sys.argv + + # Join the arguments into a single string + command_line = " ".join(command_line_arguments) + + # Save the command line to a file + sv_path = os.path.join(opt.checkpoints_dir, opt.name, "command_line.txt") + with open(sv_path, "w") as file: + file.write(command_line) + + print(f"Command line was saved at {sv_path}") + for epoch in range( opt.train_epoch_count, opt.train_n_epochs + opt.train_n_epochs_decay + 1 ): # outer loop for different epochs; we save the model by , + @@ -361,10 +375,8 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal): print("End of training") -def launch_training(opt=None): +def launch_training(opt): signal.signal(signal.SIGINT, signal_handler) # to really kill the process - if opt is None: - opt = TrainOptions().parse() # get training options opt.jg_dir = os.path.join("/".join(__file__.split("/")[:-1])) world_size = len(opt.gpu_ids) @@ -396,21 +408,6 @@ def launch_training(opt=None): train_gpu(0, world_size, opt, trainset, trainset_temporal) -def compute_test_metrics(model, dataloader): - - return metrics - - -def get_override_options_names(remaining_args): - return_options_names = [] - - for arg in remaining_args: - if arg.startswith("--"): - return_options_names.append(arg[2:]) - - return return_options_names - - if __name__ == "__main__": main_parser = argparse.ArgumentParser(add_help=False) @@ -420,25 +417,6 @@ def get_override_options_names(remaining_args): main_opt, remaining_args = main_parser.parse_known_args() - if main_opt.config_json != "": - override_options_names = get_override_options_names(remaining_args) - - if not "--dataroot" in remaining_args: - remaining_args += ["--dataroot", "unused"] - override_options_json = flatten_json( - TrainOptions().parse_to_json(remaining_args) - ) - - with open(main_opt.config_json, "r") as jsonf: - train_json = flatten_json(json.load(jsonf)) - - for name in override_options_names: - train_json[name] = override_options_json[name] - - opt = TrainOptions().parse_json(train_json) - - print("%s config file loaded" % main_opt.config_json) - else: - opt = None + opt = get_opt(main_opt, remaining_args) launch_training(opt) diff --git a/util/metrics.py b/util/metrics.py index 510724c41..d7ac0bf50 100755 --- a/util/metrics.py +++ b/util/metrics.py @@ -102,7 +102,9 @@ def get_activations( "Number of images limitation doesn't work with pytorch dataloaders, the full dataset will be used instead for activations computation." ) - for batch in tqdm(dataloader, total=len(dataloader) // batch_size): + for batch in tqdm( + dataloader, total=len(dataloader) // batch_size, desc="activations" + ): if isinstance(batch, dict) and domain is not None: img = batch[domain].to(device) else: diff --git a/util/parser.py b/util/parser.py new file mode 100644 index 000000000..43295ac72 --- /dev/null +++ b/util/parser.py @@ -0,0 +1,40 @@ +import json +import os +import torch +from util.util import flatten_json +from options.train_options import TrainOptions + + +def get_override_options_names(remaining_args): + return_options_names = [] + + for arg in remaining_args: + if arg.startswith("--"): + return_options_names.append(arg[2:]) + + return return_options_names + + +def get_opt(main_opt, remaining_args): + if main_opt.config_json != "": + override_options_names = get_override_options_names(remaining_args) + + if not "--dataroot" in remaining_args: + remaining_args += ["--dataroot", "unused"] + override_options_json = flatten_json( + TrainOptions().parse_to_json(remaining_args) + ) + + with open(main_opt.config_json, "r") as jsonf: + train_json = flatten_json(json.load(jsonf)) + + for name in override_options_names: + train_json[name] = override_options_json[name] + + opt = TrainOptions().parse_json(train_json) + + print("%s config file loaded" % main_opt.config_json) + else: + opt = TrainOptions().parse() # get training options + + return opt