diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index 38f97e1fe..2d7421aeb 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -14,8 +14,8 @@ q_posterior, gamma_embedding, extract, - rearrange_5dto4d, - rearrange_4dto5d, + rearrange_5dto4d_fh, + rearrange_4dto5d_fh, ) @@ -203,9 +203,7 @@ def p_mean_variance( sequence_length = 0 if len(y_t.shape) == 5: sequence_length = y_t.shape[1] - y_t, y_cond, mask = rearrange_5dto4d( - "b f c h w -> b c (f h) w", y_t, y_cond, mask - ) + y_t, y_cond, mask = rearrange_5dto4d_fh(y_t, y_cond, mask) noise_level = self.extract( getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1) @@ -215,9 +213,7 @@ def p_mean_variance( input = torch.cat([y_cond, y_t], dim=1) if sequence_length != 0: - input, y_t, mask = rearrange_4dto5d( - "b c (f h) w -> b f c h w", sequence_length, input, y_t, mask - ) + input, y_t, mask = rearrange_4dto5d_fh(sequence_length, input, y_t, mask) if guidance_scale > 0.0 and phase == "test": y_0_hat_uncond = predict_start_from_noise( @@ -455,9 +451,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): # vid only if len(y_0.shape) == 5: sequence_length = y_0.shape[1] - y_0, y_cond, mask = rearrange_5dto4d( - "b f c h w -> b c (f h) w", y_0, y_cond, mask - ) + y_0, y_cond, mask = rearrange_5dto4d_fh(y_0, y_cond, mask) b, *_ = y_0.shape t = torch.randint( @@ -487,8 +481,8 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): input = torch.cat([y_cond, y_noisy], dim=1) if sequence_length != 0: - input, mask, noise = rearrange_4dto5d( - "b c (f h) w -> b f c h w", sequence_length, input, mask, noise + input, mask, noise = rearrange_4dto5d_fh( + sequence_length, input, mask, noise ) noise_hat = self.denoise_fn( diff --git a/models/modules/diffusion_utils.py b/models/modules/diffusion_utils.py index 51cdb296e..e4a71cf7d 100644 --- a/models/modules/diffusion_utils.py +++ 
b/models/modules/diffusion_utils.py @@ -137,11 +137,25 @@ def extract(a, t, x_shape=(1, 1, 1, 1)): return out.reshape(b, *((1,) * (len(x_shape) - 1))) -def rearrange_5dto4d(pattern, *tensors): +def rearrange_5dto4d_fh(*tensors): """Rearrange a tensor according to a given pattern using einops.rearrange.""" - return [rearrange(tensor, pattern) for tensor in tensors] + return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors] -def rearrange_4dto5d(pattern, frame, *tensors): +def rearrange_4dto5d_fh(frame, *tensors): """Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange.""" - return [rearrange(tensor, pattern, f=frame) for tensor in tensors] + return [ + rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors + ] + + +def rearrange_5dto4d_bf(*tensors): + """Fold the frame axis of each 5D tensor into the batch axis using einops.rearrange.""" + return [rearrange(tensor, "b f c h w -> (b f) c h w") for tensor in tensors] + + +def rearrange_4dto5d_bf(frame, *tensors): + """Split the merged batch axis of each 4D tensor back into batch and `frame` frames using einops.rearrange.""" + return [ + rearrange(tensor, "(b f) c h w -> b f c h w", f=frame) for tensor in tensors + ] diff --git a/models/palette_model.py b/models/palette_model.py index e0d416624..a8d1c550a 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -18,7 +18,7 @@ from .modules.loss import MultiScaleDiffusionLoss from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm -from models.modules.diffusion_utils import rearrange_5dto4d, rearrange_4dto5d +from models.modules.diffusion_utils import rearrange_5dto4d_bf, rearrange_4dto5d_bf class PaletteModel(BaseDiffusionModel): @@ -401,8 +401,8 @@ def set_input(self, data): and self.opt.alg_diffusion_cond_image_creation == "computed_sketch" ): - self.mask, self.gt_image = rearrange_5dto4d( - "b f c h w -> (b f) c h w", self.mask, self.gt_image + self.mask, self.gt_image = 
rearrange_5dto4d_bf( + self.mask, self.gt_image ) self.cond_image = ( @@ -420,8 +420,7 @@ def set_input(self, data): == "computed_sketch" ): - self.mask, self.gt_image, self.cond_image = rearrange_4dto5d( - "(b f) c h w -> b f c h w", + self.mask, self.gt_image, self.cond_image = rearrange_4dto5d_bf( self.opt.data_temporal_number_frames, self.mask, self.gt_image,