Skip to content

Commit

Permalink
feat(ml): debug inference_vid for canny
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 18, 2024
1 parent 2c53948 commit 0ca63af
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions scripts/gen_vid_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def generate(
bbox_select_list = []
img_tensor_list = []
out_img_list = []
sequence_count = 0
for img_path, bbox_path in zip(limited_paths_img, limited_paths_bbox):
img_in = os.path.join(os.path.dirname(os.path.dirname(paths_in_file)), img_path)
bbox_in = os.path.join(
Expand Down Expand Up @@ -599,7 +600,7 @@ def generate(
cond_image = fill_img_with_sketch(
img_tensor.unsqueeze(0), mask.unsqueeze(0)
)
elif opt.alg_diffusion_cond_image_creation == "canny":
elif opt.alg_diffusion_cond_image_creation == "computed_sketch": # "canny":
clamp = torch.clamp(mask, 0, 1)
if cond_in:
# mask the background to avoid canny edges around cond image
Expand All @@ -613,6 +614,7 @@ def generate(
high_threshold=alg_diffusion_sketch_canny_thresholds[1],
low_threshold_random=-1,
high_threshold_random=-1,
select_canny=[1] + [0] * (opt.data_temporal_number_frames - 1),
)
if cond_in:
# restore background
Expand Down Expand Up @@ -683,8 +685,22 @@ def generate(
else:
ref_tensor = None

cond_image_list.append(cond_image)
y_t_list.append(y_t)
if opt.alg_diffusion_cond_image_creation == "computed_sketch":
if sequence_count == 0:
cond_image_list.append(cond_image)
y_t_list.append(y_t)
else:
cond_image_list.append(y_t_list[0])
y_t_list.append(y_t_list[0])
if opt.alg_diffusion_cond_image_creation == "y_t":
if sequence_count == 0:
cond_image_list.append(cond_image)
y_t_list.append(y_t)
else:
cond_image_list.append(cond_image)
y_t_list.append(y_t_list[0])

sequence_count = sequence_count + 1
y0_tensor_list.append(y0_tensor)
mask_list.append(mask)
bbox_select_list.append(bbox_select)
Expand Down

0 comments on commit 0ca63af

Please sign in to comment.