Skip to content

Commit

Permalink
feat(ml): unchange fill_img_with_canny with random drop canny
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 authored and beniz committed Sep 18, 2024
1 parent 06ce7d7 commit a2ed3fc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 71 deletions.
6 changes: 6 additions & 0 deletions models/base_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def modify_commandline_options_train(parser):
choices=["clip", "imagebind"],
help="embedding network to use for ref conditioning",
)
parser.add_argument(
"--alg_diffusion_vid_canny_dropout",
type=float,
default=0,
help="prob to drop canny for each frame",
)

return parser

Expand Down
1 change: 1 addition & 0 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
if len(y_0.shape) == 5:
sequence_length = y_0.shape[1]
y_0, y_cond, mask = rearrange_5dto4d_fh(y_0, y_cond, mask)

b, *_ = y_0.shape

t = torch.randint(
Expand Down
81 changes: 10 additions & 71 deletions util/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,6 @@ def fill_img_with_sketch(img, mask, **kwargs):
return mask * thresh + (1 - mask) * img


def fill_mask_with_canny_dropout(
img,
mask,
sequence_length,
dropout=0.9,
**kwargs,
):
"""Fill the masked region with fill_value"""
fill_tensor = torch.full_like(mask, -1)
mask = torch.clamp(mask, 0, 1)
output_img = mask * fill_tensor + (1 - mask) * img
frame_dropout = [
1 if random.random() > dropout else 0 for _ in range(sequence_length)
]

return output_img, frame_dropout


def fill_img_with_canny(
img,
mask,
Expand All @@ -62,63 +44,14 @@ def fill_img_with_canny(
"""Fill the masked region with canny edges."""
low_threshold_random = kwargs["low_threshold_random"]
high_threshold_random = kwargs["high_threshold_random"]
vary_thresholds = kwargs.get("vary_thresholds", False)
canny_list = kwargs["select_canny"]
max_value = 255 * 3
device = img.device
edges_list = []
threshold_pairs = []
for _ in range(img.shape[0]):
if (high_threshold is None and low_threshold is None) or (
vary_thresholds == True
):
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)
threshold_pairs.append((low_threshold, high_threshold))
for idx, cur_img in enumerate(img):

if vary_thresholds:
low_threshold, high_threshold = threshold_pairs[idx]
else:
low_threshold, high_threshold = threshold_pairs[0]
cur_img = (
(torch.einsum("chw->hwc", cur_img).cpu().numpy() + 1) * 255 / 2
).astype(np.uint8)
edges = cv2.Canny(cur_img, low_threshold, high_threshold)
edges = (
(((torch.tensor(edges, device=device) / 255) * 2) - 1)
.unsqueeze(0)
.unsqueeze(0)
)
edges_list.append(edges)
edges = torch.cat(edges_list, dim=0)
mask = torch.clamp(mask, 0, 1)
return mask * edges + (1 - mask) * img


def fill_img_with_canny_ori(
img,
mask,
low_threshold=None,
high_threshold=None,
**kwargs,
):
"""Fill the masked region with canny edges."""
low_threshold_random = kwargs["low_threshold_random"]
high_threshold_random = kwargs["high_threshold_random"]
max_value = 255 * 3
device = img.device
edges_list = []
for cur_img in img:
for cur_img, canny in zip(img, canny_list):
high_threshold = cur_high_threshold
low_threshold = (
cur_low_threshold # Reset thresholds for each image for new random
)
low_threshold = cur_low_threshold

if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
Expand All @@ -129,10 +62,16 @@ def fill_img_with_canny_ori(
high_threshold = random.randint(low_threshold, max_value)
elif high_threshold is not None and low_threshold is None:
low_threshold = random.randint(0, high_threshold)

cur_img = (
(torch.einsum("chw->hwc", cur_img).cpu().numpy() + 1) * 255 / 2
).astype(np.uint8)
edges = cv2.Canny(cur_img, low_threshold, high_threshold)

if canny == 1:
edges = cv2.Canny(cur_img, low_threshold, high_threshold)
else: # black image
edges = np.zeros_like(cur_img[:, :, 0])

edges = (
(((torch.tensor(edges, device=device) / 255) * 2) - 1)
.unsqueeze(0)
Expand Down

0 comments on commit a2ed3fc

Please sign in to comment.