diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index 69bcb029e..2d7421aeb 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -14,8 +14,8 @@ q_posterior, gamma_embedding, extract, - rearrange_5dto4d, - rearrange_4dto5d, + rearrange_5dto4d_fh, + rearrange_4dto5d_fh, ) @@ -203,7 +203,7 @@ def p_mean_variance( sequence_length = 0 if len(y_t.shape) == 5: sequence_length = y_t.shape[1] - y_t, y_cond, mask = rearrange_5dto4d(y_t, y_cond, mask) + y_t, y_cond, mask = rearrange_5dto4d_fh(y_t, y_cond, mask) noise_level = self.extract( getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1) @@ -213,7 +213,7 @@ def p_mean_variance( input = torch.cat([y_cond, y_t], dim=1) if sequence_length != 0: - input, y_t, mask = rearrange_4dto5d(sequence_length, input, y_t, mask) + input, y_t, mask = rearrange_4dto5d_fh(sequence_length, input, y_t, mask) if guidance_scale > 0.0 and phase == "test": y_0_hat_uncond = predict_start_from_noise( @@ -451,8 +451,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): # vid only if len(y_0.shape) == 5: sequence_length = y_0.shape[1] - y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask) - + y_0, y_cond, mask = rearrange_5dto4d_fh(y_0, y_cond, mask) b, *_ = y_0.shape t = torch.randint( @@ -482,7 +481,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): input = torch.cat([y_cond, y_noisy], dim=1) if sequence_length != 0: - input, mask, noise = rearrange_4dto5d(sequence_length, input, mask, noise) + input, mask, noise = rearrange_4dto5d_fh( + sequence_length, input, mask, noise + ) noise_hat = self.denoise_fn( input, embed_sample_gammas, cls=cls, mask=mask, ref=ref diff --git a/models/modules/diffusion_utils.py b/models/modules/diffusion_utils.py index a4a10865a..e4a71cf7d 100644 --- a/models/modules/diffusion_utils.py +++ b/models/modules/diffusion_utils.py @@ -137,13 +137,25 @@ def extract(a, t, x_shape=(1, 1, 1, 1)): return out.reshape(b, *((1,) * (len(x_shape) - 1))) -def rearrange_5dto4d(*tensors): +def rearrange_5dto4d_fh(*tensors): """Rearrange a tensor according to a given pattern using einops.rearrange.""" return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors] -def rearrange_4dto5d(frame, *tensors): +def rearrange_4dto5d_fh(frame, *tensors): """Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange.""" return [ rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors ] + + +def rearrange_5dto4d_bf(*tensors): + """Rearrange a tensor according to a given pattern using einops.rearrange.""" + return [rearrange(tensor, "b f c h w -> (b f) c h w") for tensor in tensors] + + +def rearrange_4dto5d_bf(frame, *tensors): + """Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange.""" + return [ + rearrange(tensor, "(b f) c h w -> b f c h w", f=frame) for tensor in tensors + ] diff --git a/models/palette_model.py b/models/palette_model.py index bd07fe67a..187ba6879 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -3,7 +3,6 @@ import math import random import warnings - import torch import torchvision.transforms as T from torch import nn @@ -19,6 +18,8 @@ from .modules.loss import MultiScaleDiffusionLoss from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm +from models.modules.diffusion_utils import rearrange_5dto4d_bf, rearrange_4dto5d_bf + class PaletteModel(BaseDiffusionModel): @staticmethod @@ -395,12 +396,37 @@ def set_input(self, data): if "canny" in fill_img_with_random_sketch.__name__: low = min(self.opt.alg_diffusion_cond_sketch_canny_range) high = max(self.opt.alg_diffusion_cond_sketch_canny_range) - self.cond_image = fill_img_with_random_sketch( - self.gt_image, - self.mask, - low_threshold_random=low, - high_threshold_random=high, + if ( + self.opt.G_netG == "unet_vid" + and self.opt.alg_diffusion_cond_image_creation + == "computed_sketch" + ): + self.mask, self.gt_image = rearrange_5dto4d_bf( + self.mask, self.gt_image + ) + + self.cond_image = ( + fill_img_with_random_sketch( # random canny in batch + self.gt_image, + self.mask, + low_threshold_random=low, + high_threshold_random=high, + ) ) + + if ( + self.opt.G_netG == "unet_vid" + and self.opt.alg_diffusion_cond_image_creation + == "computed_sketch" + ): + + self.mask, self.gt_image, self.cond_image = rearrange_4dto5d_bf( + self.opt.data_temporal_number_frames, + self.mask, + self.gt_image, + self.cond_image, + ) + elif "sam" in fill_img_with_random_sketch.__name__: self.cond_image = fill_img_with_random_sketch( self.gt_image, @@ -459,7 +485,6 @@ def compute_palette_loss(self): ref = self.ref_A else: ref = None - noise, noise_hat, min_snr_loss_weight = self.netG_A( y_0=y_0, y_cond=y_cond, noise=noise, mask=mask, cls=cls, ref=ref ) @@ -623,6 +648,7 @@ def inference(self, nb_imgs, offset=0): # other tasks else: + print("canny task inference nb_imgs , sample_num", nb_imgs, self.sample_num) self.output, self.visuals = netG.restoration( y_cond=self.cond_image[:nb_imgs], sample_num=self.sample_num ) diff --git a/util/mask_generation.py b/util/mask_generation.py index ec085a04d..9e9a9fbf0 100644 --- a/util/mask_generation.py +++ b/util/mask_generation.py @@ -37,27 +37,31 @@ def fill_img_with_sketch(img, mask, **kwargs): def fill_img_with_canny( img, mask, - low_threshold=None, - high_threshold=None, + cur_low_threshold=None, + cur_high_threshold=None, **kwargs, ): """Fill the masked region with canny edges.""" low_threshold_random = kwargs["low_threshold_random"] high_threshold_random = kwargs["high_threshold_random"] max_value = 255 * 3 - if high_threshold is None and low_threshold is None: - threshold_1 = random.randint(low_threshold_random, high_threshold_random) - threshold_2 = random.randint(low_threshold_random, high_threshold_random) - high_threshold = max(threshold_1, threshold_2) - low_threshold = min(threshold_1, threshold_2) - elif high_threshold is None and low_threshold is not None: - high_threshold = random.randint(low_threshold, max_value) - elif high_threshold is not None and low_threshold is None: - low_threshold = random.randint(0, high_threshold) - device = img.device edges_list = [] for cur_img in img: + high_threshold = cur_high_threshold + low_threshold = ( + cur_low_threshold # Reset thresholds for each image for new random + ) + + if high_threshold is None and low_threshold is None: + threshold_1 = random.randint(low_threshold_random, high_threshold_random) + threshold_2 = random.randint(low_threshold_random, high_threshold_random) + high_threshold = max(threshold_1, threshold_2) + low_threshold = min(threshold_1, threshold_2) + elif high_threshold is None and low_threshold is not None: + high_threshold = random.randint(low_threshold, max_value) + elif high_threshold is not None and low_threshold is None: + low_threshold = random.randint(0, high_threshold) cur_img = ( (torch.einsum("chw->hwc", cur_img).cpu().numpy() + 1) * 255 / 2 ).astype(np.uint8) @@ -70,7 +74,6 @@ def fill_img_with_canny( edges_list.append(edges) edges = torch.cat(edges_list, dim=0) mask = torch.clamp(mask, 0, 1) - return mask * edges + (1 - mask) * img