From 733ca37f7084354c74bfa8b08384f62bf39c33f4 Mon Sep 17 00:00:00 2001
From: beniz
Date: Thu, 29 Aug 2024 10:55:44 +0200
Subject: [PATCH] fix: unet vid with bs > 1

---
 examples/example_ddpm_vid_mario.json  |  2 +-
 models/modules/diffusion_generator.py | 39 ++++++--------
 models/modules/diffusion_utils.py     |  4 +-
 .../unet_generator_attn_vid.py        | 53 +++++++++++--------
 models/palette_model.py               |  3 --
 train.py                              |  8 ++-
 6 files changed, 55 insertions(+), 54 deletions(-)

diff --git a/examples/example_ddpm_vid_mario.json b/examples/example_ddpm_vid_mario.json
index a7f18822d..a767e27f5 100644
--- a/examples/example_ddpm_vid_mario.json
+++ b/examples/example_ddpm_vid_mario.json
@@ -270,7 +270,7 @@
     "pool_size": 50,
     "save_by_iter": false,
     "save_epoch_freq": 1,
-    "save_latest_freq": 5000,
+    "save_latest_freq": 1000,
     "semantic_cls": false,
     "semantic_mask": false,
     "temporal_criterion": false,
diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py
index c867d991c..42602d85d 100644
--- a/models/modules/diffusion_generator.py
+++ b/models/modules/diffusion_generator.py
@@ -130,10 +130,6 @@ def restoration_ddpm(
         ref,
     ):
         phase = "test"
-        frame = 0  # To adapt a temporal video, convert the 5D tensor to a 4D tensor to align with the standard DDPM process
-        if len(y_cond.shape) == 5:
-            frame = y_cond.shape[1]
-            y_cond, y_t, y_0, mask = rearrange_5dto4d(y_cond, y_t, y_0, mask)
 
         b, *_ = y_cond.shape
 
@@ -144,8 +140,6 @@ def restoration_ddpm(
 
         y_t = self.default(y_t, lambda: torch.randn_like(y_cond))
         ret_arr = y_t
-        if frame != 0:  # the temporal video use 5D tensor as input
-            y_t, y_cond, mask = rearrange_4dto5d(frame, y_t, y_cond, mask)
 
         for i in tqdm(
             reversed(range(0, self.denoise_fn.model.num_timesteps_test)),
@@ -164,17 +158,12 @@ def restoration_ddpm(
                 ref=ref,
                 guidance_scale=guidance_scale,
             )
-            if frame != 0:
-                mask = rearrange(mask, "b f c h w -> (b f) c h w")
+            if mask is not None:
                 temp_mask = torch.clamp(mask, min=0.0, max=1.0)
                 y_t = y_0 * (1.0 - temp_mask) + temp_mask * y_t
             if i % sample_inter == 0:
                 ret_arr = torch.cat([ret_arr, y_t], dim=0)
-            if frame != 0:
-                y_t, mask = rearrange_4dto5d(frame, y_t, mask)
 
-        if frame != 0:
-            ret_arr = rearrange(ret_arr, "(b f) c h w -> b f c h w", f=frame)
 
         return y_t, ret_arr
 
@@ -203,9 +192,9 @@ def p_mean_variance(
         y_cond=None,
         guidance_scale=0.0,
     ):
-        frame = 0
+        sequence_length = 0
         if len(y_t.shape) == 5:
-            frame = y_t.shape[1]
+            sequence_length = y_t.shape[1]
             y_t, y_cond, mask = rearrange_5dto4d(y_t, y_cond, mask)
 
         noise_level = self.extract(
@@ -215,8 +204,8 @@ def p_mean_variance(
         embed_noise_level = self.compute_gammas(noise_level)
 
         input = torch.cat([y_cond, y_t], dim=1)
-        if frame != 0:
-            input, mask = rearrange_4dto5d(frame, input, mask)
+        if sequence_length != 0:
+            input, y_t, mask = rearrange_4dto5d(sequence_length, input, y_t, mask)
 
         if guidance_scale > 0.0 and phase == "test":
             y_0_hat_uncond = predict_start_from_noise(
@@ -268,6 +257,7 @@ def p_sample(
         y_cond=None,
         guidance_scale=0.0,
     ):
+
         model_mean, model_log_variance = self.p_mean_variance(
             y_t=y_t,
             t=t,
@@ -279,13 +269,11 @@ def p_sample(
             clip_denoised=clip_denoised,
             mask=mask,
             cls=cls,
             ref=ref,
             guidance_scale=guidance_scale,
         )
-        frame = 0
-        if len(y_t.shape) == 5:
-            y_t = rearrange(y_t, "b f c h w -> (b f) c h w")
         noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t)
-        return model_mean + noise * (0.5 * model_log_variance).exp()
+        out = model_mean + noise * (0.5 * model_log_variance).exp()
+        return out
 
     ## DDIM
@@ -450,9 +438,9 @@ def ddim_p_mean_variance(
         return model_mean, posterior_log_variance
 
     def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
-        frame = 0
+        sequence_length = 0
         if len(y_0.shape) == 5:
-            frame = y_0.shape[1]
+            sequence_length = y_0.shape[1]
             y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask)
 
         b, *_ = y_0.shape
@@ -483,11 +471,13 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
 
         input = torch.cat([y_cond, y_noisy], dim=1)
 
-        if frame != 0:
-            input, mask = rearrange_4dto5d(frame, input, mask)
+        if sequence_length != 0:
+            input, mask, noise = rearrange_4dto5d(sequence_length, input, mask, noise)
+
         noise_hat = self.denoise_fn(
             input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
         )
+
         # min-SNR loss weight
         phase = "train"
         ksnr = 5.0
@@ -506,6 +496,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
         )
         # reshape min_snr_loss_weight to match noise_hat
         min_snr_loss_weight = min_snr_loss_weight.view(-1, 1, 1, 1)
+
         return noise, noise_hat, min_snr_loss_weight
 
     def set_new_sampling_method(self, sampling_method):
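Note on the min-SNR block above: the hunk context only shows the `# min-SNR loss weight` comment, `ksnr = 5.0`, and the final `.view(-1, 1, 1, 1)` reshape; the weighting formula itself lies outside the diff. A minimal standalone sketch, assuming the standard min-SNR-gamma weighting of Hang et al. (2023) and made-up `sample_gammas` values:

```python
import torch

ksnr = 5.0
# hypothetical per-sample cumulative alphas (gamma_t); not values from the repo
sample_gammas = torch.tensor([0.9, 0.5, 0.1])
# SNR(t) under x_t = sqrt(gamma) * x_0 + sqrt(1 - gamma) * eps
snr = sample_gammas / (1.0 - sample_gammas)
# min-SNR-gamma: clamp the SNR at ksnr, then normalize by the raw SNR
min_snr_loss_weight = torch.clamp(snr, max=ksnr) / snr
# reshape to (-1, 1, 1, 1) so the weight broadcasts over (B, C, H, W) noise
min_snr_loss_weight = min_snr_loss_weight.view(-1, 1, 1, 1)
print(min_snr_loss_weight.flatten())  # tensor([0.5556, 1.0000, 1.0000])
```

Low-noise (high-SNR) timesteps get down-weighted, which is why the returned weight multiplies both `noise` and `noise_hat` when the loss is computed in `compute_palette_loss`.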
diff --git a/models/modules/diffusion_utils.py b/models/modules/diffusion_utils.py
index 33cade079..a4a10865a 100644
--- a/models/modules/diffusion_utils.py
+++ b/models/modules/diffusion_utils.py
@@ -139,11 +139,11 @@ def extract(a, t, x_shape=(1, 1, 1, 1)):
 
 def rearrange_5dto4d(*tensors):
     """Rearrange a tensor according to a given pattern using einops.rearrange."""
-    return [rearrange(tensor, "b f c h w -> (b f) c h w") for tensor in tensors]
+    return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors]
 
 
 def rearrange_4dto5d(frame, *tensors):
     """Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange."""
     return [
-        rearrange(tensor, "(b f) c h w -> b f c h w", f=frame) for tensor in tensors
+        rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors
     ]
diff --git a/models/modules/unet_generator_attn/unet_generator_attn_vid.py b/models/modules/unet_generator_attn/unet_generator_attn_vid.py
index 035bc0e7b..6ccea4bfe 100644
--- a/models/modules/unet_generator_attn/unet_generator_attn_vid.py
+++ b/models/modules/unet_generator_attn/unet_generator_attn_vid.py
@@ -240,36 +240,49 @@ def _forward(self, x, emb):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
-            h = in_rest(x)
+            hd = in_rest(x)
 
             if self.efficient and self.up:
-                h = in_conv(h)
-                h = self.h_upd(h)
+                hd = in_conv(hd)
+                hd = self.h_upd(hd)
                 x = self.x_upd(x)
             else:
-                h = self.h_upd(h)
+                hd = self.h_upd(hd)
                 x = self.x_upd(x)
-                h = in_conv(h)
-
+                hd = in_conv(hd)
         else:
-            h = self.in_layers(x)
-        emb_out = self.emb_layers(emb).type(h.dtype)
-        while len(emb_out.shape) < len(h.shape):
+            hd = self.in_layers(x)
+
+        bf, c, h, w = hd.shape
+        b = bf // f
+        f = bf // b
+        hd = hd.view(b, f, c, h, w)
+
+        emb_out = self.emb_layers(emb).type(hd.dtype)
+        emb_out = emb_out.unsqueeze(1)
+        while len(emb_out.shape) < len(hd.shape):
             emb_out = emb_out.unsqueeze(-1)
-        # emb_out = emb_out[..., None]
+
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = torch.chunk(emb_out, 2, dim=1)
-            h = out_norm(h) * (1 + scale) + shift
-            h = out_rest(h)
+            scale, shift = torch.chunk(emb_out, 2, dim=2)
+
+            ## TODO: simplify by receiving the embedding in 'packed' shape?
+            hd = hd.view(b * f, c, h, w)
+            hd = out_norm(hd)
+            hd = hd.view(b, f, c, h, w)
+            hd = hd * (1 + scale) + shift
+            hd = hd.view(b * f, c, h, w)
+            hd = out_rest(hd)
         else:
-            h = h + emb_out
-            h = self.out_layers(h)
+            hd = hd + emb_out
+            hd = hd.view(b * f, c, h, w)
+            hd = self.out_layers(hd)
 
         skipw = 1.0
         if self.efficient:
             skipw = 1.0 / math.sqrt(2)
-        output = self.skip_connection(x) + h
+        output = self.skip_connection(x) + hd
 
         bf, c, h, w = output.shape
         b = bf // f
         f = bf // b
@@ -1355,11 +1368,6 @@ def __init__(
         }
 
     def compute_feats(self, input, embed_gammas):
-        if embed_gammas is None:
-            # Only for GAN
-            b = (input.shape[0], self.cond_embed_dim)
-            embed_gammas = torch.ones(b).to(input.device)
-
         emb = embed_gammas
 
         hs = []
@@ -1378,7 +1386,6 @@ def compute_feats(self, input, embed_gammas):
         return outs, feats, emb
 
     def forward(self, input, embed_gammas=None):
-
         h, hs, emb = self.compute_feats(input, embed_gammas=embed_gammas)
         for i, module in enumerate(self.output_blocks):
             h = torch.cat([h, hs.pop()], dim=2)
@@ -1392,7 +1399,7 @@ def forward(self, input, embed_gammas=None):
         if self.freq_space:
             outh = self.iwt(outh)
             outh = outh.reshape(b, f, -1, h_dim, w_dim)
-        outh = rearrange(outh, "b f c h w -> (b f) c h w")
+
         return outh
 
     def get_feats(self, input, extract_layer_ids):
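The `diffusion_utils.py` change above is the core of the batch-size fix: the old pattern folded frames into the batch axis, so any downstream 4D code that read the batch size off dim 0 was only correct for a single video per batch; the new pattern stacks frames along the height axis, leaves the batch dimension untouched, and the video UNet's `_forward` unpacks per-frame views itself. A standalone shape check with made-up sizes (not part of the patch):

```python
import torch
from einops import rearrange

b, f, c, h, w = 2, 4, 3, 8, 8          # two videos of four frames each
x = torch.randn(b, f, c, h, w)

packed = rearrange(x, "b f c h w -> b c (f h) w")   # new 4D packing
assert packed.shape == (b, c, f * h, w)             # batch axis unchanged

unpacked = rearrange(packed, "b c (f h) w -> b f c h w", f=f)
assert torch.equal(unpacked, x)                     # lossless round trip

# old packing for comparison: frames were folded into the batch axis,
# making (b * f, c, h, w) ambiguous downstream whenever b > 1
old = rearrange(x, "b f c h w -> (b f) c h w")
assert old.shape == (b * f, c, h, w)
```

Both patterns are pure permute/reshape operations, so the round trip is exact; the only difference is which axis absorbs the frames.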
diff --git a/models/palette_model.py b/models/palette_model.py
index 6c04bb4d4..bd07fe67a 100644
--- a/models/palette_model.py
+++ b/models/palette_model.py
@@ -18,7 +18,6 @@
 from .base_diffusion_model import BaseDiffusionModel
 from .modules.loss import MultiScaleDiffusionLoss
 from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm
-from einops import rearrange
 
 
 class PaletteModel(BaseDiffusionModel):
@@ -467,7 +466,6 @@ def compute_palette_loss(self):
         frame = 0
         if len(y_0.shape) == 5:
             frame = y_0.shape[1]
-            mask = rearrange(mask, "b f c h w -> (b f) c h w")
 
         if not self.opt.alg_palette_minsnr:
             min_snr_loss_weight = 1.0
@@ -478,7 +476,6 @@ def compute_palette_loss(self):
                 min_snr_loss_weight * mask_binary * noise,
                 min_snr_loss_weight * mask_binary * noise_hat,
             )
-
         else:
             loss = self.loss_fn(
                 min_snr_loss_weight * noise, min_snr_loss_weight * noise_hat
diff --git a/train.py b/train.py
index 93dc2cabd..29d1275b9 100644
--- a/train.py
+++ b/train.py
@@ -32,6 +32,7 @@
 from util.visualizer import Visualizer
 from util.lion_pytorch import Lion
 from util.script import get_override_options_names
+import datetime
 
 
 def setup(rank, world_size, port):
@@ -39,7 +40,12 @@ def setup(rank, world_size, port):
     os.environ["MASTER_PORT"] = port
 
     # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    dist.init_process_group(
+        "nccl",
+        rank=rank,
+        world_size=world_size,
+        timeout=datetime.timedelta(seconds=5400),
+    )  # raise the timeout from the default (10 min for NCCL on recent PyTorch, 30 min otherwise) to 1.5 h
 
 
 def optim(opt, params, lr, betas, weight_decay, eps):
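The longer `init_process_group` timeout presumably covers long stretches without collective calls, such as the more frequent checkpointing enabled by the `save_latest_freq` change at the top of this patch. A quick standalone check of the chosen value (not part of the patch):

```python
import datetime

timeout = datetime.timedelta(seconds=5400)
assert timeout == datetime.timedelta(hours=1, minutes=30)  # 5400 s == 1.5 h
print(timeout.total_seconds() / 60)  # 90.0 (minutes)
```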