Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

random canny in batch #683

Merged
merged 2 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
q_posterior,
gamma_embedding,
extract,
rearrange_5dto4d,
rearrange_4dto5d,
rearrange_5dto4d_fh,
rearrange_4dto5d_fh,
)


Expand Down Expand Up @@ -203,7 +203,7 @@ def p_mean_variance(
sequence_length = 0
if len(y_t.shape) == 5:
sequence_length = y_t.shape[1]
y_t, y_cond, mask = rearrange_5dto4d(y_t, y_cond, mask)
y_t, y_cond, mask = rearrange_5dto4d_fh(y_t, y_cond, mask)

noise_level = self.extract(
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1)
Expand All @@ -213,7 +213,7 @@ def p_mean_variance(

input = torch.cat([y_cond, y_t], dim=1)
if sequence_length != 0:
input, y_t, mask = rearrange_4dto5d(sequence_length, input, y_t, mask)
input, y_t, mask = rearrange_4dto5d_fh(sequence_length, input, y_t, mask)

if guidance_scale > 0.0 and phase == "test":
y_0_hat_uncond = predict_start_from_noise(
Expand Down Expand Up @@ -451,8 +451,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
# vid only
if len(y_0.shape) == 5:
sequence_length = y_0.shape[1]
y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask)

y_0, y_cond, mask = rearrange_5dto4d_fh(y_0, y_cond, mask)
b, *_ = y_0.shape

t = torch.randint(
Expand Down Expand Up @@ -482,7 +481,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
input = torch.cat([y_cond, y_noisy], dim=1)

if sequence_length != 0:
input, mask, noise = rearrange_4dto5d(sequence_length, input, mask, noise)
input, mask, noise = rearrange_4dto5d_fh(
sequence_length, input, mask, noise
)

noise_hat = self.denoise_fn(
input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
Expand Down
16 changes: 14 additions & 2 deletions models/modules/diffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,25 @@ def extract(a, t, x_shape=(1, 1, 1, 1)):
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def rearrange_5dto4d(*tensors):
def rearrange_5dto4d_fh(*tensors):
"""Rearrange a tensor according to a given pattern using einops.rearrange."""
return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors]


def rearrange_4dto5d(frame, *tensors):
def rearrange_4dto5d_fh(frame, *tensors):
"""Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange."""
return [
rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors
]


def rearrange_5dto4d_bf(*tensors):
    """Merge the ``f`` axis of each tensor into its ``b`` axis.

    Each positional tensor is passed through ``einops.rearrange`` with the
    pattern ``"b f c h w -> (b f) c h w"``, so a five-axis input comes back
    with four axes. The results are returned as a list, in the same order
    as the arguments.
    """
    pattern = "b f c h w -> (b f) c h w"
    merged = []
    for tensor in tensors:
        merged.append(rearrange(tensor, pattern))
    return merged


def rearrange_4dto5d_bf(frame, *tensors):
    """Split the leading ``(b f)`` axis of each tensor back into ``b`` and ``f``.

    Each positional tensor is passed through ``einops.rearrange`` with the
    pattern ``"(b f) c h w -> b f c h w"``, using ``frame`` as the size of
    the ``f`` axis. The results are returned as a list, in the same order
    as the arguments.
    """
    pattern = "(b f) c h w -> b f c h w"
    split = []
    for tensor in tensors:
        split.append(rearrange(tensor, pattern, f=frame))
    return split
40 changes: 33 additions & 7 deletions models/palette_model.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the original `fill_img_with_random_sketch` method, the inputs `gt_image` and `mask` have to be sliced and processed one at a time to produce a random Canny for each image, and using two loops would not be efficient. What do you think about changing `fill_img_with_random_sketch` by adding some additional parameters, so that it is flexible enough to cover all the cases (random Canny / fixed Canny / dropout Canny)?

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import math
import random
import warnings

import torch
import torchvision.transforms as T
from torch import nn
Expand All @@ -19,6 +18,8 @@
from .modules.loss import MultiScaleDiffusionLoss
from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm

from models.modules.diffusion_utils import rearrange_5dto4d_bf, rearrange_4dto5d_bf


class PaletteModel(BaseDiffusionModel):
@staticmethod
Expand Down Expand Up @@ -395,12 +396,37 @@ def set_input(self, data):
if "canny" in fill_img_with_random_sketch.__name__:
low = min(self.opt.alg_diffusion_cond_sketch_canny_range)
high = max(self.opt.alg_diffusion_cond_sketch_canny_range)
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
self.mask,
low_threshold_random=low,
high_threshold_random=high,
if (
self.opt.G_netG == "unet_vid"
and self.opt.alg_diffusion_cond_image_creation
== "computed_sketch"
):
self.mask, self.gt_image = rearrange_5dto4d_bf(
self.mask, self.gt_image
)

self.cond_image = (
fill_img_with_random_sketch( # random canny in batch
self.gt_image,
self.mask,
low_threshold_random=low,
high_threshold_random=high,
)
)

if (
self.opt.G_netG == "unet_vid"
and self.opt.alg_diffusion_cond_image_creation
== "computed_sketch"
):

self.mask, self.gt_image, self.cond_image = rearrange_4dto5d_bf(
self.opt.data_temporal_number_frames,
self.mask,
self.gt_image,
self.cond_image,
)

elif "sam" in fill_img_with_random_sketch.__name__:
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
Expand Down Expand Up @@ -459,7 +485,6 @@ def compute_palette_loss(self):
ref = self.ref_A
else:
ref = None

noise, noise_hat, min_snr_loss_weight = self.netG_A(
y_0=y_0, y_cond=y_cond, noise=noise, mask=mask, cls=cls, ref=ref
)
Expand Down Expand Up @@ -623,6 +648,7 @@ def inference(self, nb_imgs, offset=0):

# other tasks
else:
print("canny task inference nb_imgs , sample_num", nb_imgs, self.sample_num)
self.output, self.visuals = netG.restoration(
y_cond=self.cond_image[:nb_imgs], sample_num=self.sample_num
)
Expand Down
29 changes: 16 additions & 13 deletions util/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,31 @@ def fill_img_with_sketch(img, mask, **kwargs):
def fill_img_with_canny(
img,
mask,
low_threshold=None,
high_threshold=None,
cur_low_threshold=None,
cur_high_threshold=None,
**kwargs,
):
"""Fill the masked region with canny edges."""
low_threshold_random = kwargs["low_threshold_random"]
high_threshold_random = kwargs["high_threshold_random"]
max_value = 255 * 3
if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)

device = img.device
edges_list = []
for cur_img in img:
high_threshold = cur_high_threshold
low_threshold = (
cur_low_threshold # Reset thresholds for each image for new random
)

if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)
cur_img = (
(torch.einsum("chw->hwc", cur_img).cpu().numpy() + 1) * 255 / 2
).astype(np.uint8)
Expand All @@ -70,7 +74,6 @@ def fill_img_with_canny(
edges_list.append(edges)
edges = torch.cat(edges_list, dim=0)
mask = torch.clamp(mask, 0, 1)

return mask * edges + (1 - mask) * img


Expand Down
Loading