Skip to content

Commit

Permalink
feat(ml): random canny for batch
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 13, 2024
1 parent f697692 commit 29b4f70
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 49 deletions.
17 changes: 12 additions & 5 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def p_mean_variance(
sequence_length = 0
if len(y_t.shape) == 5:
sequence_length = y_t.shape[1]
y_t, y_cond, mask = rearrange_5dto4d(y_t, y_cond, mask)
y_t, y_cond, mask = rearrange_5dto4d(
"b f c h w -> b c (f h) w", y_t, y_cond, mask
)

noise_level = self.extract(
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1)
Expand All @@ -213,7 +215,9 @@ def p_mean_variance(

input = torch.cat([y_cond, y_t], dim=1)
if sequence_length != 0:
input, y_t, mask = rearrange_4dto5d(sequence_length, input, y_t, mask)
input, y_t, mask = rearrange_4dto5d(
"b c (f h) w -> b f c h w", sequence_length, input, y_t, mask
)

if guidance_scale > 0.0 and phase == "test":
y_0_hat_uncond = predict_start_from_noise(
Expand Down Expand Up @@ -451,8 +455,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
# vid only
if len(y_0.shape) == 5:
sequence_length = y_0.shape[1]
y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask)

y_0, y_cond, mask = rearrange_5dto4d(
"b f c h w -> b c (f h) w", y_0, y_cond, mask
)
b, *_ = y_0.shape

t = torch.randint(
Expand Down Expand Up @@ -482,7 +487,9 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
input = torch.cat([y_cond, y_noisy], dim=1)

if sequence_length != 0:
input, mask, noise = rearrange_4dto5d(sequence_length, input, mask, noise)
input, mask, noise = rearrange_4dto5d(
"b c (f h) w -> b f c h w", sequence_length, input, mask, noise
)

noise_hat = self.denoise_fn(
input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
Expand Down
10 changes: 4 additions & 6 deletions models/modules/diffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,11 @@ def extract(a, t, x_shape=(1, 1, 1, 1)):
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def rearrange_5dto4d(pattern, *tensors):
    """Apply a single einops pattern to every tensor passed in.

    Despite the name, the target layout is decided entirely by ``pattern``;
    callers use it both to fold frames into the height axis
    ("b f c h w -> b c (f h) w") and to fold frames into the batch axis
    ("b f c h w -> (b f) c h w").

    Args:
        pattern: einops rearrangement expression applied unchanged to each
            tensor.
        *tensors: tensors to rearrange; every one must be compatible with
            the left-hand side of ``pattern``.

    Returns:
        A list holding one rearranged tensor per input, in input order, so
        the result can be unpacked directly at the call site.
    """
    return [rearrange(tensor, pattern) for tensor in tensors]


def rearrange_4dto5d(pattern, frame, *tensors):
    """Apply a single einops pattern to every tensor, fixing the size of axis ``f``.

    The pattern is expected to split a composite axis that contains ``f``
    (e.g. "b c (f h) w -> b f c h w" or "(b f) c h w -> b f c h w"); einops
    cannot infer both factors of such an axis, so ``frame`` is supplied as
    the length of ``f``.

    Args:
        pattern: einops rearrangement expression applied unchanged to each
            tensor; it must reference an axis named ``f``.
        frame: length of the ``f`` axis, forwarded as ``f=frame``.
        *tensors: tensors to rearrange; every one must be compatible with
            the left-hand side of ``pattern``.

    Returns:
        A list holding one rearranged tensor per input, in input order, so
        the result can be unpacked directly at the call site.
    """
    return [rearrange(tensor, pattern, f=frame) for tensor in tensors]
49 changes: 23 additions & 26 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import math
import random
import warnings
from einops import rearrange
import torch
import torchvision.transforms as T
from torch import nn
Expand All @@ -19,6 +18,8 @@
from .modules.loss import MultiScaleDiffusionLoss
from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm

from models.modules.diffusion_utils import rearrange_5dto4d, rearrange_4dto5d


class PaletteModel(BaseDiffusionModel):
@staticmethod
Expand Down Expand Up @@ -401,37 +402,33 @@ def set_input(self, data):
and self.opt.alg_diffusion_cond_image_creation
== "computed_sketch"
):
result = []
for batch_idx in range(self.opt.train_batch_size):
frames = []
for frame_idx in range(
self.opt.data_temporal_number_frames
):
gt_image_temp = self.gt_image[
batch_idx, frame_idx, :, :, :
].unsqueeze(0)
mask_temp = self.mask[
batch_idx, frame_idx, :, :, :
].unsqueeze(0)
cond_image_temp = fill_img_with_random_sketch(
gt_image_temp,
mask_temp,
low_threshold_random=low,
high_threshold_random=high,
)
frames.append(cond_image_temp)
result.append(
torch.stack(frames, dim=1)
) # Rebuild the frame dimension
self.cond_image = torch.cat(result, dim=0)
self.mask, self.gt_image = rearrange_5dto4d(
"b f c h w -> (b f) c h w", self.mask, self.gt_image
)

else:
self.cond_image = fill_img_with_random_sketch(
self.cond_image = (
fill_img_with_random_sketch( # random canny in batch
self.gt_image,
self.mask,
low_threshold_random=low,
high_threshold_random=high,
)
)

if (
self.opt.G_netG == "unet_vid"
and self.opt.alg_diffusion_cond_image_creation
== "computed_sketch"
):

self.mask, self.gt_image, self.cond_image = rearrange_4dto5d(
"(b f) c h w -> b f c h w",
self.opt.data_temporal_number_frames,
self.mask,
self.gt_image,
self.cond_image,
)

elif "sam" in fill_img_with_random_sketch.__name__:
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
Expand Down
28 changes: 16 additions & 12 deletions util/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,31 @@ def fill_img_with_sketch(img, mask, **kwargs):
def fill_img_with_canny(
img,
mask,
low_threshold=None,
high_threshold=None,
cur_low_threshold=None,
cur_high_threshold=None,
**kwargs,
):
"""Fill the masked region with canny edges."""
low_threshold_random = kwargs["low_threshold_random"]
high_threshold_random = kwargs["high_threshold_random"]
max_value = 255 * 3
if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)

device = img.device
edges_list = []
for cur_img in img:
high_threshold = cur_high_threshold
low_threshold = (
cur_low_threshold # Reset thresholds for each image for new random
)

if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)
cur_img = (
(torch.einsum("chw->hwc", cur_img).cpu().numpy() + 1) * 255 / 2
).astype(np.uint8)
Expand Down

0 comments on commit 29b4f70

Please sign in to comment.