Skip to content

Commit

Permalink
feat(ml): Canny can use a range of dropout probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 authored and beniz committed Sep 25, 2024
1 parent 027c187 commit 7b4c860
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
7 changes: 4 additions & 3 deletions models/base_diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,10 @@ def modify_commandline_options_train(parser):
)
parser.add_argument(
"--alg_diffusion_vid_canny_dropout",
type=float,
default=0,
help="prob to drop canny for each frame",
type=pairs_of_floats,
default=[[]],
nargs="+",
help="the range of probabilities for dropping the canny for each frame",
)

return parser
Expand Down
8 changes: 5 additions & 3 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,11 @@ def set_input(self, data):
)

random_num = torch.rand(self.gt_image.shape[0])
canny_frame = (
random_num > self.opt.alg_diffusion_vid_canny_dropout
).int() # binary canny_frame
dropout_pro = torch.empty(self.gt_image.shape[0]).uniform_(
self.opt.alg_diffusion_vid_canny_dropout[0][0],
self.opt.alg_diffusion_vid_canny_dropout[1][0],
)
canny_frame = (random_num > dropout_pro).int() # binary canny_frame
self.cond_image = fill_img_with_random_sketch(
self.gt_image,
self.mask,
Expand Down

0 comments on commit 7b4c860

Please sign in to comment.