Skip to content

Commit

Permalink
feat(ml): canny dropout for vid
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 3, 2024
1 parent 03c0bec commit 2ffe84b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 7 deletions.
59 changes: 52 additions & 7 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import random
import warnings

from einops import rearrange
import torch
import torchvision.transforms as T
from torch import nn
Expand All @@ -13,7 +13,7 @@
from util.iter_calculator import IterCalculator
from util.mask_generation import random_edge_mask
from util.network_group import NetworkGroup

from util.mask_generation import fill_mask_with_canny_dropout
from . import diffusion_networks
from .base_diffusion_model import BaseDiffusionModel
from .modules.loss import MultiScaleDiffusionLoss
Expand Down Expand Up @@ -333,7 +333,7 @@ def set_input(self, data):

if self.use_ref:
self.ref_A = data["ref_A"].to(self.device)

sequence_length = 0
if self.opt.alg_diffusion_cond_image_creation == "y_t":
self.cond_image = self.y_t
elif self.opt.alg_diffusion_cond_image_creation == "previous_frame":
Expand Down Expand Up @@ -393,12 +393,47 @@ def set_input(self, data):
if "canny" in fill_img_with_random_sketch.__name__:
low = min(self.opt.alg_diffusion_cond_sketch_canny_range)
high = max(self.opt.alg_diffusion_cond_sketch_canny_range)
self.cond_image = fill_img_with_random_sketch(
if self.opt.G_netG == "unet_vid":
self.gt_image,
sequence_length = self.gt_image.shape[1]
self.mask,
low_threshold_random=low,
high_threshold_random=high,
)
self.mask = rearrange(self.mask, "b f c h w -> (b f) c h w")
low_threshold_random = (low,)
self.gt_image = rearrange(
self.gt_image, "b f c h w -> (b f) c h w"
)
cond_image_canny = fill_img_with_random_sketch(
self.gt_image,
self.mask,
low_threshold_random=low,
high_threshold_random=high,
vary_thresholds=True,
)

cond_image_nocanny, frame_select = fill_mask_with_canny_dropout(
self.gt_image,
self.mask,
self.opt.data_temporal_number_frames,
0.9,
)

print(" frame_select, ", frame_select)
frame_select = torch.tensor(frame_select, dtype=torch.bool)
frame_select = frame_select.to(cond_image_canny.device)
self.cond_image = torch.where(
frame_select.unsqueeze(1).unsqueeze(2).unsqueeze(3),
cond_image_canny,
cond_image_nocanny,
)

else:
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
self.mask,
low_threshold_random=low,
high_threshold_random=high,
)

elif "sam" in fill_img_with_random_sketch.__name__:
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
Expand All @@ -420,6 +455,16 @@ def set_input(self, data):
elif self.opt.alg_diffusion_cond_image_creation == "ref":
self.cond_image = self.ref_A

if self.opt.G_netG == "unet_vid":
self.cond_image = rearrange(
self.cond_image, "(b f) c h w -> b f c h w", f=sequence_length
)
self.gt_image = rearrange(
self.gt_image, "(b f) c h w -> b f c h w", f=sequence_length
)
self.mask = rearrange(
self.mask, "(b f) c h w -> b f c h w", f=sequence_length
)
self.batch_size = self.cond_image.shape[0]

self.real_A = self.cond_image
Expand Down
67 changes: 67 additions & 0 deletions util/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,79 @@ def fill_img_with_sketch(img, mask, **kwargs):
return mask * thresh + (1 - mask) * img


def fill_mask_with_canny_dropout(
    img,
    mask,
    sequence_length,
    dropout=0.9,
    **kwargs,
):
    """Blank out the masked region with -1 and draw one 0/1 flag per frame.

    The mask is clamped to [0, 1] and used as a blending weight: where the
    mask is 1 the output is -1, where it is 0 the output is the pixel from
    ``img``.

    Args:
        img: image tensor to blank out.
        mask: mask tensor; must broadcast against ``img``.
        sequence_length: number of flags to draw.
        dropout: a flag is 1 only when a uniform draw from ``random.random()``
            is strictly greater than this value, otherwise 0.
        **kwargs: accepted but not used.

    Returns:
        A tuple ``(output_img, frame_dropout)`` where ``output_img`` is the
        blanked image tensor and ``frame_dropout`` is a list of
        ``sequence_length`` ints, each 0 or 1.
    """
    weight = torch.clamp(mask, 0, 1)
    blank = torch.full_like(mask, -1)

    flags = []
    for _ in range(sequence_length):
        flags.append(int(random.random() > dropout))

    blanked = weight * blank + (1 - weight) * img
    return blanked, flags


def _sample_canny_thresholds(
    count, low_threshold, high_threshold, vary_thresholds, random_range
):
    """Return ``count`` (low, high) Canny threshold pairs."""
    upper_bound = 255 * 3
    pairs = []
    for _ in range(count):
        if vary_thresholds or (low_threshold is None and high_threshold is None):
            # The random range is only looked up when it is really needed, so
            # callers passing explicit thresholds need not supply it.
            range_low = random_range["low_threshold_random"]
            range_high = random_range["high_threshold_random"]
            first = random.randint(range_low, range_high)
            second = random.randint(range_low, range_high)
            low_threshold = min(first, second)
            high_threshold = max(first, second)
        elif high_threshold is None:
            high_threshold = random.randint(low_threshold, upper_bound)
        elif low_threshold is None:
            low_threshold = random.randint(0, high_threshold)
        # Once both thresholds are set and vary_thresholds is off, no branch
        # fires again, so every later pair repeats the first one.
        pairs.append((low_threshold, high_threshold))
    return pairs


def _canny_edge_map(frame, low_threshold, high_threshold, device):
    """Run cv2.Canny on one (c, h, w) frame; return a (1, 1, h, w) tensor."""
    # (x + 1) * 255 / 2 maps [-1, 1] onto [0, 255]; values are not clipped
    # before the uint8 cast. NOTE(review): presumably img is normalised to
    # [-1, 1] -- confirm against the data pipeline.
    pixels = (
        (torch.einsum("chw->hwc", frame).detach().cpu().numpy() + 1) * 255 / 2
    ).astype(np.uint8)
    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Rescale the detector output from 0..255 to -1..1.
    return (
        (((torch.tensor(edges, device=device) / 255) * 2) - 1)
        .unsqueeze(0)
        .unsqueeze(0)
    )


def fill_img_with_canny(
    img,
    mask,
    low_threshold=None,
    high_threshold=None,
    **kwargs,
):
    """Fill the masked region with canny edges.

    Args:
        img: batch of images; iterated along dim 0, each item is treated as
            a (c, h, w) frame.
        mask: mask tensor; clamped to [0, 1] and used as a blending weight
            (1 -> edge map, 0 -> original pixel).
        low_threshold: lower Canny threshold; drawn at random when None.
        high_threshold: upper Canny threshold; drawn at random when None.
        **kwargs:
            low_threshold_random / high_threshold_random: bounds of the
                random range, required whenever both thresholds have to be
                drawn (both are None, or ``vary_thresholds`` is set).
            vary_thresholds: when truthy, a fresh random pair is drawn for
                every image, overriding any explicit thresholds. Otherwise
                one pair is shared by the whole batch.

    Returns:
        ``mask * edges + (1 - mask) * img`` with ``edges`` in {-1, 1}.

    Raises:
        KeyError: a random range is needed but was not passed in ``kwargs``.
    """
    vary_thresholds = bool(kwargs.get("vary_thresholds", False))
    device = img.device

    # All thresholds are drawn up front, before any image is processed, so
    # the order of random draws does not depend on the edge detection.
    threshold_pairs = _sample_canny_thresholds(
        img.shape[0], low_threshold, high_threshold, vary_thresholds, kwargs
    )
    edges_list = [
        _canny_edge_map(frame, low, high, device)
        for frame, (low, high) in zip(img, threshold_pairs)
    ]

    if edges_list:
        edges = torch.cat(edges_list, dim=0)
    else:
        # torch.cat rejects an empty list; an empty batch yields an empty
        # edge tensor so the blend below still returns an empty result.
        edges = torch.zeros((0, 1, *img.shape[2:]), device=device)

    mask = torch.clamp(mask, 0, 1)
    return mask * edges + (1 - mask) * img


def fill_img_with_canny_ori(
img,
mask,
low_threshold=None,
high_threshold=None,
**kwargs,
):
"""Fill the masked region with canny edges."""
low_threshold_random = kwargs["low_threshold_random"]
Expand Down

0 comments on commit 2ffe84b

Please sign in to comment.