Added fallback to CPU if CUDA is not available
max810 committed Nov 12, 2023
commit adc7d4c (1 parent: 9c5693b)
Showing 5 changed files with 23 additions and 15 deletions.
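
Every file in this commit applies the same pattern: resolve a device string once via torch.cuda.is_available(), then move tensors and modules with .to(device) instead of hard-coding .cuda(). A minimal, self-contained sketch of that pattern, with illustrative names (model, batch) that are not from the repository:

import torch
import torch.nn as nn

# Resolve the device once; fall back to CPU when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = nn.Linear(16, 4).to(device).eval()  # nn.Module.to() moves parameters in place and returns the module
batch = torch.randn(8, 16).to(device)       # Tensor.to() returns a copy on the target device

with torch.no_grad():
    out = model(batch)                      # both operands now live on the same device
print(out.device)

On a machine with CUDA the behaviour is unchanged; without one, the same code runs on the CPU instead of failing inside .cuda().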
inference/frame_selection/frame_selection_utils.py (2 additions & 1 deletion)
@@ -13,12 +13,13 @@ def extract_keys(dataloder, processor, print_progress=False, flatten=True, **kwa
shrinkages = []
selections = []
device = None
+ system_device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.no_grad(): # just in case
key_sum = None

for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')):
data: Sample = data
- rgb = data.rgb.cuda()
+ rgb = data.rgb.to(system_device)
key, shrinkage, selection = processor.encode_frame_key(rgb)

if key_sum is None:
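
In extract_keys above, each batch from the dataloader is now moved to system_device before encoding. A hedged sketch of that per-batch pattern with a generic dataloader; the Conv2d encoder below is a stand-in for processor.encode_frame_key, which is not shown in this diff:

import torch
from torch.utils.data import DataLoader, TensorDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = torch.nn.Conv2d(3, 8, 3, padding=1).to(device).eval()  # stand-in key encoder

frames = TensorDataset(torch.rand(4, 3, 32, 32))                 # four fake RGB frames
keys = []
with torch.no_grad():
    for (rgb,) in DataLoader(frames, batch_size=1):
        rgb = rgb.to(device)              # same move as data.rgb.to(system_device) in the diff
        keys.append(encoder(rgb).cpu())   # accumulate results on the CPU side
print(len(keys), keys[0].shape)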
inference/interact/gui.py (6 additions & 5 deletions)
@@ -64,6 +64,7 @@ def __init__(self, net: XMem,
self.res_man = resource_manager
self.threadpool = QThreadPool()
self.last_opened_directory = str(Path.home())
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

self.num_frames = len(self.res_man)
self.height, self.width = self.res_man.h, self.res_man.w
@@ -399,7 +400,7 @@ def __init__(self, net: XMem,
self.current_image = np.zeros((self.height, self.width, 3), dtype=np.uint8)
self.current_image_torch = None
self.current_mask = np.zeros((self.height, self.width), dtype=np.uint8)
- self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).cuda()
+ self.current_prob = torch.zeros((self.num_objects, self.height, self.width), dtype=torch.float).to(self.device)

# initialize visualization
self.viz_mode = 'davis'
@@ -493,7 +494,7 @@ def load_current_torch_image_mask(self, no_mask=False):
self.current_image_torch, self.current_image_torch_no_norm = image_to_torch(self.current_image)

if self.current_prob is None and not no_mask:
- self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
+ self.current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).to(self.device)

def compose_current_im(self):
self.viz = get_visualization(self.viz_mode, self.current_image, self.current_mask,
@@ -851,7 +852,7 @@ def on_save_reference(self):
if self.interaction is not None:
self.on_commit()
current_image_torch, _ = image_to_torch(self.current_image)
- current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda()
+ current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).to(self.device)

msk = current_prob[1:]
a = perf_counter()
@@ -985,7 +986,7 @@ def on_mouse_press(self, event):
if self.curr_interaction == 'Scribble':
if last_interaction is None or type(last_interaction) != ScribbleInteraction:
self.complete_interaction()
- new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().cuda(),
+ new_interaction = ScribbleInteraction(image, torch.from_numpy(self.current_mask).float().to(self.device),
(h, w), self.s2m_controller, self.num_objects)
elif self.curr_interaction == 'Free':
if last_interaction is None or type(last_interaction) != FreeInteraction:
@@ -1264,7 +1265,7 @@ def _try_load_layer(self, file_name):
else:
self.console_push_text(f'Layer file {file_name} loaded.')
self.overlay_layer = layer
- self.overlay_layer_torch = torch.from_numpy(layer).float().cuda()/255
+ self.overlay_layer_torch = torch.from_numpy(layer).float().to(self.device)/255
self.show_current_frame()
except FileNotFoundError:
self.console_push_text(f'{file_name} not found.')
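
The GUI changes all follow from caching the resolved device on the instance (self.device in __init__) and reusing it wherever a tensor used to be sent to .cuda(). A small sketch of that idea, assuming nothing about the real GUI class beyond what the diff shows:

import numpy as np
import torch

class MaskHolder:
    # Toy object that picks its device once and allocates every tensor on it.
    def __init__(self, num_objects: int, height: int, width: int):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.current_prob = torch.zeros((num_objects, height, width), dtype=torch.float).to(self.device)

    def mask_to_tensor(self, mask: np.ndarray) -> torch.Tensor:
        # numpy mask -> float tensor on whatever device was chosen at construction time
        return torch.from_numpy(mask).float().to(self.device)

holder = MaskHolder(num_objects=2, height=4, width=4)
print(holder.mask_to_tensor(np.zeros((4, 4), dtype=np.uint8)).device)

A torch.device object would work just as well as the string here; .to() accepts either.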
inference/interact/interaction.py (3 additions & 1 deletion)
@@ -79,6 +79,8 @@ def __init__(self, image, prev_mask, true_size, num_objects):
self.curr_path = [[] for _ in range(self.K + 1)]

self.size = None
+
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

def set_size(self, size):
self.size = size
@@ -125,7 +127,7 @@ def end_path(self):
self.curr_path = [[] for _ in range(self.K + 1)]

def predict(self):
- self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
+ self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).to(self.device)
# self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
# self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
# self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
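
predict() converts the drawn index map into per-object probabilities and now places the result on self.device. The repository's index_numpy_to_one_hot_torch helper is not shown in this diff, so the sketch below uses plain torch.nn.functional.one_hot to illustrate the same index-mask-to-device conversion:

import numpy as np
import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
K = 2                                       # number of objects; class 0 is background

drawn_map = np.zeros((4, 4), dtype=np.int64)
drawn_map[1:3, 1:3] = 1                     # pretend the user scribbled object 1

idx = torch.from_numpy(drawn_map)
one_hot = F.one_hot(idx, num_classes=K + 1)             # (H, W, K+1)
out_prob = one_hot.permute(2, 0, 1).float().to(device)  # (K+1, H, W) on the chosen device
print(out_prob.shape, out_prob.device)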
inference/run_on_video.py (9 additions & 6 deletions)
@@ -39,6 +39,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou
object_color_if_single_object=(255, 255, 255),
print_fps=False,
image_saving_max_queue_size=200):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.autograd.set_grad_enabled(False)
frames_with_masks = set(frames_with_masks)
@@ -75,7 +76,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou
with torch.cuda.amp.autocast(enabled=True):
data: Sample = data # Just for Intellisense
# No batch dimension here, just single samples
- sample = replace(data, rgb=data.rgb.cuda())
+ sample = replace(data, rgb=data.rgb.to(device))

if ti in frames_with_masks:
msk = sample.mask
@@ -87,7 +88,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou
# https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True
msk, labels = mapper.convert_mask(
msk.numpy(), exhaustive=True)
- msk = torch.Tensor(msk).cuda()
+ msk = torch.Tensor(msk).to(device)
if sample.need_resize:
msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
processor.set_all_labels(list(mapper.remappings.values()))
@@ -145,8 +146,9 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou
return pd.DataFrame(stats)

def _load_main_objects(imgs_in_path, masks_in_path, config):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = config['model']
- network = XMem(config, model_path, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval()
+ network = XMem(config, model_path, pretrained_key_encoder=False, pretrained_value_encoder=False).to(device).eval()
if model_path is not None:
model_weights = torch.load(model_path)
network.load_weights(model_weights, init_as_zero_if_needed=True)
@@ -197,17 +199,18 @@ def _create_dataloaders(imgs_in_path: Union[str, PathLike], masks_in_path: Union


def _preload_permanent_memory(frames_to_put_in_permanent_memory: List[int], vid_reader: VideoReader, mapper: MaskMapper, processor: InferenceCore, augment_images_with_masks=False):
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
total_preloading_time = 0
at_least_one_mask_loaded = False
for j in frames_to_put_in_permanent_memory:
sample: Sample = vid_reader[j]
- sample = replace(sample, rgb=sample.rgb.cuda())
+ sample = replace(sample, rgb=sample.rgb.to(device))

# https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True
if sample.mask is None:
raise FileNotFoundError(f"Couldn't find mask {j}! Check that the filename is either the same as for frame {j} or follows the `frame_%06d.png` format if using a video file for input.")
msk, labels = mapper.convert_mask(sample.mask, exhaustive=True)
- msk = torch.Tensor(msk).cuda()
+ msk = torch.Tensor(msk).to(device)

if min(msk.shape) == 0: # empty mask, e.g. [1, 0, 720, 1280]
warn(f"Skipping adding frame {j} to permanent memory, as the mask is empty")
@@ …

for img_aug, mask_aug in augs:
# tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies
- rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda()
+ rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).to(device)

msk_aug = mask_aug(msk)

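
A related detail for CPU-only runs: _load_main_objects still calls torch.load(model_path) without a map_location, and checkpoint tensors that were saved from CUDA deserialize back onto CUDA by default. The usual remedy is to pass the resolved device to torch.load; a hedged sketch with an illustrative path, not the repository's actual default:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# map_location remaps GPU-saved tensors onto the chosen device at load time,
# so the same call also works on machines without a GPU.
model_weights = torch.load('saves/XMem.pth', map_location=device)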
interactive_demo.py (3 additions & 2 deletions)
@@ -68,15 +68,16 @@
config['enable_long_term'] = True
config['enable_long_term_count_usage'] = True

+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
with torch.cuda.amp.autocast(enabled=not args.no_amp):

# Load our checkpoint
- network = XMem(config, args.model, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval()
+ network = XMem(config, args.model, pretrained_key_encoder=False, pretrained_value_encoder=False).to(device).eval()

# Loads the S2M model
if args.s2m_model is not None:
s2m_saved = torch.load(args.s2m_model)
- s2m_model = S2M().cuda().eval()
+ s2m_model = S2M().to(device).eval()
s2m_model.load_state_dict(s2m_saved)
else:
s2m_model = None
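
interactive_demo.py still enters torch.cuda.amp.autocast(enabled=not args.no_amp) regardless of the device, and autocast from the torch.cuda.amp namespace only affects CUDA ops. A speculative sketch, not part of this commit, of additionally gating the flag on device availability; only the args.no_amp attribute appears in the diff, so the CLI spelling below is illustrative:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--no-amp', action='store_true', dest='no_amp')
args = parser.parse_args([])                # defaults only, for the sketch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_amp = (not args.no_amp) and torch.cuda.is_available()   # AMP only when a GPU is actually present

with torch.cuda.amp.autocast(enabled=use_amp):
    x = torch.randn(2, 3, device=device)
    y = x @ x.T                             # fp16 on CUDA when use_amp is True, fp32 otherwise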
