Skip to content

Commit

Permalink
fix: sam compatible with mask inversion in palette
Browse files Browse the repository at this point in the history
  • Loading branch information
killian31 authored and beniz committed Jul 7, 2023
1 parent ea11745 commit 4294679
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
28 changes: 26 additions & 2 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,13 +384,25 @@ def set_input(self, data):
self.previous_frame_mask = data["B_label_mask"].to(self.device)[:, 0]
### Note: the sam related stuff should eventually go into the dataloader
if self.use_sam_mask:
if self.opt.data_inverted_mask:
temp_mask = data["B_label_mask"].clone()
temp_mask[temp_mask > 0] = 2
temp_mask[temp_mask == 0] = 1
temp_mask[temp_mask == 2] = 0
else:
temp_mask = data["B_label_mask"].clone()
self.mask = compute_mask_with_sam(
self.gt_image,
data["B_label_mask"].to(self.device)[:, 1],
temp_mask.to(self.device)[:, 1],
self.freezenet_sam,
self.device,
batched=True,
)

if self.opt.data_inverted_mask:
self.mask[self.mask > 0] = 2
self.mask[self.mask == 0] = 1
self.mask[self.mask == 2] = 0
self.y_t = fill_mask_with_random(self.gt_image, self.mask, -1)
else:
self.mask = data["B_label_mask"].to(self.device)[:, 1]
Expand All @@ -402,14 +414,26 @@ def set_input(self, data):
self.gt_image = data["B"].to(self.device)
### Note: the sam related stuff should eventually go into the dataloader
if self.use_sam_mask:
if self.opt.data_inverted_mask:
temp_mask = data["B_label_mask"].clone()
temp_mask[temp_mask > 0] = 2
temp_mask[temp_mask == 0] = 1
temp_mask[temp_mask == 2] = 0
else:
temp_mask = data["B_label_mask"].clone()
self.mask = compute_mask_with_sam(
self.gt_image,
data["B_label_mask"].to(self.device),
temp_mask.to(self.device),
self.freezenet_sam,
self.device,
batched=True,
)
if self.opt.data_inverted_mask:
self.mask[self.mask > 0] = 2
self.mask[self.mask == 0] = 1
self.mask[self.mask == 2] = 0
self.y_t = fill_mask_with_random(self.gt_image, self.mask, -1)

else:
self.mask = data["B_label_mask"].to(self.device)
else: # e.g. super-resolution
Expand Down
1 change: 0 additions & 1 deletion models/semantic_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def define_f(
f_s_config_segformer,
f_s_weight_segformer,
f_s_weight_sam,
f_s_weight_mobile_sam,
jg_dir,
data_crop_size,
**unused_options,
Expand Down

0 comments on commit 4294679

Please sign in to comment.