Skip to content

Commit

Permalink
feat(ml): UNetVid for generating video with batch size > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 authored and beniz committed Sep 6, 2024
1 parent 03c0bec commit 00f11bc
Show file tree
Hide file tree
Showing 16 changed files with 442 additions and 63 deletions.
3 changes: 2 additions & 1 deletion data/self_supervised_temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, opt, phase, name=""):
# sort
self.A_img_paths.sort(key=natural_keys)
self.A_label_mask_paths.sort(key=natural_keys)

if self.opt.data_sanitize_paths:
self.sanitize()
elif opt.data_max_dataset_size != float("inf"):
Expand Down Expand Up @@ -95,6 +96,7 @@ def get_img(
print("Condition not met, generating a new index_A...")
else: # dataset from one video in form of img/frames.jpg
break

for i in range(self.num_frames):
cur_index_A = index_A + i * self.frame_step

Expand Down Expand Up @@ -212,5 +214,4 @@ def get_img(
ref_A_img_path,
)
return None

return result
1 change: 1 addition & 0 deletions docs/source/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,4 @@ Train a DDPM model to generate a sequence of frame images for inpainting, ensuri
.. code:: bash
python3 train.py --dataroot /path/to/data/online_mario2sonic_full --checkpoints_dir /path/to/checkpoints --name mario_vid --config_json examples/example_ddpm_vid_mario.json
339 changes: 339 additions & 0 deletions examples/example_ddpm_vid_mario.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
{
"D": {
"dropout": false,
"n_layers": 3,
"ndf": 64,
"netDs": [
"projected_d",
"basic"
],
"no_antialias": false,
"no_antialias_up": false,
"norm": "instance",
"proj_config_segformer": "models/configs/segformer/segformer_config_b0.json",
"proj_interp": -1,
"proj_network_type": "efficientnet",
"proj_weight_segformer": "models/configs/segformer/pretrain/segformer_mit-b0.pth",
"spectral": false,
"temporal_every": 4,
"vision_aided_backbones": "clip+dino+swin",
"weight_sam": ""
},
"G": {
"attn_nb_mask_attn": 10,
"attn_nb_mask_input": 1,
"backward_compatibility_twice_resnet_blocks": false,
"config_segformer": "models/configs/segformer/segformer_config_b0.json",
"diff_n_timestep_test": 1000,
"diff_n_timestep_train": 2000,
"dropout": false,
"hdit_depths": [
2,
2,
4
],
"hdit_patch_size": 4,
"hdit_widths": [
192,
384,
768
],
"lora_unet": 8,
"lora_vae": 8,
"nblocks": 9,
"netE": "resnet_256",
"netG": "unet_vid",
"ngf": 64,
"norm": "instance",
"padding_type": "reflect",
"spectral": false,
"unet_mha_attn_res": [
16
],
"unet_mha_channel_mults": [
1,
2,
4,
8
],
"unet_mha_group_norm_size": 32,
"unet_mha_norm_layer": "groupnorm",
"unet_mha_num_head_channels": 32,
"unet_mha_num_heads": 1,
"unet_mha_res_blocks": [
2,
2,
2,
2
],
"unet_mha_vit_efficient": false,
"uvit_num_transformer_blocks": 6
},
"alg": {
"diffusion_cond_computed_sketch_list": [
"canny",
"hed"
],
"diffusion_cond_embed": "",
"diffusion_cond_embed_dim": 32,
"diffusion_cond_image_creation": "y_t",
"diffusion_cond_prob_use_previous_frame": 0.5,
"diffusion_cond_sam_crop_delta": true,
"diffusion_cond_sam_final_canny": false,
"diffusion_cond_sam_max_mask_area": 0.99,
"diffusion_cond_sam_min_mask_area": 0.001,
"diffusion_cond_sam_no_output_binary_sam": true,
"diffusion_cond_sam_no_sample_points_in_ellipse": true,
"diffusion_cond_sam_no_sobel_filter": true,
"diffusion_cond_sam_points_per_side": 16,
"diffusion_cond_sam_redundancy_threshold": 0.62,
"diffusion_cond_sam_sobel_threshold": 0.7,
"diffusion_cond_sam_use_gaussian_filter": false,
"diffusion_cond_sketch_canny_range": [
0,
765
],
"diffusion_dropout_prob": 0.0,
"diffusion_generate_per_class": false,
"diffusion_lambda_G": 1.0,
"diffusion_ref_embed_net": "clip",
"diffusion_super_resolution_scale": 2.0,
"diffusion_task": "inpainting",
"re": {
"P_lr": 0.0002,
"adversarial_loss_p": false,
"netP": "unet_128",
"no_train_P_fake_images": false,
"nuplet_size": 3,
"projection_threshold": 1.0
},
"palette": {
"ddim_eta": 0.5,
"ddim_num_steps": 10,
"loss": "MSE",
"minsnr": false,
"sampling_method": "ddpm"
}
},
"data": {
"crop_size": 64,
"dataset_mode": "self_supervised_temporal_labeled_mask_online",
"direction": "AtoB",
"image_bits": 8,
"inverted_mask": false,
"load_size": 64,
"max_dataset_size": 1000000000,
"num_threads": 4,
"online_context_pixels": 0,
"online_fixed_mask_size": -1,
"online_random_bbox": false,
"online_select_category": -1,
"online_single_bbox": false,
"preprocess": "resize_and_crop",
"refined_mask": false,
"relative_paths": true,
"sanitize_paths": false,
"serial_batches": false,
"temporal_frame_step": 1,
"temporal_num_common_char": -1,
"temporal_number_frames": 8,
"online_creation": {
"color_mask_A": false,
"crop_delta_A": 50,
"crop_delta_B": 50,
"crop_size_A": 64,
"crop_size_B": 64,
"load_size_A": [],
"load_size_B": [],
"mask_delta_A": [
[]
],
"mask_delta_A_ratio": [
[
0.12
],
[
0.12
]
],
"mask_delta_B": [
[]
],
"mask_delta_B_ratio": [
[]
],
"mask_random_offset_A": [
0.0
],
"mask_random_offset_B": [
0.0
],
"mask_square_A": false,
"mask_square_B": false,
"rand_mask_A": true
}
},
"f_s": {
"all_classes_as_one": false,
"class_weights": [],
"config_segformer": "models/configs/segformer/segformer_config_b0.json",
"dropout": false,
"net": "vgg",
"nf": 64,
"semantic_nclasses": 2,
"semantic_threshold": 1.0,
"weight_sam": "",
"weight_segformer": ""
},
"cls": {
"all_classes_as_one": false,
"class_weights": [],
"config_segformer": "models/configs/segformer/segformer_config_b0.json",
"dropout": false,
"net": "vgg",
"nf": 64,
"semantic_nclasses": 2,
"semantic_threshold": 1.0,
"weight_segformer": ""
},
"output": {
"no_html": false,
"num_images": 20,
"print_freq": 100,
"update_html_freq": 1000,
"verbose": true,
"display": {
"G_attention_masks": false,
"aim_port": 53800,
"aim_server": "http://localhost",
"diff_fake_real": false,
"env": "",
"freq": 5000,
"id": 1,
"ncols": 0,
"networks": false,
"type": [
"visdom"
],
"visdom_autostart": false,
"visdom_port": 8097,
"visdom_server": "http://localhost",
"winsize": 256
}
},
"model": {
"depth_network": "DPT_Large",
"init_gain": 0.02,
"init_type": "normal",
"input_nc": 3,
"multimodal": false,
"output_nc": 3,
"prior_321_backwardcompatibility": false,
"type_sam": "mobile_sam"
},
"train": {
"D_accuracy_every": 1000,
"D_lr": 0.0001,
"G_ema": true,
"G_ema_beta": 0.999,
"G_lr": 1e-04,
"batch_size": 1,
"beta1": 0.9,
"beta2": 0.999,
"cls_l1_regression": false,
"cls_regression": false,
"compute_D_accuracy": false,
"compute_metrics_test": false,
"continue": false,
"epoch": "latest",
"epoch_count": 1,
"export_jit": false,
"feat_wavelet": false,
"gan_mode": "lsgan",
"iter_size": 32,
"load_iter": 0,
"lr_decay_iters": 50,
"lr_policy": "linear",
"metrics_every": 1000,
"metrics_list": [
"FID"
],
"metrics_save_images": false,
"mm_lambda_z": 0.5,
"mm_nz": 8,
"n_epochs": 100,
"n_epochs_decay": 100,
"nb_img_max_fid": 1000000000,
"optim": "adamw",
"optim_eps": 1e-08,
"optim_weight_decay": 0.0,
"pool_size": 50,
"save_by_iter": false,
"save_epoch_freq": 1,
"save_latest_freq": 1000,
"semantic_cls": false,
"semantic_mask": false,
"temporal_criterion": false,
"temporal_criterion_lambda": 1.0,
"use_contrastive_loss_D": false,
"sem": {
"cls_B": false,
"cls_lambda": 1.0,
"cls_pretrained": false,
"cls_template": "basic",
"idt": false,
"lr_cls": 0.0002,
"lr_f_s": 0.0002,
"mask_lambda": 1.0,
"net_output": false,
"use_label_B": false
},
"mask": {
"charbonnier_eps": 1e-06,
"compute_miou": false,
"disjoint_f_s": false,
"f_s_B": false,
"for_removal": false,
"lambda_out_mask": 10.0,
"loss_out_mask": "L1",
"miou_every": 1000,
"no_train_f_s_A": false,
"out_mask": false
}
},
"dataaug": {
"APA": false,
"APA_every": 4,
"APA_nimg": 50,
"APA_p": 0,
"APA_target": 0.6,
"D_diffusion": false,
"D_diffusion_every": 4,
"D_label_smooth": false,
"D_noise": 0.0,
"affine": 0.0,
"affine_scale_max": 1.2,
"affine_scale_min": 0.8,
"affine_shear": 45,
"affine_translate": 0.2,
"diff_aug_policy": "",
"diff_aug_proba": 0.5,
"flip": "horizontal",
"imgaug": false,
"no_rotate": true
},
"checkpoints_dir": "checkpoints",
"dataroot": "online_mario2sonic_full/mario",
"ddp_port": "12355",
"gpu_ids": "0",
"model_type": "palette",
"name": "mario_vid",
"phase": "train",
"suffix": "",
"test_batch_size": 1,
"warning_mode": false,
"with_amp": false,
"with_tf32": false,
"with_torch_compile": false
}

1 change: 1 addition & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,7 @@ def compute_metrics_test(
real_tensor = (torch.clamp(torch.cat(real_list), min=-1.0, max=1.0) + 1.0) / 2.0
fake_tensor = (torch.clamp(torch.cat(fake_list), min=-1.0, max=1.0) + 1.0) / 2.0
if len(real_tensor.shape) == 5: # temporal

real_tensor = real_tensor[:, 1]
fake_tensor = fake_tensor[:, 1]
ssim_test = ssim(real_tensor, fake_tensor)
Expand Down
6 changes: 2 additions & 4 deletions models/diffusion_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .modules.palette_denoise_fn import PaletteDenoiseFn
from .modules.cm_generator import CMGenerator
from .modules.unet_generator_attn.unet_generator_attn_vid import UNetVid


def define_G(
Expand Down Expand Up @@ -128,10 +129,7 @@ def define_G(
)

elif G_netG == "unet_vid":
if model_prior_321_backwardcompatibility:
cond_embed_dim = G_ngf * 4
else:
cond_embed_dim = alg_diffusion_cond_embed_dim
cond_embed_dim = alg_diffusion_cond_embed_dim

model = UNetVid(
image_size=data_crop_size,
Expand Down
Loading

0 comments on commit 00f11bc

Please sign in to comment.