From 334a33b273c276ac05c6625ccff4d91749adab5e Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 17 Nov 2022 13:46:47 +0400 Subject: [PATCH 01/49] Added script for running XMem on videos with provided annotations. Added controls on what to store in the memory --- inference/data/mask_mapper.py | 2 +- inference/inference_core.py | 8 +- run_on_video.py | 187 ++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 run_on_video.py diff --git a/inference/data/mask_mapper.py b/inference/data/mask_mapper.py index 8e5b38d..378c090 100644 --- a/inference/data/mask_mapper.py +++ b/inference/data/mask_mapper.py @@ -30,7 +30,7 @@ def convert_mask(self, mask, exhaustive=False): new_labels = list(set(labels) - set(self.labels)) if not exhaustive: - assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' + assert len(new_labels) == len(labels), 'Old labels found in non-exhaustive mode' #a: it runs if you put exhaustive = True # add new remappings for i, l in enumerate(new_labels): diff --git a/inference/inference_core.py b/inference/inference_core.py index f5459df..1b48268 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -39,14 +39,17 @@ def set_all_labels(self, all_labels): # self.all_labels = [l.item() for l in all_labels] self.all_labels = all_labels - def step(self, image, mask=None, valid_labels=None, end=False): + def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False): # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch dimension + if manually_curated_masks: + is_mem_frame = (mask is not None) and (not end) + else: + is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) - is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) is_deep_update = ( (self.deep_update_sync and is_mem_frame) or # synchronized @@ -98,6 +101,7 @@ def step(self, image, mask=None, valid_labels=None, end=False): pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) self.memory.add_memory(key, shrinkage, value, self.all_labels, selection=selection if self.enable_long_term else None) + self.last_mem_ti = self.curr_ti if is_deep_update: diff --git a/run_on_video.py b/run_on_video.py new file mode 100644 index 0000000..ec1614b --- /dev/null +++ b/run_on_video.py @@ -0,0 +1,187 @@ +import os +from os import path +from argparse import ArgumentParser +import shutil + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +import numpy as np +from PIL import Image +from tqdm import tqdm + +from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset +from inference.data.mask_mapper import MaskMapper +from inference.data.video_reader import VideoReader +from model.network import XMem +from inference.inference_core import InferenceCore + + +def inference_on_video(how_many_extra_frames, frames_with_masks): + torch.autograd.set_grad_enabled(False) + + config = { + 'buffer_size': 100, + 'deep_update_every': -1, + 'enable_long_term': True, + 'enable_long_term_count_usage': True, + 'fbrs_model': 'saves/fbrs.pth', + 'hidden_dim': 64, + 'images': None, + 'key_dim': 64, + 'max_long_term_elements': 10000, + 
'max_mid_term_frames': 10, + 'mem_every': 10, + 'min_mid_term_frames': 5, + 'model': './saves/XMem.pth', + 'no_amp': False, + 'num_objects': 1, + 'num_prototypes': 128, + 's2m_model': 'saves/s2m.pth', + 'size': 480, + 'top_k': 30, + 'value_dim': 512, + 'video': '../VIDEOS/Maksym_frontal_simple.mp4', + 'masks_out_path': f'../VIDEOS/RESULTS/XMem/lips/MaximumPossibleIoU/{how_many_extra_frames}_extra_frames', + 'workspace': None, + 'save_masks': True + } + + model_path = config['model'] + network = XMem(config, model_path).cuda().eval() + if model_path is not None: + model_weights = torch.load(model_path) + network.load_weights(model_weights, init_as_zero_if_needed=True) + else: + print('No model loaded.') + + total_process_time = 0 + total_frames = 0 + + # Start eval + vid_reader = VideoReader( + "Maksym_frontal_simple", + '/home/maksym/RESEARCH/VIDEOS/IVOS_Maksym_frontal_simple_14_Nov_2022_DATA/JPEGImages', + '/home/maksym/RESEARCH/VIDEOS/IVOS_Maksym_frontal_simple_14_Nov_2022_DATA/SegmentationMaskBinary', + size=config['size'], + use_all_mask=True + ) + + loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=8) + vid_name = vid_reader.vid_name + vid_length = len(loader) + # no need to count usage for LT if the video is not that long anyway + config['enable_long_term_count_usage'] = ( + config['enable_long_term'] and + (vid_length + / (config['max_mid_term_frames']-config['min_mid_term_frames']) + * config['num_prototypes']) + >= config['max_long_term_elements'] + ) + + mapper = MaskMapper() + processor = InferenceCore(network, config=config) + first_mask_loaded = False + + for ti, data in enumerate(loader): + with torch.cuda.amp.autocast(enabled=True): + rgb = data['rgb'].cuda()[0] + + # TODO: - only use % of the frames + if ti in frames_with_masks: + msk = data['mask'] + else: + msk = None + + info = data['info'] + frame = info['frame'][0] + shape = info['shape'] + need_resize = info['need_resize'][0] + + """ + For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964 + Seems to be very similar in testing as my previous timing method + with two cuda sync + time.time() in STCN though + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + if not first_mask_loaded: + if msk is not None: + first_mask_loaded = True + else: + # no point to do anything without a mask + continue + + if False: + rgb = torch.flip(rgb, dims=[-1]) + msk = torch.flip(msk, dims=[-1]) if msk is not None else None + + # Map possibly non-continuous labels to continuous ones + # TODO: What are labels? Debug + if msk is not None: + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask(msk[0].numpy(), exhaustive=True) + msk = torch.Tensor(msk).cuda() + if need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + processor.set_all_labels(list(mapper.remappings.values())) + + else: + labels = None + + # Run the model on this frame + # TODO: still running inference even on frames with masks? 
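            # (note: when `msk` is provided, step() below folds the given mask back into the
            #  output via aggregate(), so for the labelled objects the prediction is overridden
            #  by the annotation; running the full forward pass on annotated frames therefore
            #  mostly costs runtime rather than accuracy)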
+ prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1), manually_curated_masks=False) + + # Upsample to original size if needed + if need_resize: + prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] + + end.record() + torch.cuda.synchronize() + total_process_time += (start.elapsed_time(end)/1000) + total_frames += 1 + + if False: + prob = torch.flip(prob, dims=[-1]) + + # Probability mask -> index mask + out_mask = torch.argmax(prob, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + if False: + prob = (prob.detach().cpu().numpy()*255).astype(np.uint8) + + # Save the mask + if config['save_masks']: + this_out_path = path.join(config['masks_out_path'], vid_name) + os.makedirs(this_out_path, exist_ok=True) + out_mask = mapper.remap_index_mask(out_mask) + out_img = Image.fromarray(out_mask) + if vid_reader.get_palette() is not None: + out_img.putpalette(vid_reader.get_palette()) + out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) + + if False: #args.save_scores: + np_path = path.join(args.output, 'Scores', vid_name) + os.makedirs(np_path, exist_ok=True) + if ti==len(loader)-1: + hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w') + if args.save_all or info['save'][0]: + hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') + + + print(f'Total processing time: {total_process_time}') + print(f'Total processed frames: {total_frames}') + print(f'FPS: {total_frames / total_process_time}') + print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}') + + +if __name__ == '__main__': + for how_many_extra_frames in tqdm(range(0, 91)): + # e.g. [0, 10, 20, ..., 180] without 180 + frames_with_masks = set(np.linspace(0, 180, how_many_extra_frames+2)[0:-1].astype(int)) + + inference_on_video(how_many_extra_frames, frames_with_masks) From cdabe17b7c95e24b906620e57d7b100141a60624 Mon Sep 17 00:00:00 2001 From: max810 Date: Sat, 17 Dec 2022 16:57:52 +0400 Subject: [PATCH 02/49] Implemented permanent working memory mechanism for storing annotated frames forever as references during inference --- inference/inference_core.py | 10 ++- inference/interact/gui.py | 2 +- inference/kv_memory_store.py | 1 + inference/memory_manager.py | 123 +++++++++++++++++++++++------------ 4 files changed, 91 insertions(+), 45 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index 1b48268..29549d4 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -40,6 +40,12 @@ def set_all_labels(self, all_labels): self.all_labels = all_labels def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False): + # For feedback: + # 1. We run the model as usual + # 2. We get feedback: 2 lists, one with good prediction indices, one with bad + # 3. We force the good frames (+ annotated frames) to stay in working memory forever + # 4. We force the bad frames to never even get added to the working memory + # 5. 
Rerun with these settings # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 @@ -49,6 +55,8 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ is_mem_frame = (mask is not None) and (not end) else: is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) + + is_permanent_frame = mask is not None need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) is_deep_update = ( @@ -100,7 +108,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) self.memory.add_memory(key, shrinkage, value, self.all_labels, - selection=selection if self.enable_long_term else None) + selection=selection if self.enable_long_term else None, permanent=is_permanent_frame) self.last_mem_ti = self.curr_ti diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 039a382..679e90e 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -824,7 +824,7 @@ def update_memory_size(self): max_work_elements = self.processor.memory.max_work_elements max_long_elements = self.processor.memory.max_long_elements - curr_work_elements = self.processor.memory.work_mem.size + curr_work_elements = self.processor.memory.temporary_work_mem.size + self.processor.memory.permanent_work_mem.size curr_long_elements = self.processor.memory.long_mem.size self.work_mem_gauge.setFormat(f'{curr_work_elements} / {max_work_elements}') diff --git a/inference/kv_memory_store.py b/inference/kv_memory_store.py index 33a3326..a31ac63 100644 --- a/inference/kv_memory_store.py +++ b/inference/kv_memory_store.py @@ -105,6 +105,7 @@ def sieve_by_range(self, start: int, end: int, min_size: int): # (because they are not consolidated) if end == 0: + # just sieves till the `start` # negative 0 would not work as the end index! 
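            # (illustration: the two-sided removal used for other ranges keeps
            #  self.k[:,:,:start] and self.k[:,:,end:] with a negative `end`; if end were 0,
            #  that second slice would return the whole tensor instead of nothing, so
            #  "remove everything from `start` onwards" is handled as this special case)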
self.k = self.k[:,:,:start] if self.count_usage: diff --git a/inference/memory_manager.py b/inference/memory_manager.py index bce2c00..3641760 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -9,6 +9,7 @@ class MemoryManager: """ Manages all three memory stores and the transition between working/long-term memory """ + def __init__(self, config): self.hidden_dim = config['hidden_dim'] self.top_k = config['top_k'] @@ -16,8 +17,8 @@ def __init__(self, config): self.enable_long_term = config['enable_long_term'] self.enable_long_term_usage = config['enable_long_term_count_usage'] if self.enable_long_term: - self.max_mt_frames = config['max_mid_term_frames'] - self.min_mt_frames = config['min_mid_term_frames'] + self.max_mt_frames = config['max_mid_term_frames'] # maximum work memory size + self.min_mt_frames = config['min_mid_term_frames'] # minimum number of frames to keep in work memory when consolidating self.num_prototypes = config['num_prototypes'] self.max_long_elements = config['max_long_term_elements'] @@ -29,7 +30,8 @@ def __init__(self, config): # B x num_objects x CH x H x W self.hidden = None - self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) + self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) + self.permanent_work_mem = KeyValueMemoryStore(count_usage=False) if self.enable_long_term: self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) @@ -57,7 +59,9 @@ def _readout(self, affinity, v): def match_memory(self, query_key, selection): # query_key: B x C^k x H x W # selection: B x C^k x H x W - num_groups = self.work_mem.num_groups + # TODO: keep groups in both..? + # 1x64x30x54 + num_groups = self.temporary_work_mem.num_groups h, w = query_key.shape[-2:] query_key = query_key.flatten(start_dim=2) @@ -67,21 +71,25 @@ def match_memory(self, query_key, selection): Memory readout using keys """ + temp_work_mem_size = self.temporary_work_mem.size if self.enable_long_term and self.long_mem.engaged(): # Use long-term memory long_mem_size = self.long_mem.size - memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1) - shrinkage = torch.cat([self.long_mem.shrinkage, self.work_mem.shrinkage], -1) + + memory_key = torch.cat([self.long_mem.key, self.temporary_work_mem.key, self.permanent_work_mem.key], -1) + shrinkage = torch.cat([self.long_mem.shrinkage, self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) similarity = get_similarity(memory_key, shrinkage, query_key, selection) - work_mem_similarity = similarity[:, long_mem_size:] + long_mem_similarity = similarity[:, :long_mem_size] + temp_work_mem_similarity = similarity[:, long_mem_size:long_mem_size+temp_work_mem_size] + perm_work_mem_similarity = similarity[:, long_mem_size+temp_work_mem_size:] # get the usage with the first group # the first group always have all the keys valid affinity, usage = do_softmax( - torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], work_mem_similarity], 1), - top_k=self.top_k, inplace=True, return_usage=True) + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(0):], temp_work_mem_similarity, perm_work_mem_similarity], 1), + top_k=self.top_k, inplace=True, return_usage=True) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys @@ -89,57 +97,66 @@ def match_memory(self, query_key, selection): if gi < self.long_mem.num_groups: # merge working and lt similarities before softmax affinity_one_group = do_softmax( 
- torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], - work_mem_similarity[:, -self.work_mem.get_v_size(gi):]], 1), + torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], + temp_work_mem_similarity[:, -self.temporary_work_mem.get_v_size(gi):], + perm_work_mem_similarity[:, -self.permanent_work_mem.get_v_size(gi):]], + 1), top_k=self.top_k, inplace=True) else: # no long-term memory for this group - affinity_one_group = do_softmax(work_mem_similarity[:, -self.work_mem.get_v_size(gi):], - top_k=self.top_k, inplace=(gi==num_groups-1)) + affinity_one_group = do_softmax(torch.cat([ + temp_work_mem_similarity[:, -self.temporary_work_mem.get_v_size(gi):], + perm_work_mem_similarity[:, -self.permanent_work_mem.get_v_size(gi):]], + 1), + top_k=self.top_k, inplace=(gi == num_groups-1)) affinity.append(affinity_one_group) all_memory_value = [] - for gi, gv in enumerate(self.work_mem.value): + for gi, gv in enumerate(self.temporary_work_mem.value): # merge the working and lt values before readout if gi < self.long_mem.num_groups: - all_memory_value.append(torch.cat([self.long_mem.value[gi], self.work_mem.value[gi]], -1)) + all_memory_value.append(torch.cat([self.long_mem.value[gi], self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) else: - all_memory_value.append(gv) + all_memory_value.append(torch.cat([self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) """ Record memory usage for working and long-term memory """ # ignore the index return for long-term memory - work_usage = usage[:, long_mem_size:] - self.work_mem.update_usage(work_usage.flatten()) + work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory + self.temporary_work_mem.update_usage(work_usage.flatten()) if self.enable_long_term_usage: # ignore the index return for working memory long_usage = usage[:, :long_mem_size] self.long_mem.update_usage(long_usage.flatten()) else: + memory_key = torch.cat([self.temporary_work_mem.key, self.permanent_work_mem.key], -1) + shrinkage = torch.cat([self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) # No long-term memory - similarity = get_similarity(self.work_mem.key, self.work_mem.shrinkage, query_key, selection) + similarity = get_similarity(memory_key, shrinkage, query_key, selection) if self.enable_long_term: - affinity, usage = do_softmax(similarity, inplace=(num_groups==1), - top_k=self.top_k, return_usage=True) + affinity, usage = do_softmax(similarity, inplace=(num_groups == 1), + top_k=self.top_k, return_usage=True) # Record memory usage for working memory - self.work_mem.update_usage(usage.flatten()) + self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) else: - affinity = do_softmax(similarity, inplace=(num_groups==1), - top_k=self.top_k, return_usage=False) + affinity = do_softmax(similarity, inplace=(num_groups == 1), + top_k=self.top_k, return_usage=False) affinity = [affinity] # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): - affinity_one_group = do_softmax(similarity[:, -self.work_mem.get_v_size(gi):], - top_k=self.top_k, inplace=(gi==num_groups-1)) + affinity_one_group = do_softmax(similarity[:, -self.temporary_work_mem.get_v_size(gi):], + top_k=self.top_k, inplace=(gi == num_groups-1)) affinity.append(affinity_one_group) - - all_memory_value = self.work_mem.value + + all_memory_value = [] + for gi, gv in enumerate(self.temporary_work_mem.value): + 
all_memory_value.append(torch.cat([self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) # Shared affinity within each group all_readout_mem = torch.cat([ @@ -149,7 +166,7 @@ def match_memory(self, query_key, selection): return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) - def add_memory(self, key, shrinkage, value, objects, selection=None): + def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=False): # key: 1*C*H*W # value: 1*num_objects*C*H*W # objects contain a list of object indices @@ -165,7 +182,7 @@ def add_memory(self, key, shrinkage, value, objects, selection=None): # key: 1*C*N # value: num_objects*C*N key = key.flatten(start_dim=2) - shrinkage = shrinkage.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) value = value[0].flatten(start_dim=2) self.CK = key.shape[1] @@ -176,18 +193,26 @@ def add_memory(self, key, shrinkage, value, objects, selection=None): warnings.warn('the selection factor is only needed in long-term mode', UserWarning) selection = selection.flatten(start_dim=2) - self.work_mem.add(key, value, shrinkage, selection, objects) + if permanent: + self.permanent_work_mem.add(key, value, shrinkage, selection, objects) + else: + self.temporary_work_mem.add(key, value, shrinkage, selection, objects) + + if not self.temporary_work_mem.engaged(): + # first frame goes in goth to avoid crashes + self.temporary_work_mem.add(key, value, shrinkage, selection, objects) # long-term memory cleanup if self.enable_long_term: # Do memory compressed if needed - if self.work_mem.size >= self.max_work_elements: + if self.temporary_work_mem.size >= self.max_work_elements: + # if we have more then N elements in the work memory # Remove obsolete features if needed if self.long_mem.size >= (self.max_long_elements-self.num_prototypes): self.long_mem.remove_obsolete_features(self.max_long_elements-self.num_prototypes) - - self.compress_features() + # We NEVER remove anything from the working memory + self.compress_features() def create_hidden_state(self, n, sample_key): # n is the TOTAL number of objects @@ -196,11 +221,11 @@ def create_hidden_state(self, n, sample_key): self.hidden = torch.zeros((1, n, self.hidden_dim, h, w), device=sample_key.device) elif self.hidden.shape[1] != n: self.hidden = torch.cat([ - self.hidden, + self.hidden, torch.zeros((1, n-self.hidden.shape[1], self.hidden_dim, h, w), device=sample_key.device) ], 1) - assert(self.hidden.shape[1] == n) + assert (self.hidden.shape[1] == n) def set_hidden(self, hidden): self.hidden = hidden @@ -208,34 +233,46 @@ def set_hidden(self, hidden): def get_hidden(self): return self.hidden + # def slices_excluding_permanent(self, group_value, start, end): + # HW = self.HW + # group_value[:,:,HW:-self.min_work_elements+HW] + + # slices = [] + + # # this won't work because after just 1 consolidation all permanent frames are going to be god know where + # # and their indices would mean nothing + # # How about have 2 separate tensors and concatenate them just for memory reading? + # all_indices = torch.arange(self.temporary_work_mem.size // HW) # all frames indices from 0 to ... 
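    # The "2 separate tensors, concatenate them just for memory reading" idea from the draft
    # above is what match_memory() does: temporary and permanent keys are joined only at read
    # time, e.g.
    #     memory_key = torch.cat([self.temporary_work_mem.key, self.permanent_work_mem.key], -1)
    # (plus the long-term keys when long-term memory is engaged), so annotated frames never sit
    # inside the temporary store.  This is also why compress_features() below can slice the
    # temporary memory from index 0: there is no protected first-frame block of HW elements
    # left to skip over.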
+ def compress_features(self): HW = self.HW candidate_value = [] - total_work_mem_size = self.work_mem.size - for gv in self.work_mem.value: + total_work_mem_size = self.temporary_work_mem.size + for gv in self.temporary_work_mem.value: # Some object groups might be added later in the video # So not all keys have values associated with all objects # We need to keep track of the key->value validity mem_size_in_this_group = gv.shape[-1] if mem_size_in_this_group == total_work_mem_size: # full LT - candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + candidate_value.append(gv[:, :, :-self.min_work_elements]) else: # mem_size is smaller than total_work_mem_size, but at least HW assert HW <= mem_size_in_this_group < total_work_mem_size - if mem_size_in_this_group > self.min_work_elements+HW: + if mem_size_in_this_group > self.min_work_elements: # part of this object group still goes into LT - candidate_value.append(gv[:,:,HW:-self.min_work_elements+HW]) + candidate_value.append(gv[:, :, :-self.min_work_elements]) else: # this object group cannot go to the LT at all candidate_value.append(None) # perform memory consolidation + # now starts at zero, because the 1st frame is going into permanent memory prototype_key, prototype_value, prototype_shrinkage = self.consolidation( - *self.work_mem.get_all_sliced(HW, -self.min_work_elements+HW), candidate_value) + *self.temporary_work_mem.get_all_sliced(0, -self.min_work_elements), candidate_value) # remove consolidated working memory - self.work_mem.sieve_by_range(HW, -self.min_work_elements+HW, min_size=self.min_work_elements+HW) + self.temporary_work_mem.sieve_by_range(0, -self.min_work_elements, min_size=self.min_work_elements+HW) # add to long-term memory self.long_mem.add(prototype_key, prototype_value, prototype_shrinkage, selection=None, objects=None) From c99e44df0c106b0afa5417ac34d2cbda3d5c5ba6 Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 18 Dec 2022 11:02:42 +0400 Subject: [PATCH 03/49] Implemented preloading all annotated frames to permanent working memory before running inference --- inference/inference_core.py | 23 +++++++++++++++++++++-- inference/memory_manager.py | 8 +++++--- run_on_video.py | 34 +++++++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index 29549d4..bc5d4d2 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -56,7 +56,8 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ else: is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) - is_permanent_frame = mask is not None + # is_permanent_frame = mask is not None + is_ignore = mask is not None # to avoid adding permanent memory frames twice, since they are alredy in the memory need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) is_deep_update = ( @@ -108,7 +109,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=is_deep_update) self.memory.add_memory(key, shrinkage, value, self.all_labels, - selection=selection if self.enable_long_term else None, permanent=is_permanent_frame) + selection=selection if self.enable_long_term else None, ignore=is_ignore) self.last_mem_ti = self.curr_ti @@ -117,3 +118,21 @@ def step(self, image, mask=None, 
valid_labels=None, end=False, manually_curated_ self.last_deep_update_ti = self.curr_ti return unpad(pred_prob_with_bg, self.pad) + + def put_to_permanent_memory(self, image, mask): + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, + need_ek=True, + need_sk=True) + + mask, _ = pad_divide_by(mask, 16) + + pred_prob_with_bg = aggregate(mask, dim=0) + self.memory.create_hidden_state(len(self.all_labels), key) + + value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), + pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=False) + + self.memory.add_memory(key, shrinkage, value, self.all_labels, + selection=selection if self.enable_long_term else None, permanent=True) diff --git a/inference/memory_manager.py b/inference/memory_manager.py index 3641760..d0f992a 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -166,7 +166,7 @@ def match_memory(self, query_key, selection): return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) - def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=False): + def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=False, ignore=False): # key: 1*C*H*W # value: 1*num_objects*C*H*W # objects contain a list of object indices @@ -193,13 +193,15 @@ def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=F warnings.warn('the selection factor is only needed in long-term mode', UserWarning) selection = selection.flatten(start_dim=2) - if permanent: + if ignore: + pass # all permanent frames are pre-placed into permanent memory + elif permanent: self.permanent_work_mem.add(key, value, shrinkage, selection, objects) else: self.temporary_work_mem.add(key, value, shrinkage, selection, objects) if not self.temporary_work_mem.engaged(): - # first frame goes in goth to avoid crashes + # first frame goes in both to avoid crashes self.temporary_work_mem.add(key, value, shrinkage, selection, objects) # long-term memory cleanup diff --git a/run_on_video.py b/run_on_video.py index ec1614b..6769b5a 100644 --- a/run_on_video.py +++ b/run_on_video.py @@ -1,6 +1,7 @@ import os from os import path from argparse import ArgumentParser +from pathlib import Path import shutil import torch @@ -17,7 +18,7 @@ from inference.inference_core import InferenceCore -def inference_on_video(how_many_extra_frames, frames_with_masks): +def inference_on_video(frames_with_masks, dir_name=''): torch.autograd.set_grad_enabled(False) config = { @@ -83,11 +84,26 @@ def inference_on_video(how_many_extra_frames, frames_with_masks): processor = InferenceCore(network, config=config) first_mask_loaded = False - for ti, data in enumerate(loader): + for j in frames_with_masks: + sample = vid_reader[j] + rgb = sample['rgb'].cuda() + msk = sample['mask'] + info = sample['info'] + need_resize = info['need_resize'] + + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask(msk, exhaustive=True) + msk = torch.Tensor(msk).cuda() + if need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + + processor.set_all_labels(list(mapper.remappings.values())) + processor.put_to_permanent_memory(rgb, msk) + + for ti, data in enumerate(tqdm(loader)): with torch.cuda.amp.autocast(enabled=True): rgb = data['rgb'].cuda()[0] - # TODO: - only use % of the frames if ti in frames_with_masks: msk = data['mask'] else: @@ 
-103,9 +119,9 @@ def inference_on_video(how_many_extra_frames, frames_with_masks): Seems to be very similar in testing as my previous timing method with two cuda sync + time.time() in STCN though """ - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() + # start = torch.cuda.Event(enable_timing=True) + # end = torch.cuda.Event(enable_timing=True) + # start.record() if not first_mask_loaded: if msk is not None: @@ -139,9 +155,9 @@ def inference_on_video(how_many_extra_frames, frames_with_masks): if need_resize: prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] - end.record() - torch.cuda.synchronize() - total_process_time += (start.elapsed_time(end)/1000) + # end.record() + # torch.cuda.synchronize() + # total_process_time += (start.elapsed_time(end)/1000) total_frames += 1 if False: From 29f83413dd770e4901bdcb3d87b236ce523f0b6e Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 23 Dec 2022 18:41:16 +0400 Subject: [PATCH 04/49] Added active learning based on entropy and BALD, added configurations for running the old memory mechanism --- inference/active_learning.py | 75 ++++++++ inference/data/video_reader.py | 2 + inference/inference_core.py | 15 +- inference/memory_manager.py | 13 +- run_on_video.py | 333 ++++++++++++++++++++++++++++----- 5 files changed, 388 insertions(+), 50 deletions(-) create mode 100644 inference/active_learning.py diff --git a/inference/active_learning.py b/inference/active_learning.py new file mode 100644 index 0000000..4a3b4f8 --- /dev/null +++ b/inference/active_learning.py @@ -0,0 +1,75 @@ +from sklearn.cluster import KMeans +import pandas as pd +from PIL import Image +from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor + +def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): + df = preds_df + + df.reset_index(drop=False, inplace=True) + + # max_frame = df['frame'].max() + # max_entropy = df['entropy'].max() + + df = df[df['mask_provided'] == False] # removing frames with masks + df = df[df[uncertainty_name] >= df[uncertainty_name].median()] # removing low entropy parts + + df_backup = df.copy() + + df['index'] = df['index'] / df['index'].max() # scale to 0..1 + # df['entropy'] = df['entropy'] / df['entropy'].max() # scale to 0..1 + + X = df[['index', uncertainty_name]].to_numpy() + + clusterer = KMeans(n_clusters=n) + + labels = clusterer.fit_predict(X) + + clusters = df_backup.groupby(labels) + + candidates = [] + + for g, cluster in clusters: + if g == -1: + continue + + max_entropy_idx = cluster[uncertainty_name].argmax() + + res = cluster.iloc[max_entropy_idx] + + candidates.append(res) + + return candidates + +def select_most_uncertain_frame(preds_df: pd.DataFrame, uncertainty_name: str): + preds_df.reset_index(drop=False, inplace=True) + return preds_df.iloc[preds_df[uncertainty_name].argmax()] + + +def get_determenistic_augmentations(): + # TODO: maybe add GaussianBlur? 
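    # (a Gaussian blur is indeed added later in this series; patch 06 defines
    #  blur = partial(FT.gaussian_blur, kernel_size=7))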
+ + bright = ColorJitter(brightness=(1.5, 1.5)) + dark = ColorJitter(brightness=(0.5, 0.5)) + gray = Grayscale(num_output_channels=3) + reduce_bits = RandomPosterize(bits=3, p=1) + sharp = RandomAdjustSharpness(sharpness_factor=16, p=1) + + return [bright, dark, gray, reduce_bits, sharp] + +def apply_aug(img_path, out_path): + img = Image.open(img_path) + + bright, dark, gray, reduce_bits, sharp = get_determenistic_augmentations() + + img_augged = sharp(img) + + img_augged.save(out_path) + + +if __name__ == '__main__': + img_in = '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages/frame_000001.PNG' + img_out = 'test_aug.png' + + apply_aug(img_in, img_out) + \ No newline at end of file diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index 28cc4c6..bc468c4 100644 --- a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -5,6 +5,7 @@ from torchvision import transforms from torchvision.transforms import InterpolationMode import torch.nn.functional as F +import torchvision.transforms.functional as FT from PIL import Image import numpy as np @@ -71,6 +72,7 @@ def __getitem__(self, idx): shape = np.array(size_im).shape[:2] gt_path = path.join(self.mask_dir, frame[:-4]+'.png') + data['raw_image_tensor'] = FT.to_tensor(img) # for dataloaders it cannot be raw PIL.Image, only tensors img = self.im_transform(img) load_mask = self.use_all_mask or (gt_path == self.first_gt_path) diff --git a/inference/inference_core.py b/inference/inference_core.py index bc5d4d2..b81f4a7 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -39,7 +39,7 @@ def set_all_labels(self, all_labels): # self.all_labels = [l.item() for l in all_labels] self.all_labels = all_labels - def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False): + def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False): # For feedback: # 1. We run the model as usual # 2. 
We get feedback: 2 lists, one with good prediction indices, one with bad @@ -55,9 +55,8 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ is_mem_frame = (mask is not None) and (not end) else: is_mem_frame = ((self.curr_ti-self.last_mem_ti >= self.mem_every) or (mask is not None)) and (not end) - - # is_permanent_frame = mask is not None - is_ignore = mask is not None # to avoid adding permanent memory frames twice, since they are alredy in the memory + + is_ignore = do_not_add_mask_to_memory # to avoid adding permanent memory frames twice, since they are alredy in the memory need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) is_deep_update = ( @@ -71,6 +70,11 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ need_sk=is_mem_frame) multi_scale_features = (f16, f8, f4) + if disable_memory_updates: + is_normal_update = False + is_deep_update = False + is_mem_frame = False + # segment the current frame is needed if need_segment: memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) @@ -102,7 +106,8 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ pred_prob_with_bg = aggregate(mask, dim=0) # also create new hidden states - self.memory.create_hidden_state(len(self.all_labels), key) + if not disable_memory_updates: + self.memory.create_hidden_state(len(self.all_labels), key) # save as memory if needed if is_mem_frame: diff --git a/inference/memory_manager.py b/inference/memory_manager.py index d0f992a..bb132cf 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -194,15 +194,22 @@ def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=F selection = selection.flatten(start_dim=2) if ignore: - pass # all permanent frames are pre-placed into permanent memory + pass # all permanent frames are pre-placed into permanent memory (when using our memory modification) + # also ignores the first frame (#0) when using original memory mechanism, since it's already in the permanent memory elif permanent: self.permanent_work_mem.add(key, value, shrinkage, selection, objects) else: self.temporary_work_mem.add(key, value, shrinkage, selection, objects) if not self.temporary_work_mem.engaged(): - # first frame goes in both to avoid crashes - self.temporary_work_mem.add(key, value, shrinkage, selection, objects) + # first frame; we need to have both memories engaged to avoid crashes when concating + # so we just initialize the temporary one with an empty tensor + key0 = key[..., 0:0] + value0 = value[..., 0:0] + shrinkage0 = shrinkage[..., 0:0] + selection0 = selection[..., 0:0] + + self.temporary_work_mem.add(key0, value0, shrinkage0, selection0, objects) # long-term memory cleanup if self.enable_long_term: diff --git a/run_on_video.py b/run_on_video.py index 6769b5a..8afab2d 100644 --- a/run_on_video.py +++ b/run_on_video.py @@ -1,26 +1,38 @@ +from collections import defaultdict +import math import os from os import path from argparse import ArgumentParser from pathlib import Path import shutil +import pandas as pd import torch import torch.nn.functional as F from torch.utils.data import DataLoader import numpy as np from PIL import Image from tqdm import tqdm +from scipy.stats import entropy +from baal.active.heuristics import BALD +import torchvision.transforms.functional as FT +from inference.active_learning import get_determenistic_augmentations, select_most_uncertain_frame, 
select_n_frame_candidates from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset from inference.data.mask_mapper import MaskMapper from inference.data.video_reader import VideoReader from model.network import XMem from inference.inference_core import InferenceCore +from util.tensor_util import compute_tensor_iou -def inference_on_video(frames_with_masks, dir_name=''): +def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, + original_memory_mechanism=False, + compute_iou = False, compute_uncertainty = False, manually_curated_masks=False, print_progress=True, + uncertainty_name: str = None, + overwrite_config: dict = None): torch.autograd.set_grad_enabled(False) - + frames_with_masks = set(frames_with_masks) config = { 'buffer_size': 100, 'deep_update_every': -1, @@ -42,12 +54,25 @@ def inference_on_video(frames_with_masks, dir_name=''): 'size': 480, 'top_k': 30, 'value_dim': 512, - 'video': '../VIDEOS/Maksym_frontal_simple.mp4', - 'masks_out_path': f'../VIDEOS/RESULTS/XMem/lips/MaximumPossibleIoU/{how_many_extra_frames}_extra_frames', + # 'video': '../VIDEOS/Maksym_frontal_simple.mp4', + 'masks_out_path': masks_out_path,#f'../VIDEOS/RESULTS/XMem_feedback/thanks_two_face_5_frames/', + # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/WhichFramesWithPreds/1/{dir_name}/{"_".join(map(str, frames_with_masks))}_frames_provided', + # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/DAVIS_2017/WhichFrames/1/{dir_name}/{len(frames_with_masks) - 1}_extra_frames', 'workspace': None, 'save_masks': True } + if overwrite_config is not None: + config.update(overwrite_config) + + vid_reader = VideoReader( + "", + imgs_in_path, #f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + masks_in_path, #f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', + size=config['size'], + use_all_mask=True + ) + model_path = config['model'] network = XMem(config, model_path).cuda().eval() if model_path is not None: @@ -56,18 +81,6 @@ def inference_on_video(frames_with_masks, dir_name=''): else: print('No model loaded.') - total_process_time = 0 - total_frames = 0 - - # Start eval - vid_reader = VideoReader( - "Maksym_frontal_simple", - '/home/maksym/RESEARCH/VIDEOS/IVOS_Maksym_frontal_simple_14_Nov_2022_DATA/JPEGImages', - '/home/maksym/RESEARCH/VIDEOS/IVOS_Maksym_frontal_simple_14_Nov_2022_DATA/SegmentationMaskBinary', - size=config['size'], - use_all_mask=True - ) - loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=8) vid_name = vid_reader.vid_name vid_length = len(loader) @@ -84,7 +97,12 @@ def inference_on_video(frames_with_masks, dir_name=''): processor = InferenceCore(network, config=config) first_mask_loaded = False - for j in frames_with_masks: + if original_memory_mechanism: + frames_to_put_in_permanent_memory = [0] # only the first frame goes into permanent memory originally + # the rest are going to be processed later + else: + frames_to_put_in_permanent_memory = frames_with_masks # in our modification, all frames with provided masks go into permanent memory + for j in frames_to_put_in_permanent_memory: sample = vid_reader[j] rgb = sample['rgb'].cuda() msk = sample['mask'] @@ -99,11 +117,23 @@ def inference_on_video(frames_with_masks, dir_name=''): processor.set_all_labels(list(mapper.remappings.values())) processor.put_to_permanent_memory(rgb, msk) + + stats = [] - for ti, data in enumerate(tqdm(loader)): + if compute_uncertainty: + assert uncertainty_name is not None + 
uncertainty_name = uncertainty_name.lower() + assert uncertainty_name in {'entropy', 'bald'} + + if uncertainty_name == 'bald': + bald = BALD() + + for ti, data in enumerate(tqdm(loader, disable=not print_progress)): with torch.cuda.amp.autocast(enabled=True): rgb = data['rgb'].cuda()[0] + rgb_raw_tensor = data['raw_image_tensor'].cpu()[0] + gt = data.get('mask') # for IoU computations if ti in frames_with_masks: msk = data['mask'] else: @@ -113,15 +143,7 @@ def inference_on_video(frames_with_masks, dir_name=''): frame = info['frame'][0] shape = info['shape'] need_resize = info['need_resize'][0] - - """ - For timing see https://discuss.pytorch.org/t/how-to-measure-time-in-pytorch/26964 - Seems to be very similar in testing as my previous timing method - with two cuda sync + time.time() in STCN though - """ - # start = torch.cuda.Event(enable_timing=True) - # end = torch.cuda.Event(enable_timing=True) - # start.record() + curr_stat = {'frame': frame, 'mask_provided': msk is not None} if not first_mask_loaded: if msk is not None: @@ -146,20 +168,43 @@ def inference_on_video(frames_with_masks, dir_name=''): else: labels = None - + + if compute_uncertainty and uncertainty_name == 'bald': + dry_run_preds = [] + augs = get_determenistic_augmentations() + for aug in augs: + # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies + rgb_raw = FT.to_pil_image(rgb_raw_tensor) + rgb_aug = vid_reader.im_transform(aug(rgb_raw)).cuda() + + dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti==vid_length-1), manually_curated_masks=manually_curated_masks, disable_memory_updates=True) + dry_run_preds.append(dry_run_prob.cpu()) + + if original_memory_mechanism: + do_not_add_mask_to_memory = (ti == 0) # we only ignore the first mask, since it's already in the permanent memory + else: + do_not_add_mask_to_memory = msk is not None # we ignore all frames with masks, since they are already preloaded in the permanent memory # Run the model on this frame # TODO: still running inference even on frames with masks? - prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1), manually_curated_masks=False) + # 2+ channels, classes+ and background + prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1), manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) + + if compute_uncertainty: + if uncertainty_name == 'bald': + # [batch=1, num_classes, ..., num_iterations] + all_samples = torch.stack([x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() + score = bald.compute_score(all_samples) + # TODO: can also return the exact pixels for every frame? 
As a suggestion on what to label + curr_stat['bald'] = float(np.squeeze(score).mean()) + else: + e = entropy(prob.cpu()) + e_mean = np.mean(e) + curr_stat['entropy'] = float(e_mean) # Upsample to original size if needed if need_resize: prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] - # end.record() - # torch.cuda.synchronize() - # total_process_time += (start.elapsed_time(end)/1000) - total_frames += 1 - if False: prob = torch.flip(prob, dims=[-1]) @@ -167,6 +212,15 @@ def inference_on_video(frames_with_masks, dir_name=''): out_mask = torch.argmax(prob, dim=0) out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + if compute_iou: + # mask is [0, 1] + # gt is [0, 255] + # both -> [False, True] + if gt is not None: + iou = float(compute_tensor_iou(torch.tensor(out_mask).type(torch.bool), gt.type(torch.bool))) + else: + iou = -1 + curr_stat['iou'] = iou if False: prob = (prob.detach().cpu().numpy()*255).astype(np.uint8) @@ -188,16 +242,211 @@ def inference_on_video(frames_with_masks, dir_name=''): if args.save_all or info['save'][0]: hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') + stats.append(curr_stat) + + return pd.DataFrame(stats) + + +def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_frames: int, uncertainty_name: str, csv_out_path: str = None, mode='batched', use_cache=False): + """ + mode:str + Possible values: + 'uniform': uniformly distributed indices (np.linspace(0, num_total_frames, `num_extra_frames`).astype(int)) + 'random': pick `num_extra_frames` random frames (cannot include first or last ones) + 'batched': Pick only `num_extra_frames` best frames + 'iterative': Pick only 1 best frame instead of `num_extra_frames`, repeat `num_extra_frames` times + """ + + assert mode in {'uniform', 'random', 'batched', 'iterative'} + assert uncertainty_name in {'entropy', 'bald'} + + if mode == 'uniform': + num_total_frames = len(os.listdir(imgs_in_path)) + # linspace is [a, b] (inclusive) + frames_with_masks = np.linspace(0, num_total_frames - 1, num_extra_frames).astype(int) + elif mode == 'random': + num_total_frames = len(os.listdir(imgs_in_path)) + extra_frames = np.random.choice(np.arange(1, num_total_frames), size=num_extra_frames, replace=False).tolist() + frames_with_masks = sorted([0] + extra_frames) + elif mode == 'batched': + # we save baseline results here, with just 1 annotation + baseline_out= Path(masks_out_path).parent.parent / 'baseline' + df = inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=baseline_out / 'masks', + frames_with_masks=[0], + compute_uncertainty=True, + compute_iou=True, + uncertainty_name=uncertainty_name, + manually_curated_masks=False, + print_progress=False, + overwrite_config={'save_masks': True}, + ) + + df.to_csv(baseline_out / 'stats.csv', index=False) + + candidates = select_n_frame_candidates(df, n=num_extra_frames, uncertainty_name=uncertainty_name) + + extra_frames = [int(candidate['index']) for candidate in candidates] + + frames_with_masks = sorted([0] + extra_frames) + elif mode == 'iterative': + extra_frames = [] + for i in range(num_extra_frames): + df = inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=[0] + extra_frames, + compute_uncertainty=True, + compute_iou=False, + uncertainty_name=uncertainty_name, + manually_curated_masks=False, + print_progress=False, + overwrite_config={'save_masks': 
False}, + ) + + max_frame = select_most_uncertain_frame(df) + extra_frames.append(max_frame['index']) + + # keep unsorted to preserve order of the choices + frames_with_masks = [0] + extra_frames + if use_cache and os.path.exists(csv_out_path): + final_df = pd.read_csv(csv_out_path) + else: + final_df = inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=frames_with_masks, + compute_uncertainty=True, + compute_iou=True, + print_progress=False, + uncertainty_name=uncertainty_name, + manually_curated_masks=False, + ) + + if csv_out_path is not None: + p_csv_out = Path(csv_out_path) + + if not p_csv_out.parent.exists(): + p_csv_out.parent.mkdir(parents=True) + + final_df.to_csv(p_csv_out, index=False) + + return final_df, frames_with_masks + + +def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int, uncertainty_name: str): + assert uncertainty_name in {'entropy', 'bald'} + + p_in_ds = Path(dataset_path) + p_out = Path(out_path) + + big_stats = defaultdict(list) + for p_video_imgs_in in tqdm(sorted((p_in_ds / 'JPEGImages').iterdir())): + video_name = p_video_imgs_in.stem + p_video_masks_in = p_in_ds / 'Annotations_binarized' / video_name + + p_video_out_general = p_out / f'Active_learning_{uncertainty_name}' / video_name / f'{num_extra_frames}_extra_frames' + + for mode in ['uniform', 'random', 'batched', 'iterative']: + curr_video_stat = {'video': video_name} + p_out_masks = p_video_out_general / mode / 'masks' + p_out_stats = p_video_out_general / mode / 'stats.csv' - print(f'Total processing time: {total_process_time}') - print(f'Total processed frames: {total_frames}') - print(f'FPS: {total_frames / total_process_time}') - print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}') + stats, frames_with_masks = run_active_learning(p_video_imgs_in, p_video_masks_in, p_out_masks, + num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, use_cache=True) + + stats = stats[stats['mask_provided'] == False] # remove stats for frames with given masks + for i in range(1, len(frames_with_masks) + 1): + curr_video_stat[f'extra_frame_{i}'] = frames_with_masks[i - 1] + + curr_video_stat[f'mean_iou'] = stats['iou'].mean() + curr_video_stat[f'mean_{uncertainty_name}'] = stats[uncertainty_name].mean() + + big_stats[mode].append(curr_video_stat) + + for mode, mode_stats in big_stats.items(): + df_mode_stats = pd.DataFrame(mode_stats) + df_mode_stats.to_csv(p_out / f'Active_learning_{uncertainty_name}' / f'stats_{mode}_all_videos.csv', index=False) if __name__ == '__main__': - for how_many_extra_frames in tqdm(range(0, 91)): - # e.g. 
[0, 10, 20, ..., 180] without 180 - frames_with_masks = set(np.linspace(0, 180, how_many_extra_frames+2)[0:-1].astype(int)) + pass + # eval_active_learning('/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid', + # '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS', + # 5) + + + # res, frames_with_masks = run_active_learning('/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/JPEGImages/0tCWPOrc', + # '/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/Annotations_binarized/0tCWPOrc', + # '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS/JUNK/masks', + # num_extra_frames=5, + # csv_out_path='/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS/JUNK/stats.csv', mode='iterative') + + # print(frames_with_masks) + # pass + # bald_df = inference_on_video([0], + # '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + # '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized', + # 'JUNK', + # compute_iou=False, + # compute_uncertainty=True, + # use_bald=True) # for t.hanks style video + + # bald_df.to_csv('output/bald_thanks_0_frame.csv', index=False) + + # df = inference_on_video( + # imgs_in_path='/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + # masks_in_path='/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized', + # frames_with_masks=[0, 259, 621, 785, 1401], + # masks_out_path='../VIDEOS/RESULTS/XMem_feedback/BASELINE_REIMPLEMENTED/5_annotated_frames_new_mem', + # compute_uncertainty=False, + # compute_iou=False, + # manually_curated_masks=False, + # original_memory_mechanism=False, + # overwrite_config={'save_masks': True}) + + # df = inference_on_video( + # imgs_in_path='/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/JPEGImages/vjG0jbkQ', + # masks_in_path='/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/Annotations/vjG0jbkQ', + # masks_out_path='JUNK', + # frames_with_masks=[0], + # compute_entropy=True, + # compute_iou=True, + # manually_curated_masks=False, + # overwrite_config={'save_masks': False}) + + # print(df.shape) + # df.to_csv('junk.csv', index=False) + # p_in = Path('/home/maksym/RESEARCH/VIDEOS/DAVIS-2017-trainval-480p/DAVIS/2017_train_val_split/val/JPEGImages_chosen') + # p_in = Path('/home/maksym/RESEARCH/VIDEOS/DAVIS-2017-trainval-480p/DAVIS/2017_train_val_split/val/JPEGImages_chosen') + # num_frames_mapping = {} + + # for p_dir in sorted(p for p in p_in.iterdir() if p.is_dir()): + # dir_name = p_dir.name + # num_frames = sum(1 for _ in p_dir.iterdir()) + # num_frames_mapping[dir_name] = num_frames # math.ceil(num_frames/2) + # # print(extra_frames_ranges) + # # exit(0) + # p_bar = tqdm(total=sum(num_frames_mapping.values()), desc='% extra frames DAVIS 2017 val') + # for dir_name, total_frames in num_frames_mapping.items(): + # for how_many_extra_frames in range(0, math.ceil(total_frames // 2)): + # # frames_with_masks = set([0, frame_with_mask]) + # frames_with_masks = set(np.linspace(0, num_frames, how_many_extra_frames+2)[0:-1].astype(int)) + # inference_on_video(frames_with_masks, dir_name) + + # p_bar.update() + + # num_runs = 90 + # p_bar = tqdm(total=num_runs) + # for how_many_extra_frames in range(0, 90): + # # for j in range(0, 181 - 1): + # # e.g. 
[0, 10, 20, ..., 180] without 180 + # frames_with_masks = set(np.linspace(0, 180, how_many_extra_frames+2)[0:-1].astype(int)) + # # frames_with_masks = set([0, i, j]) + # inference_on_video(frames_with_masks) - inference_on_video(how_many_extra_frames, frames_with_masks) + # p_bar.update() From f0caad8a8fd86dd2c1d605c0d4e97b01fea997f0 Mon Sep 17 00:00:00 2001 From: Longnhat Date: Wed, 4 Jan 2023 19:08:25 +0400 Subject: [PATCH 05/49] Bug fixes --- inference/inference_core.py | 4 +++- inference/memory_manager.py | 21 +++++++++++---------- util/tensor_util.py | 14 +++++++++++++- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index b81f4a7..fdc28a7 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -49,6 +49,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 + image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch dimension if manually_curated_masks: @@ -74,10 +75,11 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ is_normal_update = False is_deep_update = False is_mem_frame = False + self.curr_ti -= 1 # do not advance the iteration further # segment the current frame is needed if need_segment: - memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) + memory_readout = self.memory.match_memory(key, selection, disable_usage_updates=disable_memory_updates).unsqueeze(0) hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout, self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False) # remove batch dim diff --git a/inference/memory_manager.py b/inference/memory_manager.py index bb132cf..aadf88d 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -56,7 +56,7 @@ def _readout(self, affinity, v): # this function is for a single object group return v @ affinity - def match_memory(self, query_key, selection): + def match_memory(self, query_key, selection, disable_usage_updates=False): # query_key: B x C^k x H x W # selection: B x C^k x H x W # TODO: keep groups in both..? 
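# `disable_usage_updates` is driven by InferenceCore.step(..., disable_memory_updates=True),
# which the BALD dry runs over augmented frames in run_on_video.py use.  A rough sketch of that
# call pattern, with names as in patch 04:
#
#     for aug in get_determenistic_augmentations():
#         rgb_aug = vid_reader.im_transform(aug(rgb_raw)).cuda()
#         dry_run_prob = processor.step(rgb_aug, msk, labels,
#                                       disable_memory_updates=True)  # read-only pass
#
# With the flag set, step() neither advances curr_ti nor adds memory, and the hunk below keeps
# these read-only passes from skewing the usage statistics that drive long-term consolidation.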
@@ -122,14 +122,15 @@ def match_memory(self, query_key, selection): """ Record memory usage for working and long-term memory """ + if not disable_usage_updates: # ignore the index return for long-term memory - work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory - self.temporary_work_mem.update_usage(work_usage.flatten()) + work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory + self.temporary_work_mem.update_usage(work_usage.flatten()) - if self.enable_long_term_usage: - # ignore the index return for working memory - long_usage = usage[:, :long_mem_size] - self.long_mem.update_usage(long_usage.flatten()) + if self.enable_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_usage(long_usage.flatten()) else: memory_key = torch.cat([self.temporary_work_mem.key, self.permanent_work_mem.key], -1) shrinkage = torch.cat([self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) @@ -139,9 +140,9 @@ def match_memory(self, query_key, selection): if self.enable_long_term: affinity, usage = do_softmax(similarity, inplace=(num_groups == 1), top_k=self.top_k, return_usage=True) - - # Record memory usage for working memory - self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) + if not disable_usage_updates: + # Record memory usage for working memory + self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) else: affinity = do_softmax(similarity, inplace=(num_groups == 1), top_k=self.top_k, return_usage=False) diff --git a/util/tensor_util.py b/util/tensor_util.py index 05189d3..bdf1ec3 100644 --- a/util/tensor_util.py +++ b/util/tensor_util.py @@ -1,4 +1,5 @@ import torch.nn.functional as F +import torch def compute_tensor_iu(seg, gt): @@ -44,4 +45,15 @@ def unpad(img, pad): img = img[:,:,pad[0]:-pad[1]] else: raise NotImplementedError - return img \ No newline at end of file + return img + +def get_bbox_from_mask(mask): + mask = torch.squeeze(mask) + assert mask.ndim == 2 + + nonzero = torch.nonzero(mask) + + min_y, min_x = nonzero.min(dim=0).values + max_y, max_x = nonzero.max(dim=0).values + + return int(min_y), int(min_x), int(max_y), int(max_x) \ No newline at end of file From 77a0218bcee71a6c4f71dd26288697ab8ffcb22e Mon Sep 17 00:00:00 2001 From: Longnhat Date: Wed, 4 Jan 2023 19:19:48 +0400 Subject: [PATCH 06/49] Quality of life improvements - fixed augmentations for frames in memory, fixed active learning/eval procedures in run_on_video.py --- inference/active_learning.py | 259 ++++++++++++++++++++++++++++++++++- run_on_video.py | 135 ++++++++++++++---- 2 files changed, 358 insertions(+), 36 deletions(-) diff --git a/inference/active_learning.py b/inference/active_learning.py index 4a3b4f8..da14bec 100644 --- a/inference/active_learning.py +++ b/inference/active_learning.py @@ -1,7 +1,14 @@ +from functools import partial +from os import access +from pathlib import Path from sklearn.cluster import KMeans import pandas as pd from PIL import Image -from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor +import torch +import torchvision.transforms.functional as FT +from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine + +from util.tensor_util import get_bbox_from_mask def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): 
df = preds_df @@ -42,20 +49,203 @@ def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5 return candidates def select_most_uncertain_frame(preds_df: pd.DataFrame, uncertainty_name: str): - preds_df.reset_index(drop=False, inplace=True) - return preds_df.iloc[preds_df[uncertainty_name].argmax()] + df = preds_df[preds_df['mask_provided'] == False] + df.reset_index(drop=False, inplace=True) + return df.iloc[df[uncertainty_name].argmax()] + +def select_n_frame_candidates_no_neighbours_simple(preds_df: pd.DataFrame, uncertainty_name: str, n=5, neighbourhood_size=4): + df = preds_df + df.reset_index(drop=False, inplace=True) + + df = df[df['mask_provided'] == False] # removing frames with masks + + neighbours_indices = set() + chosen_candidates = [] + + df_sorted = df.sort_values(uncertainty_name, ascending=False) + i = 0 + while len(chosen_candidates) < n: + candidate = df_sorted.iloc[i] + candidate_index = candidate['index'] + + if candidate_index not in neighbours_indices: + chosen_candidates.append(candidate) + candidate_neighbours = range(candidate_index - neighbourhood_size, candidate_index + neighbourhood_size + 1) + neighbours_indices.update(candidate_neighbours) + + i += 1 + + return chosen_candidates -def get_determenistic_augmentations(): - # TODO: maybe add GaussianBlur? +WhichAugToPick = -1 +def get_determenistic_augmentations(img_size=None, mask=None, subset: str=None): + assert subset in {'best_3', 'best_3_with_symmetrical', 'best_all', 'original_only', 'all'} + bright = ColorJitter(brightness=(1.5, 1.5)) dark = ColorJitter(brightness=(0.5, 0.5)) gray = Grayscale(num_output_channels=3) reduce_bits = RandomPosterize(bits=3, p=1) sharp = RandomAdjustSharpness(sharpness_factor=16, p=1) + rotate_right = RandomAffine(degrees=(30, 30)) + blur = partial(FT.gaussian_blur, kernel_size=7) + + if img_size is not None: + h, w = img_size[-2:] + translate_distance = w // 5 + else: + translate_distance = 200 + + translate_right = partial(FT.affine, angle=0, translate=(translate_distance, 0), scale=1, shear=0) + + zoom_out = partial(FT.affine, angle=0, translate=(0, 0), scale=0.5, shear=0) + zoom_in = partial(FT.affine, angle=0, translate=(0, 0), scale=1.5, shear=0) + shear_right = partial(FT.affine, angle=0, translate=(0, 0), scale=1, shear=20) + + identity = torch.nn.Identity() + identity.name = 'identity' + + if mask is not None: + if mask.any(): + min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) + h, w = mask.shape[-2:] + crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) + crop_mask.name = 'crop_mask' + else: + crop_mask = identity # if the mask is empty + else: + crop_mask = None + + bright.name = 'bright' + dark.name = 'dark' + gray.name = 'gray' + reduce_bits.name = 'reduce_bits' + sharp.name = 'sharp' + rotate_right.name = 'rotate_right' + translate_right.name = 'translate_right' + zoom_out.name = 'zoom_out' + zoom_in.name = 'zoom_in' + shear_right.name = 'shear_right' + blur.name = 'blur' + + + + rotate_left = RandomAffine(degrees=(-30, -30)) + rotate_left.name = 'rotate_left' + + shear_left = partial(FT.affine, angle=0, translate=(0, 0), scale=1, shear=-20) + shear_left.name = 'shear_left' - return [bright, dark, gray, reduce_bits, sharp] + if WhichAugToPick != -1: + return [img_mask_augs_pairs[WhichAugToPick]] + + if subset == 'best_3': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # 
(reduce_bits, identity), + # (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + # (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_3_with_symmetrical': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # (reduce_bits, identity), + # (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_all': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + # (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + # (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + + elif subset == 'original_only': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + # (zoom_in, zoom_in), + # (shear_right, shear_right), + ] + else: + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + if crop_mask is not None: + img_mask_augs_pairs.append((crop_mask, crop_mask)) + + return img_mask_augs_pairs + def apply_aug(img_path, out_path): img = Image.open(img_path) @@ -67,6 +257,63 @@ def apply_aug(img_path, out_path): img_augged.save(out_path) +def compute_disparity(predictions, augs, images:list = None, output_save_path: str = None): + assert len(predictions) - len(augs) == 1 + disparity_map = None + prev = None + + if images is None: + images = [None] * len(predictions) + else: + assert len(predictions) == len(images) + + if output_save_path is not None: + p_out_disparity = Path(output_save_path) + else: + p_out_disparity = None + + try: + aug_names = [aug.name for aug in augs] + except AttributeError: + aug_names = [aug._get_name() for aug in augs] + + names = ['original'] + aug_names + for i, (name, img, pred) in enumerate(zip(names, images, predictions)): + fg_mask = pred[1:2].squeeze().cpu() # 1:2 is Foreground + + if disparity_map is None: + disparity_map = torch.zeros_like(fg_mask) + else: + disparity_map += (prev - fg_mask).abs() + + pred_mask_ = FT.to_pil_image(fg_mask) + if p_out_disparity is not None: + p_out_save_mask = p_out_disparity / 'masks' / (f'{i}_{name}.png') + p_out_save_image = 
p_out_disparity / 'images' / (f'{i}_{name}.png') + + if not p_out_save_mask.parent.exists(): + p_out_save_mask.parent.mkdir(parents=True) + + pred_mask_.save(p_out_save_mask) + + if not p_out_save_image.parent.exists(): + p_out_save_image.parent.mkdir(parents=True) + + img.save(p_out_save_image) + + prev = fg_mask + + disparity_scaled = disparity_map / (len(augs) + 1) # 0..1; not `disparity_map.max()`, as the scale would differ across images + disparity_avg = disparity_scaled.mean() + disparity_large = (disparity_scaled > 0.5).sum() # num pixels with large disparities + + if p_out_disparity is not None: + disparity_img = FT.to_pil_image(disparity_scaled) + disparity_img.save(p_out_disparity / (f'{i+1}_absolute_disparity.png')) + + return {'full': disparity_scaled, 'avg': disparity_avg, 'large': disparity_large} + + if __name__ == '__main__': img_in = '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages/frame_000001.PNG' img_out = 'test_aug.png' diff --git a/run_on_video.py b/run_on_video.py index 8afab2d..239813c 100644 --- a/run_on_video.py +++ b/run_on_video.py @@ -17,7 +17,7 @@ from baal.active.heuristics import BALD import torchvision.transforms.functional as FT -from inference.active_learning import get_determenistic_augmentations, select_most_uncertain_frame, select_n_frame_candidates +from inference.active_learning import get_determenistic_augmentations, select_most_uncertain_frame, select_n_frame_candidates, compute_disparity as compute_disparity_func, select_n_frame_candidates_no_neighbours_simple from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset from inference.data.mask_mapper import MaskMapper from inference.data.video_reader import VideoReader @@ -29,6 +29,7 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, original_memory_mechanism=False, compute_iou = False, compute_uncertainty = False, manually_curated_masks=False, print_progress=True, + augment_images_with_masks=False, uncertainty_name: str = None, overwrite_config: dict = None): torch.autograd.set_grad_enabled(False) @@ -55,7 +56,7 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out 'top_k': 30, 'value_dim': 512, # 'video': '../VIDEOS/Maksym_frontal_simple.mp4', - 'masks_out_path': masks_out_path,#f'../VIDEOS/RESULTS/XMem_feedback/thanks_two_face_5_frames/', + 'masks_out_path': masks_out_path,#f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/WhichFramesWithPreds/1/{dir_name}/{"_".join(map(str, frames_with_masks))}_frames_provided', # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/DAVIS_2017/WhichFrames/1/{dir_name}/{len(frames_with_masks) - 1}_extra_frames', 'workspace': None, @@ -65,6 +66,14 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out if overwrite_config is not None: config.update(overwrite_config) + if compute_uncertainty: + assert uncertainty_name is not None + uncertainty_name = uncertainty_name.lower() + assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} + compute_disparity = uncertainty_name.startswith('disparity') + else: + compute_disparity = False + vid_reader = VideoReader( "", imgs_in_path, #f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', @@ -105,6 +114,7 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out for j in frames_to_put_in_permanent_memory: sample = vid_reader[j] rgb = sample['rgb'].cuda() + rgb_raw_tensor 
= sample['raw_image_tensor'].cpu() msk = sample['mask'] info = sample['info'] need_resize = info['need_resize'] @@ -118,14 +128,21 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out processor.set_all_labels(list(mapper.remappings.values())) processor.put_to_permanent_memory(rgb, msk) + if augment_images_with_masks: + augs = get_determenistic_augmentations(rgb.shape, msk, subset='best_all') + rgb_raw = FT.to_pil_image(rgb_raw_tensor) + + for img_aug, mask_aug in augs: + # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies + rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda() + + msk_aug = mask_aug(msk) + + processor.put_to_permanent_memory(rgb_aug, msk_aug) + stats = [] - if compute_uncertainty: - assert uncertainty_name is not None - uncertainty_name = uncertainty_name.lower() - assert uncertainty_name in {'entropy', 'bald'} - - if uncertainty_name == 'bald': + if compute_uncertainty and uncertainty_name == 'bald': bald = BALD() for ti, data in enumerate(tqdm(loader, disable=not print_progress)): @@ -169,13 +186,18 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out else: labels = None - if compute_uncertainty and uncertainty_name == 'bald': + if (compute_uncertainty and uncertainty_name == 'bald') or compute_disparity: dry_run_preds = [] - augs = get_determenistic_augmentations() - for aug in augs: + augged_images = [] + augs = get_determenistic_augmentations(subset='original_only') + rgb_raw = FT.to_pil_image(rgb_raw_tensor) + for img_aug, mask_aug in augs: # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies - rgb_raw = FT.to_pil_image(rgb_raw_tensor) - rgb_aug = vid_reader.im_transform(aug(rgb_raw)).cuda() + augged_img = img_aug(rgb_raw) + augged_images.append(augged_img) + rgb_aug = vid_reader.im_transform(augged_img).cuda() + + msk = mask_aug(msk) # does not do anything, since original_only=True augmentations don't alter the mask at all dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti==vid_length-1), manually_curated_masks=manually_curated_masks, disable_memory_updates=True) dry_run_preds.append(dry_run_prob.cpu()) @@ -195,12 +217,22 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out all_samples = torch.stack([x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() score = bald.compute_score(all_samples) # TODO: can also return the exact pixels for every frame? 
As a suggestion on what to label - curr_stat['bald'] = float(np.squeeze(score).mean()) + curr_stat['bald'] = float(np.squeeze(score).mean()) + elif compute_disparity: + # p_out_disparity = Path('output/masks_disparity/') + # if ti in {0, 200, 500, 900, 1100, 1300, 1450, 1600}: + # output_save_path = p_out_disparity / str(ti) + # else: + # output_save_path = None + + disparity_stats = compute_disparity_func(predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) + curr_stat['disparity'] = float(disparity_stats['avg']) + curr_stat['disparity_large'] = float(disparity_stats['large']) else: e = entropy(prob.cpu()) e_mean = np.mean(e) curr_stat['entropy'] = float(e_mean) - + # Upsample to original size if needed if need_resize: prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] @@ -247,25 +279,26 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out return pd.DataFrame(stats) -def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_frames: int, uncertainty_name: str, csv_out_path: str = None, mode='batched', use_cache=False): +def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_frames: int, uncertainty_name: str, csv_out_path: str = None, mode='batched', use_cache=False, **kwargs): """ mode:str Possible values: - 'uniform': uniformly distributed indices (np.linspace(0, num_total_frames, `num_extra_frames`).astype(int)) + 'uniform': uniformly distributed indices np.linspace(0, `num_total_frames` - 1, `num_extra_frames` + 1).astype(int) 'random': pick `num_extra_frames` random frames (cannot include first or last ones) 'batched': Pick only `num_extra_frames` best frames 'iterative': Pick only 1 best frame instead of `num_extra_frames`, repeat `num_extra_frames` times """ assert mode in {'uniform', 'random', 'batched', 'iterative'} - assert uncertainty_name in {'entropy', 'bald'} + assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} if mode == 'uniform': num_total_frames = len(os.listdir(imgs_in_path)) # linspace is [a, b] (inclusive) - frames_with_masks = np.linspace(0, num_total_frames - 1, num_extra_frames).astype(int) + frames_with_masks = np.linspace(0, num_total_frames - 1, num_extra_frames + 1).astype(int) elif mode == 'random': num_total_frames = len(os.listdir(imgs_in_path)) + np.random.seed(1) extra_frames = np.random.choice(np.arange(1, num_total_frames), size=num_extra_frames, replace=False).tolist() frames_with_masks = sorted([0] + extra_frames) elif mode == 'batched': @@ -285,8 +318,10 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f ) df.to_csv(baseline_out / 'stats.csv', index=False) - - candidates = select_n_frame_candidates(df, n=num_extra_frames, uncertainty_name=uncertainty_name) + if uncertainty_name == 'disparity_large': + candidates = select_n_frame_candidates_no_neighbours_simple(df, n=num_extra_frames, uncertainty_name=uncertainty_name) + else: + candidates = select_n_frame_candidates(df, n=num_extra_frames, uncertainty_name=uncertainty_name) extra_frames = [int(candidate['index']) for candidate in candidates] @@ -307,7 +342,7 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f overwrite_config={'save_masks': False}, ) - max_frame = select_most_uncertain_frame(df) + max_frame = select_most_uncertain_frame(df, uncertainty_name=uncertainty_name) extra_frames.append(max_frame['index']) # keep 
unsorted to preserve order of the choices @@ -324,7 +359,7 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f compute_iou=True, print_progress=False, uncertainty_name=uncertainty_name, - manually_curated_masks=False, + **kwargs ) if csv_out_path is not None: @@ -338,9 +373,12 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f return final_df, frames_with_masks -def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int, uncertainty_name: str): - assert uncertainty_name in {'entropy', 'bald'} +def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int, uncertainty_name: str, modes: list = None, **kwargs): + assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} + if modes is None: + modes = ['uniform', 'random', 'batched', 'iterative'] + p_in_ds = Path(dataset_path) p_out = Path(out_path) @@ -351,13 +389,13 @@ def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int p_video_out_general = p_out / f'Active_learning_{uncertainty_name}' / video_name / f'{num_extra_frames}_extra_frames' - for mode in ['uniform', 'random', 'batched', 'iterative']: + for mode in modes: curr_video_stat = {'video': video_name} p_out_masks = p_video_out_general / mode / 'masks' p_out_stats = p_video_out_general / mode / 'stats.csv' stats, frames_with_masks = run_active_learning(p_video_imgs_in, p_video_masks_in, p_out_masks, - num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, use_cache=True) + num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, uncertainty_name=uncertainty_name, use_cache=False, **kwargs) stats = stats[stats['mask_provided'] == False] # remove stats for frames with given masks for i in range(1, len(frames_with_masks) + 1): @@ -375,9 +413,46 @@ def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int if __name__ == '__main__': pass - # eval_active_learning('/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid', - # '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS', - # 5) + # inference_on_video( + # imgs_in_path='../VIDEOS/thanks_no_ears_5_annot/JPEGImages', + # masks_in_path='../VIDEOS/thanks_no_ears_5_annot/annotations_3_face', + # masks_out_path='../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/thanks_no_ears_5_annot/annotations_3_face_5_frames', + # frames_with_masks=[0, 625, 785, 1300, 1488], + # compute_uncertainty=False, + # compute_iou=False, + # print_progress=True, + # manually_curated_masks=False, + # ) + # df.to_csv('output/disparity/disparity.csv', index=False) + # exit(0) + # from inference import active_learning as AL + + # img_size = (3, 480, 853) + # test_mask = FT.to_tensor(Image.open('test.png')) + # for i in tqdm(range(0, 10)): + # eval_active_learning('../VIDEOS/LVOS_dataset/valid', + # f'../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/LVOS_old_scaling/{i}_extra_frames', + # i, + # uncertainty_name='entropy', + # modes=['uniform'], + # augment_images_with_mask=False, + # original_memory_mechanism=True) + + # eval_active_learning('../VIDEOS/LVOS_dataset/valid', + # '../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/LVOS_best_all_augs_in_memory_fixed', + # 5, + # uncertainty_name='entropy', + # modes=['uniform'], + # augment_images_with_masks=True, + # ) + + # eval_active_learning('../VIDEOS/LVOS_dataset/valid', + # '../VIDEOS/RESULTS/XMem_memory/LVOS/LVOS_disparity', + # 5, + # uncertainty_name='disparity_large', + # 
modes=['random', 'uniform', 'batched', 'iterative'], + # augment_images_with_masks=False) + # res, frames_with_masks = run_active_learning('/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/JPEGImages/0tCWPOrc', From 32599cf94b7f63a6f02429e252b447c5c4e6ace3 Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 23 Jan 2023 18:08:14 +0400 Subject: [PATCH 07/49] Bug fixes for using permanent memory and multi-inference Active Learning tehcniques --- inference/inference_core.py | 4 +++- inference/memory_manager.py | 21 +++++++++++---------- model/modules.py | 10 +++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index b81f4a7..fdc28a7 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -49,6 +49,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ # image: 3*H*W # mask: num_objects*H*W or None self.curr_ti += 1 + image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch dimension if manually_curated_masks: @@ -74,10 +75,11 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ is_normal_update = False is_deep_update = False is_mem_frame = False + self.curr_ti -= 1 # do not advance the iteration further # segment the current frame is needed if need_segment: - memory_readout = self.memory.match_memory(key, selection).unsqueeze(0) + memory_readout = self.memory.match_memory(key, selection, disable_usage_updates=disable_memory_updates).unsqueeze(0) hidden, _, pred_prob_with_bg = self.network.segment(multi_scale_features, memory_readout, self.memory.get_hidden(), h_out=is_normal_update, strip_bg=False) # remove batch dim diff --git a/inference/memory_manager.py b/inference/memory_manager.py index bb132cf..aadf88d 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -56,7 +56,7 @@ def _readout(self, affinity, v): # this function is for a single object group return v @ affinity - def match_memory(self, query_key, selection): + def match_memory(self, query_key, selection, disable_usage_updates=False): # query_key: B x C^k x H x W # selection: B x C^k x H x W # TODO: keep groups in both..? 
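For context, a condensed caller-side sketch of how the new flag is meant to be used (variable names as in run_on_video.py): `disable_memory_updates` in `InferenceCore.step` is forwarded to `match_memory` as `disable_usage_updates`, so the extra dry-run passes over augmented frames (used for the BALD/disparity uncertainty estimates) neither advance `curr_ti`, nor write to working memory, nor skew the usage counters that drive long-term consolidation:

    # dry-run the current frame under several deterministic augmentations;
    # these passes leave the memory state and usage statistics untouched
    dry_run_preds = []
    for img_aug, mask_aug in get_determenistic_augmentations(subset='original_only'):
        rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda()
        dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti == vid_length - 1),
                                      manually_curated_masks=manually_curated_masks,
                                      disable_memory_updates=True)
        dry_run_preds.append(dry_run_prob.cpu())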
@@ -122,14 +122,15 @@ def match_memory(self, query_key, selection): """ Record memory usage for working and long-term memory """ + if not disable_usage_updates: # ignore the index return for long-term memory - work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory - self.temporary_work_mem.update_usage(work_usage.flatten()) + work_usage = usage[:, long_mem_size:long_mem_size+temp_work_mem_size] # no usage for permanent memory + self.temporary_work_mem.update_usage(work_usage.flatten()) - if self.enable_long_term_usage: - # ignore the index return for working memory - long_usage = usage[:, :long_mem_size] - self.long_mem.update_usage(long_usage.flatten()) + if self.enable_long_term_usage: + # ignore the index return for working memory + long_usage = usage[:, :long_mem_size] + self.long_mem.update_usage(long_usage.flatten()) else: memory_key = torch.cat([self.temporary_work_mem.key, self.permanent_work_mem.key], -1) shrinkage = torch.cat([self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) @@ -139,9 +140,9 @@ def match_memory(self, query_key, selection): if self.enable_long_term: affinity, usage = do_softmax(similarity, inplace=(num_groups == 1), top_k=self.top_k, return_usage=True) - - # Record memory usage for working memory - self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) + if not disable_usage_updates: + # Record memory usage for working memory + self.temporary_work_mem.update_usage(usage[:, :temp_work_mem_size].flatten()) else: affinity = do_softmax(similarity, inplace=(num_groups == 1), top_k=self.top_k, return_usage=False) diff --git a/model/modules.py b/model/modules.py index 9920799..82d1ac4 100644 --- a/model/modules.py +++ b/model/modules.py @@ -124,13 +124,13 @@ def __init__(self, value_dim, hidden_dim, single_object=False): def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): # image_feat_f16 is the feature from the key encoder if not self.single_object: - g = torch.stack([masks, others], 2) + g_1 = torch.stack([masks, others], 2) else: - g = masks.unsqueeze(2) - g = self.distributor(image, g) + g_1 = masks.unsqueeze(2) + g_2 = self.distributor(image, g_1) - batch_size, num_objects = g.shape[:2] - g = g.flatten(start_dim=0, end_dim=1) + batch_size, num_objects = g_2.shape[:2] + g = g_2.flatten(start_dim=0, end_dim=1) g = self.conv1(g) g = self.bn1(g) # 1/2, 64 From de7b762898f04772d768de7a9b371a40bc107156 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 27 Jan 2023 16:35:44 +0400 Subject: [PATCH 08/49] Cleaned up the code, provided easier to use main.py --- inference/active_learning.py | 564 ++++++++++++++++++++++++++++++++++- inference/inference_core.py | 9 + main.py | 36 +++ run_on_video.py | 444 ++++++++++++++------------- util/configuration.py | 16 +- 5 files changed, 845 insertions(+), 224 deletions(-) create mode 100644 main.py diff --git a/inference/active_learning.py b/inference/active_learning.py index da14bec..fa8a3bf 100644 --- a/inference/active_learning.py +++ b/inference/active_learning.py @@ -1,12 +1,24 @@ +from dataclasses import asdict from functools import partial from os import access from pathlib import Path +from turtle import numinput from sklearn.cluster import KMeans import pandas as pd from PIL import Image import torch import torchvision.transforms.functional as FT +import numpy as np from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine +from 
sklearn.decomposition import PCA, FastICA +from sklearn.cluster import AgglomerativeClustering +from scipy.spatial.distance import cdist +from sklearn.manifold import TSNE +from umap import UMAP +from hdbscan import flat + +from tqdm import tqdm +from model.memory_util import do_softmax, get_similarity from util.tensor_util import get_bbox_from_mask @@ -246,6 +258,550 @@ def get_determenistic_augmentations(img_size=None, mask=None, subset: str=None): return img_mask_augs_pairs +def _extract_keys(dataloder, processor, print_progress=False): + frame_keys = [] + shrinkages = [] + selections = [] + device = None + with torch.no_grad(): # just in case + key_sum = None + + for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): + rgb = data['rgb'].cuda()[0] + key, shrinkage, selection = processor.encode_frame_key(rgb) + + if key_sum is None: + device = key.device + key_sum = torch.zeros_like(key, device=device, dtype=torch.float64) # to avoid possible overflow + + key_sum += key.type(torch.float64) + + frame_keys.append(key.flatten(start_dim=2).cpu()) + shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) + selections.append(selection.flatten(start_dim=2).cpu()) + + num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. + + return frame_keys, shrinkages, selections, device, num_frames, key_sum + +def calculate_proposals_for_annotations_with_average_distance(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): # just in case + frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + avg_key = (key_sum / num_frames).type(torch.float32) + qk = avg_key.flatten(start_dim=2) + + similarities = [] + for i in tqdm(range(num_frames), desc='Computing similarity to avg frame'): # how to run a loop for lower memory usage + frame_key = frame_keys[i] + similarity_per_pixel = get_similarity(frame_key.to(device), ms=None, qk=qk, qe=None) + similarity_avg = (similarity_per_pixel < 0).sum() # number of dissimilar pixels + + similarities.append(similarity_avg) + + # import matplotlib.pyplot as plt + # import numpy as np + # plt.figure(figsize=(16, 10)) + # plt.xticks(np.arange(0, num_frames, 100)) + # plt.plot([float(x) for x in similarities]) + # plt.title("Inner XMem mean similarity VS average frame") + # plt.savefig( + # 'output/similarities_NEG_vs_avg_frame.png' + # ) + values, indices = torch.topk(torch.tensor(similarities), k=how_many_frames, largest=True) # top `how_many_frames` frames LEAST similar to the avg_key + return indices + +def calculate_proposals_for_annotations_with_first_distance(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): # just in case + frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + qk = frame_keys[0].flatten(start_dim=2).to(device) + + similarities = [] + first_similarity = None + for i in tqdm(range(num_frames), desc='Computing similarity to avg frame'): # how to run a loop for lower memory usage + frame_key = frame_keys[i] + similarity_per_pixel = get_similarity(frame_key.to(device), ms=None, qk=qk, qe=None) + if i == 0: + first_similarity = similarity_per_pixel + similarity_avg = similarity_per_pixel.mean() + + # if i == 0 or i == 175 or i == 353 or i == 560 or i == 900: + # import seaborn as sns + # import matplotlib.pyplot as plt + # plt.figure(figsize=(40, 40)) + # sns.heatmap(similarity_per_pixel.squeeze().cpu(), square=True, cmap="icefire") + # 
plt.savefig(f'output/SIMILARITY_HEATMAPS/0_vs_{i}.png') + + # plt.figure(figsize=(40, 40)) + # sns.heatmap((similarity_per_pixel - first_similarity).squeeze().cpu(), square=True, cmap="icefire") + # plt.savefig(f'output/SIMILARITY_HEATMAPS/0_vs_{i}_diff_with_0.png') + similarities.append(similarity_avg) + + # import matplotlib.pyplot as plt + # import numpy as np + # plt.figure(figsize=(16, 10)) + # plt.xticks(np.arange(0, num_frames, 100)) + # plt.plot([float(x) for x in similarities]) + # plt.title("Inner XMem mean similarity VS 1st frame") + # plt.savefig( + # 'output/similarities_NEG_vs_1st_frame.png' + # ) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + values, indices = torch.topk(torch.tensor(similarities), k=how_many_frames, largest=False) # top `how_many_frames` frames LEAST similar to the avg_key + return indices + + +def calculate_proposals_for_annotations_with_iterative_distance(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): # just in case + frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0].flatten(start_dim=2).to(device)] + + for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames'): + similarities = [] + for j in tqdm(range(num_frames), desc='Computing similarity to avg frame', disable=True): # how to run a loop for lower memory usage + qk = frame_keys[j].to(device) + + similarities_across_mem_keys = [] + for mem_key in chosen_frames_mem_keys: + similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) + similarity_avg = similarity_per_pixel.mean() + similarities_across_mem_keys.append(similarity_avg) + + similarity_max_across_all = max(similarities_across_mem_keys) + similarities.append(similarity_max_across_all) + + values, indices = torch.topk(torch.tensor(similarities), k=1, largest=False) + idx = int(indices[0]) + + import matplotlib.pyplot as plt + import numpy as np + plt.figure(figsize=(16, 10)) + plt.xticks(np.arange(0, num_frames, 100)) + plt.plot([float(x) for x in similarities]) + plt.title(f"Inner XMem mean similarity VS frames {chosen_frames}") + plt.savefig( + f'output/iterative_similarity/{i}.png' + ) + + chosen_frames.append(idx) + next_frame_to_add = frame_keys[idx] + chosen_frames_mem_keys.append(next_frame_to_add.to(device)) + + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def calculate_proposals_for_annotations_with_iterative_distance_diff(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): # just in case + frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0].flatten(start_dim=2).to(device)] + + chosen_frames_self_similarities = [get_similarity(chosen_frames_mem_keys[0], ms=None, qk=chosen_frames_mem_keys[0], qe=None)] + for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + dissimilarities = [] + for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): # how to run a loop for lower memory usage + qk = frame_keys[j].to(device) + + dissimilarities_across_mem_keys = [] + for mem_key, self_sim in zip(chosen_frames_mem_keys, chosen_frames_self_similarities): + 
similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) + + # basically, removing scene ambiguity and only keeping differences due to the scene change + # in theory, of course + dissimilarity_score = (similarity_per_pixel - self_sim).abs().sum() / similarity_per_pixel.numel() + dissimilarities_across_mem_keys.append(dissimilarity_score) + + # filtering our existin or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) + idx = int(indices[0]) + + # import matplotlib.pyplot as plt + # import numpy as np + # plt.figure(figsize=(16, 10)) + # plt.xticks(np.arange(0, num_frames, 100)) + # plt.plot([float(x) for x in dissimilarities]) + # plt.title(f"Inner XMem mean dissimilarity VS frames {chosen_frames}") + # plt.savefig( + # f'output/iterative_dissimilarity/{i}.png' + # ) + + chosen_frames.append(idx) + next_frame_to_add_key = frame_keys[idx].to(device) + chosen_frames_mem_keys.append(next_frame_to_add_key) + chosen_frames_self_similarities.append(get_similarity(next_frame_to_add_key, ms=None, qk=next_frame_to_add_key, qe=None)) + + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + +def calculate_proposals_for_annotations_with_uniform_iterative_distance_diff(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) + ranges = zip(space, space[1:]) + + chosen_frames = [] + chosen_frames_mem_keys = [] + chosen_frames_self_similarities = [] + chosen_frames_shrinkages = [] + + for a, b in ranges: + if a == 0: + chosen_new_frame = 0 # in the first range we always pick the first frame + else: + dissimilarities = [] + for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for lower memory usage + true_frame_idx = a + i + qk = frame_keys[true_frame_idx].to(device) + selection = selections[true_frame_idx].to(device) # query + + dissimilarities_across_mem_keys = [] + for mem_key, shrinkage, self_sim in zip(chosen_frames_mem_keys, chosen_frames_shrinkages, chosen_frames_self_similarities): + similarity_per_pixel = get_similarity(mem_key, ms=shrinkage, qk=qk, qe=selection) + + # basically, removing scene ambiguity and only keeping differences due to the scene change + # in theory, of course + diff = (similarity_per_pixel - self_sim) + dissimilarity_score = diff[diff > 0].sum() / similarity_per_pixel.numel() + dissimilarities_across_mem_keys.append(dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + a + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) + chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) + chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], 
qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + +def calculate_proposals_for_annotations_with_uniform_iterative_distance_cycle(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) + ranges = zip(space, space[1:]) + + chosen_frames = [] + chosen_frames_mem_keys = [] + # chosen_frames_self_similarities = [] + chosen_frames_shrinkages = [] + + for a, b in ranges: + if a == 0: + chosen_new_frame = 0 # in the first range we always pick the first frame + else: + dissimilarities = [] + for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for lower memory usage + true_frame_idx = a + i + qk = frame_keys[true_frame_idx].to(device) + query_selection = selections[true_frame_idx].to(device) # query + query_shrinkage = shrinkages[true_frame_idx].to(device) + + dissimilarities_across_mem_keys = [] + for key_idx, mem_key, key_shrinkage in zip(chosen_frames, chosen_frames_mem_keys, chosen_frames_shrinkages): + mem_key = mem_key.to(device) + key_selection = selections[key_idx].to(device) + similarity_per_pixel = get_similarity(mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) + reverse_similarity_per_pixel = get_similarity(qk, ms=query_shrinkage, qk=mem_key, qe=key_selection) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) + cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / cycle_dissimilarity_per_pixel.numel() + + dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + a + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) + chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) + # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0].to(device)] + + for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + dissimilarities = [] + for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): # how to run a loop for lower memory usage + qk = 
frame_keys[j].to(device) + query_selection = selections[j].to(device) # query + query_shrinkage = shrinkages[j].to(device) + + dissimilarities_across_mem_keys = [] + for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): + mem_key = mem_key.to(device) + key_shrinkage = shrinkages[key_idx].to(device) + key_selection = selections[key_idx].to(device) + + similarity_per_pixel = get_similarity(mem_key, ms=None, qk=qk, qe=None) + reverse_similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) + cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / cycle_dissimilarity_per_pixel.numel() + + dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) + # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def calculate_proposals_for_annotations_with_uniform_iterative_distance_double_diff(dataloader, processor, how_many_frames=9, print_progress=False): + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + + space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) + ranges = zip(space, space[1:]) + + chosen_frames = [] + chosen_frames_mem_keys = [] + # chosen_frames_self_similarities = [] + chosen_frames_shrinkages = [] + + for a, b in ranges: + if a == 0: + chosen_new_frame = 0 # in the first range we always pick the first frame + else: + dissimilarities = [] + for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for lower memory usage + true_frame_idx = a + i + qk = frame_keys[true_frame_idx].to(device) + query_selection = selections[true_frame_idx].to(device) # query + query_shrinkage = shrinkages[true_frame_idx].to(device) + + dissimilarities_across_mem_keys = [] + for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): + mem_key = mem_key.to(device) + key_shrinkage = shrinkages[key_idx].to(device) + key_selection = selections[key_idx].to(device) + + similarity_per_pixel = get_similarity(mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) + self_similarity_key = get_similarity(mem_key, ms=key_shrinkage, qk=mem_key, qe=key_selection) + self_similarity_query = get_similarity(qk, ms=query_shrinkage, qk=query_shrinkage, qe=query_selection) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + + pure_similarity = 2 * similarity_per_pixel - self_similarity_key - self_similarity_query + + dissimilarity_score = pure_similarity.abs().sum() / 
pure_similarity.numel() + + dissimilarities_across_mem_keys.append(dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + a + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) + chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) + # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def calculate_proposals_for_annotations_iterative_pca_cosine(dataloader, processor, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + np.random.seed(1) # Just in case + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() + + # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 + pca = PCA(num_frames - 1, svd_solver='arpack') + # pca = FastICA(num_frames - 1) + smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) + # smol_keys = flat_keys # to disable PCA + + chosen_frames = [0] + for c in range(how_many_frames): + distances = cdist(smol_keys[chosen_frames], smol_keys, metric='euclidean') + closest_to_mem_key_distances = distances.min(axis=0) + most_distant_frame = np.argmax(closest_to_mem_key_distances) + chosen_frames.append(most_distant_frame) + + return chosen_frames + + +def calculate_proposals_for_annotations_iterative_umap_cosine(dataloader, processor, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + np.random.seed(1) # Just in case + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() + + # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 + pca = UMAP(n_neighbors=num_frames - 1, n_components=num_frames // 2, random_state=1) + # pca = FastICA(num_frames - 1) + smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) + # smol_keys = flat_keys # to disable PCA + + chosen_frames = [0] + for c in range(how_many_frames): + distances = cdist(smol_keys[chosen_frames], smol_keys, metric='euclidean') + closest_to_mem_key_distances = distances.min(axis=0) + most_distant_frame = np.argmax(closest_to_mem_key_distances) + chosen_frames.append(most_distant_frame) + + return chosen_frames + + +def calculate_proposals_for_annotations_uniform_iterative_pca_cosine(dataloader, processor, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() + + # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 + pca = PCA(num_frames - 1, svd_solver='arpack') + smol_keys = 
pca.fit_transform(flat_keys) + # smol_keys = flat_keys # to disable PCA + + space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) + ranges = zip(space, space[1:]) + + chosen_frames = [0] + + for a, b in ranges: + if a == 0: + # skipping the first one + continue + + distances = cdist(smol_keys[chosen_frames], smol_keys[a:b], metric='cosine') + closest_to_mem_key_distances = distances.min(axis=0) + most_distant_frame = np.argmax(closest_to_mem_key_distances) + a + + chosen_frames.append(most_distant_frame) + + return chosen_frames + + +def calculate_proposals_for_annotations_iterative_pca_cosine_values(values, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + np.random.seed(1) # Just in case + with torch.no_grad(): + num_frames = len(values) + flat_values = torch.stack([value.flatten().cpu() for value in values]).numpy() + + # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 + pca = PCA(num_frames - 1, svd_solver='arpack') + # pca = FastICA(num_frames - 1) + smol_values = pca.fit_transform(flat_values.astype(np.float64)) + # smol_keys = flat_keys # to disable PCA + + chosen_frames = [0] + for c in range(how_many_frames): + distances = cdist(smol_values[chosen_frames], smol_values, metric='euclidean') + closest_to_mem_key_distances = distances.min(axis=0) + most_distant_frame = np.argmax(closest_to_mem_key_distances) + chosen_frames.append(most_distant_frame) + + return chosen_frames + + +def calculate_proposals_for_annotations_umap_hdbscan_clustering(dataloader, processor, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() + + pca = UMAP(n_neighbors=num_frames - 1, n_components=num_frames // 2, random_state=1) + smol_keys = pca.fit_transform(flat_keys) + + # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') + clustering = flat.HDBSCAN_flat(smol_keys, n_clusters=how_many_frames + 1) + labels = clustering.labels_ + + chosen_frames = [] + for c in range(how_many_frames + 1): + vectors = smol_keys[labels == c] + true_index_mapping = {i: int(ti) for i, ti in zip(range(len(vectors)), np.nonzero(labels == c)[0])} + center = np.mean(vectors, axis=0) + + distances = cdist(vectors, [center], metric='euclidean').squeeze() + + closest_to_cluster_center_idx = np.argsort(distances)[0] + + chosen_frame_idx = true_index_mapping[closest_to_cluster_center_idx] + chosen_frames.append(chosen_frame_idx) + + return chosen_frames + + +def calculate_proposals_for_annotations_pca_hierarchical_clustering(dataloader, processor, how_many_frames=9, print_progress=False): + # might not pick 0-th frame + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() + + pca = PCA(num_frames) + smol_keys = pca.fit_transform(flat_keys) + + # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') + clustering = KMeans(n_clusters=how_many_frames + 1) + labels = clustering.fit_predict(smol_keys) + + chosen_frames = [] + for c in range(how_many_frames + 1): + vectors = smol_keys[labels == c] + true_index_mapping = {i: int(ti) for i, ti in zip(range(len(vectors)), 
np.nonzero(labels == c)[0])} + center = np.mean(vectors, axis=0) + + distances = cdist(vectors, [center], metric='euclidean').squeeze() + + closest_to_cluster_center_idx = np.argsort(distances)[0] + + chosen_frame_idx = true_index_mapping[closest_to_cluster_center_idx] + chosen_frames.append(chosen_frame_idx) + + return chosen_frames + def apply_aug(img_path, out_path): img = Image.open(img_path) @@ -312,11 +868,3 @@ def compute_disparity(predictions, augs, images:list = None, output_save_path: s disparity_img.save(p_out_disparity / (f'{i+1}_absolute_disparity.png')) return {'full': disparity_scaled, 'avg': disparity_avg, 'large': disparity_large} - - -if __name__ == '__main__': - img_in = '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages/frame_000001.PNG' - img_out = 'test_aug.png' - - apply_aug(img_in, img_out) - \ No newline at end of file diff --git a/inference/inference_core.py b/inference/inference_core.py index fdc28a7..f6a4973 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -39,6 +39,15 @@ def set_all_labels(self, all_labels): # self.all_labels = [l.item() for l in all_labels] self.all_labels = all_labels + def encode_frame_key(self, image): + image, self.pad = pad_divide_by(image, 16) + image = image.unsqueeze(0) # add the batch dimension + + key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, + need_ek=True, + need_sk=True) + + return key, shrinkage, selection def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False): # For feedback: # 1. We run the model as usual diff --git a/main.py b/main.py new file mode 100644 index 0000000..3538ba6 --- /dev/null +++ b/main.py @@ -0,0 +1,36 @@ +from run_on_video import run_on_video, predict_annotation_candidates + + +if __name__ == '__main__': + # If pytorch cannot download the weights due to an ssl error, uncomment the following lines + # import ssl + # ssl._create_default_https_context = ssl._create_unverified_context + + # Example for a fully-labeled video + video_frames_path = 'example_videos/DAVIS-bmx/frames' + video_masks_path = 'example_videos/DAVIS-bmx/masks' + output_masks_path_baseline = 'output/DAVIS-bmx/baseline' + output_masks_path_5_frames = 'output/DAVIS-bmx/5_frames' + + num_annotation_candidates = 5 + + # The following step is not necessary, you as a human can also choose suitable frames and provide annotations + compute_iou = True + chosen_annotation_frames = predict_annotation_candidates(video_frames_path, num_candidates=num_annotation_candidates) + + print(f"The following frames were chosen as annotation candidates: {chosen_annotation_frames}") + + stats_first_frame_only = run_on_video(video_frames_path, video_masks_path, output_masks_path_baseline, frames_with_masks=[0], compute_iou=True) + stats_5_frames = run_on_video(video_frames_path, video_masks_path, output_masks_path_5_frames, frames_with_masks=chosen_annotation_frames, compute_iou=True) + + print(f"Average IoU for the video: {float(stats_first_frame_only['iou'].mean())} (first frame only)") + print(f"Average IoU for the video: {float(stats_5_frames['iou'].mean())} ({num_annotation_candidates} chosen annotated frames)") + + # Example for a video with only a few annotations present + video_frames_path = 'example_videos/two-face/frames' + video_masks_path = 'example_videos/two-face/masks' + output_masks_path_baseline = 'output/two-face/baseline' + output_masks_path_5_frames = 'output/two-face/5_frames' + + 
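+    # For a sparsely annotated video, pass the indices of the frames that do
+    # have ground-truth masks via `frames_with_masks`; all remaining frames
+    # are predicted.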
run_on_video(video_frames_path, video_masks_path, output_masks_path_baseline, frames_with_masks=[0], compute_iou=False) + run_on_video(video_frames_path, video_masks_path, output_masks_path_5_frames, frames_with_masks=[0, 259, 621, 785, 1401], compute_iou=False) diff --git a/run_on_video.py b/run_on_video.py index 239813c..717dffb 100644 --- a/run_on_video.py +++ b/run_on_video.py @@ -1,37 +1,54 @@ -from collections import defaultdict +import csv +from typing import Iterable, Union, List +from util.tensor_util import compute_tensor_iou +from inference.inference_core import InferenceCore +from model.network import XMem +from inference.data.video_reader import VideoReader +from inference.data.mask_mapper import MaskMapper +from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset +from inference.active_learning import calculate_proposals_for_annotations_iterative_umap_cosine, calculate_proposals_for_annotations_iterative_pca_cosine, calculate_proposals_for_annotations_iterative_pca_cosine_values, calculate_proposals_for_annotations_pca_hierarchical_clustering, calculate_proposals_for_annotations_umap_hdbscan_clustering, calculate_proposals_for_annotations_uniform_iterative_pca_cosine, calculate_proposals_for_annotations_with_average_distance, calculate_proposals_for_annotations_with_first_distance, calculate_proposals_for_annotations_with_iterative_distance, calculate_proposals_for_annotations_with_iterative_distance_cycle, calculate_proposals_for_annotations_with_iterative_distance_diff, calculate_proposals_for_annotations_with_uniform_iterative_distance_cycle, calculate_proposals_for_annotations_with_uniform_iterative_distance_diff, calculate_proposals_for_annotations_with_uniform_iterative_distance_double_diff, get_determenistic_augmentations, select_most_uncertain_frame, select_n_frame_candidates, compute_disparity as compute_disparity_func, select_n_frame_candidates_no_neighbours_simple +import torchvision.transforms.functional as FT +from baal.active.heuristics import BALD +from scipy.stats import entropy +from tqdm import tqdm +from PIL import Image +import numpy as np +from torch.utils.data import DataLoader +import torch.nn.functional as F +import torch +import pandas as pd +import shutil +from pathlib import Path +from argparse import ArgumentParser +from os import PathLike, path import math +from collections import defaultdict import os -from os import path -from argparse import ArgumentParser -from pathlib import Path -import shutil +from torchvision.transforms import functional as FT -import pandas as pd -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader -import numpy as np -from PIL import Image -from tqdm import tqdm -from scipy.stats import entropy -from baal.active.heuristics import BALD -import torchvision.transforms.functional as FT -from inference.active_learning import get_determenistic_augmentations, select_most_uncertain_frame, select_n_frame_candidates, compute_disparity as compute_disparity_func, select_n_frame_candidates_no_neighbours_simple -from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset -from inference.data.mask_mapper import MaskMapper -from inference.data.video_reader import VideoReader -from model.network import XMem -from inference.inference_core import InferenceCore -from util.tensor_util import compute_tensor_iou +def save_frames(dataset, frame_indices, output_folder): + p_out = Path(output_folder) + + if not p_out.exists(): + 
p_out.mkdir(parents=True) + + + for i in frame_indices: + sample = dataset[i] + rgb_raw_tensor = sample['raw_image_tensor'].cpu().squeeze() + img = FT.to_pil_image(rgb_raw_tensor) + + img.save(p_out / f'frame_{i:06d}.png') -def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, - original_memory_mechanism=False, - compute_iou = False, compute_uncertainty = False, manually_curated_masks=False, print_progress=True, - augment_images_with_masks=False, - uncertainty_name: str = None, - overwrite_config: dict = None): +def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, + original_memory_mechanism=False, + compute_iou=False, compute_uncertainty=False, manually_curated_masks=False, print_progress=True, + augment_images_with_masks=False, + uncertainty_name: str = None, + only_predict_frames_to_annotate_and_quit=0, + overwrite_config: dict = None): torch.autograd.set_grad_enabled(False) frames_with_masks = set(frames_with_masks) config = { @@ -55,10 +72,7 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out 'size': 480, 'top_k': 30, 'value_dim': 512, - # 'video': '../VIDEOS/Maksym_frontal_simple.mp4', - 'masks_out_path': masks_out_path,#f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', - # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/WhichFramesWithPreds/1/{dir_name}/{"_".join(map(str, frames_with_masks))}_frames_provided', - # 'masks_out_path': f'../VIDEOS/RESULTS/XMem/DAVIS_2017/WhichFrames/1/{dir_name}/{len(frames_with_masks) - 1}_extra_frames', + 'masks_out_path': masks_out_path, # f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', 'workspace': None, 'save_masks': True } @@ -70,14 +84,14 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out assert uncertainty_name is not None uncertainty_name = uncertainty_name.lower() assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - compute_disparity = uncertainty_name.startswith('disparity') + compute_disparity = uncertainty_name.startswith('disparity') else: compute_disparity = False - + vid_reader = VideoReader( - "", - imgs_in_path, #f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', - masks_in_path, #f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', + "", + imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', size=config['size'], use_all_mask=True ) @@ -106,8 +120,16 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out processor = InferenceCore(network, config=config) first_mask_loaded = False + if only_predict_frames_to_annotate_and_quit > 0: + iterative_frames = calculate_proposals_for_annotations_iterative_pca_cosine(loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit) + + return iterative_frames + + frames_ = [] + masks_ = [] + if original_memory_mechanism: - frames_to_put_in_permanent_memory = [0] # only the first frame goes into permanent memory originally + frames_to_put_in_permanent_memory = [0] # only the first frame goes into permanent memory originally # the rest are going to be processed later else: frames_to_put_in_permanent_memory = frames_with_masks # in our modification, all frames with provided masks go into permanent memory @@ -122,40 +144,47 @@ def inference_on_video(frames_with_masks, 
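How `overwrite_config` is folded into the default config dict is not visible in these hunks; a minimal sketch of one way such an override could be applied, with `merge_config` as a purely hypothetical helper (not part of the patch):

    from typing import Optional

    def merge_config(default_config: dict, overwrite_config: Optional[dict]) -> dict:
        # Copy the defaults, then let caller-supplied keys win; this is an assumed
        # implementation, the real merge is not shown in this patch.
        merged = dict(default_config)
        if overwrite_config:
            merged.update(overwrite_config)
        return merged

    config = merge_config({'save_masks': True, 'mem_every': 10}, {'save_masks': False})
    assert config == {'save_masks': False, 'mem_every': 10}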
imgs_in_path, masks_in_path, masks_out # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True msk, labels = mapper.convert_mask(msk, exhaustive=True) msk = torch.Tensor(msk).cuda() + + if min(msk.shape) == 0: # empty mask, e.g. [1, 0, 720, 1280] + print(f"Skipping adding frame {j} to memory, as the mask is empty") + continue # just don't add anything to the memory if need_resize: msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] processor.set_all_labels(list(mapper.remappings.values())) processor.put_to_permanent_memory(rgb, msk) + frames_.append(rgb) + masks_.append(msk) + if augment_images_with_masks: augs = get_determenistic_augmentations(rgb.shape, msk, subset='best_all') rgb_raw = FT.to_pil_image(rgb_raw_tensor) - + for img_aug, mask_aug in augs: # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda() - + msk_aug = mask_aug(msk) - + processor.put_to_permanent_memory(rgb_aug, msk_aug) - + stats = [] if compute_uncertainty and uncertainty_name == 'bald': bald = BALD() - + for ti, data in enumerate(tqdm(loader, disable=not print_progress)): with torch.cuda.amp.autocast(enabled=True): rgb = data['rgb'].cuda()[0] rgb_raw_tensor = data['raw_image_tensor'].cpu()[0] - + gt = data.get('mask') # for IoU computations if ti in frames_with_masks: msk = data['mask'] else: msk = None - + info = data['info'] frame = info['frame'][0] shape = info['shape'] @@ -169,12 +198,7 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out # no point to do anything without a mask continue - if False: - rgb = torch.flip(rgb, dims=[-1]) - msk = torch.flip(msk, dims=[-1]) if msk is not None else None - # Map possibly non-continuous labels to continuous ones - # TODO: What are labels? Debug if msk is not None: # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True msk, labels = mapper.convert_mask(msk[0].numpy(), exhaustive=True) @@ -182,10 +206,10 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out if need_resize: msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] processor.set_all_labels(list(mapper.remappings.values())) - + else: labels = None - + if (compute_uncertainty and uncertainty_name == 'bald') or compute_disparity: dry_run_preds = [] augged_images = [] @@ -196,20 +220,21 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out augged_img = img_aug(rgb_raw) augged_images.append(augged_img) rgb_aug = vid_reader.im_transform(augged_img).cuda() - + msk = mask_aug(msk) # does not do anything, since original_only=True augmentations don't alter the mask at all - - dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti==vid_length-1), manually_curated_masks=manually_curated_masks, disable_memory_updates=True) + + dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, disable_memory_updates=True) dry_run_preds.append(dry_run_prob.cpu()) - + if original_memory_mechanism: do_not_add_mask_to_memory = (ti == 0) # we only ignore the first mask, since it's already in the permanent memory else: do_not_add_mask_to_memory = msk is not None # we ignore all frames with masks, since they are already preloaded in the permanent memory # Run the model on this frame - # TODO: still running inference even on frames with masks? 
# 2+ channels, classes+ and background - prob = processor.step(rgb, msk, labels, end=(ti==vid_length-1), manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) + prob = processor.step(rgb, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) if compute_uncertainty: if uncertainty_name == 'bald': @@ -217,28 +242,20 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out all_samples = torch.stack([x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() score = bald.compute_score(all_samples) # TODO: can also return the exact pixels for every frame? As a suggestion on what to label - curr_stat['bald'] = float(np.squeeze(score).mean()) + curr_stat['bald'] = float(np.squeeze(score).mean()) elif compute_disparity: - # p_out_disparity = Path('output/masks_disparity/') - # if ti in {0, 200, 500, 900, 1100, 1300, 1450, 1600}: - # output_save_path = p_out_disparity / str(ti) - # else: - # output_save_path = None - - disparity_stats = compute_disparity_func(predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) + disparity_stats = compute_disparity_func( + predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) curr_stat['disparity'] = float(disparity_stats['avg']) curr_stat['disparity_large'] = float(disparity_stats['large']) else: e = entropy(prob.cpu()) e_mean = np.mean(e) curr_stat['entropy'] = float(e_mean) - + # Upsample to original size if needed if need_resize: - prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:,0] - - if False: - prob = torch.flip(prob, dims=[-1]) + prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0] # Probability mask -> index mask out_mask = torch.argmax(prob, dim=0) @@ -253,8 +270,6 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out else: iou = -1 curr_stat['iou'] = iou - if False: - prob = (prob.detach().cpu().numpy()*255).astype(np.uint8) # Save the mask if config['save_masks']: @@ -266,10 +281,10 @@ def inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out out_img.putpalette(vid_reader.get_palette()) out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) - if False: #args.save_scores: + if False: # args.save_scores: np_path = path.join(args.output, 'Scores', vid_name) os.makedirs(np_path, exist_ok=True) - if ti==len(loader)-1: + if ti == len(loader)-1: hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w') if args.save_all or info['save'][0]: hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') @@ -288,23 +303,39 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f 'batched': Pick only `num_extra_frames` best frames 'iterative': Pick only 1 best frame instead of `num_extra_frames`, repeat `num_extra_frames` times """ - - assert mode in {'uniform', 'random', 'batched', 'iterative'} + if mode.startswith('uniform_random'): + pass + else: + assert mode in {'uniform', 'random', 'batched', 'iterative', 'umap_half_cosine', 'umap_hdbscan_clustering', 'pca_max_cosine_values_arpack'} assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - + + num_total_frames = len(os.listdir(imgs_in_path)) if mode == 'uniform': - 
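The uncertainty scores above come from `scipy.stats.entropy` and `baal`'s BALD heuristic; the sketch below spells out the underlying math (predictive entropy per frame, and BALD as mutual information across the augmented predictions) in plain NumPy. Array shapes are illustrative assumptions, not the exact tensors used in the code:

    import numpy as np

    def predictive_entropy(probs: np.ndarray) -> float:
        # probs: (num_classes, H, W) softmax probabilities for one frame
        eps = 1e-8
        per_pixel = -(probs * np.log(probs + eps)).sum(axis=0)  # entropy per pixel
        return float(per_pixel.mean())

    def bald_score(samples: np.ndarray) -> float:
        # samples: (num_samples, num_classes, H, W) predictions for the same frame
        # under different augmentations (the dry-run predictions plus the original one).
        eps = 1e-8
        mean_p = samples.mean(axis=0)                                     # (num_classes, H, W)
        entropy_of_mean = -(mean_p * np.log(mean_p + eps)).sum(axis=0)
        mean_of_entropy = -(samples * np.log(samples + eps)).sum(axis=1).mean(axis=0)
        mutual_info = entropy_of_mean - mean_of_entropy                   # BALD per pixel
        return float(mutual_info.mean())

    # toy example: 3 augmented predictions, 2 classes, a 4x4 "image"
    rng = np.random.default_rng(0)
    logits = rng.normal(size=(3, 2, 4, 4))
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    print(predictive_entropy(probs[0]), bald_score(probs))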
num_total_frames = len(os.listdir(imgs_in_path)) # linspace is [a, b] (inclusive) frames_with_masks = np.linspace(0, num_total_frames - 1, num_extra_frames + 1).astype(int) elif mode == 'random': - num_total_frames = len(os.listdir(imgs_in_path)) np.random.seed(1) extra_frames = np.random.choice(np.arange(1, num_total_frames), size=num_extra_frames, replace=False).tolist() frames_with_masks = sorted([0] + extra_frames) + elif mode.startswith('uniform_random'): + seed = int(mode.split('_')[-1]) + chosen_frames = [] + space = np.linspace(0, num_total_frames, num_extra_frames + 2, endpoint=True, dtype=int) + ranges = zip(space, space[1:]) + np.random.seed(seed) + + for a, b in ranges: + if a == 0: + chosen_frames.append(0) + else: + extra_frame = int(np.random.choice(np.arange(a, b), replace=False)) + chosen_frames.append(extra_frame) + frames_with_masks = chosen_frames + elif mode == 'batched': # we save baseline results here, with just 1 annotation - baseline_out= Path(masks_out_path).parent.parent / 'baseline' - df = inference_on_video( + baseline_out = Path(masks_out_path).parent.parent / 'baseline' + df = _inference_on_video( imgs_in_path=imgs_in_path, masks_in_path=masks_in_path, masks_out_path=baseline_out / 'masks', @@ -316,7 +347,7 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f print_progress=False, overwrite_config={'save_masks': True}, ) - + df.to_csv(baseline_out / 'stats.csv', index=False) if uncertainty_name == 'disparity_large': candidates = select_n_frame_candidates_no_neighbours_simple(df, n=num_extra_frames, uncertainty_name=uncertainty_name) @@ -329,7 +360,7 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f elif mode == 'iterative': extra_frames = [] for i in range(num_extra_frames): - df = inference_on_video( + df = _inference_on_video( imgs_in_path=imgs_in_path, masks_in_path=masks_in_path, masks_out_path=masks_out_path, @@ -339,27 +370,58 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f uncertainty_name=uncertainty_name, manually_curated_masks=False, print_progress=False, - overwrite_config={'save_masks': False}, + overwrite_config={'save_masks': False}, ) max_frame = select_most_uncertain_frame(df, uncertainty_name=uncertainty_name) extra_frames.append(max_frame['index']) # keep unsorted to preserve order of the choices - frames_with_masks = [0] + extra_frames - if use_cache and os.path.exists(csv_out_path): - final_df = pd.read_csv(csv_out_path) - else: - final_df = inference_on_video( + frames_with_masks = [0] + extra_frames + elif mode == 'umap_hdbscan_clustering' or mode == 'umap_half_cosine': + frames_with_masks = _inference_on_video( imgs_in_path=imgs_in_path, masks_in_path=masks_in_path, masks_out_path=masks_out_path, - frames_with_masks=frames_with_masks, + frames_with_masks=[0], compute_uncertainty=True, - compute_iou=True, + compute_iou=False, + manually_curated_masks=False, print_progress=False, uncertainty_name=uncertainty_name, - **kwargs + overwrite_config={'save_masks': False}, + only_predict_frames_to_annotate_and_quit=num_extra_frames, # ONLY THIS WILL RUN ANYWAY + ) + elif mode == 'pca_max_cosine_values_arpack': + # getting all the values + _, values = _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=[0], + compute_uncertainty=False, + compute_iou=False, + manually_curated_masks=False, + print_progress=True, + uncertainty_name=uncertainty_name, + 
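The 'batched' and 'iterative' modes described in the docstring differ only in how often the model is re-run between picks; a compact sketch of both loops, with `fake_run_once` standing in (as an assumption) for a call such as `_inference_on_video` that yields one uncertainty value per frame:

    from typing import Callable, List, Sequence

    def pick_batched(run_once: Callable[[List[int]], Sequence[float]],
                     num_extra_frames: int) -> List[int]:
        # One baseline pass with only frame 0 annotated, then take the top-k
        # most uncertain frames in a single shot.
        scores = run_once([0])
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return [0] + [i for i in ranked if i != 0][:num_extra_frames]

    def pick_iterative(run_once: Callable[[List[int]], Sequence[float]],
                       num_extra_frames: int) -> List[int]:
        # Re-run after every pick, so each new annotation can reshape the
        # uncertainty landscape before the next frame is chosen.
        chosen = [0]
        for _ in range(num_extra_frames):
            scores = run_once(chosen)
            best = max((i for i in range(len(scores)) if i not in chosen),
                       key=lambda i: scores[i])
            chosen.append(best)
        return chosen

    # toy stand-in: uncertainty grows with distance from already-annotated frames
    def fake_run_once(annotated: List[int]) -> List[float]:
        return [min(abs(i - a) for a in annotated) for i in range(20)]

    print(pick_batched(fake_run_once, 3), pick_iterative(fake_run_once, 3))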
return_all_values=True, ## The key argument + overwrite_config={'save_masks': False}, + ) + + frames_with_masks = calculate_proposals_for_annotations_iterative_pca_cosine_values(values, how_many_frames=num_extra_frames, print_progress=False) + if use_cache and os.path.exists(csv_out_path): + final_df = pd.read_csv(csv_out_path) + else: + final_df = _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=frames_with_masks, + compute_uncertainty=True, + compute_iou=True, + print_progress=False, + uncertainty_name=uncertainty_name, + **kwargs ) if csv_out_path is not None: @@ -375,15 +437,15 @@ def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_f def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int, uncertainty_name: str, modes: list = None, **kwargs): assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - + if modes is None: - modes = ['uniform', 'random', 'batched', 'iterative'] - + modes = ['uniform', 'random', 'uniform_random', 'batched', 'iterative', 'umap_half_cosine', 'umap_hdbscan_clustering', 'pca_max_cosine_values_arpack'] + p_in_ds = Path(dataset_path) p_out = Path(out_path) big_stats = defaultdict(list) - for p_video_imgs_in in tqdm(sorted((p_in_ds / 'JPEGImages').iterdir())): + for i, p_video_imgs_in in enumerate(tqdm(sorted((p_in_ds / 'JPEGImages').iterdir()))): video_name = p_video_imgs_in.stem p_video_masks_in = p_in_ds / 'Annotations_binarized' / video_name @@ -394,8 +456,8 @@ def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int p_out_masks = p_video_out_general / mode / 'masks' p_out_stats = p_video_out_general / mode / 'stats.csv' - stats, frames_with_masks = run_active_learning(p_video_imgs_in, p_video_masks_in, p_out_masks, - num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, uncertainty_name=uncertainty_name, use_cache=False, **kwargs) + stats, frames_with_masks = run_active_learning(p_video_imgs_in, p_video_masks_in, p_out_masks, + num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, uncertainty_name=uncertainty_name, use_cache=False, **kwargs) stats = stats[stats['mask_provided'] == False] # remove stats for frames with given masks for i in range(1, len(frames_with_masks) + 1): @@ -411,117 +473,77 @@ def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int df_mode_stats.to_csv(p_out / f'Active_learning_{uncertainty_name}' / f'stats_{mode}_all_videos.csv', index=False) -if __name__ == '__main__': - pass - # inference_on_video( - # imgs_in_path='../VIDEOS/thanks_no_ears_5_annot/JPEGImages', - # masks_in_path='../VIDEOS/thanks_no_ears_5_annot/annotations_3_face', - # masks_out_path='../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/thanks_no_ears_5_annot/annotations_3_face_5_frames', - # frames_with_masks=[0, 625, 785, 1300, 1488], - # compute_uncertainty=False, - # compute_iou=False, - # print_progress=True, - # manually_curated_masks=False, - # ) - # df.to_csv('output/disparity/disparity.csv', index=False) - # exit(0) - # from inference import active_learning as AL - - # img_size = (3, 480, 853) - # test_mask = FT.to_tensor(Image.open('test.png')) - # for i in tqdm(range(0, 10)): - # eval_active_learning('../VIDEOS/LVOS_dataset/valid', - # f'../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/LVOS_old_scaling/{i}_extra_frames', - # i, - # uncertainty_name='entropy', - # modes=['uniform'], - # 
augment_images_with_mask=False, - # original_memory_mechanism=True) - - # eval_active_learning('../VIDEOS/LVOS_dataset/valid', - # '../VIDEOS/RESULTS/XMem_memory/permanent_work_memory/LVOS_best_all_augs_in_memory_fixed', - # 5, - # uncertainty_name='entropy', - # modes=['uniform'], - # augment_images_with_masks=True, - # ) - - # eval_active_learning('../VIDEOS/LVOS_dataset/valid', - # '../VIDEOS/RESULTS/XMem_memory/LVOS/LVOS_disparity', - # 5, - # uncertainty_name='disparity_large', - # modes=['random', 'uniform', 'batched', 'iterative'], - # augment_images_with_masks=False) - +def run_on_video( + imgs_in_path: Union[str, PathLike], + masks_in_path: Union[str, PathLike], + masks_out_path: Union[str, PathLike], + frames_with_masks: Iterable[int] = (0, ), + compute_iou=False, + print_progress=True, + ) -> pd.DataFrame: + """ + Args: + imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png`. .jpg works too. + + masks_in_path (Union[str, PathLike]): Path to the directory containing video frames' masks in the same format, with corresponding names between video frames. Each unique object should have unique color. + + masks_out_path (Union[str, PathLike]): Path to the output directory (will be created if doesn't exist) where the predicted masks will be stored in .png format. + + frames_with_masks (Iterable[int]): A list of integers representing the frames on which the masks should be applied (default: [0], only applied to the first frame). 0-based. + + compute_iou (bool): A flag to indicate whether to compute the IoU metric (default: False, requires ALL video frames to have a corresponding mask). + + print_progress (bool): A flag to indicate whether to print a progress bar (default: True). + + Returns: + stats (pd.Dataframe): a table containing every frame and the following information: IoU score with corresponding mask (if `compute_iou` is True) + """ + + return _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=frames_with_masks, + compute_uncertainty=False, + compute_iou=compute_iou, + print_progress=print_progress, + manually_curated_masks=False + ) + + +def predict_annotation_candidates( + imgs_in_path: Union[str, PathLike], + num_candidates: int = 1, + print_progress=True, + ) -> List[int]: + + """ + Args: + imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png` .jpg works too. + + num_candidates (int, default: 1): How many annotations candidates to predict. + + print_progress (bool): A flag to indicate whether to print a progress bar (default: True). + + Returns: + annotation_candidates (List[int]): A list of frames indices (0-based) chosen as annotation candidates, sorted by importance (most -> least). Always contains [0] - first frame - at index 0. 
+ """ + + assert num_candidates >= 1 + + if num_candidates == 1: + return [0] # First frame is hard-coded to always be used + + return _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=imgs_in_path, # Ignored + masks_out_path=None, # Ignored + frames_with_masks=[0], # Ignored + compute_uncertainty=False, + compute_iou=False, + print_progress=print_progress, + manually_curated_masks=False, + only_predict_frames_to_annotate_and_quit=num_candidates, + ) - # res, frames_with_masks = run_active_learning('/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/JPEGImages/0tCWPOrc', - # '/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/Annotations_binarized/0tCWPOrc', - # '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS/JUNK/masks', - # num_extra_frames=5, - # csv_out_path='/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_feedback/permanent_work_memory/LVOS/JUNK/stats.csv', mode='iterative') - - # print(frames_with_masks) - # pass - # bald_df = inference_on_video([0], - # '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', - # '/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized', - # 'JUNK', - # compute_iou=False, - # compute_uncertainty=True, - # use_bald=True) # for t.hanks style video - - # bald_df.to_csv('output/bald_thanks_0_frame.csv', index=False) - - # df = inference_on_video( - # imgs_in_path='/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', - # masks_in_path='/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized', - # frames_with_masks=[0, 259, 621, 785, 1401], - # masks_out_path='../VIDEOS/RESULTS/XMem_feedback/BASELINE_REIMPLEMENTED/5_annotated_frames_new_mem', - # compute_uncertainty=False, - # compute_iou=False, - # manually_curated_masks=False, - # original_memory_mechanism=False, - # overwrite_config={'save_masks': True}) - - # df = inference_on_video( - # imgs_in_path='/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/JPEGImages/vjG0jbkQ', - # masks_in_path='/home/maksym/RESEARCH/VIDEOS/LVOS_dataset/valid/Annotations/vjG0jbkQ', - # masks_out_path='JUNK', - # frames_with_masks=[0], - # compute_entropy=True, - # compute_iou=True, - # manually_curated_masks=False, - # overwrite_config={'save_masks': False}) - - # print(df.shape) - # df.to_csv('junk.csv', index=False) - # p_in = Path('/home/maksym/RESEARCH/VIDEOS/DAVIS-2017-trainval-480p/DAVIS/2017_train_val_split/val/JPEGImages_chosen') - # p_in = Path('/home/maksym/RESEARCH/VIDEOS/DAVIS-2017-trainval-480p/DAVIS/2017_train_val_split/val/JPEGImages_chosen') - # num_frames_mapping = {} - - # for p_dir in sorted(p for p in p_in.iterdir() if p.is_dir()): - # dir_name = p_dir.name - # num_frames = sum(1 for _ in p_dir.iterdir()) - # num_frames_mapping[dir_name] = num_frames # math.ceil(num_frames/2) - # # print(extra_frames_ranges) - # # exit(0) - # p_bar = tqdm(total=sum(num_frames_mapping.values()), desc='% extra frames DAVIS 2017 val') - # for dir_name, total_frames in num_frames_mapping.items(): - # for how_many_extra_frames in range(0, math.ceil(total_frames // 2)): - # # frames_with_masks = set([0, frame_with_mask]) - # frames_with_masks = set(np.linspace(0, num_frames, how_many_extra_frames+2)[0:-1].astype(int)) - # inference_on_video(frames_with_masks, dir_name) - - # p_bar.update() - - # num_runs = 90 - # p_bar = tqdm(total=num_runs) - # for how_many_extra_frames in range(0, 90): - # # for j in range(0, 181 - 1): - # # e.g. 
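`run_on_video` and `predict_annotation_candidates` assume identically named frame and mask files; a small sketch that checks this correspondence up front (the `example_videos/...` path is the one used in `main.py` and is only illustrative):

    from pathlib import Path

    def check_layout(frames_dir: str, masks_dir: str) -> None:
        # Frames: frame_000000.jpg / .png; masks: same stem, .png, one colour per object.
        frames = sorted(Path(frames_dir).glob('frame_*'))
        masks = {p.stem for p in Path(masks_dir).glob('*.png')}
        missing = [f.name for f in frames if f.stem not in masks]
        print(f'{len(frames)} frames, {len(frames) - len(missing)} with masks')
        if missing:
            print('frames without a mask (fine unless compute_iou=True):', missing[:5])

    check_layout('example_videos/DAVIS-bmx/frames', 'example_videos/DAVIS-bmx/masks')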
[0, 10, 20, ..., 180] without 180 - # frames_with_masks = set(np.linspace(0, 180, how_many_extra_frames+2)[0:-1].astype(int)) - # # frames_with_masks = set([0, i, j]) - # inference_on_video(frames_with_masks) - - # p_bar.update() diff --git a/util/configuration.py b/util/configuration.py index 890956b..8445dad 100644 --- a/util/configuration.py +++ b/util/configuration.py @@ -13,10 +13,10 @@ def parse(self, unknown_arg_ok=False): parser.add_argument('--no_amp', action='store_true') # Data parameters - parser.add_argument('--static_root', help='Static training data root', default='../static') - parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K') - parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube') - parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS') + parser.add_argument('--static_root', help='Static training data root', default='../Datasets/static') + parser.add_argument('--bl_root', help='Blender training data root', default='../Datasets/BL30K') + parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../Datasets/YouTube') + parser.add_argument('--davis_root', help='DAVIS data root', default='.../Datasets/DAVIS') parser.add_argument('--num_workers', help='Total number of dataloader workers across all GPUs processes', type=int, default=16) parser.add_argument('--key_dim', default=64, type=int) @@ -32,7 +32,7 @@ def parse(self, unknown_arg_ok=False): Batch sizes are effective -- you don't have to scale them when you scale the number processes """ # Stage 0, static images - parser.add_argument('--s0_batch_size', default=16, type=int) + parser.add_argument('--s0_batch_size', default=8, type=int) parser.add_argument('--s0_iterations', default=150000, type=int) parser.add_argument('--s0_finetune', default=0, type=int) parser.add_argument('--s0_steps', nargs="*", default=[], type=int) @@ -133,3 +133,9 @@ def __setitem__(self, key, value): def __str__(self): return str(self.args) + +if __name__ == '__main__': + c = Configuration() + c.parse() + for k in sorted(c.args.keys()): + print(k, c.args[k]) \ No newline at end of file From c9038e8d2ccf86405121b77caa267c8f4ed040f1 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 27 Jan 2023 16:58:46 +0400 Subject: [PATCH 09/49] Changed default example video --- main.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 3538ba6..1e5ad28 100644 --- a/main.py +++ b/main.py @@ -27,10 +27,10 @@ print(f"Average IoU for the video: {float(stats_5_frames['iou'].mean())} ({num_annotation_candidates} chosen annotated frames)") # Example for a video with only a few annotations present - video_frames_path = 'example_videos/two-face/frames' - video_masks_path = 'example_videos/two-face/masks' - output_masks_path_baseline = 'output/two-face/baseline' - output_masks_path_5_frames = 'output/two-face/5_frames' + video_frames_path = 'example_videos/imbalanced-scenes/frames' + video_masks_path = 'example_videos/imbalanced-scenes/masks' + output_masks_path_baseline = 'output/imbalanced-scenes/baseline' + output_masks_path_3_frames = 'output/imbalanced-scenes/3_frames' run_on_video(video_frames_path, video_masks_path, output_masks_path_baseline, frames_with_masks=[0], compute_iou=False) - run_on_video(video_frames_path, video_masks_path, output_masks_path_5_frames, frames_with_masks=[0, 259, 621, 785, 1401], compute_iou=False) + run_on_video(video_frames_path, video_masks_path, output_masks_path_3_frames, 
frames_with_masks=[0, 140, 830], compute_iou=False) From 41b517ddef91ee111c6cd71a60308713b4bb067f Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 3 Feb 2023 15:45:07 +0400 Subject: [PATCH 10/49] Fixed whitespace inconsitencies in train_s0.sh --- train_s0.sh | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100755 train_s0.sh diff --git a/train_s0.sh b/train_s0.sh new file mode 100755 index 0000000..f940a3a --- /dev/null +++ b/train_s0.sh @@ -0,0 +1,18 @@ +STAGE=0 +B=4 # original 16 +LR=2.5e-06 # original 1e-5 / (16/4) batch size difference +START_WARM=80000 # original 20K iterations * (16/4) batch size difference +END_WARM=280000 # original 70K iterations * (16/4) batch size difference +NUM_ITER=600000 # original 150K iterations * (16/4) batch size difference +NUM_WORKERS=6 + +OMP_NUM_THREADS=6 python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py \ + --exp_id xmem_multiscale \ + --stage "$STAGE" \ + --s0_batch_size="$B" \ + --s0_lr="$LR" \ + --s0_start_warm="$START_WARM" \ + --s0_end_warm="$END_WARM" \ + --s0_iterations="$NUM_ITER" \ + --num_workers="$NUM_WORKERS" \ + --load_checkpoint='last' From 5408941109e63519f421e33b078fea8e37dfa76a Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 10 Feb 2023 15:52:44 +0400 Subject: [PATCH 11/49] Major cleanup and refactoring: split AL stuff into multiple files with usage examples, fixed a bug when the model didn't run if 0th frame was not provided, removed unused AL functions --- .gitignore | 3 + inference/active_learning.py | 870 ------------------ inference/data/test_datasets.py | 2 +- inference/data/video_reader.py | 9 +- inference/frame_selection/__init__.py | 0 inference/frame_selection/frame_selection.py | 247 +++++ .../frame_selection/frame_selection_utils.py | 323 +++++++ inference/inference_core.py | 2 +- inference/run_experiments.py | 259 ++++++ inference/run_on_video.py | 393 ++++++++ main.py | 2 +- run_on_video.py | 549 ----------- 12 files changed, 1234 insertions(+), 1425 deletions(-) delete mode 100644 inference/active_learning.py create mode 100644 inference/frame_selection/__init__.py create mode 100644 inference/frame_selection/frame_selection.py create mode 100644 inference/frame_selection/frame_selection_utils.py create mode 100644 inference/run_experiments.py create mode 100644 inference/run_on_video.py delete mode 100644 run_on_video.py diff --git a/.gitignore b/.gitignore index bee1e68..4991846 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,6 @@ dmypy.json # Pyre type checker .pyre/ + +output/ +example_videos/ \ No newline at end of file diff --git a/inference/active_learning.py b/inference/active_learning.py deleted file mode 100644 index fa8a3bf..0000000 --- a/inference/active_learning.py +++ /dev/null @@ -1,870 +0,0 @@ -from dataclasses import asdict -from functools import partial -from os import access -from pathlib import Path -from turtle import numinput -from sklearn.cluster import KMeans -import pandas as pd -from PIL import Image -import torch -import torchvision.transforms.functional as FT -import numpy as np -from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine -from sklearn.decomposition import PCA, FastICA -from sklearn.cluster import AgglomerativeClustering -from scipy.spatial.distance import cdist -from sklearn.manifold import TSNE -from umap import UMAP -from hdbscan import flat - -from tqdm import tqdm -from model.memory_util import do_softmax, get_similarity - -from util.tensor_util 
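The comments in train_s0.sh scale the original hyper-parameters linearly with the batch-size ratio: the learning rate is divided by 16/4 and the warm-up and iteration counts are multiplied by it. A small helper that reproduces those numbers, using only the values quoted in the script's own comments:

    def scale_for_batch_size(orig_bs: int = 16, new_bs: int = 4,
                             orig_lr: float = 1e-5,
                             orig_iters: dict = None) -> dict:
        # Linear scaling as stated in train_s0.sh: smaller batches -> smaller LR,
        # proportionally more iterations so training sees the same number of samples.
        ratio = orig_bs / new_bs  # 16 / 4 = 4
        orig_iters = orig_iters or {'start_warm': 20_000, 'end_warm': 70_000,
                                    'iterations': 150_000}
        scaled = {k: int(v * ratio) for k, v in orig_iters.items()}
        scaled['lr'] = orig_lr / ratio
        return scaled

    print(scale_for_batch_size())
    # {'start_warm': 80000, 'end_warm': 280000, 'iterations': 600000, 'lr': 2.5e-06}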
import get_bbox_from_mask - -def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): - df = preds_df - - df.reset_index(drop=False, inplace=True) - - # max_frame = df['frame'].max() - # max_entropy = df['entropy'].max() - - df = df[df['mask_provided'] == False] # removing frames with masks - df = df[df[uncertainty_name] >= df[uncertainty_name].median()] # removing low entropy parts - - df_backup = df.copy() - - df['index'] = df['index'] / df['index'].max() # scale to 0..1 - # df['entropy'] = df['entropy'] / df['entropy'].max() # scale to 0..1 - - X = df[['index', uncertainty_name]].to_numpy() - - clusterer = KMeans(n_clusters=n) - - labels = clusterer.fit_predict(X) - - clusters = df_backup.groupby(labels) - - candidates = [] - - for g, cluster in clusters: - if g == -1: - continue - - max_entropy_idx = cluster[uncertainty_name].argmax() - - res = cluster.iloc[max_entropy_idx] - - candidates.append(res) - - return candidates - -def select_most_uncertain_frame(preds_df: pd.DataFrame, uncertainty_name: str): - df = preds_df[preds_df['mask_provided'] == False] - df.reset_index(drop=False, inplace=True) - return df.iloc[df[uncertainty_name].argmax()] - -def select_n_frame_candidates_no_neighbours_simple(preds_df: pd.DataFrame, uncertainty_name: str, n=5, neighbourhood_size=4): - df = preds_df - df.reset_index(drop=False, inplace=True) - - df = df[df['mask_provided'] == False] # removing frames with masks - - neighbours_indices = set() - chosen_candidates = [] - - df_sorted = df.sort_values(uncertainty_name, ascending=False) - i = 0 - while len(chosen_candidates) < n: - candidate = df_sorted.iloc[i] - candidate_index = candidate['index'] - - if candidate_index not in neighbours_indices: - chosen_candidates.append(candidate) - candidate_neighbours = range(candidate_index - neighbourhood_size, candidate_index + neighbourhood_size + 1) - neighbours_indices.update(candidate_neighbours) - - i += 1 - - return chosen_candidates - - -WhichAugToPick = -1 - -def get_determenistic_augmentations(img_size=None, mask=None, subset: str=None): - assert subset in {'best_3', 'best_3_with_symmetrical', 'best_all', 'original_only', 'all'} - - bright = ColorJitter(brightness=(1.5, 1.5)) - dark = ColorJitter(brightness=(0.5, 0.5)) - gray = Grayscale(num_output_channels=3) - reduce_bits = RandomPosterize(bits=3, p=1) - sharp = RandomAdjustSharpness(sharpness_factor=16, p=1) - rotate_right = RandomAffine(degrees=(30, 30)) - blur = partial(FT.gaussian_blur, kernel_size=7) - - if img_size is not None: - h, w = img_size[-2:] - translate_distance = w // 5 - else: - translate_distance = 200 - - translate_right = partial(FT.affine, angle=0, translate=(translate_distance, 0), scale=1, shear=0) - - zoom_out = partial(FT.affine, angle=0, translate=(0, 0), scale=0.5, shear=0) - zoom_in = partial(FT.affine, angle=0, translate=(0, 0), scale=1.5, shear=0) - shear_right = partial(FT.affine, angle=0, translate=(0, 0), scale=1, shear=20) - - identity = torch.nn.Identity() - identity.name = 'identity' - - if mask is not None: - if mask.any(): - min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) - h, w = mask.shape[-2:] - crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) - crop_mask.name = 'crop_mask' - else: - crop_mask = identity # if the mask is empty - else: - crop_mask = None - - bright.name = 'bright' - dark.name = 'dark' - gray.name = 'gray' - reduce_bits.name = 'reduce_bits' - sharp.name = 'sharp' - 
rotate_right.name = 'rotate_right' - translate_right.name = 'translate_right' - zoom_out.name = 'zoom_out' - zoom_in.name = 'zoom_in' - shear_right.name = 'shear_right' - blur.name = 'blur' - - - - rotate_left = RandomAffine(degrees=(-30, -30)) - rotate_left.name = 'rotate_left' - - shear_left = partial(FT.affine, angle=0, translate=(0, 0), scale=1, shear=-20) - shear_left.name = 'shear_left' - - if WhichAugToPick != -1: - return [img_mask_augs_pairs[WhichAugToPick]] - - if subset == 'best_3': - img_mask_augs_pairs = [ - # augs only applied to the image - # (bright, identity), - # (dark, identity), - # (gray, identity), - # (reduce_bits, identity), - # (sharp, identity), - (blur, identity), - - # augs requiring modifying the mask as well: - # (rotate_right, rotate_right), - # (rotate_left, rotate_left), - # (translate_right, translate_right), - # (zoom_out, zoom_out), - (zoom_in, zoom_in), - (shear_right, shear_right), - # (shear_left, shear_left), - ] - - return img_mask_augs_pairs - elif subset == 'best_3_with_symmetrical': - img_mask_augs_pairs = [ - # augs only applied to the image - # (bright, identity), - # (dark, identity), - # (gray, identity), - # (reduce_bits, identity), - # (sharp, identity), - (blur, identity), - - # augs requiring modifying the mask as well: - # (rotate_right, rotate_right), - # (rotate_left, rotate_left), - # (translate_right, translate_right), - # (zoom_out, zoom_out), - (zoom_in, zoom_in), - (shear_right, shear_right), - (shear_left, shear_left), - ] - - return img_mask_augs_pairs - elif subset == 'best_all': - img_mask_augs_pairs = [ - # augs only applied to the image - (bright, identity), - (dark, identity), - # (gray, identity), - (reduce_bits, identity), - (sharp, identity), - (blur, identity), - - # augs requiring modifying the mask as well: - (rotate_right, rotate_right), - (rotate_left, rotate_left), - # (translate_right, translate_right), - (zoom_out, zoom_out), - (zoom_in, zoom_in), - (shear_right, shear_right), - (shear_left, shear_left), - ] - - return img_mask_augs_pairs - - elif subset == 'original_only': - img_mask_augs_pairs = [ - # augs only applied to the image - (bright, identity), - (dark, identity), - (gray, identity), - (reduce_bits, identity), - (sharp, identity), - (blur, identity), - - # augs requiring modifying the mask as well: - # (rotate_right, rotate_right), - # (translate_right, translate_right), - # (zoom_out, zoom_out), - # (zoom_in, zoom_in), - # (shear_right, shear_right), - ] - else: - img_mask_augs_pairs = [ - # augs only applied to the image - (bright, identity), - (dark, identity), - (gray, identity), - (reduce_bits, identity), - (sharp, identity), - (blur, identity), - - # augs requiring modifying the mask as well: - (rotate_right, rotate_right), - (rotate_left, rotate_left), - (translate_right, translate_right), - (zoom_out, zoom_out), - (zoom_in, zoom_in), - (shear_right, shear_right), - (shear_left, shear_left), - ] - - if crop_mask is not None: - img_mask_augs_pairs.append((crop_mask, crop_mask)) - - return img_mask_augs_pairs - -def _extract_keys(dataloder, processor, print_progress=False): - frame_keys = [] - shrinkages = [] - selections = [] - device = None - with torch.no_grad(): # just in case - key_sum = None - - for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): - rgb = data['rgb'].cuda()[0] - key, shrinkage, selection = processor.encode_frame_key(rgb) - - if key_sum is None: - device = key.device - key_sum = torch.zeros_like(key, device=device, 
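The augmentation pairs above follow one rule: photometric changes (brightness, blur, sharpness) touch only the image, while geometric changes (rotation, zoom, shear) must be applied to image and mask alike so the annotation stays aligned. A minimal torchvision sketch of that pairing, with dummy tensors in place of real frames:

    import torch
    import torchvision.transforms.functional as FT

    img = torch.rand(3, 480, 854)                   # dummy frame, C x H x W
    mask = (torch.rand(1, 480, 854) > 0.5).float()  # dummy binary mask

    # photometric: image only, mask untouched
    img_bright = FT.adjust_brightness(img, 1.5)
    mask_bright = mask

    # geometric: identical transform for both, so pixels and labels stay in sync
    angle = 30
    img_rot = FT.rotate(img, angle)
    mask_rot = FT.rotate(mask, angle)

    assert img_rot.shape == img.shape and mask_rot.shape == mask.shape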
dtype=torch.float64) # to avoid possible overflow - - key_sum += key.type(torch.float64) - - frame_keys.append(key.flatten(start_dim=2).cpu()) - shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) - selections.append(selection.flatten(start_dim=2).cpu()) - - num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. - - return frame_keys, shrinkages, selections, device, num_frames, key_sum - -def calculate_proposals_for_annotations_with_average_distance(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): # just in case - frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - avg_key = (key_sum / num_frames).type(torch.float32) - qk = avg_key.flatten(start_dim=2) - - similarities = [] - for i in tqdm(range(num_frames), desc='Computing similarity to avg frame'): # how to run a loop for lower memory usage - frame_key = frame_keys[i] - similarity_per_pixel = get_similarity(frame_key.to(device), ms=None, qk=qk, qe=None) - similarity_avg = (similarity_per_pixel < 0).sum() # number of dissimilar pixels - - similarities.append(similarity_avg) - - # import matplotlib.pyplot as plt - # import numpy as np - # plt.figure(figsize=(16, 10)) - # plt.xticks(np.arange(0, num_frames, 100)) - # plt.plot([float(x) for x in similarities]) - # plt.title("Inner XMem mean similarity VS average frame") - # plt.savefig( - # 'output/similarities_NEG_vs_avg_frame.png' - # ) - values, indices = torch.topk(torch.tensor(similarities), k=how_many_frames, largest=True) # top `how_many_frames` frames LEAST similar to the avg_key - return indices - -def calculate_proposals_for_annotations_with_first_distance(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): # just in case - frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - qk = frame_keys[0].flatten(start_dim=2).to(device) - - similarities = [] - first_similarity = None - for i in tqdm(range(num_frames), desc='Computing similarity to avg frame'): # how to run a loop for lower memory usage - frame_key = frame_keys[i] - similarity_per_pixel = get_similarity(frame_key.to(device), ms=None, qk=qk, qe=None) - if i == 0: - first_similarity = similarity_per_pixel - similarity_avg = similarity_per_pixel.mean() - - # if i == 0 or i == 175 or i == 353 or i == 560 or i == 900: - # import seaborn as sns - # import matplotlib.pyplot as plt - # plt.figure(figsize=(40, 40)) - # sns.heatmap(similarity_per_pixel.squeeze().cpu(), square=True, cmap="icefire") - # plt.savefig(f'output/SIMILARITY_HEATMAPS/0_vs_{i}.png') - - # plt.figure(figsize=(40, 40)) - # sns.heatmap((similarity_per_pixel - first_similarity).squeeze().cpu(), square=True, cmap="icefire") - # plt.savefig(f'output/SIMILARITY_HEATMAPS/0_vs_{i}_diff_with_0.png') - similarities.append(similarity_avg) - - # import matplotlib.pyplot as plt - # import numpy as np - # plt.figure(figsize=(16, 10)) - # plt.xticks(np.arange(0, num_frames, 100)) - # plt.plot([float(x) for x in similarities]) - # plt.title("Inner XMem mean similarity VS 1st frame") - # plt.savefig( - # 'output/similarities_NEG_vs_1st_frame.png' - # ) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - values, indices = torch.topk(torch.tensor(similarities), k=how_many_frames, largest=False) # top `how_many_frames` frames LEAST similar to the avg_key - return indices - - -def calculate_proposals_for_annotations_with_iterative_distance(dataloader, processor, 
how_many_frames=9, print_progress=False): - with torch.no_grad(): # just in case - frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - chosen_frames = [0] - chosen_frames_mem_keys = [frame_keys[0].flatten(start_dim=2).to(device)] - - for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames'): - similarities = [] - for j in tqdm(range(num_frames), desc='Computing similarity to avg frame', disable=True): # how to run a loop for lower memory usage - qk = frame_keys[j].to(device) - - similarities_across_mem_keys = [] - for mem_key in chosen_frames_mem_keys: - similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) - similarity_avg = similarity_per_pixel.mean() - similarities_across_mem_keys.append(similarity_avg) - - similarity_max_across_all = max(similarities_across_mem_keys) - similarities.append(similarity_max_across_all) - - values, indices = torch.topk(torch.tensor(similarities), k=1, largest=False) - idx = int(indices[0]) - - import matplotlib.pyplot as plt - import numpy as np - plt.figure(figsize=(16, 10)) - plt.xticks(np.arange(0, num_frames, 100)) - plt.plot([float(x) for x in similarities]) - plt.title(f"Inner XMem mean similarity VS frames {chosen_frames}") - plt.savefig( - f'output/iterative_similarity/{i}.png' - ) - - chosen_frames.append(idx) - next_frame_to_add = frame_keys[idx] - chosen_frames_mem_keys.append(next_frame_to_add.to(device)) - - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - - -def calculate_proposals_for_annotations_with_iterative_distance_diff(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): # just in case - frame_keys, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - chosen_frames = [0] - chosen_frames_mem_keys = [frame_keys[0].flatten(start_dim=2).to(device)] - - chosen_frames_self_similarities = [get_similarity(chosen_frames_mem_keys[0], ms=None, qk=chosen_frames_mem_keys[0], qe=None)] - for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): - dissimilarities = [] - for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): # how to run a loop for lower memory usage - qk = frame_keys[j].to(device) - - dissimilarities_across_mem_keys = [] - for mem_key, self_sim in zip(chosen_frames_mem_keys, chosen_frames_self_similarities): - similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) - - # basically, removing scene ambiguity and only keeping differences due to the scene change - # in theory, of course - dissimilarity_score = (similarity_per_pixel - self_sim).abs().sum() / similarity_per_pixel.numel() - dissimilarities_across_mem_keys.append(dissimilarity_score) - - # filtering our existin or very similar frames - dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) - idx = int(indices[0]) - - # import matplotlib.pyplot as plt - # import numpy as np - # plt.figure(figsize=(16, 10)) - # plt.xticks(np.arange(0, num_frames, 100)) - # plt.plot([float(x) for x in dissimilarities]) - # plt.title(f"Inner XMem mean dissimilarity VS frames {chosen_frames}") - # plt.savefig( - # f'output/iterative_dissimilarity/{i}.png' - # ) - - 
chosen_frames.append(idx) - next_frame_to_add_key = frame_keys[idx].to(device) - chosen_frames_mem_keys.append(next_frame_to_add_key) - chosen_frames_self_similarities.append(get_similarity(next_frame_to_add_key, ms=None, qk=next_frame_to_add_key, qe=None)) - - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - -def calculate_proposals_for_annotations_with_uniform_iterative_distance_diff(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) - ranges = zip(space, space[1:]) - - chosen_frames = [] - chosen_frames_mem_keys = [] - chosen_frames_self_similarities = [] - chosen_frames_shrinkages = [] - - for a, b in ranges: - if a == 0: - chosen_new_frame = 0 # in the first range we always pick the first frame - else: - dissimilarities = [] - for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for lower memory usage - true_frame_idx = a + i - qk = frame_keys[true_frame_idx].to(device) - selection = selections[true_frame_idx].to(device) # query - - dissimilarities_across_mem_keys = [] - for mem_key, shrinkage, self_sim in zip(chosen_frames_mem_keys, chosen_frames_shrinkages, chosen_frames_self_similarities): - similarity_per_pixel = get_similarity(mem_key, ms=shrinkage, qk=qk, qe=selection) - - # basically, removing scene ambiguity and only keeping differences due to the scene change - # in theory, of course - diff = (similarity_per_pixel - self_sim) - dissimilarity_score = diff[diff > 0].sum() / similarity_per_pixel.numel() - dissimilarities_across_mem_keys.append(dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) - chosen_new_frame = int(indices[0]) + a - - chosen_frames.append(chosen_new_frame) - chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) - chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) - chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - -def calculate_proposals_for_annotations_with_uniform_iterative_distance_cycle(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) - ranges = zip(space, space[1:]) - - chosen_frames = [] - chosen_frames_mem_keys = [] - # chosen_frames_self_similarities = [] - chosen_frames_shrinkages = [] - - for a, b in ranges: - if a == 0: - chosen_new_frame = 0 # in the first range we always pick the first frame - else: - dissimilarities = [] - for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for 
lower memory usage - true_frame_idx = a + i - qk = frame_keys[true_frame_idx].to(device) - query_selection = selections[true_frame_idx].to(device) # query - query_shrinkage = shrinkages[true_frame_idx].to(device) - - dissimilarities_across_mem_keys = [] - for key_idx, mem_key, key_shrinkage in zip(chosen_frames, chosen_frames_mem_keys, chosen_frames_shrinkages): - mem_key = mem_key.to(device) - key_selection = selections[key_idx].to(device) - similarity_per_pixel = get_similarity(mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) - reverse_similarity_per_pixel = get_similarity(qk, ms=query_shrinkage, qk=mem_key, qe=key_selection) - - # mapping of pixels A -> B would be very similar to B -> A if the images are similar - # and very different if the images are different - cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) - cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / cycle_dissimilarity_per_pixel.numel() - - dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) - chosen_new_frame = int(indices[0]) + a - - chosen_frames.append(chosen_new_frame) - chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) - chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) - # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - - -def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - chosen_frames = [0] - chosen_frames_mem_keys = [frame_keys[0].to(device)] - - for i in tqdm(range(how_many_frames), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): - dissimilarities = [] - for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): # how to run a loop for lower memory usage - qk = frame_keys[j].to(device) - query_selection = selections[j].to(device) # query - query_shrinkage = shrinkages[j].to(device) - - dissimilarities_across_mem_keys = [] - for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): - mem_key = mem_key.to(device) - key_shrinkage = shrinkages[key_idx].to(device) - key_selection = selections[key_idx].to(device) - - similarity_per_pixel = get_similarity(mem_key, ms=None, qk=qk, qe=None) - reverse_similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) - - # mapping of pixels A -> B would be very similar to B -> A if the images are similar - # and very different if the images are different - cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) - cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / cycle_dissimilarity_per_pixel.numel() - - dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all 
= min(dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) - chosen_new_frame = int(indices[0]) - - chosen_frames.append(chosen_new_frame) - chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) - # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - - -def calculate_proposals_for_annotations_with_uniform_iterative_distance_double_diff(dataloader, processor, how_many_frames=9, print_progress=False): - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - - space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) - ranges = zip(space, space[1:]) - - chosen_frames = [] - chosen_frames_mem_keys = [] - # chosen_frames_self_similarities = [] - chosen_frames_shrinkages = [] - - for a, b in ranges: - if a == 0: - chosen_new_frame = 0 # in the first range we always pick the first frame - else: - dissimilarities = [] - for i in tqdm(range(b - a), desc=f'Computing similarity to chosen frames in range({a}, {b})', disable=not print_progress): # how to run a loop for lower memory usage - true_frame_idx = a + i - qk = frame_keys[true_frame_idx].to(device) - query_selection = selections[true_frame_idx].to(device) # query - query_shrinkage = shrinkages[true_frame_idx].to(device) - - dissimilarities_across_mem_keys = [] - for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): - mem_key = mem_key.to(device) - key_shrinkage = shrinkages[key_idx].to(device) - key_selection = selections[key_idx].to(device) - - similarity_per_pixel = get_similarity(mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) - self_similarity_key = get_similarity(mem_key, ms=key_shrinkage, qk=mem_key, qe=key_selection) - self_similarity_query = get_similarity(qk, ms=query_shrinkage, qk=query_shrinkage, qe=query_selection) - - # mapping of pixels A -> B would be very similar to B -> A if the images are similar - # and very different if the images are different - - pure_similarity = 2 * similarity_per_pixel - self_similarity_key - self_similarity_query - - dissimilarity_score = pure_similarity.abs().sum() / pure_similarity.numel() - - dissimilarities_across_mem_keys.append(dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor(dissimilarities), k=1, largest=True) - chosen_new_frame = int(indices[0]) + a - - chosen_frames.append(chosen_new_frame) - chosen_frames_mem_keys.append(frame_keys[chosen_new_frame].to(device)) - chosen_frames_shrinkages.append(shrinkages[chosen_new_frame].to(device)) - # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - - -def calculate_proposals_for_annotations_iterative_pca_cosine(dataloader, processor, how_many_frames=9, 
print_progress=False): - # might not pick 0-th frame - np.random.seed(1) # Just in case - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() - - # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 - pca = PCA(num_frames - 1, svd_solver='arpack') - # pca = FastICA(num_frames - 1) - smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) - # smol_keys = flat_keys # to disable PCA - - chosen_frames = [0] - for c in range(how_many_frames): - distances = cdist(smol_keys[chosen_frames], smol_keys, metric='euclidean') - closest_to_mem_key_distances = distances.min(axis=0) - most_distant_frame = np.argmax(closest_to_mem_key_distances) - chosen_frames.append(most_distant_frame) - - return chosen_frames - - -def calculate_proposals_for_annotations_iterative_umap_cosine(dataloader, processor, how_many_frames=9, print_progress=False): - # might not pick 0-th frame - np.random.seed(1) # Just in case - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() - - # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 - pca = UMAP(n_neighbors=num_frames - 1, n_components=num_frames // 2, random_state=1) - # pca = FastICA(num_frames - 1) - smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) - # smol_keys = flat_keys # to disable PCA - - chosen_frames = [0] - for c in range(how_many_frames): - distances = cdist(smol_keys[chosen_frames], smol_keys, metric='euclidean') - closest_to_mem_key_distances = distances.min(axis=0) - most_distant_frame = np.argmax(closest_to_mem_key_distances) - chosen_frames.append(most_distant_frame) - - return chosen_frames - - -def calculate_proposals_for_annotations_uniform_iterative_pca_cosine(dataloader, processor, how_many_frames=9, print_progress=False): - # might not pick 0-th frame - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() - - # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 - pca = PCA(num_frames - 1, svd_solver='arpack') - smol_keys = pca.fit_transform(flat_keys) - # smol_keys = flat_keys # to disable PCA - - space = np.linspace(0, num_frames, how_many_frames + 2, endpoint=True, dtype=int) - ranges = zip(space, space[1:]) - - chosen_frames = [0] - - for a, b in ranges: - if a == 0: - # skipping the first one - continue - - distances = cdist(smol_keys[chosen_frames], smol_keys[a:b], metric='cosine') - closest_to_mem_key_distances = distances.min(axis=0) - most_distant_frame = np.argmax(closest_to_mem_key_distances) + a - - chosen_frames.append(most_distant_frame) - - return chosen_frames - - -def calculate_proposals_for_annotations_iterative_pca_cosine_values(values, how_many_frames=9, print_progress=False): - # might not pick 0-th frame - np.random.seed(1) # Just in case - with torch.no_grad(): - num_frames = len(values) - flat_values = torch.stack([value.flatten().cpu() for value in values]).numpy() - - # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 - pca = PCA(num_frames - 1, svd_solver='arpack') 
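The selection variants in this file share one core idea: reduce the flattened per-frame keys with PCA (the arpack solver avoids the sklearn hang referenced in the comment above) and then greedily pick the frame farthest from everything already chosen. A distilled, self-contained version of that farthest-point loop, with random features standing in for real XMem keys:

    import numpy as np
    from scipy.spatial.distance import cdist
    from sklearn.decomposition import PCA

    def farthest_point_frames(frame_features: np.ndarray, num_extra: int) -> list:
        # frame_features: (num_frames, feature_dim), e.g. flattened per-frame keys
        num_frames = frame_features.shape[0]
        # arpack requires n_components strictly below min(n_samples, n_features)
        reduced = PCA(n_components=num_frames - 1, svd_solver='arpack').fit_transform(
            frame_features.astype(np.float64))
        chosen = [0]  # frame 0 is always annotated
        for _ in range(num_extra):
            dists = cdist(reduced[chosen], reduced, metric='euclidean')
            farthest = int(np.argmax(dists.min(axis=0)))  # max distance to nearest chosen frame
            chosen.append(farthest)
        return chosen

    rng = np.random.default_rng(0)
    print(farthest_point_frames(rng.normal(size=(30, 64)), num_extra=4))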
- # pca = FastICA(num_frames - 1) - smol_values = pca.fit_transform(flat_values.astype(np.float64)) - # smol_keys = flat_keys # to disable PCA - - chosen_frames = [0] - for c in range(how_many_frames): - distances = cdist(smol_values[chosen_frames], smol_values, metric='euclidean') - closest_to_mem_key_distances = distances.min(axis=0) - most_distant_frame = np.argmax(closest_to_mem_key_distances) - chosen_frames.append(most_distant_frame) - - return chosen_frames - - -def calculate_proposals_for_annotations_umap_hdbscan_clustering(dataloader, processor, how_many_frames=9, print_progress=False): - # might not pick 0-th frame - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() - - pca = UMAP(n_neighbors=num_frames - 1, n_components=num_frames // 2, random_state=1) - smol_keys = pca.fit_transform(flat_keys) - - # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') - clustering = flat.HDBSCAN_flat(smol_keys, n_clusters=how_many_frames + 1) - labels = clustering.labels_ - - chosen_frames = [] - for c in range(how_many_frames + 1): - vectors = smol_keys[labels == c] - true_index_mapping = {i: int(ti) for i, ti in zip(range(len(vectors)), np.nonzero(labels == c)[0])} - center = np.mean(vectors, axis=0) - - distances = cdist(vectors, [center], metric='euclidean').squeeze() - - closest_to_cluster_center_idx = np.argsort(distances)[0] - - chosen_frame_idx = true_index_mapping[closest_to_cluster_center_idx] - chosen_frames.append(chosen_frame_idx) - - return chosen_frames - - -def calculate_proposals_for_annotations_pca_hierarchical_clustering(dataloader, processor, how_many_frames=9, print_progress=False): - # might not pick 0-th frame - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys(dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() for key in frame_keys]).numpy() - - pca = PCA(num_frames) - smol_keys = pca.fit_transform(flat_keys) - - # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') - clustering = KMeans(n_clusters=how_many_frames + 1) - labels = clustering.fit_predict(smol_keys) - - chosen_frames = [] - for c in range(how_many_frames + 1): - vectors = smol_keys[labels == c] - true_index_mapping = {i: int(ti) for i, ti in zip(range(len(vectors)), np.nonzero(labels == c)[0])} - center = np.mean(vectors, axis=0) - - distances = cdist(vectors, [center], metric='euclidean').squeeze() - - closest_to_cluster_center_idx = np.argsort(distances)[0] - - chosen_frame_idx = true_index_mapping[closest_to_cluster_center_idx] - chosen_frames.append(chosen_frame_idx) - - return chosen_frames - - -def apply_aug(img_path, out_path): - img = Image.open(img_path) - - bright, dark, gray, reduce_bits, sharp = get_determenistic_augmentations() - - img_augged = sharp(img) - - img_augged.save(out_path) - - -def compute_disparity(predictions, augs, images:list = None, output_save_path: str = None): - assert len(predictions) - len(augs) == 1 - disparity_map = None - prev = None - - if images is None: - images = [None] * len(predictions) - else: - assert len(predictions) == len(images) - - if output_save_path is not None: - p_out_disparity = Path(output_save_path) - else: - p_out_disparity = None - - try: - aug_names = [aug.name for aug in augs] - except AttributeError: - aug_names = 
[aug._get_name() for aug in augs] - - names = ['original'] + aug_names - for i, (name, img, pred) in enumerate(zip(names, images, predictions)): - fg_mask = pred[1:2].squeeze().cpu() # 1:2 is Foreground - - if disparity_map is None: - disparity_map = torch.zeros_like(fg_mask) - else: - disparity_map += (prev - fg_mask).abs() - - pred_mask_ = FT.to_pil_image(fg_mask) - if p_out_disparity is not None: - p_out_save_mask = p_out_disparity / 'masks' / (f'{i}_{name}.png') - p_out_save_image = p_out_disparity / 'images' / (f'{i}_{name}.png') - - if not p_out_save_mask.parent.exists(): - p_out_save_mask.parent.mkdir(parents=True) - - pred_mask_.save(p_out_save_mask) - - if not p_out_save_image.parent.exists(): - p_out_save_image.parent.mkdir(parents=True) - - img.save(p_out_save_image) - - prev = fg_mask - - disparity_scaled = disparity_map / (len(augs) + 1) # 0..1; not `disparity_map.max()`, as the scale would differ across images - disparity_avg = disparity_scaled.mean() - disparity_large = (disparity_scaled > 0.5).sum() # num pixels with large disparities - - if p_out_disparity is not None: - disparity_img = FT.to_pil_image(disparity_scaled) - disparity_img.save(p_out_disparity / (f'{i+1}_absolute_disparity.png')) - - return {'full': disparity_scaled, 'avg': disparity_avg, 'large': disparity_large} diff --git a/inference/data/test_datasets.py b/inference/data/test_datasets.py index 3a4446e..1f2a1d5 100644 --- a/inference/data/test_datasets.py +++ b/inference/data/test_datasets.py @@ -89,7 +89,7 @@ def get_datasets(self): path.join(self.mask_dir, video), size=self.size, to_save=self.req_frame_list[video], - use_all_mask=True + use_all_masks=True ) def __len__(self): diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index bc468c4..8f1669c 100644 --- a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -16,7 +16,7 @@ class VideoReader(Dataset): """ This class is used to read a video, one frame at a time """ - def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None): + def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_masks=False, size_dir=None): """ image_dir - points to a directory of jpg images mask_dir - points to a directory of png masks @@ -30,7 +30,7 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all self.image_dir = image_dir self.mask_dir = mask_dir self.to_save = to_save - self.use_all_mask = use_all_mask + self.use_all_masks = use_all_masks if size_dir is None: self.size_dir = self.image_dir else: @@ -72,10 +72,13 @@ def __getitem__(self, idx): shape = np.array(size_im).shape[:2] gt_path = path.join(self.mask_dir, frame[:-4]+'.png') + if not os.path.exists(gt_path): + gt_path = path.join(self.mask_dir, frame[:-4]+'.PNG') + data['raw_image_tensor'] = FT.to_tensor(img) # for dataloaders it cannot be raw PIL.Image, only tensors img = self.im_transform(img) - load_mask = self.use_all_mask or (gt_path == self.first_gt_path) + load_mask = self.use_all_masks or (gt_path == self.first_gt_path) if load_mask and path.exists(gt_path): mask = Image.open(gt_path).convert('P') mask = np.array(mask, dtype=np.uint8) diff --git a/inference/frame_selection/__init__.py b/inference/frame_selection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py new file mode 100644 index 0000000..5b2da04 --- /dev/null +++ 
b/inference/frame_selection/frame_selection.py @@ -0,0 +1,247 @@ +from functools import partial +import json +from pathlib import Path +from typing import Any, Dict, List, Set, Tuple, Union + +import torch +import torchvision.transforms.functional as FT +import numpy as np +from sklearn.decomposition import PCA +from scipy.spatial.distance import cdist +from umap import UMAP +from hdbscan import flat +from tqdm import tqdm + +from model.memory_util import get_similarity + + +# -----------------------------CHOSEN FRAME SELECTORS--------------------------------------- + +# Utility +def _extract_keys(dataloder, processor, print_progress=False): + frame_keys = [] + shrinkages = [] + selections = [] + device = None + with torch.no_grad(): # just in case + key_sum = None + + for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): + rgb = data['rgb'].cuda()[0] + key, shrinkage, selection = processor.encode_frame_key(rgb) + + if key_sum is None: + device = key.device + # to avoid possible overflow + key_sum = torch.zeros_like( + key, device=device, dtype=torch.float64) + + key_sum += key.type(torch.float64) + + frame_keys.append(key.flatten(start_dim=2).cpu()) + shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) + selections.append(selection.flatten(start_dim=2).cpu()) + + num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. + + return frame_keys, shrinkages, selections, device, num_frames, key_sum + + +def first_frame_only(*args, **kwargs): + # baseline + return [0] + + +def uniformly_selected_frames(dataloader, *args, how_many_frames=10, **kwargs) -> List[int]: + # baseline + # TODO: debug and check if works + num_total_frames = len(dataloader) + return np.linspace(0, num_total_frames - 1, how_many_frames).astype(int).tolist() + + +def calculate_proposals_for_annotations_iterative_pca(dataloader, processor, how_many_frames=10, print_progress=False, distance_metric='euclidean') -> List[int]: + assert distance_metric in {'cosine', 'euclidean'} + # might not pick 0-th frame + np.random.seed(1) # Just in case + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( + dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() + for key in frame_keys]).numpy() + + # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 + pca = PCA(num_frames - 1, svd_solver='arpack') + smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) + # smol_keys = flat_keys # to disable PCA + + chosen_frames = [0] + for c in range(how_many_frames - 1): + distances = cdist(smol_keys[chosen_frames], + smol_keys, metric=distance_metric) + closest_to_mem_key_distances = distances.min(axis=0) + most_distant_frame = np.argmax(closest_to_mem_key_distances) + chosen_frames.append(int(most_distant_frame)) + + return chosen_frames + + +def calculate_proposals_for_annotations_umap_half_hdbscan_clustering(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: + # might not pick 0-th frame + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( + dataloader, processor, print_progress) + flat_keys = torch.stack([key.flatten().cpu() + for key in frame_keys]).numpy() + + pca = UMAP(n_neighbors=num_frames - 1, + n_components=num_frames // 2, random_state=1) + smol_keys = pca.fit_transform(flat_keys) + + # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') + 
clustering = flat.HDBSCAN_flat( + smol_keys, n_clusters=how_many_frames + 1) + labels = clustering.labels_ + + chosen_frames = [] + for c in range(how_many_frames): + vectors = smol_keys[labels == c] + true_index_mapping = {i: int(ti) for i, ti in zip( + range(len(vectors)), np.nonzero(labels == c)[0])} + center = np.mean(vectors, axis=0) + + # since HDBSCAN is density-based, it makes 0 sense to use anything but euclidean distance here + distances = cdist(vectors, [center], metric='euclidean').squeeze() + + closest_to_cluster_center_idx = np.argsort(distances)[0] + + chosen_frame_idx = true_index_mapping[closest_to_cluster_center_idx] + chosen_frames.append(chosen_frame_idx) + + return chosen_frames + + +def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( + dataloader, processor, print_progress) + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0].to(device)] + + for i in tqdm(range(how_many_frames - 1), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + dissimilarities = [] + # how to run a loop for lower memory usage + for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): + qk = frame_keys[j].to(device) + query_selection = selections[j].to(device) # query + query_shrinkage = shrinkages[j].to(device) + + dissimilarities_across_mem_keys = [] + for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): + mem_key = mem_key.to(device) + key_shrinkage = shrinkages[key_idx].to(device) + key_selection = selections[key_idx].to(device) + + similarity_per_pixel = get_similarity( + mem_key, ms=None, qk=qk, qe=None) + reverse_similarity_per_pixel = get_similarity( + qk, ms=None, qk=mem_key, qe=None) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = ( + similarity_per_pixel - reverse_similarity_per_pixel) + cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / \ + cycle_dissimilarity_per_pixel.numel() + + dissimilarities_across_mem_keys.append( + cycle_dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min( + dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor( + dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append( + frame_keys[chosen_new_frame].to(device)) + # chosen_frames_self_similarities.append(get_similarity(chosen_frames_mem_keys[-1], ms=shrinkages[chosen_new_frame].to(device), qk=chosen_frames_mem_keys[-1], qe=selections[chosen_new_frame].to(device))) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +def calculate_proposals_for_annotations_with_iterative_distance_double_diff(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: + with torch.no_grad(): + frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( + dataloader, processor, print_progress) + + chosen_frames = [0] + chosen_frames_mem_keys = [frame_keys[0]] + # chosen_frames_self_similarities = [] + for c in range(how_many_frames - 1): 
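# The cycle-consistency score computed just above can be pictured with a simplified,
# self-contained affinity standing in for the codebase's `get_similarity` (no shrinkage
# or selection terms). `xmem_style_affinity` and `cycle_dissimilarity` are illustrative
# names, and both frames are assumed to be flattened to the same number of pixels.

import torch

def xmem_style_affinity(src, tgt):
    # unnormalised L2-style affinity: 2 * src.tgt - ||src||^2, per pixel pair
    # src, tgt: (num_pixels, channels) flattened frame keys
    return 2 * src @ tgt.T - src.pow(2).sum(dim=1, keepdim=True)

def cycle_dissimilarity(key_a, key_b):
    s_ab = xmem_style_affinity(key_a, key_b)   # pixels of A matched against B
    s_ba = xmem_style_affinity(key_b, key_a)   # pixels of B matched against A
    # similar frames give nearly symmetric maps (identical frames give exactly 0),
    # so the mean absolute difference grows with how different the frames are
    return float((s_ab - s_ba).abs().mean())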
+ dissimilarities = [] + # how to run a loop for lower memory usage + for i in tqdm(range(num_frames), desc=f'Computing similarity to chosen frames', disable=not print_progress): + true_frame_idx = i + qk = frame_keys[true_frame_idx].to(device) + query_selection = selections[true_frame_idx].to( + device) # query + query_shrinkage = shrinkages[true_frame_idx].to(device) + + dissimilarities_across_mem_keys = [] + for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): + mem_key = mem_key.to(device) + key_shrinkage = shrinkages[key_idx].to(device) + key_selection = selections[key_idx].to(device) + + similarity_per_pixel = get_similarity( + mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) + self_similarity_key = get_similarity( + mem_key, ms=key_shrinkage, qk=mem_key, qe=key_selection) + self_similarity_query = get_similarity( + qk, ms=query_shrinkage, qk=qk, qe=query_selection) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + + pure_similarity = 2 * similarity_per_pixel - \ + self_similarity_key - self_similarity_query + + dissimilarity_score = pure_similarity.abs().sum() / pure_similarity.numel() + + dissimilarities_across_mem_keys.append(dissimilarity_score) + + # filtering our existing or very similar frames + dissimilarity_min_across_all = min( + dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) + + values, indices = torch.topk(torch.tensor( + dissimilarities), k=1, largest=True) + chosen_new_frame = int(indices[0]) + + chosen_frames.append(chosen_new_frame) + chosen_frames_mem_keys.append( + frame_keys[chosen_new_frame].to(device)) + + # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames + return chosen_frames + + +KNOWN_ANNOTATION_PREDICTORS = { + 'PCA_EUCLIDEAN': partial(calculate_proposals_for_annotations_iterative_pca, distance_metric='euclidean'), + 'PCA_COSINE': partial(calculate_proposals_for_annotations_iterative_pca, distance_metric='cosine'), + 'UMAP_EUCLIDEAN': calculate_proposals_for_annotations_umap_half_hdbscan_clustering, + 'INTERNAL_CYCLE_CONSISTENCY': calculate_proposals_for_annotations_with_iterative_distance_cycle, + 'INTERNAL_DOUBLE_DIFF': calculate_proposals_for_annotations_with_iterative_distance_double_diff, + + 'FIRST_FRAME_ONLY': first_frame_only, # ignores the number of candidates, baseline + 'UNIFORM': uniformly_selected_frames # baseline +} + +# ------------------------END CHOSEN----------------------------------------------- diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py new file mode 100644 index 0000000..1746463 --- /dev/null +++ b/inference/frame_selection/frame_selection_utils.py @@ -0,0 +1,323 @@ +from functools import partial +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Set, Tuple, Union + +from sklearn.cluster import KMeans +import pandas as pd +from PIL import Image +import torch +import torchvision.transforms.functional as FT +import numpy as np +from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine +import cv2 +from tqdm import tqdm +import matplotlib.pyplot as plt + +from util.tensor_util import get_bbox_from_mask +from inference.run_on_video import run_on_video, predict_annotation_candidates, KNOWN_ANNOTATION_PREDICTORS + + +def select_n_frame_candidates(preds_df: pd.DataFrame, 
uncertainty_name: str, n=5): + df = preds_df + + df.reset_index(drop=False, inplace=True) + + # max_frame = df['frame'].max() + # max_entropy = df['entropy'].max() + + df = df[df['mask_provided'] == False] # removing frames with masks + # removing low entropy parts + df = df[df[uncertainty_name] >= df[uncertainty_name].median()] + + df_backup = df.copy() + + df['index'] = df['index'] / df['index'].max() # scale to 0..1 + # df['entropy'] = df['entropy'] / df['entropy'].max() # scale to 0..1 + + X = df[['index', uncertainty_name]].to_numpy() + + clusterer = KMeans(n_clusters=n) + + labels = clusterer.fit_predict(X) + + clusters = df_backup.groupby(labels) + + candidates = [] + + for g, cluster in clusters: + if g == -1: + continue + + max_entropy_idx = cluster[uncertainty_name].argmax() + + res = cluster.iloc[max_entropy_idx] + + candidates.append(res) + + return candidates + + +def select_most_uncertain_frame(preds_df: pd.DataFrame, uncertainty_name: str): + df = preds_df[preds_df['mask_provided'] == False] + df.reset_index(drop=False, inplace=True) + return df.iloc[df[uncertainty_name].argmax()] + + +def select_n_frame_candidates_no_neighbours_simple(preds_df: pd.DataFrame, uncertainty_name: str, n=5, neighbourhood_size=4): + df = preds_df + df.reset_index(drop=False, inplace=True) + + df = df[df['mask_provided'] == False] # removing frames with masks + + neighbours_indices = set() + chosen_candidates = [] + + df_sorted = df.sort_values(uncertainty_name, ascending=False) + i = 0 + while len(chosen_candidates) < n: + candidate = df_sorted.iloc[i] + candidate_index = candidate['index'] + + if candidate_index not in neighbours_indices: + chosen_candidates.append(candidate) + candidate_neighbours = range( + candidate_index - neighbourhood_size, candidate_index + neighbourhood_size + 1) + neighbours_indices.update(candidate_neighbours) + + i += 1 + + return chosen_candidates + + +WhichAugToPick = -1 + + +def get_determenistic_augmentations(img_size=None, mask=None, subset: str = None): + assert subset in {'best_3', 'best_3_with_symmetrical', + 'best_all', 'original_only', 'all'} + + bright = ColorJitter(brightness=(1.5, 1.5)) + dark = ColorJitter(brightness=(0.5, 0.5)) + gray = Grayscale(num_output_channels=3) + reduce_bits = RandomPosterize(bits=3, p=1) + sharp = RandomAdjustSharpness(sharpness_factor=16, p=1) + rotate_right = RandomAffine(degrees=(30, 30)) + blur = partial(FT.gaussian_blur, kernel_size=7) + + if img_size is not None: + h, w = img_size[-2:] + translate_distance = w // 5 + else: + translate_distance = 200 + + translate_right = partial(FT.affine, angle=0, translate=( + translate_distance, 0), scale=1, shear=0) + + zoom_out = partial(FT.affine, angle=0, + translate=(0, 0), scale=0.5, shear=0) + zoom_in = partial(FT.affine, angle=0, translate=(0, 0), scale=1.5, shear=0) + shear_right = partial(FT.affine, angle=0, + translate=(0, 0), scale=1, shear=20) + + identity = torch.nn.Identity() + identity.name = 'identity' + + if mask is not None: + if mask.any(): + min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) + h, w = mask.shape[-2:] + crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, + height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) + crop_mask.name = 'crop_mask' + else: + crop_mask = identity # if the mask is empty + else: + crop_mask = None + + bright.name = 'bright' + dark.name = 'dark' + gray.name = 'gray' + reduce_bits.name = 'reduce_bits' + sharp.name = 'sharp' + rotate_right.name = 'rotate_right' + translate_right.name = 
'translate_right' + zoom_out.name = 'zoom_out' + zoom_in.name = 'zoom_in' + shear_right.name = 'shear_right' + blur.name = 'blur' + + rotate_left = RandomAffine(degrees=(-30, -30)) + rotate_left.name = 'rotate_left' + + shear_left = partial(FT.affine, angle=0, + translate=(0, 0), scale=1, shear=-20) + shear_left.name = 'shear_left' + + if WhichAugToPick != -1: + return [img_mask_augs_pairs[WhichAugToPick]] + + if subset == 'best_3': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # (reduce_bits, identity), + # (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + # (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_3_with_symmetrical': + img_mask_augs_pairs = [ + # augs only applied to the image + # (bright, identity), + # (dark, identity), + # (gray, identity), + # (reduce_bits, identity), + # (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (rotate_left, rotate_left), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + elif subset == 'best_all': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + # (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + # (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + return img_mask_augs_pairs + + elif subset == 'original_only': + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + # (rotate_right, rotate_right), + # (translate_right, translate_right), + # (zoom_out, zoom_out), + # (zoom_in, zoom_in), + # (shear_right, shear_right), + ] + else: + img_mask_augs_pairs = [ + # augs only applied to the image + (bright, identity), + (dark, identity), + (gray, identity), + (reduce_bits, identity), + (sharp, identity), + (blur, identity), + + # augs requiring modifying the mask as well: + (rotate_right, rotate_right), + (rotate_left, rotate_left), + (translate_right, translate_right), + (zoom_out, zoom_out), + (zoom_in, zoom_in), + (shear_right, shear_right), + (shear_left, shear_left), + ] + + if crop_mask is not None: + img_mask_augs_pairs.append((crop_mask, crop_mask)) + + return img_mask_augs_pairs + +def disparity_func(predictions, augs, images: list = None, output_save_path: str = None): + assert len(predictions) - len(augs) == 1 + disparity_map = None + prev = None + + if images is None: + images = [None] * len(predictions) + else: + assert len(predictions) == len(images) + + if output_save_path is not None: + p_out_disparity = Path(output_save_path) + else: + p_out_disparity = None + + try: + aug_names = [aug.name for aug in augs] + except AttributeError: + aug_names = [aug._get_name() for aug in augs] + + names = ['original'] + aug_names + 
for i, (name, img, pred) in enumerate(zip(names, images, predictions)): + fg_mask = pred[1:2].squeeze().cpu() # 1:2 is Foreground + + if disparity_map is None: + disparity_map = torch.zeros_like(fg_mask) + else: + disparity_map += (prev - fg_mask).abs() + + pred_mask_ = FT.to_pil_image(fg_mask) + if p_out_disparity is not None: + p_out_save_mask = p_out_disparity / 'masks' / (f'{i}_{name}.png') + p_out_save_image = p_out_disparity / 'images' / (f'{i}_{name}.png') + + if not p_out_save_mask.parent.exists(): + p_out_save_mask.parent.mkdir(parents=True) + + pred_mask_.save(p_out_save_mask) + + if not p_out_save_image.parent.exists(): + p_out_save_image.parent.mkdir(parents=True) + + img.save(p_out_save_image) + + prev = fg_mask + + # 0..1; not `disparity_map.max()`, as the scale would differ across images + disparity_scaled = disparity_map / (len(augs) + 1) + disparity_avg = disparity_scaled.mean() + # num pixels with large disparities + disparity_large = (disparity_scaled > 0.5).sum() + + if p_out_disparity is not None: + disparity_img = FT.to_pil_image(disparity_scaled) + disparity_img.save(p_out_disparity / (f'{i+1}_absolute_disparity.png')) + + return {'full': disparity_scaled, 'avg': disparity_avg, 'large': disparity_large} diff --git a/inference/inference_core.py b/inference/inference_core.py index f6a4973..d53e062 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -68,7 +68,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ is_ignore = do_not_add_mask_to_memory # to avoid adding permanent memory frames twice, since they are alredy in the memory - need_segment = (self.curr_ti > 0) and ((valid_labels is None) or (len(self.all_labels) != len(valid_labels))) + need_segment = (valid_labels is None) or (len(self.all_labels) != len(valid_labels)) is_deep_update = ( (self.deep_update_sync and is_mem_frame) or # synchronized (not self.deep_update_sync and self.curr_ti-self.last_deep_update_ti >= self.deep_update_every) # no-sync diff --git a/inference/run_experiments.py b/inference/run_experiments.py new file mode 100644 index 0000000..9acd525 --- /dev/null +++ b/inference/run_experiments.py @@ -0,0 +1,259 @@ +import os +import json +from pathlib import Path +from typing import Any, Dict, List, Set, Tuple, Union + +import cv2 +import numpy as np +import pandas as pd +from tqdm import tqdm +from matplotlib import pyplot as plt + +from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS +from inference.run_on_video import predict_annotation_candidates, run_on_video + +# ---------------BEGIN Inference and visualization utils -------------------------- + +def make_non_uniform_grid(rows_of_image_paths: List[List[str]], output_path: str, grid_size=3, resize_to: Tuple[int, int]=(854, 480)): + assert len(rows_of_image_paths) == grid_size + for row in rows_of_image_paths: + assert len(row) <= grid_size + + p_out_dir = Path(output_path) + if not p_out_dir.exists(): + p_out_dir.mkdir(parents=True) + num_frames = None + + for row in rows_of_image_paths: + for img_path_dir in row: + num_frames_in_dir = len(os.listdir(img_path_dir)) + if num_frames is None: + num_frames = num_frames_in_dir + else: + assert num_frames == num_frames_in_dir + + rows_of_iterators = [] + for row_of_image_dir_paths in rows_of_image_paths: + row = [] + for image_dir_path in row_of_image_dir_paths: + p = Path(image_dir_path) + iterator = iter(sorted(p.iterdir())) + row.append(iterator) + rows_of_iterators.append(row) + + for i in 
tqdm(range(num_frames)): + rows_of_frames = [] + for row in rows_of_iterators: + frames = [] + global_h, global_w = None, None + for iterator in row: + frame_path = str(next(iterator)) + frame = cv2.imread(frame_path) + h, w = frame.shape[0:2] + + if resize_to is not None: + desired_w, desired_h = resize_to + if h != desired_w or w != desired_w: + frame = cv2.resize(frame, (desired_w, desired_h)) + h, w = frame.shape[0:2] + + frames.append(frame) + + if global_h is None: + global_h, global_w = h, w + + wide_frame = np.concatenate(frames, axis=1) + + if len(frames) < grid_size: + pad_size = global_w * (grid_size - len(frames)) // 2 + # center the frame + wide_frame = np.pad(wide_frame, [(0, 0), (pad_size, pad_size), (0, 0)], mode='constant', constant_values=0) + rows_of_frames.append(wide_frame) + + big_frame = np.concatenate(rows_of_frames, axis=0) + cv2.imwrite(str(p_out_dir / f'frame_{i:06d}.png'), big_frame) + + +def visualize_grid(video_names: List[str], labeled=True): + for video_name in video_names: + p_in_general = Path( + f'/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/{video_name}/Overlay') + if labeled: + p_in_general /= 'Labeled' + + cycle = p_in_general / 'INTERNAL_CYCLE_CONSISTENCY' + ddiff = p_in_general / 'INTERNAL_DOUBLE_DIFF' + umap = p_in_general / 'UMAP_EUCLIDEAN' + pca_euclidean = p_in_general / 'PCA_EUCLIDEAN' + pca_cosine = p_in_general / 'PCA_COSINE' + one_frame_only = p_in_general / 'ONLY_ONE_FRAME' + baseline_uniform = p_in_general / 'BASELINE_UNIFORM' + baseline_human = p_in_general / 'HUMAN_CHOSEN' + ULTIMATE = p_in_general / 'ULTIMATE_AUTO' + + grid = [ + [cycle, ddiff, umap], + [pca_euclidean, pca_cosine, baseline_uniform], + [baseline_human, one_frame_only, ULTIMATE] + ] + if labeled: + p_out = p_in_general.parent.parent / 'All_combined' + else: + p_out = p_in_general.parent / 'All_combined_unlabeled' + + make_non_uniform_grid(grid, p_out, grid_size=3) + + +def get_videos_info(): + return { + 'long_scene': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/long_scene/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/long_scene/Annotations', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/long_scene' + }, + 'long_scene_scale': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/long_scene_scale/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/long_scene_scale/Annotations', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/long_scene_scale' + }, + 'ariana_smile': { + 'num_annotation_candidates': 3, # 3, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/smile/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/smile/Annotations/Lips', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/ariana_smile' + }, + 'ariana_blog': { + 'num_annotation_candidates': 5, # 5, + 'video_frames_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/blog/JPEGImages', + 'video_masks_path': '/home/maksym/RESEARCH/VIDEOS/Scenes_ariana_fixed_naming/blog/Annotations/Together', + 'masks_out_path': '/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/ariana_blog' + }, + } + + +def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str): 
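# `run_multiple_frame_selectors` writes one row per video and one column per selector
# in KNOWN_ANNOTATION_PREDICTORS; each cell holds a JSON-encoded list of chosen frame
# indices, so the (slow) selection step never has to be re-run. A hedged sketch of
# reading one cell back, using the csv path from the usage example further below:

import json
import pandas as pd

df = pd.read_csv('output/al_videos_chosen_frames.csv', index_col='video_name')
chosen_frames = json.loads(df.loc['long_scene', 'PCA_EUCLIDEAN'])  # list of 0-based frame indices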
+ output = pd.DataFrame(columns=list(KNOWN_ANNOTATION_PREDICTORS)) + p_bar = tqdm(total=len(videos_info) * len(KNOWN_ANNOTATION_PREDICTORS)) + + for video_name, info in videos_info.items(): + video_frames_path = info['video_frames_path'] + num_candidate_frames = info['num_annotation_candidates'] + + results = {} + for method_name in KNOWN_ANNOTATION_PREDICTORS: + chosen_annotation_frames = predict_annotation_candidates( + video_frames_path, num_candidates=num_candidate_frames, approach=method_name) + results[method_name] = json.dumps(chosen_annotation_frames) + p_bar.update() + + output.loc[video_name] = results + + output.index.name = 'video_name' + output.to_csv(csv_output_path) + + +def run_inference_with_pre_chosen_frames(chosen_frames_csv_path: str, videos_info: Dict[str, Dict], output_path: str, only_methods_subset: Set[str] = None): + df = pd.read_csv(chosen_frames_csv_path, index_col='video_name') + num_runs = np.prod(df.shape) + p_bar = tqdm( + desc='Running inference comparing multiple different AL approaches', total=num_runs) + + for video_name, info in videos_info.items(): + video_row = df.loc[video_name] + for method in video_row.index: + if only_methods_subset is not None and method not in only_methods_subset: + continue + + chosen_frames_str = video_row.loc[method] + chosen_frames = json.loads(chosen_frames_str) + print(chosen_frames) + + video_frames_path = info['video_frames_path'] + video_masks_path = info['video_masks_path'] + + output_masks_path = Path(output_path) / video_name / method + + run_on_video(video_frames_path, video_masks_path, output_masks_path, + frames_with_masks=chosen_frames, compute_iou=False, print_progress=False) + + p_bar.update() + + +def visualize_chosen_frames(video_name: str, num_total_frames: int, data: pd.Series, output_path: str): + def _sort_index(series): + ll = list(series.index) + sorted_ll = sorted(ll, key=lambda x: str( + min(json.loads(series.loc[x])))) + return sorted_ll + + sorted_index = _sort_index(data) + plt.figure(figsize=(16, 10)) + plt.title(f"Chosen frames for {video_name}") + plt.xlim(-10, num_total_frames + 10) + num_methods = len(data.index) + + plt.ylim(-0.25, num_methods + 0.25) + + plt.xlabel('Frame number') + plt.ylabel('AL method') + + plt.yticks([]) # disable yticks + + previous_plots = [] + + for i, method_name in enumerate(sorted_index): + chosen_frames = json.loads(data.loc[method_name]) + num_frames = len(chosen_frames) + + x = sorted(chosen_frames) + y = [i for _ in chosen_frames] + plt.axhline(y=i, zorder=1, xmin=0.01, xmax=0.99) + + plt.scatter(x=x, y=y, label=method_name, s=256, zorder=3, marker="v") + if len(previous_plots) != 0: + for i in range(num_frames): + curr_x, curr_y = x[i], y[i] + prev_x, prev_y = previous_plots[-1][0][i], previous_plots[-1][1][i] + + plt.plot([prev_x, curr_x], [prev_y, curr_y], + linewidth=1, color='gray', alpha=0.5) + + previous_plots.append((x, y)) + + # texts = map(str, range(num_frames)) + # for i, txt in enumerate(texts): + # plt.annotate(txt, (x[i] + 2, y[i] + 0.1), zorder=4, fontproperties={'weight': 'bold'}) + + plt.legend() + p_out = Path(f'{output_path}/chosen_frames_{video_name}.png') + if not p_out.parent.exists(): + p_out.parent.mkdir(parents=True) + + plt.savefig(p_out, bbox_inches='tight') + +# -------------------------END Inference and visualization utils -------------------------- + + +if __name__ == "__main__": + pass + + # ## Usage examples + # ## Run from root-level directory, e.g. 
in `main.py` + + # ## Running multiple frame selectors, saving their predicted frame numbers to a .csv file + # run_multiple_frame_selectors(get_videos_info(), csv_output_path='output/al_videos_chosen_frames.csv') + + # ## Running and visualizing inference based on pre-calculated frames selected + # run_inference_with_pre_chosen_frames( + # chosen_frames_csv_path='output/al_videos_chosen_frames.csv', + # videos_info=get_videos_info(), + # output_path='/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/permanent_work_memory/AL_comparison/' + # ) + + # ## Concatenating multiple video results into a non-uniform grid + # visualize_grid( + # names=['long_scene', 'ariana_blog', 'ariana_smile', 'long_scene_scale'], + # labeled=True, + # ) diff --git a/inference/run_on_video.py b/inference/run_on_video.py new file mode 100644 index 0000000..b03dee7 --- /dev/null +++ b/inference/run_on_video.py @@ -0,0 +1,393 @@ +import os +from os import PathLike, path +from typing import Iterable, Literal, Union, List +from collections import defaultdict +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torchvision.transforms import functional as FT +from torch.utils.data import DataLoader +from baal.active.heuristics import BALD +from scipy.stats import entropy +from tqdm import tqdm +from PIL import Image + +from model.network import XMem +from util.tensor_util import compute_tensor_iou +from inference.inference_core import InferenceCore +from inference.data.video_reader import VideoReader +from inference.data.mask_mapper import MaskMapper +from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS +from inference.frame_selection.frame_selection_utils import disparity_func, get_determenistic_augmentations + + +def save_frames(dataset, frame_indices, output_folder): + p_out = Path(output_folder) + + if not p_out.exists(): + p_out.mkdir(parents=True) + + for i in frame_indices: + sample = dataset[i] + rgb_raw_tensor = sample['raw_image_tensor'].cpu().squeeze() + img = FT.to_pil_image(rgb_raw_tensor) + + img.save(p_out / f'frame_{i:06d}.png') + + +def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, + original_memory_mechanism=False, + compute_iou=False, compute_uncertainty=False, manually_curated_masks=False, print_progress=True, + augment_images_with_masks=False, + uncertainty_name: str = None, + only_predict_frames_to_annotate_and_quit=0, + overwrite_config: dict = None, + frame_selector_func: callable = None): + torch.autograd.set_grad_enabled(False) + frames_with_masks = set(frames_with_masks) + config = { + 'buffer_size': 100, + 'deep_update_every': -1, + 'enable_long_term': True, + 'enable_long_term_count_usage': True, + 'fbrs_model': 'saves/fbrs.pth', + 'hidden_dim': 64, + 'images': None, + 'key_dim': 64, + 'max_long_term_elements': 10000, + 'max_mid_term_frames': 10, + 'mem_every': 10, + 'min_mid_term_frames': 5, + 'model': './saves/XMem.pth', + 'no_amp': False, + 'num_objects': 1, + 'num_prototypes': 128, + 's2m_model': 'saves/s2m.pth', + 'size': 480, + 'top_k': 30, + 'value_dim': 512, + # f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', + 'masks_out_path': masks_out_path, + 'workspace': None, + 'save_masks': True + } + + if overwrite_config is not None: + config.update(overwrite_config) + + if compute_uncertainty: + assert uncertainty_name is not None + uncertainty_name = uncertainty_name.lower() + assert uncertainty_name in {'entropy', + 'bald', 'disparity', 
'disparity_large'} + compute_disparity = uncertainty_name.startswith('disparity') + else: + compute_disparity = False + + vid_reader = VideoReader( + "", + imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', + size=config['size'], + use_all_masks=(only_predict_frames_to_annotate_and_quit == 0) + ) + + model_path = config['model'] + network = XMem(config, model_path).cuda().eval() + if model_path is not None: + model_weights = torch.load(model_path) + network.load_weights(model_weights, init_as_zero_if_needed=True) + else: + print('No model loaded.') + + loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=8) + vid_name = vid_reader.vid_name + vid_length = len(loader) + # no need to count usage for LT if the video is not that long anyway + config['enable_long_term_count_usage'] = ( + config['enable_long_term'] and + (vid_length + / (config['max_mid_term_frames']-config['min_mid_term_frames']) + * config['num_prototypes']) + >= config['max_long_term_elements'] + ) + + mapper = MaskMapper() + processor = InferenceCore(network, config=config) + first_mask_loaded = False + + if only_predict_frames_to_annotate_and_quit > 0: + assert frame_selector_func is not None + chosen_annotation_candidate_frames = frame_selector_func( + loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit) + + return chosen_annotation_candidate_frames + + frames_ = [] + masks_ = [] + + if original_memory_mechanism: + # only the first frame goes into permanent memory originally + frames_to_put_in_permanent_memory = [0] + # the rest are going to be processed later + else: + # in our modification, all frames with provided masks go into permanent memory + frames_to_put_in_permanent_memory = frames_with_masks + for j in frames_to_put_in_permanent_memory: + sample = vid_reader[j] + rgb = sample['rgb'].cuda() + rgb_raw_tensor = sample['raw_image_tensor'].cpu() + msk = sample['mask'] + info = sample['info'] + need_resize = info['need_resize'] + + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask(msk, exhaustive=True) + msk = torch.Tensor(msk).cuda() + + if min(msk.shape) == 0: # empty mask, e.g. 
[1, 0, 720, 1280] + print(f"Skipping adding frame {j} to memory, as the mask is empty") + continue # just don't add anything to the memory + if need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + + processor.set_all_labels(list(mapper.remappings.values())) + processor.put_to_permanent_memory(rgb, msk) + + if not first_mask_loaded: + first_mask_loaded = True + + frames_.append(rgb) + masks_.append(msk) + + if augment_images_with_masks: + augs = get_determenistic_augmentations( + rgb.shape, msk, subset='best_all') + rgb_raw = FT.to_pil_image(rgb_raw_tensor) + + for img_aug, mask_aug in augs: + # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies + rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda() + + msk_aug = mask_aug(msk) + + processor.put_to_permanent_memory(rgb_aug, msk_aug) + + if not first_mask_loaded: + raise ValueError("No valid masks provided!") + + stats = [] + + if compute_uncertainty and uncertainty_name == 'bald': + bald = BALD() + + for ti, data in enumerate(tqdm(loader, disable=not print_progress)): + with torch.cuda.amp.autocast(enabled=True): + rgb = data['rgb'].cuda()[0] + rgb_raw_tensor = data['raw_image_tensor'].cpu()[0] + + gt = data.get('mask') # for IoU computations + if ti in frames_with_masks: + msk = data['mask'] + else: + msk = None + + info = data['info'] + frame = info['frame'][0] + shape = info['shape'] + need_resize = info['need_resize'][0] + curr_stat = {'frame': frame, 'mask_provided': msk is not None} + + # not important anymore as long as at least one mask is in permanent memory + if original_memory_mechanism and not first_mask_loaded: + if msk is not None: + first_mask_loaded = True + else: + # no point to do anything without a mask + continue + + # Map possibly non-continuous labels to continuous ones + if msk is not None: + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask( + msk[0].numpy(), exhaustive=True) + msk = torch.Tensor(msk).cuda() + if need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + processor.set_all_labels(list(mapper.remappings.values())) + + else: + labels = None + + if (compute_uncertainty and uncertainty_name == 'bald') or compute_disparity: + dry_run_preds = [] + augged_images = [] + augs = get_determenistic_augmentations(subset='original_only') + rgb_raw = FT.to_pil_image(rgb_raw_tensor) + for img_aug, mask_aug in augs: + # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies + augged_img = img_aug(rgb_raw) + augged_images.append(augged_img) + rgb_aug = vid_reader.im_transform(augged_img).cuda() + + # does not do anything, since original_only=True augmentations don't alter the mask at all + msk = mask_aug(msk) + + dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, disable_memory_updates=True) + dry_run_preds.append(dry_run_prob.cpu()) + + if original_memory_mechanism: + # we only ignore the first mask, since it's already in the permanent memory + do_not_add_mask_to_memory = (ti == 0) + else: + # we ignore all frames with masks, since they are already preloaded in the permanent memory + do_not_add_mask_to_memory = msk is not None + # Run the model on this frame + # 2+ channels, classes+ and background + prob = processor.step(rgb, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) + + if compute_uncertainty: + if uncertainty_name 
== 'bald': + # [batch=1, num_classes, ..., num_iterations] + all_samples = torch.stack( + [x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() + score = bald.compute_score(all_samples) + # TODO: can also return the exact pixels for every frame? As a suggestion on what to label + curr_stat['bald'] = float(np.squeeze(score).mean()) + elif compute_disparity: + disparity_stats = disparity_func( + predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) + curr_stat['disparity'] = float(disparity_stats['avg']) + curr_stat['disparity_large'] = float( + disparity_stats['large']) + else: + e = entropy(prob.cpu()) + e_mean = np.mean(e) + curr_stat['entropy'] = float(e_mean) + + # Upsample to original size if needed + if need_resize: + prob = F.interpolate(prob.unsqueeze( + 1), shape, mode='bilinear', align_corners=False)[:, 0] + + # Probability mask -> index mask + out_mask = torch.argmax(prob, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + + if compute_iou: + # mask is [0, 1] + # gt is [0, 255] + # both -> [False, True] + if gt is not None: + iou = float(compute_tensor_iou(torch.tensor( + out_mask).type(torch.bool), gt.type(torch.bool))) + else: + iou = -1 + curr_stat['iou'] = iou + + # Save the mask + if config['save_masks']: + this_out_path = path.join(config['masks_out_path'], vid_name) + os.makedirs(this_out_path, exist_ok=True) + out_mask = mapper.remap_index_mask(out_mask) + out_img = Image.fromarray(out_mask) + if vid_reader.get_palette() is not None: + out_img.putpalette(vid_reader.get_palette()) + out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) + + if False: # args.save_scores: + np_path = path.join(args.output, 'Scores', vid_name) + os.makedirs(np_path, exist_ok=True) + if ti == len(loader)-1: + hkl.dump(mapper.remappings, path.join( + np_path, f'backward.hkl'), mode='w') + if args.save_all or info['save'][0]: + hkl.dump(prob, path.join( + np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') + + stats.append(curr_stat) + + return pd.DataFrame(stats) + + +def run_on_video( + imgs_in_path: Union[str, PathLike], + masks_in_path: Union[str, PathLike], + masks_out_path: Union[str, PathLike], + frames_with_masks: Iterable[int] = (0, ), + compute_iou=False, + print_progress=True, +) -> pd.DataFrame: + """ + Args: + imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png`. .jpg works too. + + masks_in_path (Union[str, PathLike]): Path to the directory containing video frames' masks in the same format, with corresponding names between video frames. Each unique object should have unique color. + + masks_out_path (Union[str, PathLike]): Path to the output directory (will be created if doesn't exist) where the predicted masks will be stored in .png format. + + frames_with_masks (Iterable[int]): A list of integers representing the frames on which the masks should be applied (default: [0], only applied to the first frame). 0-based. + + compute_iou (bool): A flag to indicate whether to compute the IoU metric (default: False, requires ALL video frames to have a corresponding mask). + + print_progress (bool): A flag to indicate whether to print a progress bar (default: True). 
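        Example (a minimal sketch; the directory paths are placeholders and the default
        `./saves/XMem.pth` checkpoint from the config above is assumed to be present):

            stats = run_on_video(
                imgs_in_path='my_video/JPEGImages',
                masks_in_path='my_video/Annotations',
                masks_out_path='output/my_video',
                frames_with_masks=[0, 50, 100],
                compute_iou=False,
            )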
+ + Returns: + stats (pd.Dataframe): a table containing every frame and the following information: IoU score with corresponding mask (if `compute_iou` is True) + """ + + return _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, + masks_out_path=masks_out_path, + frames_with_masks=frames_with_masks, + compute_uncertainty=False, + compute_iou=compute_iou, + print_progress=print_progress, + manually_curated_masks=False + ) + + +def predict_annotation_candidates( + imgs_in_path: Union[str, PathLike], + approach: str, + num_candidates: int = 1, + print_progress=True, +) -> List[int]: + """ + Args: + imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png` .jpg works too. + + if num_candidates == 1: + return [0] # First frame is hard-coded to always be used + + # p_bar.update() + + Returns: + annotation_candidates (List[int]): A list of frames indices (0-based) chosen as annotation candidates, sorted by importance (most -> least). Always contains [0] - first frame - at index 0. + """ + + candidate_selection_function = KNOWN_ANNOTATION_PREDICTORS[approach] + + assert num_candidates >= 1 + + if num_candidates == 1: + return [0] # First frame is hard-coded to always be used + + return _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=imgs_in_path, # Ignored + masks_out_path=None, # Ignored + frames_with_masks=[0], # Ignored + compute_uncertainty=False, + compute_iou=False, + print_progress=print_progress, + manually_curated_masks=False, + only_predict_frames_to_annotate_and_quit=num_candidates, + frame_selector_func=candidate_selection_function + ) diff --git a/main.py b/main.py index 1e5ad28..e7d6a9a 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from run_on_video import run_on_video, predict_annotation_candidates +from inference.run_on_video import run_on_video, predict_annotation_candidates if __name__ == '__main__': diff --git a/run_on_video.py b/run_on_video.py deleted file mode 100644 index 717dffb..0000000 --- a/run_on_video.py +++ /dev/null @@ -1,549 +0,0 @@ -import csv -from typing import Iterable, Union, List -from util.tensor_util import compute_tensor_iou -from inference.inference_core import InferenceCore -from model.network import XMem -from inference.data.video_reader import VideoReader -from inference.data.mask_mapper import MaskMapper -from inference.data.test_datasets import LongTestDataset, DAVISTestDataset, YouTubeVOSTestDataset -from inference.active_learning import calculate_proposals_for_annotations_iterative_umap_cosine, calculate_proposals_for_annotations_iterative_pca_cosine, calculate_proposals_for_annotations_iterative_pca_cosine_values, calculate_proposals_for_annotations_pca_hierarchical_clustering, calculate_proposals_for_annotations_umap_hdbscan_clustering, calculate_proposals_for_annotations_uniform_iterative_pca_cosine, calculate_proposals_for_annotations_with_average_distance, calculate_proposals_for_annotations_with_first_distance, calculate_proposals_for_annotations_with_iterative_distance, calculate_proposals_for_annotations_with_iterative_distance_cycle, calculate_proposals_for_annotations_with_iterative_distance_diff, calculate_proposals_for_annotations_with_uniform_iterative_distance_cycle, calculate_proposals_for_annotations_with_uniform_iterative_distance_diff, calculate_proposals_for_annotations_with_uniform_iterative_distance_double_diff, get_determenistic_augmentations, select_most_uncertain_frame, select_n_frame_candidates, 
compute_disparity as compute_disparity_func, select_n_frame_candidates_no_neighbours_simple -import torchvision.transforms.functional as FT -from baal.active.heuristics import BALD -from scipy.stats import entropy -from tqdm import tqdm -from PIL import Image -import numpy as np -from torch.utils.data import DataLoader -import torch.nn.functional as F -import torch -import pandas as pd -import shutil -from pathlib import Path -from argparse import ArgumentParser -from os import PathLike, path -import math -from collections import defaultdict -import os -from torchvision.transforms import functional as FT - - -def save_frames(dataset, frame_indices, output_folder): - p_out = Path(output_folder) - - if not p_out.exists(): - p_out.mkdir(parents=True) - - - for i in frame_indices: - sample = dataset[i] - rgb_raw_tensor = sample['raw_image_tensor'].cpu().squeeze() - img = FT.to_pil_image(rgb_raw_tensor) - - img.save(p_out / f'frame_{i:06d}.png') - - -def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, - original_memory_mechanism=False, - compute_iou=False, compute_uncertainty=False, manually_curated_masks=False, print_progress=True, - augment_images_with_masks=False, - uncertainty_name: str = None, - only_predict_frames_to_annotate_and_quit=0, - overwrite_config: dict = None): - torch.autograd.set_grad_enabled(False) - frames_with_masks = set(frames_with_masks) - config = { - 'buffer_size': 100, - 'deep_update_every': -1, - 'enable_long_term': True, - 'enable_long_term_count_usage': True, - 'fbrs_model': 'saves/fbrs.pth', - 'hidden_dim': 64, - 'images': None, - 'key_dim': 64, - 'max_long_term_elements': 10000, - 'max_mid_term_frames': 10, - 'mem_every': 10, - 'min_mid_term_frames': 5, - 'model': './saves/XMem.pth', - 'no_amp': False, - 'num_objects': 1, - 'num_prototypes': 128, - 's2m_model': 'saves/s2m.pth', - 'size': 480, - 'top_k': 30, - 'value_dim': 512, - 'masks_out_path': masks_out_path, # f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', - 'workspace': None, - 'save_masks': True - } - - if overwrite_config is not None: - config.update(overwrite_config) - - if compute_uncertainty: - assert uncertainty_name is not None - uncertainty_name = uncertainty_name.lower() - assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - compute_disparity = uncertainty_name.startswith('disparity') - else: - compute_disparity = False - - vid_reader = VideoReader( - "", - imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', - masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', - size=config['size'], - use_all_mask=True - ) - - model_path = config['model'] - network = XMem(config, model_path).cuda().eval() - if model_path is not None: - model_weights = torch.load(model_path) - network.load_weights(model_weights, init_as_zero_if_needed=True) - else: - print('No model loaded.') - - loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=8) - vid_name = vid_reader.vid_name - vid_length = len(loader) - # no need to count usage for LT if the video is not that long anyway - config['enable_long_term_count_usage'] = ( - config['enable_long_term'] and - (vid_length - / (config['max_mid_term_frames']-config['min_mid_term_frames']) - * config['num_prototypes']) - >= config['max_long_term_elements'] - ) - - mapper = MaskMapper() - processor = InferenceCore(network, config=config) - first_mask_loaded = False - - if only_predict_frames_to_annotate_and_quit 
> 0: - iterative_frames = calculate_proposals_for_annotations_iterative_pca_cosine(loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit) - - return iterative_frames - - frames_ = [] - masks_ = [] - - if original_memory_mechanism: - frames_to_put_in_permanent_memory = [0] # only the first frame goes into permanent memory originally - # the rest are going to be processed later - else: - frames_to_put_in_permanent_memory = frames_with_masks # in our modification, all frames with provided masks go into permanent memory - for j in frames_to_put_in_permanent_memory: - sample = vid_reader[j] - rgb = sample['rgb'].cuda() - rgb_raw_tensor = sample['raw_image_tensor'].cpu() - msk = sample['mask'] - info = sample['info'] - need_resize = info['need_resize'] - - # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True - msk, labels = mapper.convert_mask(msk, exhaustive=True) - msk = torch.Tensor(msk).cuda() - - if min(msk.shape) == 0: # empty mask, e.g. [1, 0, 720, 1280] - print(f"Skipping adding frame {j} to memory, as the mask is empty") - continue # just don't add anything to the memory - if need_resize: - msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] - - processor.set_all_labels(list(mapper.remappings.values())) - processor.put_to_permanent_memory(rgb, msk) - - frames_.append(rgb) - masks_.append(msk) - - if augment_images_with_masks: - augs = get_determenistic_augmentations(rgb.shape, msk, subset='best_all') - rgb_raw = FT.to_pil_image(rgb_raw_tensor) - - for img_aug, mask_aug in augs: - # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies - rgb_aug = vid_reader.im_transform(img_aug(rgb_raw)).cuda() - - msk_aug = mask_aug(msk) - - processor.put_to_permanent_memory(rgb_aug, msk_aug) - - stats = [] - - if compute_uncertainty and uncertainty_name == 'bald': - bald = BALD() - - for ti, data in enumerate(tqdm(loader, disable=not print_progress)): - with torch.cuda.amp.autocast(enabled=True): - rgb = data['rgb'].cuda()[0] - rgb_raw_tensor = data['raw_image_tensor'].cpu()[0] - - gt = data.get('mask') # for IoU computations - if ti in frames_with_masks: - msk = data['mask'] - else: - msk = None - - info = data['info'] - frame = info['frame'][0] - shape = info['shape'] - need_resize = info['need_resize'][0] - curr_stat = {'frame': frame, 'mask_provided': msk is not None} - - if not first_mask_loaded: - if msk is not None: - first_mask_loaded = True - else: - # no point to do anything without a mask - continue - - # Map possibly non-continuous labels to continuous ones - if msk is not None: - # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True - msk, labels = mapper.convert_mask(msk[0].numpy(), exhaustive=True) - msk = torch.Tensor(msk).cuda() - if need_resize: - msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] - processor.set_all_labels(list(mapper.remappings.values())) - - else: - labels = None - - if (compute_uncertainty and uncertainty_name == 'bald') or compute_disparity: - dry_run_preds = [] - augged_images = [] - augs = get_determenistic_augmentations(subset='original_only') - rgb_raw = FT.to_pil_image(rgb_raw_tensor) - for img_aug, mask_aug in augs: - # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies - augged_img = img_aug(rgb_raw) - augged_images.append(augged_img) - rgb_aug = vid_reader.im_transform(augged_img).cuda() - - msk = mask_aug(msk) # does not do anything, since original_only=True augmentations don't alter the mask at all - - dry_run_prob = 
processor.step(rgb_aug, msk, labels, end=(ti == vid_length-1), - manually_curated_masks=manually_curated_masks, disable_memory_updates=True) - dry_run_preds.append(dry_run_prob.cpu()) - - if original_memory_mechanism: - do_not_add_mask_to_memory = (ti == 0) # we only ignore the first mask, since it's already in the permanent memory - else: - do_not_add_mask_to_memory = msk is not None # we ignore all frames with masks, since they are already preloaded in the permanent memory - # Run the model on this frame - # 2+ channels, classes+ and background - prob = processor.step(rgb, msk, labels, end=(ti == vid_length-1), - manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) - - if compute_uncertainty: - if uncertainty_name == 'bald': - # [batch=1, num_classes, ..., num_iterations] - all_samples = torch.stack([x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() - score = bald.compute_score(all_samples) - # TODO: can also return the exact pixels for every frame? As a suggestion on what to label - curr_stat['bald'] = float(np.squeeze(score).mean()) - elif compute_disparity: - disparity_stats = compute_disparity_func( - predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) - curr_stat['disparity'] = float(disparity_stats['avg']) - curr_stat['disparity_large'] = float(disparity_stats['large']) - else: - e = entropy(prob.cpu()) - e_mean = np.mean(e) - curr_stat['entropy'] = float(e_mean) - - # Upsample to original size if needed - if need_resize: - prob = F.interpolate(prob.unsqueeze(1), shape, mode='bilinear', align_corners=False)[:, 0] - - # Probability mask -> index mask - out_mask = torch.argmax(prob, dim=0) - out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) - - if compute_iou: - # mask is [0, 1] - # gt is [0, 255] - # both -> [False, True] - if gt is not None: - iou = float(compute_tensor_iou(torch.tensor(out_mask).type(torch.bool), gt.type(torch.bool))) - else: - iou = -1 - curr_stat['iou'] = iou - - # Save the mask - if config['save_masks']: - this_out_path = path.join(config['masks_out_path'], vid_name) - os.makedirs(this_out_path, exist_ok=True) - out_mask = mapper.remap_index_mask(out_mask) - out_img = Image.fromarray(out_mask) - if vid_reader.get_palette() is not None: - out_img.putpalette(vid_reader.get_palette()) - out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) - - if False: # args.save_scores: - np_path = path.join(args.output, 'Scores', vid_name) - os.makedirs(np_path, exist_ok=True) - if ti == len(loader)-1: - hkl.dump(mapper.remappings, path.join(np_path, f'backward.hkl'), mode='w') - if args.save_all or info['save'][0]: - hkl.dump(prob, path.join(np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') - - stats.append(curr_stat) - - return pd.DataFrame(stats) - - -def run_active_learning(imgs_in_path, masks_in_path, masks_out_path, num_extra_frames: int, uncertainty_name: str, csv_out_path: str = None, mode='batched', use_cache=False, **kwargs): - """ - mode:str - Possible values: - 'uniform': uniformly distributed indices np.linspace(0, `num_total_frames` - 1, `num_extra_frames` + 1).astype(int) - 'random': pick `num_extra_frames` random frames (cannot include first or last ones) - 'batched': Pick only `num_extra_frames` best frames - 'iterative': Pick only 1 best frame instead of `num_extra_frames`, repeat `num_extra_frames` times - """ - if mode.startswith('uniform_random'): - pass - else: - assert mode in 
{'uniform', 'random', 'batched', 'iterative', 'umap_half_cosine', 'umap_hdbscan_clustering', 'pca_max_cosine_values_arpack'} - assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - - num_total_frames = len(os.listdir(imgs_in_path)) - if mode == 'uniform': - # linspace is [a, b] (inclusive) - frames_with_masks = np.linspace(0, num_total_frames - 1, num_extra_frames + 1).astype(int) - elif mode == 'random': - np.random.seed(1) - extra_frames = np.random.choice(np.arange(1, num_total_frames), size=num_extra_frames, replace=False).tolist() - frames_with_masks = sorted([0] + extra_frames) - elif mode.startswith('uniform_random'): - seed = int(mode.split('_')[-1]) - chosen_frames = [] - space = np.linspace(0, num_total_frames, num_extra_frames + 2, endpoint=True, dtype=int) - ranges = zip(space, space[1:]) - np.random.seed(seed) - - for a, b in ranges: - if a == 0: - chosen_frames.append(0) - else: - extra_frame = int(np.random.choice(np.arange(a, b), replace=False)) - chosen_frames.append(extra_frame) - frames_with_masks = chosen_frames - - elif mode == 'batched': - # we save baseline results here, with just 1 annotation - baseline_out = Path(masks_out_path).parent.parent / 'baseline' - df = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=baseline_out / 'masks', - frames_with_masks=[0], - compute_uncertainty=True, - compute_iou=True, - uncertainty_name=uncertainty_name, - manually_curated_masks=False, - print_progress=False, - overwrite_config={'save_masks': True}, - ) - - df.to_csv(baseline_out / 'stats.csv', index=False) - if uncertainty_name == 'disparity_large': - candidates = select_n_frame_candidates_no_neighbours_simple(df, n=num_extra_frames, uncertainty_name=uncertainty_name) - else: - candidates = select_n_frame_candidates(df, n=num_extra_frames, uncertainty_name=uncertainty_name) - - extra_frames = [int(candidate['index']) for candidate in candidates] - - frames_with_masks = sorted([0] + extra_frames) - elif mode == 'iterative': - extra_frames = [] - for i in range(num_extra_frames): - df = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=masks_out_path, - frames_with_masks=[0] + extra_frames, - compute_uncertainty=True, - compute_iou=False, - uncertainty_name=uncertainty_name, - manually_curated_masks=False, - print_progress=False, - overwrite_config={'save_masks': False}, - ) - - max_frame = select_most_uncertain_frame(df, uncertainty_name=uncertainty_name) - extra_frames.append(max_frame['index']) - - # keep unsorted to preserve order of the choices - frames_with_masks = [0] + extra_frames - elif mode == 'umap_hdbscan_clustering' or mode == 'umap_half_cosine': - frames_with_masks = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=masks_out_path, - frames_with_masks=[0], - compute_uncertainty=True, - compute_iou=False, - manually_curated_masks=False, - print_progress=False, - uncertainty_name=uncertainty_name, - overwrite_config={'save_masks': False}, - only_predict_frames_to_annotate_and_quit=num_extra_frames, # ONLY THIS WILL RUN ANYWAY - ) - elif mode == 'pca_max_cosine_values_arpack': - # getting all the values - _, values = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=masks_out_path, - frames_with_masks=[0], - compute_uncertainty=False, - compute_iou=False, - manually_curated_masks=False, - print_progress=True, - uncertainty_name=uncertainty_name, - 
return_all_values=True, ## The key argument - overwrite_config={'save_masks': False}, - ) - - frames_with_masks = calculate_proposals_for_annotations_iterative_pca_cosine_values(values, how_many_frames=num_extra_frames, print_progress=False) - if use_cache and os.path.exists(csv_out_path): - final_df = pd.read_csv(csv_out_path) - else: - final_df = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=masks_out_path, - frames_with_masks=frames_with_masks, - compute_uncertainty=True, - compute_iou=True, - print_progress=False, - uncertainty_name=uncertainty_name, - **kwargs - ) - - if csv_out_path is not None: - p_csv_out = Path(csv_out_path) - - if not p_csv_out.parent.exists(): - p_csv_out.parent.mkdir(parents=True) - - final_df.to_csv(p_csv_out, index=False) - - return final_df, frames_with_masks - - -def eval_active_learning(dataset_path: str, out_path: str, num_extra_frames: int, uncertainty_name: str, modes: list = None, **kwargs): - assert uncertainty_name in {'entropy', 'bald', 'disparity', 'disparity_large'} - - if modes is None: - modes = ['uniform', 'random', 'uniform_random', 'batched', 'iterative', 'umap_half_cosine', 'umap_hdbscan_clustering', 'pca_max_cosine_values_arpack'] - - p_in_ds = Path(dataset_path) - p_out = Path(out_path) - - big_stats = defaultdict(list) - for i, p_video_imgs_in in enumerate(tqdm(sorted((p_in_ds / 'JPEGImages').iterdir()))): - video_name = p_video_imgs_in.stem - p_video_masks_in = p_in_ds / 'Annotations_binarized' / video_name - - p_video_out_general = p_out / f'Active_learning_{uncertainty_name}' / video_name / f'{num_extra_frames}_extra_frames' - - for mode in modes: - curr_video_stat = {'video': video_name} - p_out_masks = p_video_out_general / mode / 'masks' - p_out_stats = p_video_out_general / mode / 'stats.csv' - - stats, frames_with_masks = run_active_learning(p_video_imgs_in, p_video_masks_in, p_out_masks, - num_extra_frames=num_extra_frames, csv_out_path=p_out_stats, mode=mode, uncertainty_name=uncertainty_name, use_cache=False, **kwargs) - - stats = stats[stats['mask_provided'] == False] # remove stats for frames with given masks - for i in range(1, len(frames_with_masks) + 1): - curr_video_stat[f'extra_frame_{i}'] = frames_with_masks[i - 1] - - curr_video_stat[f'mean_iou'] = stats['iou'].mean() - curr_video_stat[f'mean_{uncertainty_name}'] = stats[uncertainty_name].mean() - - big_stats[mode].append(curr_video_stat) - - for mode, mode_stats in big_stats.items(): - df_mode_stats = pd.DataFrame(mode_stats) - df_mode_stats.to_csv(p_out / f'Active_learning_{uncertainty_name}' / f'stats_{mode}_all_videos.csv', index=False) - - -def run_on_video( - imgs_in_path: Union[str, PathLike], - masks_in_path: Union[str, PathLike], - masks_out_path: Union[str, PathLike], - frames_with_masks: Iterable[int] = (0, ), - compute_iou=False, - print_progress=True, - ) -> pd.DataFrame: - - """ - Args: - imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png`. .jpg works too. - - masks_in_path (Union[str, PathLike]): Path to the directory containing video frames' masks in the same format, with corresponding names between video frames. Each unique object should have unique color. - - masks_out_path (Union[str, PathLike]): Path to the output directory (will be created if doesn't exist) where the predicted masks will be stored in .png format. 
- - frames_with_masks (Iterable[int]): A list of integers representing the frames on which the masks should be applied (default: [0], only applied to the first frame). 0-based. - - compute_iou (bool): A flag to indicate whether to compute the IoU metric (default: False, requires ALL video frames to have a corresponding mask). - - print_progress (bool): A flag to indicate whether to print a progress bar (default: True). - - Returns: - stats (pd.Dataframe): a table containing every frame and the following information: IoU score with corresponding mask (if `compute_iou` is True) - """ - - return _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, - masks_out_path=masks_out_path, - frames_with_masks=frames_with_masks, - compute_uncertainty=False, - compute_iou=compute_iou, - print_progress=print_progress, - manually_curated_masks=False - ) - - -def predict_annotation_candidates( - imgs_in_path: Union[str, PathLike], - num_candidates: int = 1, - print_progress=True, - ) -> List[int]: - - """ - Args: - imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png` .jpg works too. - - num_candidates (int, default: 1): How many annotations candidates to predict. - - print_progress (bool): A flag to indicate whether to print a progress bar (default: True). - - Returns: - annotation_candidates (List[int]): A list of frames indices (0-based) chosen as annotation candidates, sorted by importance (most -> least). Always contains [0] - first frame - at index 0. - """ - - assert num_candidates >= 1 - - if num_candidates == 1: - return [0] # First frame is hard-coded to always be used - - return _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=imgs_in_path, # Ignored - masks_out_path=None, # Ignored - frames_with_masks=[0], # Ignored - compute_uncertainty=False, - compute_iou=False, - print_progress=print_progress, - manually_curated_masks=False, - only_predict_frames_to_annotate_and_quit=num_candidates, - ) - From 8c0db37c7d215e6b1b3f323047bdacdc7d0850c3 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 10 Feb 2023 15:53:53 +0400 Subject: [PATCH 12/49] Added environment.yml file --- environment.yml | 180 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..dd0b0f4 --- /dev/null +++ b/environment.yml @@ -0,0 +1,180 @@ +name: XMem +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - asttokens=2.0.5=pyhd3eb1b0_0 + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - brotlipy=0.7.0=py39h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.01.10=h06a4308_0 + - certifi=2022.12.7=py39h06a4308_0 + - cffi=1.15.1=py39h74dc2b5_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cryptography=37.0.1=py39h9ce1e76_0 + - cudatoolkit=11.3.1=h2bc3f7f_2 + - daal4py=2021.6.0=py39h79cecc1_1 + - dal=2021.6.0=hdb19cb5_916 + - decorator=5.1.1=pyhd3eb1b0_0 + - executing=0.8.3=pyhd3eb1b0_0 + - ffmpeg=4.3=hf484d3e_0 + - fftw=3.3.10=nompi_h77c792f_102 + - freetype=2.11.0=h70c0345_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - hdbscan=0.8.28=py39hce5d2b2_1 + - idna=3.4=py39h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipython=8.8.0=py39h06a4308_0 + - jedi=0.18.1=py39h06a4308_1 + - joblib=1.1.1=py39h06a4308_0 + - jpeg=9e=h7f8727e_0 + - lame=3.100=h7b6447c_0 + - 
lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.8=h7f8727e_5 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=12.2.0=h69a702a_19 + - libgfortran5=12.2.0=h337968e_19 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - libllvm11=11.1.0=hf817b99_2 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.4.0=hecacb30_0 + - libunistring=0.9.10=h27cfd23_0 + - libwebp=1.2.4=h11a3e52_0 + - libwebp-base=1.2.4=h5eee18b_0 + - llvmlite=0.39.1=py39he621ea3_0 + - lz4-c=1.9.3=h295c915_1 + - matplotlib-inline=0.1.6=py39h06a4308_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py39h7f8727e_0 + - mkl_fft=1.3.1=py39hd3c417c_0 + - mkl_random=1.2.2=py39h51133e4_0 + - mpi=1.0=mpich + - mpich=3.3.2=external_0 + - ncurses=6.3=h5eee18b_3 + - nettle=3.7.3=hbbd107a_1 + - numba=0.56.4=py39h417a72b_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1s=h7f8727e_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pip=22.2.2=py39h06a4308_0 + - prompt-toolkit=3.0.20=pyhd3eb1b0_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pure_eval=0.2.2=pyhd3eb1b0_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pygments=2.11.2=pyhd3eb1b0_0 + - pynndescent=0.5.8=pyh1a96a4e_0 + - pyopenssl=22.0.0=pyhd3eb1b0_0 + - pysocks=1.7.1=py39h06a4308_0 + - python=3.9.13=haa1d7c7_2 + - python_abi=3.9=2_cp39 + - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - readline=8.1.2=h7f8727e_1 + - scikit-learn-intelex=2021.6.0=py39h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.39.3=h5082296_0 + - stack_data=0.2.0=pyhd3eb1b0_0 + - tbb=2021.6.0=hdb19cb5_0 + - threadpoolctl=3.1.0=pyh8a188c0_0 + - tk=8.6.12=h1ccaba5_0 + - torchaudio=0.12.1=py39_cu113 + - tqdm=4.64.1=pyhd8ed1ab_0 + - traitlets=5.7.1=py39h06a4308_0 + - tzdata=2022e=h04d1e81_0 + - umap-learn=0.5.3=py39hf3d152e_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - xz=5.2.6=h5eee18b_0 + - zlib=1.2.12=h5eee18b_3 + - zstd=1.5.2=ha4553b6_0 + - pip: + - absl-py==1.4.0 + - autopep8==2.0.0 + - baal==1.7.0 + - beautifulsoup4==4.11.1 + - blessed==1.19.1 + - cachetools==5.3.0 + - charset-normalizer==3.0.1 + - contourpy==1.0.6 + - cycler==0.11.0 + - cython==0.29.32 + - filelock==3.8.0 + - fonttools==4.38.0 + - gdown==4.5.3 + - gitdb==4.0.9 + - gitpython==3.1.29 + - google-auth==2.16.0 + - google-auth-oauthlib==0.4.6 + - gprof2dot==2022.7.29 + - gpustat==1.0.0 + - grpcio==1.51.1 + - h5py==3.7.0 + - hickle==5.0.2 + - importlib-metadata==6.0.0 + - kiwisolver==1.4.4 + - markdown==3.4.1 + - markupsafe==2.1.2 + - matplotlib==3.6.2 + - numpy==1.23.5 + - nvidia-ml-py==11.495.46 + - oauthlib==3.2.2 + - opencv-python==4.6.0.66 + - packaging==22.0 + - pandas==1.5.2 + - pillow==9.2.0 + - profilehooks==1.12.0 + - progressbar2==4.1.1 + - protobuf==3.20.3 + - psutil==5.9.4 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pycodestyle==2.9.1 + - pyparsing==3.0.9 + - pyqt5==5.15.7 + - pyqt5-qt5==5.15.2 + - pyqt5-sip==12.11.0 + - python-dateutil==2.8.2 + - python-graphviz==0.20.1 + - python-utils==3.3.3 + - pytz==2022.7 + - requests==2.28.2 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-learn==1.2.0 + - scipy==1.9.3 + - seaborn==0.12.2 + - setuptools==66.1.1 + - smmap==5.0.0 + - snakeviz==2.1.1 + - soupsieve==2.3.2.post1 + - structlog==21.5.0 + - tensorboard==2.11.2 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - thin-plate-spline==1.0.1 + - thinplate==1.0.0 + - tomli==2.0.1 + - 
torch-tb-profiler==0.4.1 + - torchmetrics==0.9.3 + - torchvision==0.13.1 + - torchviz==0.0.2 + - tornado==6.2 + - typing-extensions==4.4.0 + - urllib3==1.26.14 + - werkzeug==2.2.2 + - wheel==0.38.4 + - zipp==3.12.0 From 74e1948ee6c412d9e367320b46756a520b3144af Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 10 Feb 2023 15:56:17 +0400 Subject: [PATCH 13/49] Updated .gitignore --- .gitignore | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4991846..3198e73 100644 --- a/.gitignore +++ b/.gitignore @@ -137,4 +137,12 @@ dmypy.json .pyre/ output/ -example_videos/ \ No newline at end of file +example_videos/ +torchlogs/ +*.png +*.csv +*.txt +*junk* +*.profile + +.DS_Store \ No newline at end of file From 48633d72a85aecabf9a61c254ac924edc5497228 Mon Sep 17 00:00:00 2001 From: max810 Date: Tue, 14 Feb 2023 14:36:50 +0400 Subject: [PATCH 14/49] Fixed a minor import error, XMem now saves overlaid images as well as original masks by default --- inference/data/video_reader.py | 8 ++-- .../frame_selection/frame_selection_utils.py | 1 - inference/run_on_video.py | 19 +++++--- util/image_saver.py | 46 ++++++++++++++++++- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index 8f1669c..b389678 100644 --- a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -37,7 +37,7 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all self.size_dir = size_dir self.frames = sorted(os.listdir(self.image_dir)) - self.palette = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).getpalette() + self.reference_mask = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).convert('P') self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) if size < 0: @@ -98,8 +98,10 @@ def resize_mask(self, mask): return F.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)), mode='nearest') - def get_palette(self): - return self.palette + def map_the_colors_back(self, pred_mask: Image.Image): + # https://stackoverflow.com/questions/29433243/convert-image-to-specific-palette-using-pil-without-dithering + # dither=Dither.NONE just in case + return pred_mask.quantize(palette=self.reference_mask, dither=Image.Dither.NONE).convert('RGB') def __len__(self): return len(self.frames) \ No newline at end of file diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index 1746463..d5d2ebf 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ b/inference/frame_selection/frame_selection_utils.py @@ -16,7 +16,6 @@ import matplotlib.pyplot as plt from util.tensor_util import get_bbox_from_mask -from inference.run_on_video import run_on_video, predict_annotation_candidates, KNOWN_ANNOTATION_PREDICTORS def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): diff --git a/inference/run_on_video.py b/inference/run_on_video.py index b03dee7..45ba55f 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -16,6 +16,7 @@ from PIL import Image from model.network import XMem +from util.image_saver import create_overlay, save_image from util.tensor_util import compute_tensor_iou from inference.inference_core import InferenceCore from inference.data.video_reader import VideoReader @@ -45,7 +46,9 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou uncertainty_name: str 
= None, only_predict_frames_to_annotate_and_quit=0, overwrite_config: dict = None, - frame_selector_func: callable = None): + frame_selector_func: callable = None, + save_overlay=True, + b_and_w_color=(255, 0, 0)): torch.autograd.set_grad_enabled(False) frames_with_masks = set(frames_with_masks) config = { @@ -257,7 +260,6 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou all_samples = torch.stack( [x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() score = bald.compute_score(all_samples) - # TODO: can also return the exact pixels for every frame? As a suggestion on what to label curr_stat['bald'] = float(np.squeeze(score).mean()) elif compute_disparity: disparity_stats = disparity_func( @@ -292,13 +294,16 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou # Save the mask if config['save_masks']: - this_out_path = path.join(config['masks_out_path'], vid_name) - os.makedirs(this_out_path, exist_ok=True) + original_img = FT.to_pil_image(rgb_raw_tensor) + out_mask = mapper.remap_index_mask(out_mask) out_img = Image.fromarray(out_mask) - if vid_reader.get_palette() is not None: - out_img.putpalette(vid_reader.get_palette()) - out_img.save(os.path.join(this_out_path, frame[:-4]+'.png')) + out_img = vid_reader.map_the_colors_back(out_img) + save_image(out_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='masks') + + if save_overlay: + overlaid_img = create_overlay(original_img, out_img, color_if_black_and_white=b_and_w_color) + save_image(overlaid_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='overlay') if False: # args.save_scores: np_path = path.join(args.output, 'Scores', vid_name) diff --git a/util/image_saver.py b/util/image_saver.py index c43d9de..ca026a6 100644 --- a/util/image_saver.py +++ b/util/image_saver.py @@ -1,5 +1,8 @@ +import os +from time import perf_counter import cv2 import numpy as np +from PIL import Image import torch from dataset.range_transform import inv_im_trans @@ -133,4 +136,45 @@ def pool_pairs(images, size, num_objects): # print(mask_transform(images['cls_gt'][bi,ti,0]==(oi+1), size).shape) - return get_image_array(req_images, size, key_captions) \ No newline at end of file + return get_image_array(req_images, size, key_captions) + +def _check_if_black_and_white(img: Image.Image): + unique_colors = img.getcolors() + if len(unique_colors) > 2: + return False + + if len(unique_colors) == 1: + return True # just a black image + + for _, color_rgb in unique_colors: + if color_rgb == (255, 255, 255): + return True + + return False + +def create_overlay(img: Image.Image, mask: Image.Image, mask_alpha=0.5, color_if_black_and_white=(255, 0, 0)): # all RGB + mask = mask.convert('RGB') + is_b_and_w = _check_if_black_and_white(mask) + + if img.size != mask.size: + mask = mask.resize(img.size, resample=Image.NEAREST) + + mask_arr = np.array(mask) + + if is_b_and_w: + mask_arr = np.where(mask_arr, np.array(color_if_black_and_white), mask_arr).astype(np.uint8) + mask = Image.fromarray(mask_arr, mode='RGB') + + alpha_mask = np.full(mask_arr.shape[0:2], 255) + alpha_mask[cv2.cvtColor(mask_arr, cv2.COLOR_BGR2GRAY) > 0] = int(mask_alpha * 255) # 255 for black (to keep original image in full), `mask_alpha` for predicted pixels + + overlay = Image.composite(img, mask, Image.fromarray(alpha_mask.astype(np.uint8), mode='L')) + + return overlay + +def save_image(img: Image.Image, frame_name, video_name, general_dir_path, 
sub_dir_name='masks'): + this_out_path = os.path.join(general_dir_path, video_name, sub_dir_name) + os.makedirs(this_out_path, exist_ok=True) + + img_save_path = os.path.join(this_out_path, frame_name[:-4]+'.png') + cv2.imwrite(img_save_path, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) \ No newline at end of file From ca2436e3ca8adb5234896dd58dcaa8ea3d74d360 Mon Sep 17 00:00:00 2001 From: max810 Date: Tue, 14 Feb 2023 17:21:36 +0400 Subject: [PATCH 15/49] Changed .png to .jpg for saved overlay masks, minor tweaks --- .../frame_selection/frame_selection_utils.py | 22 +++++++++---------- inference/run_on_video.py | 8 ++++--- util/image_saver.py | 4 ++-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index d5d2ebf..1301d5b 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ b/inference/frame_selection/frame_selection_utils.py @@ -123,17 +123,17 @@ def get_determenistic_augmentations(img_size=None, mask=None, subset: str = None identity = torch.nn.Identity() identity.name = 'identity' - if mask is not None: - if mask.any(): - min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) - h, w = mask.shape[-2:] - crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, - height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) - crop_mask.name = 'crop_mask' - else: - crop_mask = identity # if the mask is empty - else: - crop_mask = None + # if mask is not None: + # if mask.any(): + # min_y, min_x, max_y, max_x = get_bbox_from_mask(mask) + # h, w = mask.shape[-2:] + # crop_mask = partial(FT.resized_crop, top=min_y - 10, left=min_x - 10, + # height=max_y - min_y + 10, width=max_x - min_x + 10, size=(w, h)) + # crop_mask.name = 'crop_mask' + # else: + # crop_mask = identity # if the mask is empty + # else: + crop_mask = None bright.name = 'bright' dark.name = 'dark' diff --git a/inference/run_on_video.py b/inference/run_on_video.py index 45ba55f..5f5d10e 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -299,11 +299,11 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou out_mask = mapper.remap_index_mask(out_mask) out_img = Image.fromarray(out_mask) out_img = vid_reader.map_the_colors_back(out_img) - save_image(out_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='masks') + save_image(out_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='masks', extension='.png') if save_overlay: overlaid_img = create_overlay(original_img, out_img, color_if_black_and_white=b_and_w_color) - save_image(overlaid_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='overlay') + save_image(overlaid_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='overlay', extension='.jpg') if False: # args.save_scores: np_path = path.join(args.output, 'Scores', vid_name) @@ -327,6 +327,7 @@ def run_on_video( frames_with_masks: Iterable[int] = (0, ), compute_iou=False, print_progress=True, + **kwargs ) -> pd.DataFrame: """ Args: @@ -354,7 +355,8 @@ def run_on_video( compute_uncertainty=False, compute_iou=compute_iou, print_progress=print_progress, - manually_curated_masks=False + manually_curated_masks=False, + **kwargs ) diff --git a/util/image_saver.py b/util/image_saver.py index ca026a6..f858671 100644 --- a/util/image_saver.py +++ b/util/image_saver.py @@ -172,9 +172,9 @@ def create_overlay(img: Image.Image, 
mask: Image.Image, mask_alpha=0.5, color_if return overlay -def save_image(img: Image.Image, frame_name, video_name, general_dir_path, sub_dir_name='masks'): +def save_image(img: Image.Image, frame_name, video_name, general_dir_path, sub_dir_name='masks', extension='.png'): this_out_path = os.path.join(general_dir_path, video_name, sub_dir_name) os.makedirs(this_out_path, exist_ok=True) - img_save_path = os.path.join(this_out_path, frame_name[:-4]+'.png') + img_save_path = os.path.join(this_out_path, frame_name[:-4] + extension) cv2.imwrite(img_save_path, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) \ No newline at end of file From d4f8335b80525ea0435cf3f0c4b24d4a67e69013 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 10 Mar 2023 14:38:47 +0400 Subject: [PATCH 16/49] Cleanup; Added minor experiments functions; Added metrics; Fixed num_groups bug in memory_manager --- environment.yml | 14 +- .../frame_selection/frame_selection_utils.py | 10 - inference/memory_manager.py | 63 +++-- inference/run_experiments.py | 198 ++++++++++++-- util/metrics.py | 255 ++++++++++++++++++ util/tensor_util.py | 29 ++ 6 files changed, 526 insertions(+), 43 deletions(-) create mode 100644 util/metrics.py diff --git a/environment.yml b/environment.yml index dd0b0f4..44d3fc0 100644 --- a/environment.yml +++ b/environment.yml @@ -21,8 +21,10 @@ dependencies: - dal=2021.6.0=hdb19cb5_916 - decorator=5.1.1=pyhd3eb1b0_0 - executing=0.8.3=pyhd3eb1b0_0 + - faiss-gpu=1.7.3=py3.9_h28a55e0_0_cuda11.3 - ffmpeg=4.3=hf484d3e_0 - fftw=3.3.10=nompi_h77c792f_102 + - flit-core=3.6.0=pyhd3eb1b0_0 - freetype=2.11.0=h70c0345_0 - giflib=5.2.1=h7b6447c_0 - gmp=6.2.1=h295c915_3 @@ -39,6 +41,7 @@ dependencies: - ld_impl_linux-64=2.38=h1181459_1 - lerc=3.0=h295c915_0 - libdeflate=1.8=h7f8727e_5 + - libfaiss=1.7.3=hfc2d529_0_cuda11.3 - libffi=3.3=he6710b0_2 - libgcc-ng=11.2.0=h1234567_1 - libgfortran-ng=12.2.0=h69a702a_19 @@ -67,7 +70,7 @@ dependencies: - nettle=3.7.3=hbbd107a_1 - numba=0.56.4=py39h417a72b_0 - openh264=2.1.1=h4ff587b_0 - - openssl=1.1.1s=h7f8727e_0 + - openssl=1.1.1t=h7f8727e_0 - parso=0.8.3=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 @@ -95,6 +98,7 @@ dependencies: - torchaudio=0.12.1=py39_cu113 - tqdm=4.64.1=pyhd8ed1ab_0 - traitlets=5.7.1=py39h06a4308_0 + - typing_extensions=4.3.0=py39h06a4308_0 - tzdata=2022e=h04d1e81_0 - umap-learn=0.5.3=py39hf3d152e_0 - wcwidth=0.2.5=pyhd3eb1b0_0 @@ -110,8 +114,10 @@ dependencies: - cachetools==5.3.0 - charset-normalizer==3.0.1 - contourpy==1.0.6 + - cupy-cuda11x==11.5.0 - cycler==0.11.0 - cython==0.29.32 + - fastrlock==0.8.1 - filelock==3.8.0 - fonttools==4.38.0 - gdown==4.5.3 @@ -123,6 +129,7 @@ dependencies: - gpustat==1.0.0 - grpcio==1.51.1 - h5py==3.7.0 + - haishoku==1.1.8 - hickle==5.0.2 - importlib-metadata==6.0.0 - kiwisolver==1.4.4 @@ -165,6 +172,7 @@ dependencies: - tensorboard==2.11.2 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.1 + - termcolor==2.2.0 - thin-plate-spline==1.0.1 - thinplate==1.0.0 - tomli==2.0.1 @@ -173,8 +181,10 @@ dependencies: - torchvision==0.13.1 - torchviz==0.0.2 - tornado==6.2 - - typing-extensions==4.4.0 + - trash-cli==0.23.2.13.2 + - typing==3.7.4.3 - urllib3==1.26.14 - werkzeug==2.2.2 - wheel==0.38.4 - zipp==3.12.0 +prefix: /home/maksym/miniconda3/envs/XMem diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index 1301d5b..14a3449 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ 
b/inference/frame_selection/frame_selection_utils.py @@ -1,21 +1,11 @@ from functools import partial -import json -import os from pathlib import Path -from typing import Any, Dict, List, Set, Tuple, Union from sklearn.cluster import KMeans import pandas as pd -from PIL import Image import torch import torchvision.transforms.functional as FT -import numpy as np from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine -import cv2 -from tqdm import tqdm -import matplotlib.pyplot as plt - -from util.tensor_util import get_bbox_from_mask def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): diff --git a/inference/memory_manager.py b/inference/memory_manager.py index aadf88d..0e721b0 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -61,7 +61,9 @@ def match_memory(self, query_key, selection, disable_usage_updates=False): # selection: B x C^k x H x W # TODO: keep groups in both..? # 1x64x30x54 - num_groups = self.temporary_work_mem.num_groups + + # = permanent_work_mem.num_groups, since it's always >= temporary_work_mem.num_groups + num_groups = max(self.temporary_work_mem.num_groups, self.permanent_work_mem.num_groups) h, w = query_key.shape[-2:] query_key = query_key.flatten(start_dim=2) @@ -94,25 +96,30 @@ def match_memory(self, query_key, selection, disable_usage_updates=False): # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): + temp_group_v_size = self.temporary_work_mem.get_v_size(gi) + perm_group_v_size = self.permanent_work_mem.get_v_size(gi) + temp_sim_size = temp_work_mem_similarity.shape[1] + perm_sim_size = perm_work_mem_similarity.shape[1] + if gi < self.long_mem.num_groups: # merge working and lt similarities before softmax affinity_one_group = do_softmax( torch.cat([long_mem_similarity[:, -self.long_mem.get_v_size(gi):], - temp_work_mem_similarity[:, -self.temporary_work_mem.get_v_size(gi):], - perm_work_mem_similarity[:, -self.permanent_work_mem.get_v_size(gi):]], - 1), + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:]], + dim=1), top_k=self.top_k, inplace=True) else: # no long-term memory for this group affinity_one_group = do_softmax(torch.cat([ - temp_work_mem_similarity[:, -self.temporary_work_mem.get_v_size(gi):], - perm_work_mem_similarity[:, -self.permanent_work_mem.get_v_size(gi):]], + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:]], 1), top_k=self.top_k, inplace=(gi == num_groups-1)) affinity.append(affinity_one_group) all_memory_value = [] - for gi, gv in enumerate(self.temporary_work_mem.value): + for gi in range(num_groups): # merge the working and lt values before readout if gi < self.long_mem.num_groups: all_memory_value.append(torch.cat([self.long_mem.value[gi], self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) @@ -136,6 +143,8 @@ def match_memory(self, query_key, selection, disable_usage_updates=False): shrinkage = torch.cat([self.temporary_work_mem.shrinkage, self.permanent_work_mem.shrinkage], -1) # No long-term memory similarity = get_similarity(memory_key, shrinkage, query_key, selection) + temp_work_mem_similarity = similarity[:, :temp_work_mem_size] + perm_work_mem_similarity = similarity[:, temp_work_mem_size:] if self.enable_long_term: affinity, usage = do_softmax(similarity, inplace=(num_groups 
== 1), @@ -151,13 +160,25 @@ def match_memory(self, query_key, selection, disable_usage_updates=False): # compute affinity group by group as later groups only have a subset of keys for gi in range(1, num_groups): - affinity_one_group = do_softmax(similarity[:, -self.temporary_work_mem.get_v_size(gi):], - top_k=self.top_k, inplace=(gi == num_groups-1)) + temp_group_v_size = self.temporary_work_mem.get_v_size(gi) + perm_group_v_size = self.permanent_work_mem.get_v_size(gi) + temp_sim_size = temp_work_mem_similarity.shape[1] + perm_sim_size = perm_work_mem_similarity.shape[1] + + affinity_one_group = do_softmax( + torch.cat([ + # concats empty tensor if the group is also empty for temporary memory + temp_work_mem_similarity[:, temp_sim_size-temp_group_v_size:], + perm_work_mem_similarity[:, perm_sim_size-perm_group_v_size:], + ], dim=1), + top_k=self.top_k, inplace=(gi == num_groups-1) + ) affinity.append(affinity_one_group) all_memory_value = [] - for gi, gv in enumerate(self.temporary_work_mem.value): - all_memory_value.append(torch.cat([self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1)) + for gi in range(num_groups): + group_v_cat = torch.cat([self.temporary_work_mem.value[gi], self.permanent_work_mem.value[gi]], -1) + all_memory_value.append(group_v_cat) # Shared affinity within each group all_readout_mem = torch.cat([ @@ -202,16 +223,28 @@ def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=F else: self.temporary_work_mem.add(key, value, shrinkage, selection, objects) - if not self.temporary_work_mem.engaged(): - # first frame; we need to have both memories engaged to avoid crashes when concating + + num_temp_groups = self.temporary_work_mem.num_groups + num_perm_groups = self.permanent_work_mem.num_groups + + if not self.temporary_work_mem.engaged() or (num_temp_groups != num_perm_groups): + # print(f"PERM_NUM_GROUPS={num_perm_groups} vs TEMP_NUM_GROUPS={num_temp_groups}", end=' ') + + # first frame or new group; we need to have both memories engaged to avoid crashes when concating # so we just initialize the temporary one with an empty tensor key0 = key[..., 0:0] value0 = value[..., 0:0] shrinkage0 = shrinkage[..., 0:0] selection0 = selection[..., 0:0] + if num_perm_groups > num_temp_groups: + # for preloading into permanent memory + self.temporary_work_mem.add(key0, value0, shrinkage0, selection0, objects) + else: + # for original memory mechanism + self.permanent_work_mem.add(key0, value0, shrinkage0, selection0, objects) + + # print(f"AFTER->PERM_NUM_GROUPS={self.permanent_work_mem.num_groups} vs TEMP_NUM_GROUPS={self.temporary_work_mem.num_groups}") - self.temporary_work_mem.add(key0, value0, shrinkage0, selection0, objects) - # long-term memory cleanup if self.enable_long_term: # Do memory compressed if needed diff --git a/inference/run_experiments.py b/inference/run_experiments.py index 9acd525..db6607c 100644 --- a/inference/run_experiments.py +++ b/inference/run_experiments.py @@ -6,8 +6,12 @@ import cv2 import numpy as np import pandas as pd +import torch from tqdm import tqdm from matplotlib import pyplot as plt +from PIL import Image +from util.metrics import batched_f_measure, batched_jaccard +from p_tqdm import p_umap from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS from inference.run_on_video import predict_annotation_candidates, run_on_video @@ -133,52 +137,104 @@ def get_videos_info(): } -def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str): - output = 
pd.DataFrame(columns=list(KNOWN_ANNOTATION_PREDICTORS)) - p_bar = tqdm(total=len(videos_info) * len(KNOWN_ANNOTATION_PREDICTORS)) +def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str, predictors: Dict[str, callable] = KNOWN_ANNOTATION_PREDICTORS, load_existing_masks=False): + output = pd.DataFrame(columns=list(predictors)) + p_bar = tqdm(total=len(videos_info) * len(predictors)) + + exceptions = pd.DataFrame(columns=['video', 'method', 'error_message']) for video_name, info in videos_info.items(): video_frames_path = info['video_frames_path'] num_candidate_frames = info['num_annotation_candidates'] + if load_existing_masks: + masks_first_frame_only = Path(info['masks_out_path']) / 'ONLY_ONE_FRAME' + else: + masks_first_frame_only = None results = {} - for method_name in KNOWN_ANNOTATION_PREDICTORS: - chosen_annotation_frames = predict_annotation_candidates( - video_frames_path, num_candidates=num_candidate_frames, approach=method_name) + for method_name, method_func in predictors.items(): + try: + chosen_annotation_frames = predict_annotation_candidates( + video_frames_path, + num_candidates=num_candidate_frames, + candidate_selection_function=method_func, + masks_first_frame_only=masks_first_frame_only, + masks_in_path=info['video_masks_path'], + masks_out_path=Path(info['masks_out_path']) / 'FIRST_FRAME_ONLY' / 'masks', # used by some target-aware algorithms + print_progress=False + ) + except Exception as e: + print(f"[!!!] ERROR ({video_name},{method_name})={e}") + print("Resulting to uniform baseline") + chosen_annotation_frames = predict_annotation_candidates( + video_frames_path, + num_candidates=num_candidate_frames, + candidate_selection_function=KNOWN_ANNOTATION_PREDICTORS['UNIFORM'], + masks_in_path=info['video_masks_path'], + print_progress=False + ) + exceptions.append([video_name, method_name, str(e)]) + + torch.cuda.empty_cache() results[method_name] = json.dumps(chosen_annotation_frames) p_bar.update() output.loc[video_name] = results - output.index.name = 'video_name' - output.to_csv(csv_output_path) + # save updated after every video + output.index.name = 'video_name' + output.to_csv(csv_output_path) + if min(exceptions.shape) > 0: + exceptions.to_csv('output/exceptions.csv') -def run_inference_with_pre_chosen_frames(chosen_frames_csv_path: str, videos_info: Dict[str, Dict], output_path: str, only_methods_subset: Set[str] = None): + +def run_inference_with_pre_chosen_frames(chosen_frames_csv_path: str, videos_info: Dict[str, Dict], output_path: str, only_methods_subset: Set[str] = None, compute_iou=False, IoU_results_save_path=None, **kwargs): df = pd.read_csv(chosen_frames_csv_path, index_col='video_name') - num_runs = np.prod(df.shape) - p_bar = tqdm( - desc='Running inference comparing multiple different AL approaches', total=num_runs) + if only_methods_subset is not None: + num_runs = df.shape[0] * len(only_methods_subset) + else: + num_runs = np.prod(df.shape) + + if compute_iou: + assert IoU_results_save_path is not None + p_iou_dir = Path(IoU_results_save_path) + + i = 0 + p_bar = tqdm(desc='Running inference comparing multiple different AL approaches', total=num_runs) for video_name, info in videos_info.items(): video_row = df.loc[video_name] + # ious = {} for method in video_row.index: if only_methods_subset is not None and method not in only_methods_subset: continue - + chosen_frames_str = video_row.loc[method] chosen_frames = json.loads(chosen_frames_str) - print(chosen_frames) video_frames_path = info['video_frames_path'] 
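# --- Editor's note: illustrative sketch, not part of the patch. ---------------
# The chosen-frames CSV written by run_multiple_frame_selectors has one row per
# video and one column per selector; each cell holds a JSON-encoded list of
# frame indices. Layout roughly like the following (only the 'UNIFORM' column
# name appears in this patch; the video name and index values are hypothetical):
#
#   video_name,UNIFORM,<other selector columns ...>
#   some_video,"[0, 25, 50, 75, 99]",...
#
# which is why the loop above recovers each list with a plain json round-trip:
#
#   import json
#   cell = json.dumps([0, 25, 50, 75, 99])         # written by run_multiple_frame_selectors
#   assert json.loads(cell) == [0, 25, 50, 75, 99]  # read back here before inference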
video_masks_path = info['video_masks_path'] output_masks_path = Path(output_path) / video_name / method - run_on_video(video_frames_path, video_masks_path, output_masks_path, - frames_with_masks=chosen_frames, compute_iou=False, print_progress=False) + stats = run_on_video(video_frames_path, video_masks_path, output_masks_path, + frames_with_masks=chosen_frames, compute_iou=compute_iou, print_progress=False, **kwargs) + + if compute_iou: + p_out_curr_video_method = p_iou_dir / video_name + if not p_out_curr_video_method.exists(): + p_out_curr_video_method.mkdir(parents=True) + + stats.to_csv(p_out_curr_video_method / f'{method}.csv')#f'output/AL_comparison_all_methods/{video_name}_{method}.csv') + # print(f"Video={video_name},method={method},IoU={stats['iou'].mean():.4f}") + # ious[f'{video_name}_{method}'] = [float(iou) for iou in stats['iou']] p_bar.update() + i += 1 + + # with open(f'output/AL_comparison_all_methods/ious_{video_name}_all_methods.json', 'wt') as f_out: + # json.dump(ious, f_out) def visualize_chosen_frames(video_name: str, num_total_frames: int, data: pd.Series, output_path: str): @@ -234,6 +290,116 @@ def _sort_index(series): plt.savefig(p_out, bbox_inches='tight') # -------------------------END Inference and visualization utils -------------------------- +# ------------------------BEGIN metrics --------------------------------------------------- + +def _load_gt(p): + return np.stack([np.array(Image.open(p_gt).convert('P')) for p_gt in sorted(p.iterdir())]) + + +def _load_preds(p, palette: Image.Image, size: tuple): + return np.stack([Image.open(p_gt).convert('RGB').resize(size, resample=Image.Resampling.NEAREST).quantize(palette=palette, dither=Image.Dither.NONE) for p_gt in sorted(p.iterdir())]) + +def compute_metrics_al(p_source_masks, p_preds, looped=True): + def _proc(p_video: Path): + video_name = p_video.name + p_gts = p_source_masks / p_video.name + first_mask = Image.open(next(p_gts.iterdir())).convert('P') + w, h = first_mask.size + gts = _load_gt(p_gts) + + stats = { + 'video_name': video_name + } + + for p_method in p_video.iterdir(): + method_name = p_method.name + p_masks = p_method / 'masks' + preds = _load_preds(p_masks, palette=first_mask, size=(w, h)) + + assert preds.shape == gts.shape + + iou = batched_jaccard(gts, preds) + avg_iou = iou.mean(axis=0) + + f_score = batched_f_measure(gts, preds) + avg_f_score = f_score.mean(axis=0) + + stats[f'{method_name}-iou'] = float(avg_iou) + stats[f'{method_name}-f'] = float(avg_f_score) + + if looped: + n = iou.shape[0] + between = int(0.9 * n) + first_part_iou = iou[:between].mean() + second_part_iou = iou[between:].mean() + + first_part_f_score = f_score[:between].mean() + second_part_f_score = f_score[between:].mean() + + stats[f'{method_name}-iou-90'] = float(first_part_iou) + stats[f'{method_name}-iou-10'] = float(second_part_iou) + stats[f'{method_name}-f-90'] = float(first_part_f_score) + stats[f'{method_name}-f-10'] = float(second_part_f_score) + + return stats + + list_of_stats = p_umap(_proc, list(p_preds.iterdir()), num_cpus=3) + + results = pd.DataFrame.from_records(list_of_stats).dropna(axis='columns').set_index('video_name') + return results + +def compute_metrics(p_source_masks, p_preds): + list_of_stats = [] + # for p_pred_video in list(p_preds.iterdir()): + def _proc(p_pred_video: Path): + video_name = p_pred_video.name + # if 'XMem' in str(p_pred_video): + p_pred_video = Path(p_pred_video) / 'masks' + p_gts = p_source_masks / video_name + first_mask = 
Image.open(next(p_gts.iterdir())).convert('P') + w, h = first_mask.size + gts = _load_gt(p_gts) + + preds = _load_preds(p_pred_video, palette=first_mask, size=(w, h)) + + assert preds.shape == gts.shape + + avg_iou = batched_jaccard(gts, preds).mean(axis=0) + avg_f_score = batched_f_measure(gts, preds).mean(axis=0) + stats = { + 'video_name': video_name, + 'iou': float(avg_iou), + 'f': float(avg_f_score), + } + + return stats + # list_of_stats.append(stats) + # p_source_masks = Path('/home/maksym/RESEARCH/Datasets/MOSE/train/Annotations') + # p_preds = Path('/home/maksym/RESEARCH/VIDEOS/RESULTS/XMem_memory/MOSE/AL_comparison') + + list_of_stats = p_umap(_proc, sorted(p_preds.iterdir(), key=lambda x: len(os.listdir(x)), reverse=True), num_cpus=4) + + results = pd.DataFrame.from_records(list_of_stats).dropna(axis='columns').set_index('video_name') + return results + +# -------------------------END metrics ------------------------------------------------------ + +def get_dataset_video_info(p_imgs_general, p_annotations_general, p_out_general, num_annotation_candidates=5): + videos_info = {} + + for p_video in sorted(p_imgs_general.iterdir(), key=lambda x: len(os.listdir(x)), reverse=True): # longest video first to avoid OOM in the future + video_name = p_video.name + p_masks = p_annotations_general / video_name + + videos_info[video_name] = dict( + num_annotation_candidates=num_annotation_candidates, + video_frames_path=p_video, + video_masks_path=p_masks, + masks_out_path=p_out_general / video_name + ) + + return videos_info + if __name__ == "__main__": diff --git a/util/metrics.py b/util/metrics.py new file mode 100644 index 0000000..84a01cc --- /dev/null +++ b/util/metrics.py @@ -0,0 +1,255 @@ +from __future__ import absolute_import, division +import math + +import cv2 +import numpy as np +from skimage.morphology import disk + +__all__ = ['batched_jaccard', 'batched_f_measure'] + + +def batched_jaccard(y_true, y_pred, average_over_objects=True, nb_objects=None): + """ Batch jaccard similarity for multiple instance segmentation. + + Jaccard similarity over two subsets of binary elements $A$ and $B$: + + $$ + \mathcal{J} = \\frac{A \\cap B}{A \\cup B} + $$ + + # Arguments + y_true: Numpy Array. Array of shape (B x H x W) and type integer giving the + ground truth of the object instance segmentation. + y_pred: Numpy Array. Array of shape (B x H x W) and type integer giving the + prediction of the object segmentation. + average_over_objects: Boolean. Weather or not to average the jaccard over + all the objects in the sequence. Default True. + nb_objects: Integer. Number of objects in the ground truth mask. If + `None` the value will be infered from `y_true`. Setting this value + will speed up the computation. + + # Returns + ndarray: Returns an array of shape (B) with the average jaccard for + all instances at each frame if `average_over_objects=True`. If + `average_over_objects=False` returns an array of shape (B x nObj) + with nObj being the number of objects on `y_true`. + """ + y_true = np.asarray(y_true, dtype=np.int) + y_pred = np.asarray(y_pred, dtype=np.int) + if y_true.ndim != 3: + raise ValueError('y_true array must have 3 dimensions.') + if y_pred.ndim != 3: + raise ValueError('y_pred array must have 3 dimensions.') + if y_true.shape != y_pred.shape: + raise ValueError('y_true and y_pred must have the same shape. 
{} != {}'.format(y_true.shape, y_pred.shape)) + + if nb_objects is None: + objects_ids = np.unique(y_true[(y_true < 255) & (y_true > 0)]) + nb_objects = len(objects_ids) + else: + objects_ids = [i + 1 for i in range(nb_objects)] + objects_ids = np.asarray(objects_ids, dtype=np.int) + if nb_objects == 0: + raise ValueError('Number of objects in y_true should be higher than 0.') + nb_frames = len(y_true) + + jaccard = np.empty((nb_frames, nb_objects), dtype=np.float) + + for i, obj_id in enumerate(objects_ids): + mask_true, mask_pred = y_true == obj_id, y_pred == obj_id + + union = (mask_true | mask_pred).sum(axis=(1, 2)) + intersection = (mask_true & mask_pred).sum(axis=(1, 2)) + + for j in range(nb_frames): + if np.isclose(union[j], 0): + jaccard[j, i] = 1. + else: + jaccard[j, i] = intersection[j] / union[j] + + if average_over_objects: + jaccard = jaccard.mean(axis=1) + return jaccard + + +def _seg2bmap(seg, width=None, height=None): + """ + From a segmentation, compute a binary boundary map with 1 pixel wide + boundaries. The boundary pixels are offset by 1/2 pixel towards the + origin from the actual segment boundary. + + # Arguments + seg: Segments labeled from 1..k. + width: Width of desired bmap <= seg.shape[1] + height: Height of desired bmap <= seg.shape[0] + + # Returns + bmap (ndarray): Binary boundary map. + + David Martin + January 2003 + """ + + seg = seg.astype(np.bool) + seg[seg > 0] = 1 + + assert np.atleast_3d(seg).shape[2] == 1 + + width = seg.shape[1] if width is None else width + height = seg.shape[0] if height is None else height + + h, w = seg.shape[:2] + + ar1 = float(width) / float(height) + ar2 = float(w) / float(h) + + assert not (width > w | height > h | abs(ar1 - ar2) > + 0.01), "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, + height) + + e = np.zeros_like(seg) + s = np.zeros_like(seg) + se = np.zeros_like(seg) + + e[:, :-1] = seg[:, 1:] + s[:-1, :] = seg[1:, :] + se[:-1, :-1] = seg[1:, 1:] + + b = seg ^ e | seg ^ s | seg ^ se + b[-1, :] = seg[-1, :] ^ e[-1, :] + b[:, -1] = seg[:, -1] ^ s[:, -1] + b[-1, -1] = 0 + + if w == width and h == height: + bmap = b + else: + bmap = np.zeros((height, width)) + for x in range(w): + for y in range(h): + if b[y, x]: + j = 1 + math.floor((y - 1) + height / h) + i = 1 + math.floor((x - 1) + width / h) + bmap[j, i] = 1 + + return bmap + + +def f_measure(true_mask, pred_mask, bound_th=0.008): + """F-measure for two 2D masks. + + # Arguments + true_mask: Numpy Array, Binary array of shape (H x W) representing the + ground truth mask. + pred_mask: Numpy Array. Binary array of shape (H x W) representing the + predicted mask. + bound_th: Float. Optional parameter to compute the F-measure. Default is + 0.008. + + # Returns + float: F-measure. 
+ """ + true_mask = np.asarray(true_mask, dtype=np.bool) + pred_mask = np.asarray(pred_mask, dtype=np.bool) + + assert true_mask.shape == pred_mask.shape + + bound_pix = bound_th if bound_th >= 1 else (np.ceil( + bound_th * np.linalg.norm(true_mask.shape))) + + fg_boundary = _seg2bmap(pred_mask) + gt_boundary = _seg2bmap(true_mask) + + fg_dil = cv2.dilate( + fg_boundary.astype(np.uint8), + disk(bound_pix).astype(np.uint8)) + gt_dil = cv2.dilate( + gt_boundary.astype(np.uint8), + disk(bound_pix).astype(np.uint8)) + + # Get the intersection + gt_match = gt_boundary * fg_dil + fg_match = fg_boundary * gt_dil + + # Area of the intersection + n_fg = np.sum(fg_boundary) + n_gt = np.sum(gt_boundary) + + # Compute precision and recall + if n_fg == 0 and n_gt > 0: + precision = 1 + recall = 0 + elif n_fg > 0 and n_gt == 0: + precision = 0 + recall = 1 + elif n_fg == 0 and n_gt == 0: + precision = 1 + recall = 1 + else: + precision = np.sum(fg_match) / float(n_fg) + recall = np.sum(gt_match) / float(n_gt) + + # Compute F measure + if precision + recall == 0: + F = 0 + else: + F = 2 * precision * recall / (precision + recall) + + return F + + +def batched_f_measure(y_true, + y_pred, + average_over_objects=True, + nb_objects=None, + bound_th=0.008): + """ Batch F-measure for multiple instance segmentation. + + # Arguments + y_true: Numpy Array. Array of shape (B x H x W) and type integer giving + the ground truth of the object instance segmentation. + y_pred: Numpy Array. Array of shape (B x H x W) and type integer giving + the + prediction of the object segmentation. + average_over_objects: Boolean. Weather or not to average the F-measure + over all the objects in the sequence. Default True. + nb_objects: Integer. Number of objects in the ground truth mask. If + `None` the value will be infered from `y_true`. Setting this value + will speed up the computation. + + # Returns + ndarray: Returns an array of shape (B) with the average F-measure for + all instances at each frame if `average_over_objects=True`. If + `average_over_objects=False` returns an array of shape (B x nObj) + with nObj being the number of objects on `y_true`. + """ + y_true = np.asarray(y_true, dtype=np.int) + y_pred = np.asarray(y_pred, dtype=np.int) + if y_true.ndim != 3: + raise ValueError('y_true array must have 3 dimensions.') + if y_pred.ndim != 3: + raise ValueError('y_pred array must have 3 dimensions.') + if y_true.shape != y_pred.shape: + raise ValueError('y_true and y_pred must have the same shape. 
{} != {}'.format(y_true.shape, y_pred.shape)) + + if nb_objects is None: + objects_ids = np.unique(y_true[(y_true < 255) & (y_true > 0)]) + nb_objects = len(objects_ids) + else: + objects_ids = [i + 1 for i in range(nb_objects)] + objects_ids = np.asarray(objects_ids, dtype=np.int) + if nb_objects == 0: + raise ValueError('Number of objects in y_true should be higher than 0.') + nb_frames = len(y_true) + + f_measure_result = np.empty((nb_frames, nb_objects), dtype=np.float) + + for i, obj_id in enumerate(objects_ids): + for frame_id in range(nb_frames): + gt_mask = y_true[frame_id, :, :] == obj_id + pred_mask = y_pred[frame_id, :, :] == obj_id + f_measure_result[frame_id, i] = f_measure( + gt_mask, pred_mask, bound_th=bound_th) + + if average_over_objects: + f_measure_result = f_measure_result.mean(axis=1) + return f_measure_result diff --git a/util/tensor_util.py b/util/tensor_util.py index bdf1ec3..de21517 100644 --- a/util/tensor_util.py +++ b/util/tensor_util.py @@ -1,3 +1,4 @@ +import numpy as np import torch.nn.functional as F import torch @@ -14,6 +15,34 @@ def compute_tensor_iou(seg, gt): return iou +def compute_array_iou(seg, gt): + # grayscale 2D masks, each gray shade - unique object + seg = seg.squeeze() + gt = gt.squeeze() + + ious = [] + for color in np.unique(seg): + if color == 0: + continue # skipping background + + curr_object_iou = compute_tensor_iou( + torch.tensor(seg == color), + torch.tensor(gt == color), + ) + + ious.append(curr_object_iou) + + if not len(ious): + # GT is pure black, let's check if the mask also doesn't have any junk + curr_object_iou = compute_tensor_iou( + torch.tensor(seg == 0), + torch.tensor(gt == 0), + ) + + ious.append(curr_object_iou) + + return sum(ious) / len(ious) + # STM def pad_divide_by(in_img, d): h, w = in_img.shape[-2:] From 80434bad768b62bdaa9d98d29b99aa4c02a01c6d Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 10 Mar 2023 23:04:27 +0400 Subject: [PATCH 17/49] WIP: added code for running inference, saving masks to permanent memory and selecting candidates --- inference/frame_selection/frame_selection.py | 353 ++++++++---------- .../frame_selection/frame_selection_utils.py | 30 ++ inference/inference_core.py | 29 +- inference/interact/gui.py | 95 ++++- inference/interact/gui_utils.py | 77 +++- inference/interact/resource_manager.py | 35 ++ inference/memory_manager.py | 6 +- inference/run_experiments.py | 4 +- 8 files changed, 410 insertions(+), 219 deletions(-) diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index 5b2da04..1fce745 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -1,52 +1,19 @@ -from functools import partial -import json +from copy import copy from pathlib import Path +import time from typing import Any, Dict, List, Set, Tuple, Union +import cv2 import torch +import torch.nn.functional as F import torchvision.transforms.functional as FT import numpy as np -from sklearn.decomposition import PCA -from scipy.spatial.distance import cdist -from umap import UMAP -from hdbscan import flat from tqdm import tqdm +from inference.frame_selection.frame_selection_utils import extract_keys from model.memory_util import get_similarity -# -----------------------------CHOSEN FRAME SELECTORS--------------------------------------- - -# Utility -def _extract_keys(dataloder, processor, print_progress=False): - frame_keys = [] - shrinkages = [] - selections = [] - device = None - with torch.no_grad(): # just in case - 
key_sum = None - - for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): - rgb = data['rgb'].cuda()[0] - key, shrinkage, selection = processor.encode_frame_key(rgb) - - if key_sum is None: - device = key.device - # to avoid possible overflow - key_sum = torch.zeros_like( - key, device=device, dtype=torch.float64) - - key_sum += key.type(torch.float64) - - frame_keys.append(key.flatten(start_dim=2).cpu()) - shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) - selections.append(selection.flatten(start_dim=2).cpu()) - - num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. - - return frame_keys, shrinkages, selections, device, num_frames, key_sum - - def first_frame_only(*args, **kwargs): # baseline return [0] @@ -59,71 +26,26 @@ def uniformly_selected_frames(dataloader, *args, how_many_frames=10, **kwargs) - return np.linspace(0, num_total_frames - 1, how_many_frames).astype(int).tolist() -def calculate_proposals_for_annotations_iterative_pca(dataloader, processor, how_many_frames=10, print_progress=False, distance_metric='euclidean') -> List[int]: - assert distance_metric in {'cosine', 'euclidean'} - # might not pick 0-th frame - np.random.seed(1) # Just in case +def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(dataloader, processor, existing_masks_path: str, how_many_frames=10, print_progress=False, mult_instead=False, alpha=1.0, too_small_mask_threshold_px=9, **kwargs) -> List[int]: with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( - dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() - for key in frame_keys]).numpy() - - # PCA hangs at num_frames // 2: https://github.com/scikit-learn/scikit-learn/issues/22434 - pca = PCA(num_frames - 1, svd_solver='arpack') - smol_keys = pca.fit_transform(flat_keys.astype(np.float64)) - # smol_keys = flat_keys # to disable PCA - - chosen_frames = [0] - for c in range(how_many_frames - 1): - distances = cdist(smol_keys[chosen_frames], - smol_keys, metric=distance_metric) - closest_to_mem_key_distances = distances.min(axis=0) - most_distant_frame = np.argmax(closest_to_mem_key_distances) - chosen_frames.append(int(most_distant_frame)) - - return chosen_frames - - -def calculate_proposals_for_annotations_umap_half_hdbscan_clustering(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: - # might not pick 0-th frame - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( - dataloader, processor, print_progress) - flat_keys = torch.stack([key.flatten().cpu() - for key in frame_keys]).numpy() - - pca = UMAP(n_neighbors=num_frames - 1, - n_components=num_frames // 2, random_state=1) - smol_keys = pca.fit_transform(flat_keys) - - # clustering = AgglomerativeClustering(n_clusters=how_many_frames + 1, linkage='single') - clustering = flat.HDBSCAN_flat( - smol_keys, n_clusters=how_many_frames + 1) - labels = clustering.labels_ - - chosen_frames = [] - for c in range(how_many_frames): - vectors = smol_keys[labels == c] - true_index_mapping = {i: int(ti) for i, ti in zip( - range(len(vectors)), np.nonzero(labels == c)[0])} - center = np.mean(vectors, axis=0) - - # since HDBSCAN is density-based, it makes 0 sense to use anything but euclidean distance here - distances = cdist(vectors, [center], metric='euclidean').squeeze() - - closest_to_cluster_center_idx = np.argsort(distances)[0] - - chosen_frame_idx = 
true_index_mapping[closest_to_cluster_center_idx] - chosen_frames.append(chosen_frame_idx) - - return chosen_frames - - -def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: - with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( - dataloader, processor, print_progress) + frame_keys, shrinkages, selections, device, num_frames, key_sum = extract_keys(dataloader, processor, print_progress) + + h, w = frame_keys[0].squeeze().shape[1:3] # removing batch dimension + p_masks_dir = Path(existing_masks_path) + mask_sizes_px = [] + for i, p_img in enumerate(p_masks_dir.iterdir()): + img = cv2.imread(str(p_img)) + img = cv2.resize(img, (w, h)) / 255. + img_tensor = FT.to_tensor(img) + mask_size_px = (img_tensor > 0).sum() + mask_sizes_px.append(mask_size_px) + + if not mult_instead: + composite_key = torch.cat([frame_keys[i].cpu().squeeze(), img_tensor], dim=0) # along channels + else: + composite_key = frame_keys[i].cpu().squeeze() * img_tensor.max(dim=0, keepdim=True).values # all objects -> 1., background -> 0.. Keep 1 channel only + composite_key = composite_key * alpha + frame_keys[i].cpu().squeeze() * (1 - alpha) + frame_keys[i] = composite_key chosen_frames = [0] chosen_frames_mem_keys = [frame_keys[0].to(device)] @@ -133,33 +55,32 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader # how to run a loop for lower memory usage for j in tqdm(range(num_frames), desc='Computing similarity to chosen frames', disable=not print_progress): qk = frame_keys[j].to(device) - query_selection = selections[j].to(device) # query - query_shrinkage = shrinkages[j].to(device) - - dissimilarities_across_mem_keys = [] - for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): - mem_key = mem_key.to(device) - key_shrinkage = shrinkages[key_idx].to(device) - key_selection = selections[key_idx].to(device) - - similarity_per_pixel = get_similarity( - mem_key, ms=None, qk=qk, qe=None) - reverse_similarity_per_pixel = get_similarity( - qk, ms=None, qk=mem_key, qe=None) - - # mapping of pixels A -> B would be very similar to B -> A if the images are similar - # and very different if the images are different - cycle_dissimilarity_per_pixel = ( - similarity_per_pixel - reverse_similarity_per_pixel) - cycle_dissimilarity_score = cycle_dissimilarity_per_pixel.abs().sum() / \ - cycle_dissimilarity_per_pixel.numel() - - dissimilarities_across_mem_keys.append( - cycle_dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all = min( - dissimilarities_across_mem_keys) + + if mask_sizes_px[j] < too_small_mask_threshold_px: + dissimilarity_min_across_all = 0 + else: + dissimilarities_across_mem_keys = [] + for mem_key in chosen_frames_mem_keys: + mem_key = mem_key.to(device) + + similarity_per_pixel = get_similarity( + mem_key, ms=None, qk=qk, qe=None) + reverse_similarity_per_pixel = get_similarity( + qk, ms=None, qk=mem_key, qe=None) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = ( + similarity_per_pixel - reverse_similarity_per_pixel) + + cycle_dissimilarity_score = F.relu(cycle_dissimilarity_per_pixel).sum() / \ + cycle_dissimilarity_per_pixel.numel() + dissimilarities_across_mem_keys.append( + cycle_dissimilarity_score) + + # filtering our existing or very similar frames 
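
The cycle-consistency score computed in the loop above can be reproduced in isolation. A minimal sketch, assuming a plain L2-style affinity in place of XMem's get_similarity and purely illustrative shapes: two identical frames score 0, and the score grows as the key maps diverge.

import torch
import torch.nn.functional as F

def affinity(mem: torch.Tensor, qry: torch.Tensor) -> torch.Tensor:
    # mem, qry: C x N key maps flattened over space (N = H * W)
    a_sq = mem.pow(2).sum(0).unsqueeze(1)   # N x 1, squared norms of memory pixels
    two_ab = 2 * (mem.T @ qry)              # N x N, 2 * <m_i, q_j>
    return two_ab - a_sq                    # asymmetric: depends on which side is the "memory"

def cycle_dissimilarity(key_a: torch.Tensor, key_b: torch.Tensor) -> float:
    fwd = affinity(key_a, key_b)            # A -> B
    bwd = affinity(key_b, key_a)            # B -> A
    diff = F.relu(fwd - bwd)                # keep the non-negative part, as in the code above
    return (diff.sum() / diff.numel()).item()

a = torch.randn(64, 30 * 54)
print(cycle_dissimilarity(a, a))                    # 0.0 for identical frames
print(cycle_dissimilarity(a, torch.randn_like(a)))  # > 0 for unrelated frames
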
+ dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + dissimilarities.append(dissimilarity_min_across_all) values, indices = torch.topk(torch.tensor( @@ -175,73 +96,109 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle(dataloader return chosen_frames -def calculate_proposals_for_annotations_with_iterative_distance_double_diff(dataloader, processor, how_many_frames=10, print_progress=False) -> List[int]: +def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None): + assert len(keys) == len(masks) + assert len(keys) > 0 + assert keys[0].shape[-2:] == masks[0].shape[-2:] + assert num_next_candidates > 0 + assert len(previously_chosen_candidates) > 0 + assert 0.0 <= alpha <= 1.0 + assert min_mask_presence_px >= 0 + assert len(previously_chosen_candidates) < len(keys) + + """ + Select candidate frames for annotation based on dissimilarity and cycle consistency. + + Parameters + ---------- + `keys` : `List[torch.Tensor]` + A list of "key" feature maps for all frames of the video. + `masks` : `List[torch.Tensor]` + A list of masks for each frame (predicted or user-provided). + `num_next_candidates` : `int` + The number of candidate frames to select. + `previously_chosen_candidates` : `List[int]`, optional + A list of previously chosen candidates. Default is (0,). + `print_progress` : `bool`, optional + Whether to print progress information. Default is False. + `alpha` : `float`, optional + The weight for cycle consistency in the candidate selection process. Default is 1.0. + `min_mask_presence_px` : `int`, optional + The minimum number of pixels for a valid mask. Default is 9. + + Returns + ------- + `List[int]` + A list of indices of the selected candidate frames. + + Notes + ----- + This function uses a dissimilarity measure and cycle consistency to select candidate frames for the user to annotate. + The dissimilarity measure ensures that the selected frames are as diverse as possible, while the cycle consistency + ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise). 
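    In symbols, writing S(X -> Y) for the per-pixel affinity from frame X to frame Y, the score used below is
    D(A, B) = mean(relu(S(A -> B) - S(B -> A))), computed on the mask-weighted composite keys.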
+ + """ with torch.no_grad(): - frame_keys, shrinkages, selections, device, num_frames, key_sum = _extract_keys( - dataloader, processor, print_progress) - - chosen_frames = [0] - chosen_frames_mem_keys = [frame_keys[0]] - # chosen_frames_self_similarities = [] - for c in range(how_many_frames - 1): - dissimilarities = [] - # how to run a loop for lower memory usage - for i in tqdm(range(num_frames), desc=f'Computing similarity to chosen frames', disable=not print_progress): - true_frame_idx = i - qk = frame_keys[true_frame_idx].to(device) - query_selection = selections[true_frame_idx].to( - device) # query - query_shrinkage = shrinkages[true_frame_idx].to(device) - - dissimilarities_across_mem_keys = [] - for key_idx, mem_key in zip(chosen_frames, chosen_frames_mem_keys): - mem_key = mem_key.to(device) - key_shrinkage = shrinkages[key_idx].to(device) - key_selection = selections[key_idx].to(device) - - similarity_per_pixel = get_similarity( - mem_key, ms=key_shrinkage, qk=qk, qe=query_selection) - self_similarity_key = get_similarity( - mem_key, ms=key_shrinkage, qk=mem_key, qe=key_selection) - self_similarity_query = get_similarity( - qk, ms=query_shrinkage, qk=qk, qe=query_selection) - - # mapping of pixels A -> B would be very similar to B -> A if the images are similar - # and very different if the images are different - - pure_similarity = 2 * similarity_per_pixel - \ - self_similarity_key - self_similarity_query - - dissimilarity_score = pure_similarity.abs().sum() / pure_similarity.numel() - - dissimilarities_across_mem_keys.append(dissimilarity_score) - - # filtering our existing or very similar frames - dissimilarity_min_across_all = min( - dissimilarities_across_mem_keys) - dissimilarities.append(dissimilarity_min_across_all) - - values, indices = torch.topk(torch.tensor( - dissimilarities), k=1, largest=True) - chosen_new_frame = int(indices[0]) - - chosen_frames.append(chosen_new_frame) - chosen_frames_mem_keys.append( - frame_keys[chosen_new_frame].to(device)) - - # we don't need to worry about 1st frame with itself, since we take the LEAST similar frames - return chosen_frames - - -KNOWN_ANNOTATION_PREDICTORS = { - 'PCA_EUCLIDEAN': partial(calculate_proposals_for_annotations_iterative_pca, distance_metric='euclidean'), - 'PCA_COSINE': partial(calculate_proposals_for_annotations_iterative_pca, distance_metric='cosine'), - 'UMAP_EUCLIDEAN': calculate_proposals_for_annotations_umap_half_hdbscan_clustering, - 'INTERNAL_CYCLE_CONSISTENCY': calculate_proposals_for_annotations_with_iterative_distance_cycle, - 'INTERNAL_DOUBLE_DIFF': calculate_proposals_for_annotations_with_iterative_distance_double_diff, - - 'FIRST_FRAME_ONLY': first_frame_only, # ignores the number of candidates, baseline - 'UNIFORM': uniformly_selected_frames # baseline -} - -# ------------------------END CHOSEN----------------------------------------------- + composite_keys = [] + keys = keys.squeeze() + N = len(keys) + h, w = keys[0].shape[1:3] # removing batch dimension + masks_validity = np.full(N, True) + + for i, mask in enumerate(masks): + mask_size_px = (mask > 0).sum() + + if mask_size_px < min_mask_presence_px: + masks_validity[i] = False + composite_keys.append(None) + + composite_key = keys[i] * mask.max(dim=0, keepdim=True).values # any object -> 1., background -> 0.. 
Keep 1 channel only + composite_key = composite_key * alpha + keys[i] * (1 - alpha) + + composite_keys.append(composite_key.to(device)) + + chosen_candidates = list(previously_chosen_candidates) + chosen_candidate_keys = [composite_keys[i] for i in chosen_candidates] + + for i in tqdm(range(num_next_candidates), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): + candidate_dissimilarities = [] + for j in tqdm(range(N), desc='Computing similarity to chosen frames', disable=not print_progress): + qk = composite_keys[j].to(device) + + if not masks_validity[j]: + # ignore this potential candidate + dissimilarity_min_across_all = 0 + else: + dissimilarities_across_mem_keys = [] + for mem_key in chosen_candidate_keys: + mem_key = mem_key + + similarity_per_pixel = get_similarity(mem_key, ms=None, qk=qk, qe=None) + reverse_similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) + + # mapping of pixels A -> B would be very similar to B -> A if the images are similar + # and very different if the images are different + cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) + + # Take non-negative mappings, normalize by tensor size + cycle_dissimilarity_score = F.relu(cycle_dissimilarity_per_pixel).sum() / cycle_dissimilarity_per_pixel.numel() + + dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) + + # filtering our existing or very similar frames + # if the key has already been used or is very similar to at least one of the chosen candidates + # dissimilarity_min_across_all -> 0 (or close to) + dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) + + candidate_dissimilarities.append(dissimilarity_min_across_all) + + index = torch.argmax(torch.tensor(candidate_dissimilarities)) + chosen_new_frame = int(index) + + chosen_candidates.append(chosen_new_frame) + chosen_candidate_keys.append(composite_keys[chosen_new_frame]) + + if progress_callback is not None: + progress_callback.emit(i + 1) + + return chosen_candidates \ No newline at end of file diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index 14a3449..59865d7 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ b/inference/frame_selection/frame_selection_utils.py @@ -8,6 +8,36 @@ from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine +# +def extract_keys(dataloder, processor, print_progress=False): + frame_keys = [] + shrinkages = [] + selections = [] + device = None + with torch.no_grad(): # just in case + key_sum = None + + for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): + rgb = data['rgb'].cuda()[0] + key, shrinkage, selection = processor.encode_frame_key(rgb) + + if key_sum is None: + device = key.device + # to avoid possible overflow + key_sum = torch.zeros_like( + key, device=device, dtype=torch.float64) + + key_sum += key.type(torch.float64) + + frame_keys.append(key.flatten(start_dim=2).cpu()) + shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) + selections.append(selection.flatten(start_dim=2).cpu()) + + num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. 
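
For reference, the alpha-blended composite key used by select_next_candidates above reduces to a few lines; this sketch assumes illustrative shapes only. The mask collapses all objects into a single foreground channel, and alpha trades object-focused features against the raw frame features.

import torch

C, h, w = 64, 30, 54
key = torch.randn(C, h, w)                   # key features of one frame
mask = torch.zeros(3, h, w)                  # three objects, one channel each
mask[0, 5:15, 5:15] = 1.0

alpha = 0.5
fg = mask.max(dim=0, keepdim=True).values    # 1 x h x w: any object -> 1, background -> 0
composite = alpha * (key * fg) + (1 - alpha) * key   # alpha=1: masked only, alpha=0: raw key
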
+ + return frame_keys, shrinkages, selections, device, num_frames, key_sum + + def select_n_frame_candidates(preds_df: pd.DataFrame, uncertainty_name: str, n=5): df = preds_df diff --git a/inference/inference_core.py b/inference/inference_core.py index d53e062..9c225fc 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -48,7 +48,7 @@ def encode_frame_key(self, image): need_sk=True) return key, shrinkage, selection - def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False): + def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False, return_key=False): # For feedback: # 1. We run the model as usual # 2. We get feedback: 2 lists, one with good prediction indices, one with bad @@ -132,10 +132,15 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ if is_deep_update: self.memory.set_hidden(hidden) self.last_deep_update_ti = self.curr_ti - - return unpad(pred_prob_with_bg, self.pad) - def put_to_permanent_memory(self, image, mask): + res = unpad(pred_prob_with_bg, self.pad) + + if return_key: + return res, key + else: + return res + + def put_to_permanent_memory(self, image, mask, ti=None): image, self.pad = pad_divide_by(image, 16) image = image.unsqueeze(0) # add the batch dimension key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, @@ -149,6 +154,16 @@ def put_to_permanent_memory(self, image, mask): value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=False) - - self.memory.add_memory(key, shrinkage, value, self.all_labels, - selection=selection if self.enable_long_term else None, permanent=True) + + if self.memory.frame_already_saved(ti): + # self.memory.update_permanent_memory(ti, ) + # maybe delete and update? 
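
One way to picture the "delete and update" / "splice the memory" idea in the TODO here: when permanent keys are stored as a single B x C x (num_frames * HW) tensor, updating an already-saved frame amounts to overwriting its HW-wide column slice. A toy sketch under that assumption (not the patch's code):

import torch

B, C, HW, num_frames = 1, 64, 30 * 54, 4
mem_key = torch.randn(B, C, num_frames * HW)       # flattened permanent keys

def replace_frame_slot(mem_key, slot, new_key):
    # new_key: B x C x HW; slot: 0-based position of the saved frame
    mem_key[:, :, slot * HW:(slot + 1) * HW] = new_key
    return mem_key

mem_key = replace_frame_slot(mem_key, slot=2, new_key=torch.randn(B, C, HW))
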
+ # TODO: splice the memory, update existing one + pass + else: + self.memory.add_memory(key, shrinkage, value, self.all_labels, + selection=selection if self.enable_long_term else None, permanent=True) + + @property + def permanent_memory_frames(self): + return list(self.memory.frame_id_to_mem_idx.keys()) \ No newline at end of file diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 679e90e..8435295 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -16,6 +16,8 @@ import os import cv2 + +from inference.frame_selection.frame_selection import select_next_candidates # fix conflicts between qt5 and cv2 os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH") @@ -24,10 +26,10 @@ from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, - QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton) + QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton, QTabWidget, QDialog) from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon -from PyQt5.QtCore import Qt, QTimer +from PyQt5.QtCore import Qt, QTimer, QThreadPool from model.network import XMem @@ -56,6 +58,7 @@ def __init__(self, net: XMem, self.processor = InferenceCore(net, config) self.processor.set_all_labels(list(range(1, self.num_objects+1))) self.res_man = resource_manager + self.threadpool = QThreadPool() self.num_frames = len(self.res_man) self.height, self.width = self.res_man.h, self.res_man.w @@ -70,6 +73,10 @@ def __init__(self, net: XMem, self.play_button.clicked.connect(self.on_play_video) self.commit_button = QPushButton('Commit') self.commit_button.clicked.connect(self.on_commit) + self.save_reference_button = QPushButton('Save reference') + self.save_reference_button.clicked.connect(self.on_save_reference) + self.compute_candidates_button = QPushButton('Compute Annotation candidates') + self.compute_candidates_button.clicked.connect(self.on_compute_candidates) self.forward_run_button = QPushButton('Forward Propagate') self.forward_run_button.clicked.connect(self.on_forward_propagation) @@ -233,6 +240,13 @@ def __init__(self, net: XMem, navi.addWidget(QLabel('Save overlay during propagation')) navi.addWidget(self.save_visualization_checkbox) navi.addStretch(1) + + self.test_btn = QPushButton('TEST') + self.test_btn.clicked.connect(self.TEST) + + navi.addWidget(self.test_btn) + navi.addWidget(self.save_reference_button) + navi.addWidget(self.compute_candidates_button) navi.addWidget(self.commit_button) navi.addWidget(self.forward_run_button) navi.addWidget(self.backward_run_button) @@ -241,8 +255,17 @@ def __init__(self, net: XMem, draw_area = QHBoxLayout() draw_area.addWidget(self.main_canvas, 4) + self.tabs = QTabWidget() + self.map_tab = QWidget() + self.references_tab = QWidget() + + self.tabs.addTab(self.map_tab,"Minimap && Stats") + self.tabs.addTab(self.references_tab,"References && Candidates") + + tabs_layout = QVBoxLayout() + # Minimap area - minimap_area = QVBoxLayout() + minimap_area = QVBoxLayout(self.map_tab) minimap_area.setAlignment(Qt.AlignTop) mini_label = QLabel('Minimap') mini_label.setAlignment(Qt.AlignTop) @@ -275,10 +298,12 @@ def __init__(self, net: XMem, import_area.addWidget(self.import_layer_button) minimap_area.addLayout(import_area) - # console - minimap_area.addWidget(self.console) + chosen_figures_area = QVBoxLayout(self.references_tab) + chosen_figures_area.addWidget(QLabel("TEST TEST FIGURES")) - 
draw_area.addLayout(minimap_area, 1) + tabs_layout.addWidget(self.tabs) + tabs_layout.addWidget(self.console) + draw_area.addLayout(tabs_layout, 1) layout = QVBoxLayout() layout.addLayout(draw_area) @@ -312,6 +337,8 @@ def __init__(self, net: XMem, self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) self.cursur = 0 self.on_showing = None + self.reference_ids = [] + self.candidates_ids = [] # Zoom parameters self.zoom_pixels = 150 @@ -349,6 +376,10 @@ def __init__(self, net: XMem, self.console_push_text('Initialized.') self.initialized = True + def TEST(self): + print(self.res_man.all_masks_present()) + pass + def resizeEvent(self, event): self.show_current_frame() @@ -459,6 +490,10 @@ def show_current_frame(self, fast=False): self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) self.tl_slider.setValue(self.cursur) + def show_candidates(self): + # TODO: draw image grid + pass + def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates oh, ow = self.image_size.height(), self.image_size.width() @@ -573,10 +608,13 @@ def on_pause(self): def on_propagation(self): # start to propagate self.load_current_torch_image_mask() + # TODO: put into permanent memory self.show_current_frame(fast=True) self.console_push_text('Propagation started.') - self.current_prob = self.processor.step(self.current_image_torch, self.current_prob[1:]) + self.current_prob, key = self.processor.step(self.current_image_torch, self.current_prob[1:], return_key=True) + self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) # clear self.interacted_prob = None @@ -591,7 +629,10 @@ def on_propagation(self): self.load_current_image_mask(no_mask=True) self.load_current_torch_image_mask(no_mask=True) - self.current_prob = self.processor.step(self.current_image_torch) + # TODO: read existing mask (if there is one), pass to .step() + self.current_prob, key = self.processor.step(self.current_image_torch, return_key=True) + self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) self.save_current_mask() @@ -602,7 +643,8 @@ def on_propagation(self): if self.cursur == 0 or self.cursur == self.num_frames-1: break - + + # TODO: after finished, compute candidates and show maybe? 
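
A hypothetical shape for the TODO above, reusing the keys and downsized masks cached through add_key_with_mask during the loop (a sketch only, mirroring on_compute_candidates further down; not part of the patch):

# once every frame has a cached key and mask, selection can run on the cached tensors
if self.res_man.all_masks_present():
    candidates = select_next_candidates(
        self.res_man.keys, self.res_man.small_masks,
        num_next_candidates=5,
        previously_chosen_candidates=sorted(self.reference_ids) or [0],
        min_mask_presence_px=9)
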
self.propagating = False self.curr_frame_dirty = False self.on_pause() @@ -616,6 +658,41 @@ def on_commit(self): self.complete_interaction() self.update_interacted_mask() + def on_compute_candidates(self): + def _update_candidates(candidates_ids): + print(candidates_ids) + self.candidates_ids = candidates_ids + + def _update_progress(i): + candidate_progress.setValue(i) + + k = 5 + # self.candidate_progress.setMaximum(k) + # self.candidate_progress.exec_() + # Candidate progress dialog + + # time.sleep(5) + # my_dialog.close() + candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) + worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=0.5, min_mask_presence_px=9) # Any other args, kwargs are passed to the run function + worker.signals.result.connect(_update_candidates) + worker.signals.progress.connect(_update_progress) + + self.threadpool.start(worker) + + candidate_progress.open() + + def on_save_reference(self): + # TODO: update permanent memory. Add if new, replace if already existed + current_image_torch, _ = image_to_torch(self.current_image) + current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() + + self.processor.put_to_permanent_memory(current_image_torch, current_prob, self.cursur) + + self.reference_ids.append(self.cursur) + + # TODO: remove from candidates if was there + def on_prev_frame(self): # self.tl_slide will trigger on setValue self.cursur = max(0, self.cursur-1) diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index daf852b..b4f9fed 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -1,5 +1,78 @@ -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar) +from typing import Optional, Union +import time +import traceback, sys + +from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog) + +class WorkerSignals(QObject): + ''' + Defines the signals available from a running worker thread. + + Supported signals are: + + finished + No data + + error + tuple (exctype, value, traceback.format_exc() ) + + result + object data returned from processing, anything + + progress + int indicating % progress + + ''' + finished = pyqtSignal() + error = pyqtSignal(tuple) + result = pyqtSignal(object) + progress = pyqtSignal(int) + + +class Worker(QRunnable): + ''' + Worker thread + + Inherits from QRunnable to handler worker thread setup, signals and wrap-up. + + :param callback: The function callback to run on this worker thread. Supplied args and + kwargs will be passed through to the runner. + :type callback: function + :param args: Arguments to pass to the callback function + :param kwargs: Keywords to pass to the callback function + + ''' + + def __init__(self, fn, *args, **kwargs): + super(Worker, self).__init__() + + # Store constructor arguments (re-used for processing) + self.fn = fn + self.args = args + self.kwargs = kwargs + self.signals = WorkerSignals() + + # Add the callback to our kwargs + self.kwargs['progress_callback'] = self.signals.progress + + @pyqtSlot() + def run(self): + ''' + Initialise the runner function with passed args, kwargs. 
+ ''' + + # Retrieve args/kwargs here; and fire processing using them + try: + result = self.fn(*self.args, **self.kwargs) + except: + traceback.print_exc() + exctype, value = sys.exc_info()[:2] + self.signals.error.emit((exctype, value, traceback.format_exc())) + else: + self.signals.result.emit(result) # Return the result of the processing + finally: + self.signals.finished.emit() # Done def create_parameter_box(min_val, max_val, text, step=1, callback=None): diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index b0f28af..ef36830 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -5,6 +5,9 @@ import cv2 from PIL import Image +import torch +from torchvision.transforms import Resize, InterpolationMode + if not hasattr(Image, 'Resampling'): # Pillow<9.0 Image.Resampling = Image import numpy as np @@ -100,6 +103,11 @@ def __init__(self, config): self.height, self.width = self.get_image(0).shape[:2] self.visualization_init = False + self._resize = None + self._small_masks = None + self._keys = None + self._keys_processed = np.zeros(self.length, dtype=bool) + def _extract_frames(self, video): cap = cv2.VideoCapture(video) frame_index = 0 @@ -138,6 +146,24 @@ def _copy_resize_frames(self, images): cv2.imwrite(path.join(self.image_dir, image_name), frame) print('Done!') + def add_key_with_mask(self, ti, key, mask): + if self._keys is None: + c, h, w = key.squeeze().shape + c_mask = mask.shape[0] + self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) + self._small_masks = torch.empty((self.length, c_mask, h, w), dtype=mask.dtype, device=key.device) + self._resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) + + if not self._keys_processed[ti]: + # keys don't change for the video, so we only save them once + self._keys[ti] = key + self._keys_processed[ti] = True + + self._small_masks[ti] = self._resize(mask) + + def all_masks_present(self): + return self._keys_processed.sum() == self.length + def save_mask(self, ti, mask): # mask should be uint8 H*W without channels assert 0 <= ti < self.length @@ -204,3 +230,12 @@ def h(self): @property def w(self): return self.width + + @property + def small_masks(self): + return self._small_masks + + @property + def keys(self): + return self._keys + diff --git a/inference/memory_manager.py b/inference/memory_manager.py index 0e721b0..cbe6a35 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -32,6 +32,7 @@ def __init__(self, config): self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) self.permanent_work_mem = KeyValueMemoryStore(count_usage=False) + self.frame_id_to_mem_idx = dict() if self.enable_long_term: self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) @@ -229,7 +230,7 @@ def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=F if not self.temporary_work_mem.engaged() or (num_temp_groups != num_perm_groups): # print(f"PERM_NUM_GROUPS={num_perm_groups} vs TEMP_NUM_GROUPS={num_temp_groups}", end=' ') - + # first frame or new group; we need to have both memories engaged to avoid crashes when concating # so we just initialize the temporary one with an empty tensor key0 = key[..., 0:0] @@ -275,6 +276,9 @@ def set_hidden(self, hidden): def get_hidden(self): return self.hidden + + def frame_already_saved(self, ti): + return ti in self.frame_id_to_mem_idx # def slices_excluding_permanent(self, group_value, start, end): # 
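
One detail of the Worker/WorkerSignals helper added to gui_utils.py above: run() always injects a progress_callback keyword, so any function handed to a Worker must accept that parameter (select_next_candidates does). A minimal, hypothetical wiring outside the GUI, with count_to standing in for the real target function:

import sys
from PyQt5.QtCore import QCoreApplication, QThreadPool
from inference.interact.gui_utils import Worker

def count_to(n, progress_callback=None):
    # toy target; the injected progress_callback is the worker's progress signal
    for i in range(n):
        if progress_callback is not None:
            progress_callback.emit(i + 1)
    return n

app = QCoreApplication(sys.argv)
pool = QThreadPool()
worker = Worker(count_to, 10)
worker.signals.progress.connect(lambda i: print('progress', i))   # 1 .. 10 while running
worker.signals.result.connect(lambda r: print('result', r))       # 10 when done
worker.signals.finished.connect(app.quit)
pool.start(worker)
sys.exit(app.exec_())
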
HW = self.HW diff --git a/inference/run_experiments.py b/inference/run_experiments.py index db6607c..d980d44 100644 --- a/inference/run_experiments.py +++ b/inference/run_experiments.py @@ -13,7 +13,7 @@ from util.metrics import batched_f_measure, batched_jaccard from p_tqdm import p_umap -from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS +# from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS from inference.run_on_video import predict_annotation_candidates, run_on_video # ---------------BEGIN Inference and visualization utils -------------------------- @@ -137,7 +137,7 @@ def get_videos_info(): } -def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str, predictors: Dict[str, callable] = KNOWN_ANNOTATION_PREDICTORS, load_existing_masks=False): +def run_multiple_frame_selectors(videos_info: Dict[str, Dict], csv_output_path: str, predictors: Dict[str, callable] = None, load_existing_masks=False): output = pd.DataFrame(columns=list(predictors)) p_bar = tqdm(total=len(videos_info) * len(predictors)) From f26193d104524c304e3385b130d99e2df3a49b5d Mon Sep 17 00:00:00 2001 From: max810 Date: Sat, 11 Mar 2023 17:53:20 +0400 Subject: [PATCH 18/49] WIP: now showing references and candidates, scroll+click works --- inference/interact/gui.py | 50 +++++++-- inference/interact/gui_utils.py | 192 +++++++++++++++++++++++++++++++- 2 files changed, 231 insertions(+), 11 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 8435295..514a9eb 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -299,7 +299,13 @@ def __init__(self, net: XMem, minimap_area.addLayout(import_area) chosen_figures_area = QVBoxLayout(self.references_tab) - chosen_figures_area.addWidget(QLabel("TEST TEST FIGURES")) + chosen_figures_area.addWidget(QLabel("SAVED REFERENCES IN PERMANENT MEMORY")) + self.references_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail) + chosen_figures_area.addWidget(self.references_collection) + + self.candidates_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail) + chosen_figures_area.addWidget(QLabel("ANNOTATION CANDIDATES")) + chosen_figures_area.addWidget(self.candidates_collection) tabs_layout.addWidget(self.tabs) tabs_layout.addWidget(self.console) @@ -337,7 +343,7 @@ def __init__(self, net: XMem, self.brush_vis_alpha = np.zeros((self.height, self.width, 1), dtype=np.float32) self.cursur = 0 self.on_showing = None - self.reference_ids = [] + self.reference_ids = set() self.candidates_ids = [] # Zoom parameters @@ -477,6 +483,19 @@ def update_current_image_fast(self): qImg = QImage(self.viz.data, width, height, bytesPerLine, QImage.Format_RGB888) self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation))) + + def load_current_image_thumbnail(self, size=128): + curr_pixmap = self.main_canvas.pixmap() + curr_size = curr_pixmap.size() + h = curr_size.height() + w = curr_size.width() + + if h < w: + thumbnail = curr_pixmap.scaledToHeight(size) + else: + thumbnail = curr_pixmap.scaledToWidth(size) + + return thumbnail def show_current_frame(self, fast=False): # Re-compute overlay and show the image @@ -566,6 +585,11 @@ def tl_slide(self): self.load_current_image_mask() self.show_current_frame() + def scroll_to(self, idx): + assert self.tl_slider.minimum() <= idx <= self.tl_slider.maximum() + self.tl_slider.setValue(idx) + self.tl_slide() + def 
brush_slide(self): self.brush_size = self.brush_slider.value() self.brush_label.setText('Brush size: %d' % self.brush_size) @@ -661,18 +685,21 @@ def on_commit(self): def on_compute_candidates(self): def _update_candidates(candidates_ids): print(candidates_ids) + for i in self.candidates_ids: + # removing any old candidates left + self.candidates_collection.remove_image(i) self.candidates_ids = candidates_ids + prev_pos = self.cursur + for i in self.candidates_ids: + self.scroll_to(i) + self.candidates_collection.add_image(i) + self.scroll_to(prev_pos) + def _update_progress(i): candidate_progress.setValue(i) k = 5 - # self.candidate_progress.setMaximum(k) - # self.candidate_progress.exec_() - # Candidate progress dialog - - # time.sleep(5) - # my_dialog.close() candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=0.5, min_mask_presence_px=9) # Any other args, kwargs are passed to the run function worker.signals.result.connect(_update_candidates) @@ -689,8 +716,13 @@ def on_save_reference(self): self.processor.put_to_permanent_memory(current_image_torch, current_prob, self.cursur) - self.reference_ids.append(self.cursur) + self.reference_ids.add(self.cursur) + self.references_collection.add_image(self.cursur) + if self.cursur in self.candidates_ids: + self.candidates_ids.remove(self.cursur) + + self.candidates_collection.remove_image(self.cursur) # TODO: remove from candidates if was there def on_prev_frame(self): diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index b4f9fed..288f742 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -1,9 +1,10 @@ +from functools import partial from typing import Optional, Union import time import traceback, sys -from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject -from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog) +from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject, QPoint, QRect, QSize +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, QFrame, QPushButton) class WorkerSignals(QObject): ''' @@ -111,3 +112,190 @@ def create_gauge(text): layout.addWidget(gauge) return gauge, layout + + +class FlowLayout(QLayout): + def __init__(self, parent: QWidget=None, margin: int=-1, hSpacing: int=-1, vSpacing: int=-1): + super().__init__(parent) + + self.itemList = list() + self.m_hSpace = hSpacing + self.m_vSpace = vSpacing + + self.setContentsMargins(margin, margin, margin, margin) + + def __del__(self): + # copied for consistency, not sure this is needed or ever called + item = self.takeAt(0) + while item: + item = self.takeAt(0) + + def addItem(self, item: QLayoutItem): + self.itemList.append(item) + + def horizontalSpacing(self) -> int: + if self.m_hSpace >= 0: + return self.m_hSpace + else: + return self.smartSpacing(QStyle.PM_LayoutHorizontalSpacing) + + def verticalSpacing(self) -> int: + if self.m_vSpace >= 0: + return self.m_vSpace + else: + return self.smartSpacing(QStyle.PM_LayoutVerticalSpacing) + + def count(self) -> int: + return len(self.itemList) + + def itemAt(self, index: int) -> Union[QLayoutItem, None]: + if 0 <= 
index < len(self.itemList): + return self.itemList[index] + else: + return None + + def takeAt(self, index: int) -> Union[QLayoutItem, None]: + if 0 <= index < len(self.itemList): + return self.itemList.pop(index) + else: + return None + + def expandingDirections(self) -> Qt.Orientations: + return Qt.Orientations(Qt.Orientation(0)) + + def hasHeightForWidth(self) -> bool: + return True + + def heightForWidth(self, width: int) -> int: + height = self.doLayout(QRect(0, 0, width, 0), True) + return height + + def setGeometry(self, rect: QRect) -> None: + super().setGeometry(rect) + self.doLayout(rect, False) + + def sizeHint(self) -> QSize: + return self.minimumSize() + + def minimumSize(self) -> QSize: + size = QSize() + for item in self.itemList: + size = size.expandedTo(item.minimumSize()) + + margins = self.contentsMargins() + size += QSize(margins.left() + margins.right(), margins.top() + margins.bottom()) + return size + + def smartSpacing(self, pm: QStyle.PixelMetric) -> int: + parent = self.parent() + if not parent: + return -1 + elif parent.isWidgetType(): + return parent.style().pixelMetric(pm, None, parent) + else: + return parent.spacing() + + def doLayout(self, rect: QRect, testOnly: bool) -> int: + left, top, right, bottom = self.getContentsMargins() + effectiveRect = rect.adjusted(+left, +top, -right, -bottom) + x = effectiveRect.x() + y = effectiveRect.y() + lineHeight = 0 + + for item in self.itemList: + wid = item.widget() + spaceX = self.horizontalSpacing() + if spaceX == -1: + spaceX = wid.style().layoutSpacing(QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Horizontal) + spaceY = self.verticalSpacing() + if spaceY == -1: + spaceY = wid.style().layoutSpacing(QSizePolicy.PushButton, QSizePolicy.PushButton, Qt.Vertical) + + nextX = x + item.sizeHint().width() + spaceX + if nextX - spaceX > effectiveRect.right() and lineHeight > 0: + x = effectiveRect.x() + y = y + lineHeight + spaceY + nextX = x + item.sizeHint().width() + spaceX + lineHeight = 0 + + if not testOnly: + item.setGeometry(QRect(QPoint(x, y), item.sizeHint())) + + x = nextX + lineHeight = max(lineHeight, item.sizeHint().height()) + + return y + lineHeight - rect.y() + bottom + + +class JFlowLayout(FlowLayout): + # flow layout, similar to an HTML `
` + # this is our "wrapper" to the `FlowLayout` sample Qt code we have implemented + # we use it in place of where we used to use a `QHBoxLayout` + # in order to make few outside-world changes, and revert to `QHBoxLayout`if we ever want to, + # there are a couple of methods here which are available on a `QBoxLayout` but not on a `QLayout` + # for which we provide a "lite-equivalent" which will suffice for our purposes + + def addLayout(self, layout: QLayout, stretch: int=0): + # "equivalent" of `QBoxLayout.addLayout()` + # we want to add sub-layouts (e.g. a `QVBoxLayout` holding a label above a widget) + # there is some dispute as to how to do this/whether it is supported by `FlowLayout` + # see my https://forum.qt.io/topic/104653/how-to-do-a-no-break-qhboxlayout + # there is a suggestion that we should not add a sub-layout but rather enclose it in a `QWidget` + # but since it seems to be working as I've done it below I'm elaving it at that for now... + + # suprisingly to me, we do not need to add the layout via `addChildLayout()`, that seems to make no difference + # self.addChildLayout(layout) + # all that seems to be reuqired is to add it onto the list via `addItem()` + self.addItem(layout) + + def addStretch(self, stretch: int=0): + # "equivalent" of `QBoxLayout.addStretch()` + # we can't do stretches, we just arbitrarily put in a "spacer" to give a bit of a gap + w = stretch * 20 + spacerItem = QSpacerItem(w, 0, QSizePolicy.Expanding, QSizePolicy.Minimum) + self.addItem(spacerItem) + + +class ClickableLabel(QLabel): + clicked = pyqtSignal() + def mouseReleaseEvent(self, event): + super(ClickableLabel, self).mousePressEvent(event) + if event.button() == Qt.LeftButton and event.pos() in self.rect(): + self.clicked.emit() + +class ImageLinkCollection(QWidget): + def __init__(self, on_click: callable, load_image: callable, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.on_click = on_click + self.load_image = load_image + + # scrollable_area = QScrollArea(self) + # frame = QFrame(scrollable_area) + + self.flow_layout = JFlowLayout(self) + + self.img_widgets_lookup = dict() + + + def add_image(self, img_idx): + image = self.load_image() + + # TODO: add frame number + # layout = QVBoxLayout() + # frame_num = QLabel(f"Frame {img_idx}") + img_widget = ClickableLabel() + img_widget.setPixmap(image) + + img_widget.clicked.connect(partial(self.on_click, img_idx)) + # layout.addWidget(img_widget) + + self.img_widgets_lookup[img_idx] = img_widget + self.flow_layout.addWidget(img_widget) + + def remove_image(self, img_idx): + img_widget = self.img_widgets_lookup.pop(img_idx) + self.flow_layout.removeWidget(img_widget) + + # def set_active(img_idx): + # # TODO: make red border on selected, remove on others + # pass From d4f13a8babdb5f7b57d0f9c937d62713a42eda6c Mon Sep 17 00:00:00 2001 From: max810 Date: Sat, 11 Mar 2023 19:08:28 +0400 Subject: [PATCH 19/49] Fixed a bug resulting in bad segmentation. 
Usable build --- inference/interact/gui.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 514a9eb..d8606be 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -78,6 +78,9 @@ def __init__(self, net: XMem, self.compute_candidates_button = QPushButton('Compute Annotation candidates') self.compute_candidates_button.clicked.connect(self.on_compute_candidates) + self.full_run_button = QPushButton('FULL Propagate') + self.full_run_button.clicked.connect(self.on_full_propagation) + self.forward_run_button = QPushButton('Forward Propagate') self.forward_run_button.clicked.connect(self.on_forward_propagation) self.forward_run_button.setMinimumWidth(200) @@ -173,7 +176,7 @@ def __init__(self, net: XMem, self.zoom_m_button.clicked.connect(self.on_zoom_minus) # Parameters setting - self.clear_mem_button = QPushButton('Clear memory') + self.clear_mem_button = QPushButton('Clear TEMP memory') self.clear_mem_button.clicked.connect(self.on_clear_memory) self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size') @@ -248,6 +251,7 @@ def __init__(self, net: XMem, navi.addWidget(self.save_reference_button) navi.addWidget(self.compute_candidates_button) navi.addWidget(self.commit_button) + navi.addWidget(self.full_run_button) navi.addWidget(self.forward_run_button) navi.addWidget(self.backward_run_button) @@ -599,6 +603,10 @@ def brush_slide(self): except AttributeError: # Initialization, forget about it pass + + def on_full_propagation(self): + self.scroll_to(0) + self.on_forward_propagation() def on_forward_propagation(self): if self.propagating: @@ -633,10 +641,12 @@ def on_propagation(self): # start to propagate self.load_current_torch_image_mask() # TODO: put into permanent memory + # TODO: debug why quality so bad self.show_current_frame(fast=True) self.console_push_text('Propagation started.') - self.current_prob, key = self.processor.step(self.current_image_torch, self.current_prob[1:], return_key=True) + msk = self.current_prob[1:] if self.cursur in self.reference_ids else None + self.current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) @@ -653,8 +663,8 @@ def on_propagation(self): self.load_current_image_mask(no_mask=True) self.load_current_torch_image_mask(no_mask=True) - # TODO: read existing mask (if there is one), pass to .step() - self.current_prob, key = self.processor.step(self.current_image_torch, return_key=True) + msk = self.current_prob[1:] if self.cursur in self.reference_ids else None + self.current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) @@ -714,7 +724,7 @@ def on_save_reference(self): current_image_torch, _ = image_to_torch(self.current_image) current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() - self.processor.put_to_permanent_memory(current_image_torch, current_prob, self.cursur) + self.processor.put_to_permanent_memory(current_image_torch, current_prob[1:], self.cursur) self.reference_ids.add(self.cursur) self.references_collection.add_image(self.cursur) From c2d76413bb602776f969524e9476b576429f5566 Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 12 Mar 2023 
16:23:55 +0400 Subject: [PATCH 20/49] Working temp and long memory clearing, works if want to change the target object --- inference/inference_core.py | 24 ++++++++----- inference/interact/gui.py | 65 +++++++++++++++++++++++++++++------- inference/kv_memory_store.py | 24 +++++++++++++ inference/memory_manager.py | 55 +++++++++++++++++++++++++++--- 4 files changed, 144 insertions(+), 24 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index 9c225fc..0498f1b 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -19,12 +19,17 @@ def __init__(self, network:XMem, config): self.clear_memory() self.all_labels = None - def clear_memory(self): + def clear_memory(self, keep_permanent=False): self.curr_ti = -1 self.last_mem_ti = 0 if not self.deep_update_sync: self.last_deep_update_ti = -self.deep_update_every - self.memory = MemoryManager(config=self.config) + if keep_permanent: + new_memory = self.memory.copy_perm_mem_only() + else: + new_memory = MemoryManager(config=self.config) + + self.memory = new_memory def update_config(self, config): self.mem_every = config['mem_every'] @@ -155,15 +160,18 @@ def put_to_permanent_memory(self, image, mask, ti=None): value, hidden = self.network.encode_value(image, f16, self.memory.get_hidden(), pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=False) + is_update = self.memory.frame_already_saved(ti) + print(ti, f"update={is_update}") if self.memory.frame_already_saved(ti): - # self.memory.update_permanent_memory(ti, ) - # maybe delete and update? - # TODO: splice the memory, update existing one - pass + self.memory.update_permanent_memory(ti, key, shrinkage, value, selection=selection if self.enable_long_term else None) else: self.memory.add_memory(key, shrinkage, value, self.all_labels, - selection=selection if self.enable_long_term else None, permanent=True) + selection=selection if self.enable_long_term else None, permanent=True, ti=ti) + print(self.memory.permanent_work_mem.key.shape) + + return is_update + @property def permanent_memory_frames(self): - return list(self.memory.frame_id_to_mem_idx.keys()) \ No newline at end of file + return list(self.memory.frame_id_to_permanent_mem_idx.keys()) \ No newline at end of file diff --git a/inference/interact/gui.py b/inference/interact/gui.py index d8606be..9167444 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -26,7 +26,8 @@ from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, - QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, QRadioButton, QTabWidget, QDialog) + QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, + QRadioButton, QTabWidget, QDialog, QErrorMessage, QMessageBox) from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon from PyQt5.QtCore import Qt, QTimer, QThreadPool @@ -79,14 +80,14 @@ def __init__(self, net: XMem, self.compute_candidates_button.clicked.connect(self.on_compute_candidates) self.full_run_button = QPushButton('FULL Propagate') - self.full_run_button.clicked.connect(self.on_full_propagation) + self.full_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='full')) self.forward_run_button = QPushButton('Forward Propagate') - self.forward_run_button.clicked.connect(self.on_forward_propagation) + self.forward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='forward')) 
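
The keep_permanent flag added to clear_memory in this patch boils down to rebuilding the working state while carrying the permanent (reference) store over. A toy analogy with dicts standing in for the real key/value stores, for illustration only:

class TinyMemory:
    def __init__(self):
        self.temporary = {}   # rebuilt on every clear
        self.permanent = {}   # user-saved references, meant to survive clears

def clear_memory(mem: TinyMemory, keep_permanent: bool = False) -> TinyMemory:
    fresh = TinyMemory()
    if keep_permanent:
        fresh.permanent = mem.permanent   # hand the references over unchanged
    return fresh

mem = TinyMemory()
mem.permanent[0] = 'features of reference frame 0'
mem = clear_memory(mem, keep_permanent=True)
print(list(mem.permanent))   # [0] - the saved reference survives the reset
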
self.forward_run_button.setMinimumWidth(200) self.backward_run_button = QPushButton('Backward Propagate') - self.backward_run_button.clicked.connect(self.on_backward_propagation) + self.backward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='backward')) self.backward_run_button.setMinimumWidth(200) self.reset_button = QPushButton('Reset Frame') @@ -176,7 +177,7 @@ def __init__(self, net: XMem, self.zoom_m_button.clicked.connect(self.on_zoom_minus) # Parameters setting - self.clear_mem_button = QPushButton('Clear TEMP memory') + self.clear_mem_button = QPushButton('Clear TEMP and LONG memory') self.clear_mem_button.clicked.connect(self.on_clear_memory) self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size') @@ -378,10 +379,11 @@ def __init__(self, net: XMem, self.vis_target_objects = [1] # try to load the default overlay self._try_load_layer('./docs/ECCV-logo.png') - + self.load_current_image_mask() self.show_current_frame() self.show() + self.style_new_reference() self.console_push_text('Initialized.') self.initialized = True @@ -513,9 +515,18 @@ def show_current_frame(self, fast=False): self.lcd.setText('{: 3d} / {: 3d}'.format(self.cursur, self.num_frames-1)) self.tl_slider.setValue(self.cursur) - def show_candidates(self): - # TODO: draw image grid - pass + if self.cursur in self.reference_ids: + self.style_editing_reference() + else: + self.style_new_reference() + + def style_editing_reference(self): + self.save_reference_button.setText("Update reference") + self.save_reference_button.setStyleSheet('QPushButton {background-color: #E4A11B; font-size: bold; }') + + def style_new_reference(self): + self.save_reference_button.setText("Save reference") + self.save_reference_button.setStyleSheet('QPushButton {background-color: #14A44D; font-size: bold;}') def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates @@ -603,6 +614,26 @@ def brush_slide(self): except AttributeError: # Initialization, forget about it pass + + def confirm_ready_for_propagation(self): + if len(self.reference_ids) > 0: + return True + + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + qm.showMessage("Save at least 1 reference!") + + return False + + def general_propagation_callback(self, propagation_type: str): + if not self.confirm_ready_for_propagation(): + return + if propagation_type == 'full': + self.on_full_propagation() + elif propagation_type == 'forward': + self.on_forward_propagation() + elif propagation_type == 'backward': + self.on_backward_propagation() def on_full_propagation(self): self.scroll_to(0) @@ -614,6 +645,7 @@ def on_forward_propagation(self): self.propagating = False else: self.propagate_fn = self.on_next_frame + self.full_run_button.setEnabled(False) self.backward_run_button.setEnabled(False) self.forward_run_button.setText('Pause Propagation') self.on_propagation() @@ -624,12 +656,14 @@ def on_backward_propagation(self): self.propagating = False else: self.propagate_fn = self.on_prev_frame + self.full_run_button.setEnabled(False) self.forward_run_button.setEnabled(False) self.backward_run_button.setText('Pause Propagation') self.on_propagation() def on_pause(self): self.propagating = False + self.full_run_button.setEnabled(True) self.forward_run_button.setEnabled(True) self.backward_run_button.setEnabled(True) self.clear_mem_button.setEnabled(True) @@ -721,10 +755,16 @@ def _update_progress(i): def on_save_reference(self): # TODO: update 
permanent memory. Add if new, replace if already existed + if self.interaction is not None: + self.on_commit() current_image_torch, _ = image_to_torch(self.current_image) current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() - self.processor.put_to_permanent_memory(current_image_torch, current_prob[1:], self.cursur) + is_update = self.processor.put_to_permanent_memory(current_image_torch, current_prob[1:], self.cursur) + + if is_update: + self.reference_ids.remove(self.cursur) + self.references_collection.remove_image(self.cursur) self.reference_ids.add(self.cursur) self.references_collection.add_image(self.cursur) @@ -733,7 +773,8 @@ def on_save_reference(self): self.candidates_ids.remove(self.cursur) self.candidates_collection.remove_image(self.cursur) - # TODO: remove from candidates if was there + + self.show_current_frame() def on_prev_frame(self): # self.tl_slide will trigger on setValue @@ -979,7 +1020,7 @@ def update_config(self): self.processor.update_config(self.config) def on_clear_memory(self): - self.processor.clear_memory() + self.processor.clear_memory(keep_permanent=True) torch.cuda.empty_cache() self.update_gpu_usage() self.update_memory_size() diff --git a/inference/kv_memory_store.py b/inference/kv_memory_store.py index a31ac63..31b62e9 100644 --- a/inference/kv_memory_store.py +++ b/inference/kv_memory_store.py @@ -89,6 +89,10 @@ def add(self, key, value, shrinkage, selection, objects: List[int]): else: self.v.append(gv) + pos = int((self.k.shape[-1] + 1e-9) // (key.shape[-1] + 1e-9)) - 1 # index of newly added frame + + return pos + def update_usage(self, usage): # increase all life count by 1 # increase use of indexed elements @@ -98,6 +102,26 @@ def update_usage(self, usage): self.use_count += usage.view_as(self.use_count) self.life_count += 1 + def replace_at(self, start_pos: int, key, value, shrinkage=None, selection=None): + start = start_pos * key.shape[-1] + end = (start_pos + 1) * key.shape[-1] + + self.k[:,:,start:end] = key + + for gi in range(self.num_groups): + self.v[gi][:, :, start:end] = value[gi] + + if self.s is not None and shrinkage is not None: + self.s[:, :, start:end] = shrinkage + + if self.e is not None and selection is not None: + self.e[:, :, start:end] = selection + + def remove_at(self, start: int, elem_size: int): + end = start + elem_size + + self.sieve_by_range(start, end, min_size=0) # remove the value irrespective of its size + def sieve_by_range(self, start: int, end: int, min_size: int): # keep only the elements *outside* of this range (with some boundary conditions) # i.e., concat (a[:start], a[end:]) diff --git a/inference/memory_manager.py b/inference/memory_manager.py index cbe6a35..114839b 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -11,6 +11,7 @@ class MemoryManager: """ def __init__(self, config): + self.config = config self.hidden_dim = config['hidden_dim'] self.top_k = config['top_k'] @@ -32,7 +33,7 @@ def __init__(self, config): self.temporary_work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term) self.permanent_work_mem = KeyValueMemoryStore(count_usage=False) - self.frame_id_to_mem_idx = dict() + self.frame_id_to_permanent_mem_idx = dict() if self.enable_long_term: self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage) @@ -189,7 +190,27 @@ def match_memory(self, query_key, selection, disable_usage_updates=False): return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w) - def add_memory(self, key, shrinkage, 
value, objects, selection=None, permanent=False, ignore=False): + def update_permanent_memory(self, frame_idx, key, shrinkage, value, selection=None): + saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx] + + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + value = value[0].flatten(start_dim=2) + + if selection is not None: + selection = selection.flatten(start_dim=2) + + self.permanent_work_mem.replace_at(saved_pos, key, value, shrinkage, selection) + + def remove_from_permanent_memory(self, frame_idx): + elem_size = self.HW + saved_pos = self.frame_id_to_permanent_mem_idx[frame_idx] + + self.permanent_work_mem.remove_at(saved_pos, elem_size) + + del self.frame_id_to_permanent_mem_idx[frame_idx] + + def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=False, ignore=False, ti=None): # key: 1*C*H*W # value: 1*num_objects*C*H*W # objects contain a list of object indices @@ -220,7 +241,9 @@ def add_memory(self, key, shrinkage, value, objects, selection=None, permanent=F pass # all permanent frames are pre-placed into permanent memory (when using our memory modification) # also ignores the first frame (#0) when using original memory mechanism, since it's already in the permanent memory elif permanent: - self.permanent_work_mem.add(key, value, shrinkage, selection, objects) + pos = self.permanent_work_mem.add(key, value, shrinkage, selection, objects) + if ti is not None: + self.frame_id_to_permanent_mem_idx[ti] = pos else: self.temporary_work_mem.add(key, value, shrinkage, selection, objects) @@ -278,7 +301,7 @@ def get_hidden(self): return self.hidden def frame_already_saved(self, ti): - return ti in self.frame_id_to_mem_idx + return ti in self.frame_id_to_permanent_mem_idx # def slices_excluding_permanent(self, group_value, start, end): # HW = self.HW @@ -366,3 +389,27 @@ def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, prototype_shrinkage = self._readout(affinity[0], candidate_shrinkage) if candidate_shrinkage is not None else None return prototype_key, prototype_value, prototype_shrinkage + + def copy_perm_mem_only(self): + new_mem = MemoryManager(config=self.config) + new_mem.permanent_work_mem = self.permanent_work_mem + + key0 = self.permanent_work_mem.key[..., 0:0] + value0 = self.permanent_work_mem.value[0][..., 0:0] + shrinkage0 = self.permanent_work_mem.shrinkage[..., 0:0] if self.permanent_work_mem.shrinkage is not None else None + selection0 = self.permanent_work_mem.selection[..., 0:0] if self.permanent_work_mem.selection is not None else None + + new_mem.temporary_work_mem.add(key0, value0, shrinkage0, selection0, self.permanent_work_mem.all_objects) + + new_mem.CK = self.permanent_work_mem.key.shape[1] + new_mem.CV = self.permanent_work_mem.value[0].shape[1] + + key_shape = self.permanent_work_mem.key.shape + sample_key = self.permanent_work_mem.key[..., 0:self.HW].view(*key_shape[:-1], self.H, self.W) + new_mem.create_hidden_state(len(self.permanent_work_mem.all_objects), sample_key) + + new_mem.temporary_work_mem.obj_groups = self.temporary_work_mem.obj_groups + new_mem.temporary_work_mem.all_objects = self.temporary_work_mem.all_objects + + return new_mem + From 5a191417998d00051fa65558f42a433b52faa5f7 Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 12 Mar 2023 16:50:18 +0400 Subject: [PATCH 21/49] Quality of life improvements: added frame labels, sped up reference saving time --- inference/frame_selection/frame_selection.py | 6 ++++-- inference/inference_core.py | 8 +++++++- 
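# update_permanent_memory()/remove_from_permanent_memory() above hinge on the
# frame_id_to_permanent_mem_idx mapping: remember the slot returned by add(), then
# overwrite or drop that slot by frame index later. A standalone sketch of the
# update-or-insert bookkeeping (hypothetical names, list-backed for brevity):
class PermanentReferenceIndex:
    def __init__(self):
        self.slots = []                 # per-frame features, one entry per slot
        self.frame_to_slot = {}         # frame index -> slot position

    def put(self, frame_idx: int, features) -> bool:
        # Returns True when an existing reference was updated, False when added.
        if frame_idx in self.frame_to_slot:
            self.slots[self.frame_to_slot[frame_idx]] = features
            return True
        self.frame_to_slot[frame_idx] = len(self.slots)
        self.slots.append(features)
        return False

    def remove(self, frame_idx: int):
        pos = self.frame_to_slot.pop(frame_idx)
        del self.slots[pos]
        # later slots shift left, so the remaining entries are re-indexed here
        self.frame_to_slot = {f: (s - 1 if s > pos else s)
                              for f, s in self.frame_to_slot.items()}

refs = PermanentReferenceIndex()
refs.put(0, 'key for frame 0')      # False: newly added
refs.put(0, 'updated key')          # True: frame 0 already saved
refs.put(17, 'key for frame 17')
refs.remove(0)
print(refs.frame_to_slot)           # {17: 0}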
inference/interact/gui.py | 9 +++++++-- inference/interact/gui_utils.py | 17 +++++++++++++++-- 4 files changed, 33 insertions(+), 7 deletions(-) diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index 1fce745..e8fc8d9 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -96,7 +96,7 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(data return chosen_frames -def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None): +def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True): assert len(keys) == len(masks) assert len(keys) > 0 assert keys[0].shape[-2:] == masks[0].shape[-2:] @@ -201,4 +201,6 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne if progress_callback is not None: progress_callback.emit(i + 1) - return chosen_candidates \ No newline at end of file + if only_new_candidates: + chosen_candidates = chosen_candidates[len(previously_chosen_candidates):] + return chosen_candidates diff --git a/inference/inference_core.py b/inference/inference_core.py index 0498f1b..e898257 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -1,3 +1,6 @@ +from time import perf_counter + +import torch from inference.memory_manager import MemoryManager from model.network import XMem from model.aggregate import aggregate @@ -19,6 +22,9 @@ def __init__(self, network:XMem, config): self.clear_memory() self.all_labels = None + # warmup + self.network.encode_key(torch.zeros((1, 3, 480, 854), device='cuda:0')) + def clear_memory(self, keep_permanent=False): self.curr_ti = -1 self.last_mem_ti = 0 @@ -28,7 +34,7 @@ def clear_memory(self, keep_permanent=False): new_memory = self.memory.copy_perm_mem_only() else: new_memory = MemoryManager(config=self.config) - + self.memory = new_memory def update_config(self, config): diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 9167444..f09c7f6 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -15,6 +15,7 @@ import functools import os +from time import perf_counter import cv2 from inference.frame_selection.frame_selection import select_next_candidates @@ -754,13 +755,17 @@ def _update_progress(i): candidate_progress.open() def on_save_reference(self): - # TODO: update permanent memory. 
Add if new, replace if already existed if self.interaction is not None: self.on_commit() current_image_torch, _ = image_to_torch(self.current_image) current_prob = index_numpy_to_one_hot_torch(self.current_mask, self.num_objects+1).cuda() - is_update = self.processor.put_to_permanent_memory(current_image_torch, current_prob[1:], self.cursur) + msk = current_prob[1:] + a = perf_counter() + is_update = self.processor.put_to_permanent_memory(current_image_torch, msk, self.cursur) + b = perf_counter() + + self.console_push_text(f"Saving took {(b-a)*1000:.2f} ms.") if is_update: self.reference_ids.remove(self.cursur) diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index 288f742..47190bc 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -263,6 +263,17 @@ def mouseReleaseEvent(self, event): if event.button() == Qt.LeftButton and event.pos() in self.rect(): self.clicked.emit() + +class ImageWithCaption(QWidget): + def __init__(self, img: QLabel, caption: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.layout = QVBoxLayout(self) + self.text_label = QLabel(caption) + self.layout.addWidget(self.text_label) + self.layout.addWidget(img) + + self.layout.setAlignment(self.text_label, Qt.AlignmentFlag.AlignHCenter) class ImageLinkCollection(QWidget): def __init__(self, on_click: callable, load_image: callable, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -287,10 +298,12 @@ def add_image(self, img_idx): img_widget.setPixmap(image) img_widget.clicked.connect(partial(self.on_click, img_idx)) + + wrapper = ImageWithCaption(img_widget, f"Frame {img_idx:>6d}") # layout.addWidget(img_widget) - self.img_widgets_lookup[img_idx] = img_widget - self.flow_layout.addWidget(img_widget) + self.img_widgets_lookup[img_idx] = wrapper + self.flow_layout.addWidget(wrapper) def remove_image(self, img_idx): img_widget = self.img_widgets_lookup.pop(img_idx) From 235d5a633d2faee2c2816a9138fcff44ec5be66c Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 12 Mar 2023 20:48:34 +0400 Subject: [PATCH 22/49] Added scrollable references tab, added a warning for selecting candidates of propagation has not been run on all frames yet --- inference/interact/gui.py | 51 ++++++++++++++++++++++----- inference/interact/gui_utils.py | 61 ++++++++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index f09c7f6..96268eb 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -94,6 +94,9 @@ def __init__(self, net: XMem, self.reset_button = QPushButton('Reset Frame') self.reset_button.clicked.connect(self.on_reset_mask) + self.spacebar = QShortcut(QKeySequence(Qt.Key_Space), self) + self.spacebar.activated.connect(self.pause_propagation) + # LCD self.lcd = QTextEdit() self.lcd.setReadOnly(True) @@ -246,12 +249,12 @@ def __init__(self, net: XMem, navi.addWidget(self.save_visualization_checkbox) navi.addStretch(1) - self.test_btn = QPushButton('TEST') - self.test_btn.clicked.connect(self.TEST) + # self.test_btn = QPushButton('TEST') + # self.test_btn.clicked.connect(self.TEST) - navi.addWidget(self.test_btn) + # navi.addWidget(self.test_btn) navi.addWidget(self.save_reference_button) - navi.addWidget(self.compute_candidates_button) + # navi.addWidget(self.compute_candidates_button) navi.addWidget(self.commit_button) navi.addWidget(self.full_run_button) navi.addWidget(self.forward_run_button) @@ -262,11 +265,20 @@ def __init__(self, 
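# The reference-saving path above is timed with perf_counter and reported on the
# console in milliseconds. A small reusable sketch of the same measurement
# (illustrative helper; for GPU work, a torch.cuda.synchronize() before reading
# the clock keeps the number honest, since kernels launch asynchronously):
from contextlib import contextmanager
from time import perf_counter

@contextmanager
def timed(label: str, log=print):
    start = perf_counter()
    yield
    log(f"{label} took {(perf_counter() - start) * 1000:.2f} ms.")

with timed("Saving reference"):
    sum(i * i for i in range(1_000_000))    # stand-in for the real work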
net: XMem, draw_area.addWidget(self.main_canvas, 4) self.tabs = QTabWidget() + self.tabs.setMinimumWidth(500) self.map_tab = QWidget() self.references_tab = QWidget() + references_scroll = QScrollArea() + references_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + references_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + references_scroll.setWidgetResizable(True) + references_scroll.setWidget(self.references_tab) + + self.references_scroll = references_scroll + self.tabs.addTab(self.map_tab,"Minimap && Stats") - self.tabs.addTab(self.references_tab,"References && Candidates") + self.tabs.addTab(self.references_scroll, "References && Candidates") tabs_layout = QVBoxLayout() @@ -317,6 +329,15 @@ def __init__(self, net: XMem, tabs_layout.addWidget(self.console) draw_area.addLayout(tabs_layout, 1) + candidates_area = QVBoxLayout() + self.candidates_k_slider = NamedSlider("k", 1, 10, 1, default=5) + self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks') + candidates_area.addWidget(QLabel("Candidates calculation hyperparameters")) + candidates_area.addWidget(self.candidates_k_slider) + candidates_area.addWidget(self.candidates_alpha_slider) + candidates_area.addWidget(self.compute_candidates_button) + tabs_layout.addLayout(candidates_area) + layout = QVBoxLayout() layout.addLayout(draw_area) layout.addWidget(self.tl_slider) @@ -713,7 +734,6 @@ def on_propagation(self): if self.cursur == 0 or self.cursur == self.num_frames-1: break - # TODO: after finished, compute candidates and show maybe? self.propagating = False self.curr_frame_dirty = False self.on_pause() @@ -727,6 +747,16 @@ def on_commit(self): self.complete_interaction() self.update_interacted_mask() + def confirm_ready_for_candidates_selection(self): + if self.res_man.all_masks_present(): + return True + + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + qm.showMessage("Run propagation on all frames first!") + + return False + def on_compute_candidates(self): def _update_candidates(candidates_ids): print(candidates_ids) @@ -740,13 +770,18 @@ def _update_candidates(candidates_ids): self.scroll_to(i) self.candidates_collection.add_image(i) self.scroll_to(prev_pos) + self.tabs.setCurrentIndex(1) def _update_progress(i): candidate_progress.setValue(i) - k = 5 + if not self.confirm_ready_for_candidates_selection(): + return + + k = self.candidates_k_slider.value() + alpha = self.candidates_alpha_slider.value() candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) - worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=0.5, min_mask_presence_px=9) # Any other args, kwargs are passed to the run function + worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=alpha, min_mask_presence_px=9) # Any other args, kwargs are passed to the run function worker.signals.result.connect(_update_candidates) worker.signals.progress.connect(_update_progress) diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index 47190bc..737bffa 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -4,7 +4,7 @@ import traceback, sys from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject, QPoint, QRect, QSize -from 
PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, QFrame, QPushButton) +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, QFrame, QPushButton, QSlider) class WorkerSignals(QObject): ''' @@ -256,6 +256,65 @@ def addStretch(self, stretch: int=0): self.addItem(spacerItem) +class NamedSlider(QWidget): + valueChanged = pyqtSignal(float) + + def __init__(self, name: str, min_: int, max_: int, step_size: int, default: int, multiplier=1, min_text=None, max_text=None, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.name = name + self.multiplier = multiplier + self.min_text = min_text + self.max_text = max_text + + layout = QHBoxLayout(self) + self.slider = QSlider(Qt.Horizontal) + self.slider.setMinimum(min_) + self.slider.setMaximum(max_) + self.slider.setValue(default) + self.slider.setTickPosition(QSlider.TicksBelow) + self.slider.setTickInterval(step_size) + + name_label = QLabel(name + " |") + self.value_label = QLabel() + + layout.addWidget(name_label) + layout.addWidget(self.value_label) + layout.addWidget(self.slider) + + self.update_name() + + self.slider.valueChanged.connect(self.on_slide) + + def value(self): + return self.slider.value() * self.multiplier + + def on_slide(self): + self.update_name() + self.valueChanged.emit(self.slider.value() * self.multiplier) + + def update_name(self): + value = self.value() + value_str = None + if self.multiplier != 1: + if isinstance(self.multiplier, float): + min_str = f'{self.slider.minimum() * self.multiplier:.2f}' + value_str = f'{value:.2f}' + max_str = f'{self.slider.maximum() * self.multiplier:.2f}' + + if value_str is None: + min_str = f'{self.slider.minimum() * self.multiplier:d}' + value_str = f'{value:d}' + max_str = f'{self.slider.maximum() * self.multiplier:d}' + + if self.min_text is not None: + min_str += f' ({self.min_text})' + if self.max_text is not None: + max_str += f' ({self.max_text})' + + final_str = f'{min_str} <= {value_str} <= {max_str}' + + self.value_label.setText(final_str) + class ClickableLabel(QLabel): clicked = pyqtSignal() def mouseReleaseEvent(self, event): From 933aba9fb6ccd3cb2e5493930096e43815340736 Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 12 Mar 2023 21:40:52 +0400 Subject: [PATCH 23/49] Added removal from memory, minor bug fixes --- inference/inference_core.py | 3 +++ inference/interact/gui.py | 32 ++++++++++++++++------ inference/interact/gui_utils.py | 48 ++++++++++++++++++++++++++------- inference/memory_manager.py | 14 +++++++++- 4 files changed, 78 insertions(+), 19 deletions(-) diff --git a/inference/inference_core.py b/inference/inference_core.py index e898257..0981d61 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -178,6 +178,9 @@ def put_to_permanent_memory(self, image, mask, ti=None): return is_update + def remove_from_permanent_memory(self, frame_idx): + self.memory.remove_from_permanent_memory(frame_idx) + @property def permanent_memory_frames(self): return list(self.memory.frame_id_to_permanent_mem_idx.keys()) \ No newline at end of file diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 96268eb..8074ea6 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -318,10 +318,10 @@ def __init__(self, net: XMem, chosen_figures_area 
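# NamedSlider above keeps integer positions internally and exposes a scaled float
# through `multiplier`, so the α slider's 0..100 range maps to 0.00..1.00. A short
# usage sketch, assuming the class as added by this patch is importable:
import sys
from PyQt5.QtWidgets import QApplication
from inference.interact.gui_utils import NamedSlider

app = QApplication(sys.argv)
alpha = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01,
                    min_text='Frames', max_text='Masks')
alpha.valueChanged.connect(lambda v: print(f"alpha = {v:.2f}"))
print(alpha.value())          # 0.5 at the default position of 50
# alpha.show(); app.exec_()   # uncomment to interact with the widget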
= QVBoxLayout(self.references_tab) chosen_figures_area.addWidget(QLabel("SAVED REFERENCES IN PERMANENT MEMORY")) - self.references_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail) + self.references_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, delete_image=self.on_remove_reference, name='Reference frames') chosen_figures_area.addWidget(self.references_collection) - self.candidates_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail) + self.candidates_collection = ImageLinkCollection(self.scroll_to, self.load_current_image_thumbnail, name='Candidate frames') chosen_figures_area.addWidget(QLabel("ANNOTATION CANDIDATES")) chosen_figures_area.addWidget(self.candidates_collection) @@ -330,7 +330,7 @@ def __init__(self, net: XMem, draw_area.addLayout(tabs_layout, 1) candidates_area = QVBoxLayout() - self.candidates_k_slider = NamedSlider("k", 1, 10, 1, default=5) + self.candidates_k_slider = NamedSlider("k", 1, 20, 1, default=5) self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks') candidates_area.addWidget(QLabel("Candidates calculation hyperparameters")) candidates_area.addWidget(self.candidates_k_slider) @@ -512,8 +512,17 @@ def update_current_image_fast(self): self.main_canvas.setPixmap(QPixmap(qImg.scaled(self.main_canvas.size(), Qt.KeepAspectRatio, Qt.FastTransformation))) - def load_current_image_thumbnail(self, size=128): - curr_pixmap = self.main_canvas.pixmap() + def load_current_image_thumbnail(self, *args, size=128): + # all this instead of self.main_canvas.pixmap() because it contains the brush as well + viz = get_visualization(self.viz_mode, self.current_image, self.current_mask, + self.overlay_layer, self.vis_target_objects) + + height, width, channel = viz.shape + bytesPerLine = 3 * width + qImg = QImage(viz.data, width, height, bytesPerLine, QImage.Format_RGB888) + curr_pixmap = QPixmap(qImg.scaled(self.main_canvas.size(), + Qt.KeepAspectRatio, Qt.FastTransformation)) + curr_size = curr_pixmap.size() h = curr_size.height() w = curr_size.width() @@ -544,11 +553,11 @@ def show_current_frame(self, fast=False): def style_editing_reference(self): self.save_reference_button.setText("Update reference") - self.save_reference_button.setStyleSheet('QPushButton {background-color: #E4A11B; font-size: bold; }') + self.save_reference_button.setStyleSheet('QPushButton {background-color: #E4A11B; font-weight: bold; }') def style_new_reference(self): self.save_reference_button.setText("Save reference") - self.save_reference_button.setStyleSheet('QPushButton {background-color: #14A44D; font-size: bold;}') + self.save_reference_button.setStyleSheet('QPushButton {background-color: #14A44D; font-weight: bold;}') def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates @@ -650,6 +659,8 @@ def confirm_ready_for_propagation(self): def general_propagation_callback(self, propagation_type: str): if not self.confirm_ready_for_propagation(): return + + self.tabs.setCurrentIndex(0) if propagation_type == 'full': self.on_full_propagation() elif propagation_type == 'forward': @@ -811,10 +822,15 @@ def on_save_reference(self): if self.cursur in self.candidates_ids: self.candidates_ids.remove(self.cursur) - self.candidates_collection.remove_image(self.cursur) self.show_current_frame() + self.tabs.setCurrentIndex(1) + + def on_remove_reference(self, img_idx): + 
self.processor.remove_from_permanent_memory(img_idx) + self.reference_ids.remove(img_idx) + self.show_current_frame() def on_prev_frame(self): # self.tl_slide will trigger on setValue diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index 737bffa..8bd3621 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -4,7 +4,9 @@ import traceback, sys from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject, QPoint, QRect, QSize -from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, QFrame, QPushButton, QSlider) +from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, + QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, + QFrame, QPushButton, QSlider, QMessageBox) class WorkerSignals(QObject): ''' @@ -324,21 +326,37 @@ def mouseReleaseEvent(self, event): class ImageWithCaption(QWidget): - def __init__(self, img: QLabel, caption: str, *args, **kwargs) -> None: + def __init__(self, img: QLabel, caption: str, on_close: callable = None, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.layout = QVBoxLayout(self) self.text_label = QLabel(caption) - self.layout.addWidget(self.text_label) + self.close_btn = QPushButton("x") + self.close_btn.setMaximumSize(35, 35) + self.close_btn.setMinimumSize(35, 35) + self.close_btn.setStyleSheet('QPushButton {background-color: #DC4C64; font-weight: bold; }') + if on_close is not None: + self.close_btn.clicked.connect(on_close) + + self.top_tab_layout = QHBoxLayout() + self.top_tab_layout.addWidget(self.text_label) + self.top_tab_layout.addWidget(self.close_btn) + self.top_tab_layout.setAlignment(self.text_label, Qt.AlignmentFlag.AlignCenter) + self.top_tab_layout.setAlignment(self.close_btn, Qt.AlignmentFlag.AlignRight) + + self.layout.addLayout(self.top_tab_layout) + self.layout.addWidget(img) self.layout.setAlignment(self.text_label, Qt.AlignmentFlag.AlignHCenter) + class ImageLinkCollection(QWidget): - def __init__(self, on_click: callable, load_image: callable, *args, **kwargs) -> None: + def __init__(self, on_click: callable, load_image: callable, delete_image: callable = None, name: str = None, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.on_click = on_click self.load_image = load_image - + self.delete_image = delete_image + self.name = name # scrollable_area = QScrollArea(self) # frame = QFrame(scrollable_area) @@ -348,7 +366,7 @@ def __init__(self, on_click: callable, load_image: callable, *args, **kwargs) -> def add_image(self, img_idx): - image = self.load_image() + image = self.load_image(img_idx) # TODO: add frame number # layout = QVBoxLayout() @@ -358,7 +376,7 @@ def add_image(self, img_idx): img_widget.clicked.connect(partial(self.on_click, img_idx)) - wrapper = ImageWithCaption(img_widget, f"Frame {img_idx:>6d}") + wrapper = ImageWithCaption(img_widget, f"Frame {img_idx:>6d}", on_close=partial(self.on_close_click, img_idx)) # layout.addWidget(img_widget) self.img_widgets_lookup[img_idx] = wrapper @@ -368,6 +386,16 @@ def remove_image(self, img_idx): img_widget = self.img_widgets_lookup.pop(img_idx) self.flow_layout.removeWidget(img_widget) - # def set_active(img_idx): - # # TODO: make red border on selected, remove on others - # pass + def on_close_click(self, img_idx): + qm = QMessageBox() + question = f"Delete Frame 
{img_idx}" + if self.name is not None: + question += f' from {self.name}' + + question += '?' + ret = qm.question(self, 'Confirm deletion', question, qm.Yes | qm.No) + + if ret == qm.Yes: + self.remove_image(img_idx) + if self.delete_image is not None: + self.delete_image(img_idx) diff --git a/inference/memory_manager.py b/inference/memory_manager.py index 114839b..93f4563 100644 --- a/inference/memory_manager.py +++ b/inference/memory_manager.py @@ -392,7 +392,12 @@ def consolidation(self, candidate_key, candidate_shrinkage, candidate_selection, def copy_perm_mem_only(self): new_mem = MemoryManager(config=self.config) + + if self.permanent_work_mem.key is None or self.permanent_work_mem.key.size(-1) == 0: + return new_mem + new_mem.permanent_work_mem = self.permanent_work_mem + new_mem.frame_id_to_permanent_mem_idx = self.frame_id_to_permanent_mem_idx key0 = self.permanent_work_mem.key[..., 0:0] value0 = self.permanent_work_mem.value[0][..., 0:0] @@ -407,9 +412,16 @@ def copy_perm_mem_only(self): key_shape = self.permanent_work_mem.key.shape sample_key = self.permanent_work_mem.key[..., 0:self.HW].view(*key_shape[:-1], self.H, self.W) new_mem.create_hidden_state(len(self.permanent_work_mem.all_objects), sample_key) - + new_mem.temporary_work_mem.obj_groups = self.temporary_work_mem.obj_groups new_mem.temporary_work_mem.all_objects = self.temporary_work_mem.all_objects + + new_mem.CK = self.CK + new_mem.CV = self.CV + new_mem.H = self.H + new_mem.W = self.W + new_mem.HW = self.HW + return new_mem From 89d4c353afd153f276072eda48c8967bcfbc1a40 Mon Sep 17 00:00:00 2001 From: max810 Date: Sun, 12 Mar 2023 23:35:49 +0400 Subject: [PATCH 24/49] Fixed importing masks (rgb and grayscale) --- inference/interact/gui.py | 30 ++++++++++++++----------- inference/interact/gui_utils.py | 2 +- inference/interact/resource_manager.py | 31 ++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 8074ea6..127cc4d 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -15,6 +15,7 @@ import functools import os +from pathlib import Path from time import perf_counter import cv2 @@ -61,6 +62,7 @@ def __init__(self, net: XMem, self.processor.set_all_labels(list(range(1, self.num_objects+1))) self.res_man = resource_manager self.threadpool = QThreadPool() + self.last_opened_directory = str(Path.home()) self.num_frames = len(self.res_man) self.height, self.width = self.res_man.h, self.res_man.w @@ -76,7 +78,7 @@ def __init__(self, net: XMem, self.commit_button = QPushButton('Commit') self.commit_button.clicked.connect(self.on_commit) self.save_reference_button = QPushButton('Save reference') - self.save_reference_button.clicked.connect(self.on_save_reference) + self.save_reference_button.clicked.connect(self.on_save_reference) self.compute_candidates_button = QPushButton('Compute Annotation candidates') self.compute_candidates_button.clicked.connect(self.on_compute_candidates) @@ -251,7 +253,6 @@ def __init__(self, net: XMem, # self.test_btn = QPushButton('TEST') # self.test_btn.clicked.connect(self.TEST) - # navi.addWidget(self.test_btn) navi.addWidget(self.save_reference_button) # navi.addWidget(self.compute_candidates_button) @@ -410,10 +411,6 @@ def __init__(self, net: XMem, self.console_push_text('Initialized.') self.initialized = True - def TEST(self): - print(self.res_man.all_masks_present()) - pass - def resizeEvent(self, event): self.show_current_frame() @@ -1083,7 +1080,9 @@ def 
on_clear_memory(self): def _open_file(self, prompt): options = QFileDialog.Options() - file_name, _ = QFileDialog.getOpenFileName(self, prompt, "", "Image files (*)", options=options) + file_name, _ = QFileDialog.getOpenFileName(self, prompt, self.last_opened_directory, "Image files (*)", options=options) + if file_name: + self.last_opened_directory = str(Path(file_name).parent) return file_name def on_import_mask(self): @@ -1091,7 +1090,7 @@ def on_import_mask(self): if len(file_name) == 0: return - mask = self.res_man.read_external_image(file_name, size=(self.height, self.width)) + mask = self.res_man.read_external_image(file_name, size=(self.height, self.width), force_mask=True) shape_condition = ( (len(mask.shape) == 2) and @@ -1108,11 +1107,16 @@ def on_import_mask(self): elif not object_condition: self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.') else: - self.console_push_text(f'Mask file {file_name} loaded.') - self.current_image_torch = self.current_prob = None - self.current_mask = mask - self.show_current_frame() - self.save_current_mask() + qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") + question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?" + ret = qm.question(self, 'Confirm mask replacemen', question, qm.Yes | qm.No) + + if ret == qm.Yes: + self.console_push_text(f'Mask file {file_name} loaded.') + self.current_image_torch = self.current_prob = None + self.current_mask = mask + self.show_current_frame() + self.save_current_mask() def on_import_layer(self): file_name = self._open_file('Layer') diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index 8bd3621..be30410 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -387,7 +387,7 @@ def remove_image(self, img_idx): self.flow_layout.removeWidget(img_widget) def on_close_click(self, img_idx): - qm = QMessageBox() + qm = QMessageBox(QMessageBox.Icon.Warning, "Confirm deletion", "") question = f"Delete Frame {img_idx}" if self.name is not None: question += f' from {self.name}' diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index ef36830..f8d261d 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -206,15 +206,42 @@ def _get_mask_unbuffered(self, ti): else: return None - def read_external_image(self, file_name, size=None): + def read_external_image(self, file_name, size=None, force_mask=False): image = Image.open(file_name) is_mask = image.mode in ['L', 'P'] + if size is not None: # PIL uses (width, height) image = image.resize((size[1], size[0]), - resample=Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC) + resample=Image.Resampling.NEAREST if is_mask or force_mask else Image.Resampling.BICUBIC) + if force_mask and image.mode != 'P': + if image.mode in ['RGB', 'L'] and len(image.getcolors()) <= 2: + image = np.array(image.convert('L')) + # hardcoded for b&w images + image = np.where(image, 1, 0) # 255 (or whatever) -> binarize + + return image.astype('uint8') + image = image.convert('P', palette=self.palette) # saved without DAVIS palette, just number objects 0, 1, ... 
+ image = np.array(image) return image + + def replace_mask_with_external(self, ti, mask_path): + try: + image = Image.open(mask_path) + assert image.mode in ['L', 'P', 'RGB'] + + if image.mode == 'RGB': + if len(image.getcolors()) <= 2: + image = image.convert('L') + + image = image.convert('P') + image.putpalette(self.palette) + image.save(path.join(self.mask_dir, self.names[ti]+'.png')) + + except (FileNotFoundError, AssertionError, ValueError): + raise ValueError(f"Invalid file: {mask_path}") + def invalidate(self, ti): # the image buffer is never invalidated From 530fb336a4c5cada8296b59ecdfd096328a76908 Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 13 Mar 2023 14:42:07 +0400 Subject: [PATCH 25/49] Fixed duplicating video names errors --- inference/interact/resource_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index f8d261d..65b9b58 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -1,5 +1,6 @@ import os from os import path +from pathlib import Path import shutil import collections @@ -51,7 +52,11 @@ def __init__(self, config): # create temporary workspace if not specified if self.workspace is None: if images is not None: - basename = path.basename(images) + p_images = Path(images) + if p_images.name == 'JPEGImages': + basename = p_images.parent.name + else: + basename = p_images.name elif video is not None: basename = path.basename(video)[:-4] else: From 06d03ef8452bc8e55bf200a555b1ffb455967c4a Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 13 Mar 2023 17:38:15 +0400 Subject: [PATCH 26/49] Added option to upload all existing masks at once. Importing a mask now automatically saves it to the references --- inference/interact/gui.py | 48 +++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 127cc4d..fcbffa9 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -16,6 +16,7 @@ import os from pathlib import Path +import re from time import perf_counter import cv2 @@ -214,6 +215,9 @@ def __init__(self, net: XMem, # import mask/layer self.import_mask_button = QPushButton('Import mask') self.import_mask_button.clicked.connect(self.on_import_mask) + + self.import_all_masks_button = QPushButton('Import ALL masks') + self.import_all_masks_button.clicked.connect(self.on_import_all_masks) self.import_layer_button = QPushButton('Import layer') self.import_layer_button.clicked.connect(self.on_import_layer) @@ -314,6 +318,7 @@ def __init__(self, net: XMem, import_area = QHBoxLayout() import_area.setAlignment(Qt.AlignTop) import_area.addWidget(self.import_mask_button) + import_area.addWidget(self.import_all_masks_button) import_area.addWidget(self.import_layer_button) minimap_area.addLayout(import_area) @@ -1085,10 +1090,44 @@ def _open_file(self, prompt): self.last_opened_directory = str(Path(file_name).parent) return file_name - def on_import_mask(self): - file_name = self._open_file('Mask') - if len(file_name) == 0: - return + def on_import_all_masks(self): + dir_path = QFileDialog.getExistingDirectory() + if dir_path: + self.last_opened_directory = dir_path + + all_correct = True + frame_ids = [] + incorrect_files = [] + pattern = re.compile(r'([0-9]+)') + files_paths = sorted(Path(dir_path).iterdir()) + for p_f in files_paths: + match = pattern.search(p_f.name) + if match: + frame_id = 
int(match.string[match.start():match.end()]) + frame_ids.append(frame_id) + else: + all_correct = False + incorrect_files.apend(p_f.name) + + + if not all_correct or frame_ids != sorted(frame_ids): + qm = QErrorMessage(self) + qm.setWindowModality(Qt.WindowModality.WindowModal) + broken_file_names = '\n'.join(incorrect_files) + qm.showMessage(f"Files with incorrect names: {broken_file_names}") + + else: + for i, p_f in zip(frame_ids, files_paths): + self.scroll_to(i) + self.on_import_mask(str(p_f)) + + def on_import_mask(self, mask_file_path=None): + if mask_file_path: + file_name = mask_file_path + else: + file_name = self._open_file('Mask') + if len(file_name) == 0: + return mask = self.res_man.read_external_image(file_name, size=(self.height, self.width), force_mask=True) @@ -1117,6 +1156,7 @@ def on_import_mask(self): self.current_mask = mask self.show_current_frame() self.save_current_mask() + self.on_save_reference() def on_import_layer(self): file_name = self._open_file('Layer') From d7bdbc5dec410c3c5500f842b6de50363ae6d56e Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 13 Mar 2023 17:54:53 +0400 Subject: [PATCH 27/49] Save overlay by default --- inference/interact/gui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index fcbffa9..13664cd 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -140,9 +140,9 @@ def __init__(self, net: XMem, self.combo.currentTextChanged.connect(self.set_viz_mode) self.save_visualization_checkbox = QCheckBox(self) + self.save_visualization_checkbox.setChecked(True) self.save_visualization_checkbox.toggled.connect(self.on_save_visualization_toggle) - self.save_visualization_checkbox.setChecked(False) - self.save_visualization = False + self.save_visualization = True # Radio buttons for type of interactions self.curr_interaction = 'Click' From bffb1f316a9917c30743e199035489d18eab70a7 Mon Sep 17 00:00:00 2001 From: max810 Date: Tue, 14 Mar 2023 12:04:52 +0400 Subject: [PATCH 28/49] Fixed existing masks being changed during propagation --- inference/interact/gui.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 13664cd..4a46677 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -709,13 +709,14 @@ def on_pause(self): def on_propagation(self): # start to propagate self.load_current_torch_image_mask() - # TODO: put into permanent memory - # TODO: debug why quality so bad self.show_current_frame(fast=True) self.console_push_text('Propagation started.') + is_mask = self.cursur in self.reference_ids msk = self.current_prob[1:] if self.cursur in self.reference_ids else None - self.current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) + current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) + if not is_mask: + self.current_prob = current_prob self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) @@ -731,14 +732,16 @@ def on_propagation(self): self.load_current_image_mask(no_mask=True) self.load_current_torch_image_mask(no_mask=True) - + is_mask = self.cursur in self.reference_ids msk = self.current_prob[1:] if self.cursur in self.reference_ids else None - self.current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) + current_prob, key = 
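# The bulk-import loop above pulls a frame index out of each mask filename with a
# regular expression and refuses the whole directory if any name has no number or
# the recovered order is not sorted. A standalone sketch of that validation step
# (illustrative paths; match.group() is equivalent to the string slicing used above):
import re
from pathlib import Path

def collect_frame_ids(mask_dir: str):
    pattern = re.compile(r'([0-9]+)')
    frame_ids, bad_names = [], []
    for p in sorted(Path(mask_dir).iterdir()):
        match = pattern.search(p.name)
        if match:
            frame_ids.append(int(match.group(1)))     # e.g. 'mask_00017.png' -> 17
        else:
            bad_names.append(p.name)
    ok = not bad_names and frame_ids == sorted(frame_ids)
    return ok, frame_ids, bad_names

# e.g. ok, ids, bad = collect_frame_ids('workspace/video1/masks')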
self.processor.step(self.current_image_torch, msk, return_key=True) self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) - self.current_mask = torch_prob_to_numpy_mask(self.current_prob) - - self.save_current_mask() + if not is_mask: + self.current_prob = current_prob + self.current_mask = torch_prob_to_numpy_mask(self.current_prob) + self.save_current_mask() + self.show_current_frame(fast=True) self.update_memory_size() From 33ee1a9de5fabd70321df0f2c7ceedce928f9df1 Mon Sep 17 00:00:00 2001 From: max810 Date: Wed, 15 Mar 2023 19:13:13 +0400 Subject: [PATCH 29/49] Fixed frame selector not using alpha properly --- inference/frame_selection/frame_selection.py | 15 +++++++++++---- inference/interact/gui.py | 6 +++--- inference/interact/resource_manager.py | 18 ++++++++++++------ 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index e8fc8d9..c6951db 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -10,6 +10,7 @@ import numpy as np from tqdm import tqdm from inference.frame_selection.frame_selection_utils import extract_keys +from torchvision.transforms import Resize, InterpolationMode from model.memory_util import get_similarity @@ -96,16 +97,17 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(data return chosen_frames -def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True): +def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=1e-4, h=30, w=54): assert len(keys) == len(masks) assert len(keys) > 0 - assert keys[0].shape[-2:] == masks[0].shape[-2:] + # assert keys[0].shape[-2:] == masks[0].shape[-2:] assert num_next_candidates > 0 assert len(previously_chosen_candidates) > 0 assert 0.0 <= alpha <= 1.0 assert min_mask_presence_px >= 0 assert len(previously_chosen_candidates) < len(keys) + """ Select candidate frames for annotation based on dissimilarity and cycle consistency. @@ -138,6 +140,8 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise). """ + resize = Resize((h, w), interpolation=InterpolationMode.BILINEAR) + with torch.no_grad(): composite_keys = [] keys = keys.squeeze() @@ -146,12 +150,15 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne masks_validity = np.full(N, True) for i, mask in enumerate(masks): - mask_size_px = (mask > 0).sum() + mask_size_px = (mask > epsilon).sum() if mask_size_px < min_mask_presence_px: masks_validity[i] = False composite_keys.append(None) + print(i, mask_size_px) + continue + mask = resize(mask) composite_key = keys[i] * mask.max(dim=0, keepdim=True).values # any object -> 1., background -> 0.. 
Keep 1 channel only composite_key = composite_key * alpha + keys[i] * (1 - alpha) @@ -163,12 +170,12 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne for i in tqdm(range(num_next_candidates), desc='Iteratively picking the most dissimilar frames', disable=not print_progress): candidate_dissimilarities = [] for j in tqdm(range(N), desc='Computing similarity to chosen frames', disable=not print_progress): - qk = composite_keys[j].to(device) if not masks_validity[j]: # ignore this potential candidate dissimilarity_min_across_all = 0 else: + qk = composite_keys[j].to(device) dissimilarities_across_mem_keys = [] for mem_key in chosen_candidate_keys: mem_key = mem_key diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 4a46677..14623c8 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -717,7 +717,7 @@ def on_propagation(self): current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) if not is_mask: self.current_prob = current_prob - self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) + self.res_man.add_key_with_mask(self.cursur, key, self.current_prob[1:]) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) # clear @@ -735,7 +735,7 @@ def on_propagation(self): is_mask = self.cursur in self.reference_ids msk = self.current_prob[1:] if self.cursur in self.reference_ids else None current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) - self.res_man.add_key_with_mask(self.cursur, key, self.current_prob) + self.res_man.add_key_with_mask(self.cursur, key, self.current_prob[1:]) if not is_mask: self.current_prob = current_prob @@ -797,7 +797,7 @@ def _update_progress(i): k = self.candidates_k_slider.value() alpha = self.candidates_alpha_slider.value() candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) - worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=alpha, min_mask_presence_px=9) # Any other args, kwargs are passed to the run function + worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=alpha, min_mask_presence_px=9, h=self.res_man.key_h, w=self.res_man.key_w) # Any other args, kwargs are passed to the run function worker.signals.result.connect(_update_candidates) worker.signals.progress.connect(_update_progress) diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 65b9b58..7ffdada 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -109,9 +109,11 @@ def __init__(self, config): self.visualization_init = False self._resize = None - self._small_masks = None + self._masks = None self._keys = None self._keys_processed = np.zeros(self.length, dtype=bool) + self.key_h = None + self.key_w = None def _extract_frames(self, video): cap = cv2.VideoCapture(video) @@ -154,17 +156,21 @@ def _copy_resize_frames(self, images): def add_key_with_mask(self, ti, key, mask): if self._keys is None: c, h, w = key.squeeze().shape - c_mask = mask.shape[0] + if self.key_h is None: + self.key_h = h + if self.key_w is None: + self.key_w = w + c_mask, h_mask, w_mask = mask.shape self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) - self._small_masks = 
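# The candidate selector above scores frames on an alpha-blend between the raw
# encoder key and a mask-gated key: alpha=0 compares whole frames, alpha=1 compares
# only the masked object region (frames whose mask is too small are skipped before
# this blend is even computed). A standalone sketch of building one composite key,
# with random tensors standing in for real encoder features:
import torch
from torchvision.transforms import Resize, InterpolationMode

def composite_key(key: torch.Tensor, mask: torch.Tensor, alpha: float) -> torch.Tensor:
    # key: C x h x w encoder features; mask: num_objects x H x W soft masks
    resize = Resize(key.shape[-2:], interpolation=InterpolationMode.NEAREST)
    fg = resize(mask).max(dim=0, keepdim=True).values  # any object -> 1, background -> 0
    return key * fg * alpha + key * (1 - alpha)

key = torch.randn(64, 30, 54)
mask = (torch.rand(1, 480, 854) > 0.5).float()
print(composite_key(key, mask, alpha=0.5).shape)        # torch.Size([64, 30, 54])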
torch.empty((self.length, c_mask, h, w), dtype=mask.dtype, device=key.device) - self._resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) + self._masks = torch.empty((self.length, c_mask, h_mask, w_mask), dtype=mask.dtype, device=key.device) + # self._resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) if not self._keys_processed[ti]: # keys don't change for the video, so we only save them once self._keys[ti] = key self._keys_processed[ti] = True - self._small_masks[ti] = self._resize(mask) + self._masks[ti] = mask# self._resize(mask) def all_masks_present(self): return self._keys_processed.sum() == self.length @@ -265,7 +271,7 @@ def w(self): @property def small_masks(self): - return self._small_masks + return self._masks @property def keys(self): From b4308c55be5e5d56dae0358424b42c16153bc66f Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 24 Mar 2023 15:25:01 +0400 Subject: [PATCH 30/49] Final bugfixes --- environment.yml | 14 +++++- inference/frame_selection/frame_selection.py | 41 ++++++++++------- inference/inference_core.py | 12 ++--- inference/interact/gui.py | 23 ++++++---- inference/interact/resource_manager.py | 42 +++++++++-------- inference/run_on_video.py | 47 +++++++++++++++----- 6 files changed, 118 insertions(+), 61 deletions(-) diff --git a/environment.yml b/environment.yml index 44d3fc0..5abaab2 100644 --- a/environment.yml +++ b/environment.yml @@ -117,6 +117,7 @@ dependencies: - cupy-cuda11x==11.5.0 - cycler==0.11.0 - cython==0.29.32 + - dill==0.3.6 - fastrlock==0.8.1 - filelock==3.8.0 - fonttools==4.38.0 @@ -131,18 +132,26 @@ dependencies: - h5py==3.7.0 - haishoku==1.1.8 - hickle==5.0.2 + - imageio==2.26.0 - importlib-metadata==6.0.0 - kiwisolver==1.4.4 + - lazy-loader==0.1 - markdown==3.4.1 - markupsafe==2.1.2 - matplotlib==3.6.2 + - multiprocess==0.70.14 + - networkx==3.0 - numpy==1.23.5 - nvidia-ml-py==11.495.46 - oauthlib==3.2.2 - opencv-python==4.6.0.66 + - p-tqdm==1.4.0 - packaging==22.0 - pandas==1.5.2 + - pathos==0.3.0 - pillow==9.2.0 + - pox==0.3.2 + - ppft==1.7.6.6 - profilehooks==1.12.0 - progressbar2==4.1.1 - protobuf==3.20.3 @@ -158,11 +167,13 @@ dependencies: - python-graphviz==0.20.1 - python-utils==3.3.3 - pytz==2022.7 + - pywavelets==1.4.1 - requests==2.28.2 - requests-oauthlib==1.3.1 - rsa==4.9 + - scikit-image==0.20.0 - scikit-learn==1.2.0 - - scipy==1.9.3 + - scipy==1.9.1 - seaborn==0.12.2 - setuptools==66.1.1 - smmap==5.0.0 @@ -175,6 +186,7 @@ dependencies: - termcolor==2.2.0 - thin-plate-spline==1.0.1 - thinplate==1.0.0 + - tifffile==2023.2.28 - tomli==2.0.1 - torch-tb-profiler==0.4.1 - torchmetrics==0.9.3 diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index c6951db..262acdb 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -97,14 +97,14 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(data return chosen_frames -def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_px=9, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=1e-4, h=30, w=54): +def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_percent=0.25, device: 
torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=0.5, h=30, w=54): assert len(keys) == len(masks) assert len(keys) > 0 # assert keys[0].shape[-2:] == masks[0].shape[-2:] assert num_next_candidates > 0 assert len(previously_chosen_candidates) > 0 assert 0.0 <= alpha <= 1.0 - assert min_mask_presence_px >= 0 + assert min_mask_presence_percent >= 0 assert len(previously_chosen_candidates) < len(keys) @@ -140,7 +140,7 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise). """ - resize = Resize((h, w), interpolation=InterpolationMode.BILINEAR) + resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) with torch.no_grad(): composite_keys = [] @@ -149,21 +149,26 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne h, w = keys[0].shape[1:3] # removing batch dimension masks_validity = np.full(N, True) + invalid = 0 for i, mask in enumerate(masks): - mask_size_px = (mask > epsilon).sum() + mask_3ch = mask if mask.ndim == 3 else mask.unsqueeze(0) + mask_bin = mask_3ch.max(dim=0).values + mask_size_px = (mask_bin > epsilon).sum() + print(f"{i:3d}", float(mask_size_px / mask_bin.numel() * 100)) - if mask_size_px < min_mask_presence_px: + if mask_size_px / mask_bin.numel() < (min_mask_presence_percent / 100.0): # percentages to ratio masks_validity[i] = False composite_keys.append(None) - print(i, mask_size_px) + invalid += 1 continue mask = resize(mask) composite_key = keys[i] * mask.max(dim=0, keepdim=True).values # any object -> 1., background -> 0.. Keep 1 channel only composite_key = composite_key * alpha + keys[i] * (1 - alpha) - composite_keys.append(composite_key.to(device)) - + composite_keys.append(composite_key.to(dtype=keys[i].dtype, device=device)) + + print(f"INVALID: {invalid} / {len(masks)}") chosen_candidates = list(previously_chosen_candidates) chosen_candidate_keys = [composite_keys[i] for i in chosen_candidates] @@ -175,28 +180,32 @@ def select_next_candidates(keys: torch.Tensor, masks: List[torch.tensor], num_ne # ignore this potential candidate dissimilarity_min_across_all = 0 else: - qk = composite_keys[j].to(device) + qk = composite_keys[j].to(device).unsqueeze(0) + q_shrinkage = shrinkages[j].to(device).unsqueeze(0) + q_selection = selections[j].to(device).unsqueeze(0) + dissimilarities_across_mem_keys = [] - for mem_key in chosen_candidate_keys: - mem_key = mem_key + for mem_idx, mem_key in zip(chosen_candidates, chosen_candidate_keys): + mem_key = mem_key.unsqueeze(0) + mem_shrinkage = shrinkages[mem_idx].to(device).unsqueeze(0) + mem_selection = selections[mem_idx].to(device).unsqueeze(0) - similarity_per_pixel = get_similarity(mem_key, ms=None, qk=qk, qe=None) - reverse_similarity_per_pixel = get_similarity(qk, ms=None, qk=mem_key, qe=None) + similarity_per_pixel = get_similarity(mem_key, ms=mem_shrinkage, qk=qk, qe=q_selection) + reverse_similarity_per_pixel = get_similarity(qk, ms=q_shrinkage, qk=mem_key, qe=mem_selection) # mapping of pixels A -> B would be very similar to B -> A if the images are similar # and very different if the images are different - cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel) + cycle_dissimilarity_per_pixel = (similarity_per_pixel - reverse_similarity_per_pixel).to(dtype=torch.float32) # Take non-negative mappings, normalize by tensor size cycle_dissimilarity_score = 
F.relu(cycle_dissimilarity_per_pixel).sum() / cycle_dissimilarity_per_pixel.numel() - dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) # filtering our existing or very similar frames # if the key has already been used or is very similar to at least one of the chosen candidates # dissimilarity_min_across_all -> 0 (or close to) dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) - + candidate_dissimilarities.append(dissimilarity_min_across_all) index = torch.argmax(torch.tensor(candidate_dissimilarities)) diff --git a/inference/inference_core.py b/inference/inference_core.py index 0981d61..ed9190e 100644 --- a/inference/inference_core.py +++ b/inference/inference_core.py @@ -59,7 +59,7 @@ def encode_frame_key(self, image): need_sk=True) return key, shrinkage, selection - def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False, return_key=False): + def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_masks=False, disable_memory_updates=False, do_not_add_mask_to_memory=False, return_key_and_stuff=False): # For feedback: # 1. We run the model as usual # 2. We get feedback: 2 lists, one with good prediction indices, one with bad @@ -88,7 +88,7 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(image, need_ek=(self.enable_long_term or need_segment), - need_sk=is_mem_frame) + need_sk=True) multi_scale_features = (f16, f8, f4) if disable_memory_updates: @@ -146,8 +146,8 @@ def step(self, image, mask=None, valid_labels=None, end=False, manually_curated_ res = unpad(pred_prob_with_bg, self.pad) - if return_key: - return res, key + if return_key_and_stuff: + return res, key, shrinkage, selection else: return res @@ -167,14 +167,14 @@ def put_to_permanent_memory(self, image, mask, ti=None): pred_prob_with_bg[1:].unsqueeze(0), is_deep_update=False) is_update = self.memory.frame_already_saved(ti) - print(ti, f"update={is_update}") + # print(ti, f"update={is_update}") if self.memory.frame_already_saved(ti): self.memory.update_permanent_memory(ti, key, shrinkage, value, selection=selection if self.enable_long_term else None) else: self.memory.add_memory(key, shrinkage, value, self.all_labels, selection=selection if self.enable_long_term else None, permanent=True, ti=ti) - print(self.memory.permanent_work_mem.key.shape) + # print(self.memory.permanent_work_mem.key.shape) return is_update diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 14623c8..865d243 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -30,10 +30,10 @@ from PyQt5.QtWidgets import (QWidget, QApplication, QComboBox, QCheckBox, QHBoxLayout, QLabel, QPushButton, QTextEdit, QSpinBox, QFileDialog, QPlainTextEdit, QVBoxLayout, QSizePolicy, QButtonGroup, QSlider, QShortcut, - QRadioButton, QTabWidget, QDialog, QErrorMessage, QMessageBox) + QRadioButton, QTabWidget, QDialog, QErrorMessage, QMessageBox, QLineEdit) -from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon -from PyQt5.QtCore import Qt, QTimer, QThreadPool +from PyQt5.QtGui import QPixmap, QKeySequence, QImage, QTextCursor, QIcon, QRegExpValidator +from PyQt5.QtCore import Qt, QTimer, QThreadPool, QRegExp from model.network import XMem @@ -336,8 +336,14 @@ def __init__(self, net: XMem, draw_area.addLayout(tabs_layout, 1) candidates_area = QVBoxLayout() + 
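# The selector above scores each remaining frame by "cycle dissimilarity": build
# the affinity from a chosen reference to the candidate and back, and measure how
# asymmetric the two maps are. A near-duplicate of an already-chosen frame scores
# close to zero (hence the min over chosen frames), and argmax then picks the most
# novel frame. A self-contained sketch with a plain scaled dot-product affinity
# standing in for the repository's get_similarity():
import torch
import torch.nn.functional as F

def affinity(mem_key: torch.Tensor, query_key: torch.Tensor) -> torch.Tensor:
    # mem_key, query_key: C x N flattened keys; returns an N x N affinity map
    return mem_key.T @ query_key / mem_key.shape[0] ** 0.5

def cycle_dissimilarity(query_key: torch.Tensor, chosen_keys) -> float:
    scores = []
    for mem_key in chosen_keys:
        forward = affinity(mem_key, query_key)       # mem -> query
        backward = affinity(query_key, mem_key)      # query -> mem (same N x N size)
        diff = forward - backward
        scores.append(float(F.relu(diff).sum() / diff.numel()))
    return min(scores)   # ~0 if the frame is close to any already-chosen one

keys = [torch.randn(64, 30 * 54) for _ in range(4)]
chosen = [keys[0]]
best = max(range(1, 4), key=lambda j: cycle_dissimilarity(keys[j], chosen))
print('next candidate:', best)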
self.candidates_min_mask_size_edit = QLineEdit() + float_validator = QRegExpValidator(QRegExp(r"^(100(\.0+)?|[1-9]?\d(\.\d+)?|0(\.\d+)?)$")) + self.candidates_min_mask_size_edit.setValidator(float_validator) + self.candidates_min_mask_size_edit.setText("0.25") self.candidates_k_slider = NamedSlider("k", 1, 20, 1, default=5) self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks') + candidates_area.addWidget(QLabel("Min mask size, % of the total image size, 0-100")) + candidates_area.addWidget(self.candidates_min_mask_size_edit) candidates_area.addWidget(QLabel("Candidates calculation hyperparameters")) candidates_area.addWidget(self.candidates_k_slider) candidates_area.addWidget(self.candidates_alpha_slider) @@ -714,10 +720,10 @@ def on_propagation(self): self.console_push_text('Propagation started.') is_mask = self.cursur in self.reference_ids msk = self.current_prob[1:] if self.cursur in self.reference_ids else None - current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) + current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True) if not is_mask: self.current_prob = current_prob - self.res_man.add_key_with_mask(self.cursur, key, self.current_prob[1:]) + self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:]) self.current_mask = torch_prob_to_numpy_mask(self.current_prob) # clear @@ -734,8 +740,8 @@ def on_propagation(self): self.load_current_torch_image_mask(no_mask=True) is_mask = self.cursur in self.reference_ids msk = self.current_prob[1:] if self.cursur in self.reference_ids else None - current_prob, key = self.processor.step(self.current_image_torch, msk, return_key=True) - self.res_man.add_key_with_mask(self.cursur, key, self.current_prob[1:]) + current_prob, key, shrinkage, selection = self.processor.step(self.current_image_torch, msk, return_key_and_stuff=True) + self.res_man.add_key_and_stuff_with_mask(self.cursur, key, shrinkage, selection, self.current_prob[1:]) if not is_mask: self.current_prob = current_prob @@ -797,7 +803,8 @@ def _update_progress(i): k = self.candidates_k_slider.value() alpha = self.candidates_alpha_slider.value() candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) - worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.small_masks, k, self.reference_ids, print_progress=False, alpha=alpha, min_mask_presence_px=9, h=self.res_man.key_h, w=self.res_man.key_w) # Any other args, kwargs are passed to the run function + worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.shrinkages, self.res_man.selections, self.res_man.small_masks, k, self.reference_ids, + print_progress=False, min_mask_presence_percent=float(self.candidates_min_mask_size_edit.text()), alpha=alpha, h=self.res_man.key_h, w=self.res_man.key_w) # Any other args, kwargs are passed to the run function worker.signals.result.connect(_update_candidates) worker.signals.progress.connect(_update_progress) diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 7ffdada..459c99c 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -153,7 +153,7 @@ def _copy_resize_frames(self, images): cv2.imwrite(path.join(self.image_dir, image_name), frame) print('Done!') - def 
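# The "min mask size" field above is a QLineEdit restricted by a QRegExpValidator,
# so only numbers in the 0-100 range (optionally fractional) can be typed. The
# same pattern can be exercised outside Qt; a quick sketch of what it accepts:
import re

PERCENT_RE = re.compile(r"^(100(\.0+)?|[1-9]?\d(\.\d+)?|0(\.\d+)?)$")

for text in ["0", "0.25", "5", "99.9", "100", "100.0", "101", "-1", "12.", "abc"]:
    print(f"{text!r:>8}: {'ok' if PERCENT_RE.fullmatch(text) else 'rejected'}")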
add_key_with_mask(self, ti, key, mask): + def add_key_and_stuff_with_mask(self, ti, key, shrinkage, selection, mask): if self._keys is None: c, h, w = key.squeeze().shape if self.key_h is None: @@ -162,12 +162,16 @@ def add_key_with_mask(self, ti, key, mask): self.key_w = w c_mask, h_mask, w_mask = mask.shape self._keys = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) + self._shrinkages = torch.empty((self.length, 1, h, w), dtype=key.dtype, device=key.device) + self._selections = torch.empty((self.length, c, h, w), dtype=key.dtype, device=key.device) self._masks = torch.empty((self.length, c_mask, h_mask, w_mask), dtype=mask.dtype, device=key.device) # self._resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) if not self._keys_processed[ti]: # keys don't change for the video, so we only save them once self._keys[ti] = key + self._shrinkages[ti] = shrinkage + self._selections[ti] = selection self._keys_processed[ti] = True self._masks[ti] = mask# self._resize(mask) @@ -232,27 +236,21 @@ def read_external_image(self, file_name, size=None, force_mask=False): image = np.where(image, 1, 0) # 255 (or whatever) -> binarize return image.astype('uint8') + elif image.mode == 'RGB': + image = image.convert('P', palette=self.palette) + tmp_image = np.array(image) + out_image = np.zeros_like(tmp_image) + for i, c in enumerate(np.unique(tmp_image)): + if i == 0: + continue + out_image[tmp_image == c] = i # palette indices into 0, 1, 2, ... + self.palette = image.getpalette() + return out_image + image = image.convert('P', palette=self.palette) # saved without DAVIS palette, just number objects 0, 1, ... image = np.array(image) return image - - def replace_mask_with_external(self, ti, mask_path): - try: - image = Image.open(mask_path) - assert image.mode in ['L', 'P', 'RGB'] - - if image.mode == 'RGB': - if len(image.getcolors()) <= 2: - image = image.convert('L') - - image = image.convert('P') - image.putpalette(self.palette) - image.save(path.join(self.mask_dir, self.names[ti]+'.png')) - - except (FileNotFoundError, AssertionError, ValueError): - raise ValueError(f"Invalid file: {mask_path}") - def invalidate(self, ti): # the image buffer is never invalidated @@ -277,3 +275,11 @@ def small_masks(self): def keys(self): return self._keys + + @property + def shrinkages(self): + return self._shrinkages + + @property + def selections(self): + return self._selections diff --git a/inference/run_on_video.py b/inference/run_on_video.py index 5f5d10e..bdbfdaf 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -1,5 +1,6 @@ import os from os import PathLike, path +from time import perf_counter from typing import Iterable, Literal, Union, List from collections import defaultdict from pathlib import Path @@ -17,11 +18,11 @@ from model.network import XMem from util.image_saver import create_overlay, save_image -from util.tensor_util import compute_tensor_iou +from util.tensor_util import compute_array_iou, compute_tensor_iou from inference.inference_core import InferenceCore from inference.data.video_reader import VideoReader from inference.data.mask_mapper import MaskMapper -from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS +# from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS from inference.frame_selection.frame_selection_utils import disparity_func, get_determenistic_augmentations @@ -48,7 +49,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou 
overwrite_config: dict = None, frame_selector_func: callable = None, save_overlay=True, - b_and_w_color=(255, 0, 0)): + b_and_w_color=(255, 0, 0), measure_fps=False): torch.autograd.set_grad_enabled(False) frames_with_masks = set(frames_with_masks) config = { @@ -125,13 +126,14 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou if only_predict_frames_to_annotate_and_quit > 0: assert frame_selector_func is not None chosen_annotation_candidate_frames = frame_selector_func( - loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit) + loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit, existing_masks_path=masks_out_path) return chosen_annotation_candidate_frames frames_ = [] masks_ = [] + total_preloading_time = 0.0 if original_memory_mechanism: # only the first frame goes into permanent memory originally frames_to_put_in_permanent_memory = [0] @@ -158,7 +160,10 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] processor.set_all_labels(list(mapper.remappings.values())) + a = perf_counter() processor.put_to_permanent_memory(rgb, msk) + b = perf_counter() + total_preloading_time += (b - a) if not first_mask_loaded: first_mask_loaded = True @@ -187,6 +192,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou if compute_uncertainty and uncertainty_name == 'bald': bald = BALD() + total_processing_time = 0.0 for ti, data in enumerate(tqdm(loader, disable=not print_progress)): with torch.cuda.amp.autocast(enabled=True): rgb = data['rgb'].cuda()[0] @@ -251,6 +257,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou do_not_add_mask_to_memory = msk is not None # Run the model on this frame # 2+ channels, classes+ and background + a = perf_counter() prob = processor.step(rgb, msk, labels, end=(ti == vid_length-1), manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) @@ -280,14 +287,16 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou # Probability mask -> index mask out_mask = torch.argmax(prob, dim=0) out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + b = perf_counter() + total_processing_time += (b - a) if compute_iou: # mask is [0, 1] # gt is [0, 255] # both -> [False, True] - if gt is not None: - iou = float(compute_tensor_iou(torch.tensor( - out_mask).type(torch.bool), gt.type(torch.bool))) + if gt is not None and msk is None: + # skipping frames + iou = float(compute_array_iou(out_mask, gt)) else: iou = -1 curr_stat['iou'] = iou @@ -317,6 +326,13 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou stats.append(curr_stat) + if measure_fps: + print(f"TOTAL PRELOADING TIME: {total_preloading_time:.4f}s") + print(f"TOTAL PROCESSING TIME: {total_processing_time:.4f}s") + print(f"TOTAL TIME: {total_preloading_time + total_processing_time:.4f}s") + print(f"TOTAL PROCESSING FPS: {len(loader) / total_processing_time:.4f}") + print(f"TOTAL FPS: {len(loader) / (total_preloading_time + total_processing_time):.4f}") + return pd.DataFrame(stats) @@ -362,9 +378,11 @@ def run_on_video( def predict_annotation_candidates( imgs_in_path: Union[str, PathLike], - approach: str, + candidate_selection_function: callable, + masks_in_path: Union[str, PathLike] = None, num_candidates: int = 1, print_progress=True, + 
**kwargs ) -> List[int]: """ Args: @@ -379,18 +397,23 @@ def predict_annotation_candidates( annotation_candidates (List[int]): A list of frames indices (0-based) chosen as annotation candidates, sorted by importance (most -> least). Always contains [0] - first frame - at index 0. """ - candidate_selection_function = KNOWN_ANNOTATION_PREDICTORS[approach] + # candidate_selection_function = KNOWN_ANNOTATION_PREDICTORS[approach] assert num_candidates >= 1 if num_candidates == 1: return [0] # First frame is hard-coded to always be used + try: + masks_out_path = kwargs.pop('masks_out_path') + except KeyError: + masks_out_path = None + return _inference_on_video( imgs_in_path=imgs_in_path, - masks_in_path=imgs_in_path, # Ignored - masks_out_path=None, # Ignored - frames_with_masks=[0], # Ignored + masks_in_path=masks_in_path, # Ignored + masks_out_path=masks_out_path, # Used for some frame selectors + frames_with_masks=[0], compute_uncertainty=False, compute_iou=False, print_progress=print_progress, From d5c899024699d9f87bd6f492f5ef4363696f5ff7 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 24 Mar 2023 15:44:28 +0400 Subject: [PATCH 31/49] Updated the environment file --- environment.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/environment.yml b/environment.yml index 5abaab2..bd478d8 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: XMem +name: XMemA channels: - pytorch - conda-forge @@ -164,7 +164,7 @@ dependencies: - pyqt5-qt5==5.15.2 - pyqt5-sip==12.11.0 - python-dateutil==2.8.2 - - python-graphviz==0.20.1 + # - python-graphviz==0.20.1 - python-utils==3.3.3 - pytz==2022.7 - pywavelets==1.4.1 @@ -185,7 +185,7 @@ dependencies: - tensorboard-plugin-wit==1.8.1 - termcolor==2.2.0 - thin-plate-spline==1.0.1 - - thinplate==1.0.0 + # - thinplate==1.0.0 - tifffile==2023.2.28 - tomli==2.0.1 - torch-tb-profiler==0.4.1 @@ -198,5 +198,4 @@ dependencies: - urllib3==1.26.14 - werkzeug==2.2.2 - wheel==0.38.4 - - zipp==3.12.0 -prefix: /home/maksym/miniconda3/envs/XMem + - zipp==3.12.0 \ No newline at end of file From 6be53463eb57f427949fe9354d6d5de061f193b5 Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 6 Apr 2023 13:19:01 +0400 Subject: [PATCH 32/49] Full propagation now clear temp and long memory automatically --- inference/frame_selection/frame_selection.py | 1 - inference/interact/gui.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index 262acdb..f4cfbcf 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -154,7 +154,6 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li mask_3ch = mask if mask.ndim == 3 else mask.unsqueeze(0) mask_bin = mask_3ch.max(dim=0).values mask_size_px = (mask_bin > epsilon).sum() - print(f"{i:3d}", float(mask_size_px / mask_bin.numel() * 100)) if mask_size_px / mask_bin.numel() < (min_mask_presence_percent / 100.0): # percentages to ratio masks_validity[i] = False diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 865d243..adcd17f 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -677,6 +677,7 @@ def general_propagation_callback(self, propagation_type: str): self.on_backward_propagation() def on_full_propagation(self): + self.on_clear_memory() self.scroll_to(0) self.on_forward_propagation() From 612afa891b373f18d8d604fc35f196468a2ccbe4 Mon Sep 17 00:00:00 2001 From: max810 
Date: Fri, 16 Jun 2023 11:53:14 +0400 Subject: [PATCH 33/49] Minor cleanup --- inference/frame_selection/frame_selection.py | 26 +++++-- .../frame_selection/frame_selection_utils.py | 15 ++-- inference/run_experiments.py | 30 +++++++- inference/run_on_video.py | 69 ++++++++++++++++++- 4 files changed, 125 insertions(+), 15 deletions(-) diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index f4cfbcf..908aeaa 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -140,13 +140,13 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise). """ - resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) with torch.no_grad(): composite_keys = [] keys = keys.squeeze() N = len(keys) h, w = keys[0].shape[1:3] # removing batch dimension + resize = Resize((h, w), interpolation=InterpolationMode.NEAREST) masks_validity = np.full(N, True) invalid = 0 @@ -155,12 +155,24 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li mask_bin = mask_3ch.max(dim=0).values mask_size_px = (mask_bin > epsilon).sum() - if mask_size_px / mask_bin.numel() < (min_mask_presence_percent / 100.0): # percentages to ratio - masks_validity[i] = False - composite_keys.append(None) - invalid += 1 - continue + ratio = mask_size_px / mask_bin.numel() * 100 + if ratio < min_mask_presence_percent: # percentages to ratio + if i not in previously_chosen_candidates: + # if it's previously chosen, it's okay, we don't test for their validity + # e.g. we select frame #J, because we predicted something for it + # but in reality it's actually empty, so gt=0 + # so next iteration will break + masks_validity[i] = False + composite_keys.append(None) + invalid += 1 + + continue + # if it's previously chosen, it's okay + # if i in previously_chosen_candidates: + # print(f"{i} previous candidate would be invalid (ratio perc={ratio})") + # raise ValueError(f"Given min_mask_presence_percent={min_mask_presence_percent}, even the previous candidates will be ignored. Reduce the value to avoid the error.") + mask = resize(mask) composite_key = keys[i] * mask.max(dim=0, keepdim=True).values # any object -> 1., background -> 0.. 
Keep 1 channel only composite_key = composite_key * alpha + keys[i] * (1 - alpha) @@ -190,7 +202,7 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li mem_selection = selections[mem_idx].to(device).unsqueeze(0) similarity_per_pixel = get_similarity(mem_key, ms=mem_shrinkage, qk=qk, qe=q_selection) - reverse_similarity_per_pixel = get_similarity(qk, ms=q_shrinkage, qk=mem_key, qe=mem_selection) + reverse_similarity_per_pixel = get_similarity(qk, ms=q_shrinkage, qk=mem_key, qe=mem_selection) # mapping of pixels A -> B would be very similar to B -> A if the images are similar # and very different if the images are different diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index 59865d7..1d3bc85 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ b/inference/frame_selection/frame_selection_utils.py @@ -6,10 +6,10 @@ import torch import torchvision.transforms.functional as FT from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine +from tqdm import tqdm -# -def extract_keys(dataloder, processor, print_progress=False): +def extract_keys(dataloder, processor, print_progress=False, flatten=True, **kwargs): frame_keys = [] shrinkages = [] selections = [] @@ -29,9 +29,14 @@ def extract_keys(dataloder, processor, print_progress=False): key_sum += key.type(torch.float64) - frame_keys.append(key.flatten(start_dim=2).cpu()) - shrinkages.append(shrinkage.flatten(start_dim=2).cpu()) - selections.append(selection.flatten(start_dim=2).cpu()) + if flatten: + key = key.flatten(start_dim=2) + shrinkage = shrinkage.flatten(start_dim=2) + selection = selection.flatten(start_dim=2) + + frame_keys.append(key.cpu()) + shrinkages.append(shrinkage.cpu()) + selections.append(selection.cpu()) num_frames = ti + 1 # 0 after 1 iteration, 1 after 2, etc. diff --git a/inference/run_experiments.py b/inference/run_experiments.py index d980d44..c25961d 100644 --- a/inference/run_experiments.py +++ b/inference/run_experiments.py @@ -10,6 +10,7 @@ from tqdm import tqdm from matplotlib import pyplot as plt from PIL import Image +from inference.frame_selection.frame_selection import uniformly_selected_frames from util.metrics import batched_f_measure, batched_jaccard from p_tqdm import p_umap @@ -236,6 +237,28 @@ def run_inference_with_pre_chosen_frames(chosen_frames_csv_path: str, videos_inf # with open(f'output/AL_comparison_all_methods/ious_{video_name}_all_methods.json', 'wt') as f_out: # json.dump(ious, f_out) +def run_inference_with_uniform_frames(videos_info: Dict[str, Dict], output_path: str, **kwargs): + num_runs = len(videos_info) + + i = 0 + p_bar = tqdm(desc='Running inference comparing multiple different AL approaches', total=num_runs) + + for video_name, info in videos_info.items(): + frames = os.listdir(info['video_frames_path']) + chosen_frames = uniformly_selected_frames(frames, how_many_frames=info['num_annotation_candidates']) + + video_frames_path = info['video_frames_path'] + video_masks_path = info['video_masks_path'] + + output_masks_path = Path(output_path) / video_name + try: + stats = run_on_video(video_frames_path, video_masks_path, output_masks_path, + frames_with_masks=chosen_frames, compute_iou=False, print_progress=False, **kwargs) + except ValueError as e: + print(f"[!!!] 
{e}") + p_bar.update() + i += 1 + def visualize_chosen_frames(video_name: str, num_total_frames: int, data: pd.Series, output_path: str): def _sort_index(series): @@ -312,6 +335,8 @@ def _proc(p_video: Path): } for p_method in p_video.iterdir(): + if not p_method.is_dir(): + continue method_name = p_method.name p_masks = p_method / 'masks' preds = _load_preds(p_masks, palette=first_mask, size=(w, h)) @@ -348,11 +373,14 @@ def _proc(p_video: Path): results = pd.DataFrame.from_records(list_of_stats).dropna(axis='columns').set_index('video_name') return results -def compute_metrics(p_source_masks, p_preds): +def compute_metrics(p_source_masks, p_preds, pred_to_annot_names_lookup=None): list_of_stats = [] # for p_pred_video in list(p_preds.iterdir()): def _proc(p_pred_video: Path): video_name = p_pred_video.name + if pred_to_annot_names_lookup is not None: + video_name = pred_to_annot_names_lookup[video_name] + # if 'XMem' in str(p_pred_video): p_pred_video = Path(p_pred_video) / 'masks' p_gts = p_source_masks / video_name diff --git a/inference/run_on_video.py b/inference/run_on_video.py index bdbfdaf..798f66f 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -1,21 +1,25 @@ +from functools import partial import os from os import PathLike, path +from tempfile import TemporaryDirectory from time import perf_counter from typing import Iterable, Literal, Union, List from collections import defaultdict from pathlib import Path +from warnings import warn import numpy as np import pandas as pd import torch import torch.nn.functional as F -from torchvision.transforms import functional as FT +from torchvision.transforms import functional as FT, ToTensor from torch.utils.data import DataLoader from baal.active.heuristics import BALD from scipy.stats import entropy from tqdm import tqdm from PIL import Image +from inference.frame_selection.frame_selection import select_next_candidates from model.network import XMem from util.image_saver import create_overlay, save_image from util.tensor_util import compute_array_iou, compute_tensor_iou @@ -23,7 +27,7 @@ from inference.data.video_reader import VideoReader from inference.data.mask_mapper import MaskMapper # from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS -from inference.frame_selection.frame_selection_utils import disparity_func, get_determenistic_augmentations +from inference.frame_selection.frame_selection_utils import disparity_func, extract_keys, get_determenistic_augmentations def save_frames(dataset, frame_indices, output_folder): @@ -384,6 +388,9 @@ def predict_annotation_candidates( print_progress=True, **kwargs ) -> List[int]: + + warn('predict_annotation_candidates is deprecated, used ', DeprecationWarning, stacklevel=2) + """ Args: imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png` .jpg works too. 
@@ -421,3 +428,61 @@ def predict_annotation_candidates( only_predict_frames_to_annotate_and_quit=num_candidates, frame_selector_func=candidate_selection_function ) + +def select_k_next_best_annotation_candidates( + imgs_in_path: Union[str, PathLike], + masks_in_path: Union[str, PathLike], + k: int = 5, + print_progress=True, + previously_chosen_candidates=[0], + **kwargs +): + # extracting the keys and corresponding matrices + keys, shrinkages, selections, *_ = _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, # Ignored + masks_out_path=None, # Used for some frame selectors + frames_with_masks=previously_chosen_candidates, + compute_uncertainty=False, + compute_iou=False, + print_progress=print_progress, + manually_curated_masks=False, + only_predict_frames_to_annotate_and_quit=True, # exact number is ignored here + frame_selector_func=partial(extract_keys, flatten=False), + **kwargs + ) + + # running inference once to obtain masks + to_tensor = ToTensor() + with TemporaryDirectory() as d: + p_masks_out = Path(d) + _inference_on_video( + imgs_in_path=imgs_in_path, + masks_in_path=masks_in_path, # Ignored + masks_out_path=p_masks_out, # Used for some frame selectors + frames_with_masks=previously_chosen_candidates, + compute_uncertainty=False, + compute_iou=False, + print_progress=print_progress, + manually_curated_masks=False, + **kwargs + ) + + masks = [to_tensor(Image.open(p)) for p in sorted((p_masks_out / 'masks').iterdir())] + + keys = torch.cat(keys) + shrinkages = torch.cat(shrinkages) + selections = torch.cat(selections) + + # TODO: fix shapes + print(f"[xxx] Running with previously chosen candidates: {previously_chosen_candidates}") + min_mask_presence_percent = 0.25 + try: + all_selected_frames = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=0.5, only_new_candidates=False, min_mask_presence_percent=min_mask_presence_percent) + except ValueError: + print(f"INVALID in video {imgs_in_path}") + min_mask_presence_percent = 0.01 + all_selected_frames = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=0.5, only_new_candidates=False, min_mask_presence_percent=min_mask_presence_percent) + + + return all_selected_frames \ No newline at end of file From 6bed4e1016a1e0d037c6f236cf682194c083c416 Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 22 Jun 2023 17:21:39 +0400 Subject: [PATCH 34/49] Quality of Life improvements: saving and loading references now --- .gitignore | 4 +++- inference/interact/gui.py | 12 ++++++++++- inference/interact/resource_manager.py | 30 +++++++++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 3198e73..096f495 100644 --- a/.gitignore +++ b/.gitignore @@ -145,4 +145,6 @@ torchlogs/ *junk* *.profile -.DS_Store \ No newline at end of file +.DS_Store +*.jpg +*.zip \ No newline at end of file diff --git a/inference/interact/gui.py b/inference/interact/gui.py index adcd17f..20ea915 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -418,6 +418,7 @@ def __init__(self, net: XMem, self.show_current_frame() self.show() self.style_new_reference() + self.load_existing_references() self.console_push_text('Initialized.') self.initialized = True @@ 
-567,6 +568,12 @@ def style_new_reference(self): self.save_reference_button.setText("Save reference") self.save_reference_button.setStyleSheet('QPushButton {background-color: #14A44D; font-weight: bold;}') + def load_existing_references(self): + for i in self.res_man.references: + self.scroll_to(i) + self.on_save_reference() + self.scroll_to(0) + def pixel_pos_to_image_pos(self, x, y): # Un-scale and un-pad the label coordinates into image coordinates oh, ow = self.image_size.height(), self.image_size.width() @@ -832,6 +839,7 @@ def on_save_reference(self): self.reference_ids.add(self.cursur) self.references_collection.add_image(self.cursur) + self.res_man.add_reference(self.cursur) if self.cursur in self.candidates_ids: self.candidates_ids.remove(self.cursur) @@ -843,6 +851,8 @@ def on_save_reference(self): def on_remove_reference(self, img_idx): self.processor.remove_from_permanent_memory(img_idx) self.reference_ids.remove(img_idx) + self.res_man.remove_reference(self.cursur) + self.show_current_frame() def on_prev_frame(self): @@ -1159,7 +1169,7 @@ def on_import_mask(self, mask_file_path=None): else: qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?" - ret = qm.question(self, 'Confirm mask replacemen', question, qm.Yes | qm.No) + ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) if ret == qm.Yes: self.console_push_text(f'Mask file {file_name} loaded.') diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 459c99c..c0eb155 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -1,3 +1,4 @@ +import json import os from os import path from pathlib import Path @@ -53,7 +54,8 @@ def __init__(self, config): if self.workspace is None: if images is not None: p_images = Path(images) - if p_images.name == 'JPEGImages': + if p_images.name == 'JPEGImages' or (Path.cwd() / 'workspace') in p_images.parents: + # take the name instead of actual images dir (second case checks for videos already in ./workspace ) basename = p_images.parent.name else: basename = p_images.name @@ -66,6 +68,9 @@ def __init__(self, config): self.workspace = path.join('./workspace', basename) print(f'Workspace is in: {self.workspace}') + self.workspace_info_file = path.join(self.workspace, 'info.json') + self.references = set() + self._try_load_references() # determine the location of input images need_decoding = False @@ -178,6 +183,29 @@ def add_key_and_stuff_with_mask(self, ti, key, shrinkage, selection, mask): def all_masks_present(self): return self._keys_processed.sum() == self.length + + def add_reference(self, frame_id: int): + self.references.add(frame_id) + self._save_references() + + def remove_reference(self, frame_id: int): + self.references.remove(frame_id) + self._save_references() + + def _save_references(self): + with open(self.workspace_info_file, 'wt') as f: + data = {'references': sorted(self.references)} + + json.dump(data, f) + + def _try_load_references(self): + try: + with open(self.workspace_info_file) as f: + data = json.load(f) + self.references = set(data['references']) + except Exception: + pass + def save_mask(self, ti, mask): # mask should be uint8 H*W without channels From 30068ffb459fd75ec01dbb43914822e279b7c4b1 Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 22 Jun 2023 17:53:40 +0400 Subject: [PATCH 35/49] Added utility scripts as a submodule --- .gitmodules | 3 +++ 
XMem_utilities | 1 +
 2 files changed, 4 insertions(+)
 create mode 100644 .gitmodules
 create mode 160000 XMem_utilities

diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..ed43ff8
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "XMem_utilities"]
+ path = XMem_utilities
+ url = git@github.com:max810/Xmem_utility_scripts.git
diff --git a/XMem_utilities b/XMem_utilities
new file mode 160000
index 0000000..46df79e
--- /dev/null
+++ b/XMem_utilities
@@ -0,0 +1 @@
+Subproject commit 46df79ecde808089cd9244ba43d8f28863c9fa71

From e45760c8217be6532f22560123ce71fd9035c4d2 Mon Sep 17 00:00:00 2001
From: max810
Date: Thu, 22 Jun 2023 17:54:51 +0400
Subject: [PATCH 36/49] Moved train_s0.sh to utility scripts

---
 train_s0.sh | 18 ------------------
 1 file changed, 18 deletions(-)
 delete mode 100755 train_s0.sh

diff --git a/train_s0.sh b/train_s0.sh
deleted file mode 100755
index f940a3a..0000000
--- a/train_s0.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-STAGE=0
-B=4 # original 16
-LR=2.5e-06 # original 1e-5 / (16/4) batch size difference
-START_WARM=80000 # original 20K iterations * (16/4) batch size difference
-END_WARM=280000 # original 70K iterations * (16/4) batch size difference
-NUM_ITER=600000 # original 150K iterations * (16/4) batch size difference
-NUM_WORKERS=6
-
-OMP_NUM_THREADS=6 python -m torch.distributed.run --master_port 25763 --nproc_per_node=2 train.py \
- --exp_id xmem_multiscale \
- --stage "$STAGE" \
- --s0_batch_size="$B" \
- --s0_lr="$LR" \
- --s0_start_warm="$START_WARM" \
- --s0_end_warm="$END_WARM" \
- --s0_iterations="$NUM_ITER" \
- --num_workers="$NUM_WORKERS" \
- --load_checkpoint='last'

From fcfbc74db38b202eae92d914c6629c1bcbc91765 Mon Sep 17 00:00:00 2001
From: max810
Date: Thu, 27 Jul 2023 13:09:45 +0400
Subject: [PATCH 37/49] App now stores last used num_objects, so no need to specify it every time.
Also switched default brush to Free for convenience --- inference/interact/gui.py | 4 ++-- inference/interact/resource_manager.py | 24 ++++++++++++++++++------ interactive_demo.py | 13 ++++++++----- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 20ea915..2b1c997 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -145,7 +145,7 @@ def __init__(self, net: XMem, self.save_visualization = True # Radio buttons for type of interactions - self.curr_interaction = 'Click' + self.curr_interaction = 'Free' self.interaction_group = QButtonGroup() self.radio_fbrs = QRadioButton('Click') self.radio_s2m = QRadioButton('Scribble') @@ -156,7 +156,7 @@ def __init__(self, net: XMem, self.radio_fbrs.toggled.connect(self.interaction_radio_clicked) self.radio_s2m.toggled.connect(self.interaction_radio_clicked) self.radio_free.toggled.connect(self.interaction_radio_clicked) - self.radio_fbrs.toggle() + self.radio_free.toggle() # Main canvas -> QLabel self.main_canvas = QLabel() diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index c0eb155..dc193a9 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -70,7 +70,14 @@ def __init__(self, config): print(f'Workspace is in: {self.workspace}') self.workspace_info_file = path.join(self.workspace, 'info.json') self.references = set() - self._try_load_references() + self._num_objects = None + self._try_load_info() + + if config['num_objects'] is not None: # forced overwrite from user + self._num_objects = config['num_objects'] + elif self._num_objects is None: # both are None, single object first run use case + self._num_objects = config['num_objects_default_value'] + self._save_info() # determine the location of input images need_decoding = False @@ -186,23 +193,24 @@ def all_masks_present(self): def add_reference(self, frame_id: int): self.references.add(frame_id) - self._save_references() + self._save_info() def remove_reference(self, frame_id: int): self.references.remove(frame_id) - self._save_references() + self._save_info() - def _save_references(self): + def _save_info(self): with open(self.workspace_info_file, 'wt') as f: - data = {'references': sorted(self.references)} + data = {'references': sorted(self.references), 'num_objects': self._num_objects} json.dump(data, f) - def _try_load_references(self): + def _try_load_info(self): try: with open(self.workspace_info_file) as f: data = json.load(f) self.references = set(data['references']) + self._num_objects = data['num_objects'] except Exception: pass @@ -311,3 +319,7 @@ def shrinkages(self): @property def selections(self): return self._selections + + @property + def num_objects(self): + return self._num_objects diff --git a/interactive_demo.py b/interactive_demo.py index fef2140..9dcba6b 100644 --- a/interactive_demo.py +++ b/interactive_demo.py @@ -46,7 +46,7 @@ parser.add_argument('--buffer_size', help='Correlate with CPU memory consumption', type=int, default=100) - parser.add_argument('--num_objects', type=int, default=1) + parser.add_argument('--num_objects', type=int, default=None) # Long-memory options # Defaults. Some can be changed in the GUI. 
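With --num_objects now defaulting to None, the object count is resolved from three sources. A schematic sketch of the precedence (the real logic is split between interactive_demo.py and ResourceManager in the hunks around this one, so this is an illustration rather than code from the patch):

    def resolve_num_objects(cli_value, saved_value, default_value=1):
        # 1. an explicit --num_objects from the command line always wins (forced overwrite)
        if cli_value is not None:
            return cli_value
        # 2. otherwise reuse the value stored in <workspace>/info.json from a previous run
        if saved_value is not None:
            return saved_value
        # 3. fresh project with no stored value: fall back to the single-object default
        return default_value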
@@ -81,15 +81,18 @@ else: s2m_model = None - s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255) + # Manages most IO + config['num_objects_default_value'] = 1 + resource_manager = ResourceManager(config) + num_objects = resource_manager.num_objects + config['num_objects'] = num_objects + + s2m_controller = S2MController(s2m_model, num_objects, ignore_class=255) if args.fbrs_model is not None: fbrs_controller = FBRSController(args.fbrs_model) else: fbrs_controller = None - # Manages most IO - resource_manager = ResourceManager(config) - app = QApplication(sys.argv) ex = App(network, resource_manager, s2m_controller, fbrs_controller, config) sys.exit(app.exec_()) From 919e3bc72e74c12d717d8b19fcce7345339b5e4c Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 27 Jul 2023 14:36:44 +0400 Subject: [PATCH 38/49] Added a script to import existing images and masks; if masks are imported, will also calculate and save num_objects --- import_existing.py | 98 ++++++++++++++++++++++++++ inference/interact/resource_manager.py | 6 +- 2 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 import_existing.py diff --git a/import_existing.py b/import_existing.py new file mode 100644 index 0000000..a00b255 --- /dev/null +++ b/import_existing.py @@ -0,0 +1,98 @@ +import json +from pathlib import Path +import argparse + +import numpy as np +from PIL import Image +import progressbar +from tqdm import tqdm + + +def resize_preserve(img, size, interpolation): + h, w = img.height, img.width + # Resize preserving aspect ratio + new_w = (w*size//min(w, h)) + new_h = (h*size//min(w, h)) + + resized_img = img.resize((new_w, new_h), resample=interpolation) + + return resized_img + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--name', type=str, help='The name of the project to use (name of the corresponding folder in the workspace). Will be created if doesn\'t exist ', required=True) + parser.add_argument('--size', type=str, help='The name of the project to use (name of the corresponding folder in the workspace). Will be created if doesn\'t exist ', default=480) + parser.add_argument('--images', type=str, help='Path to the folder with video frames', required=False) + parser.add_argument('--masks', type=str, help='Path to the folder with existing masks', required=False) + + args = parser.parse_args() + p_project = Path('workspace') / str(args.name) + if p_project.exists(): + print(f"Found the project {args.name} in the workspace.") + else: + print(f"Creating new project {args.name} in the workspace.") + + if args.images is not None: + p_imgs = Path(args.images) + p_imgs_out = p_project / 'images' + p_imgs_out.mkdir(parents=True, exist_ok=True) + + if any(p_imgs_out.iterdir()): + print(f"The project {args.name} already has images in the workspace. Delete them first.") + exit(0) + + img_files = sorted(p_imgs.iterdir()) + + for i in progressbar.progressbar(range(len(img_files)), prefix="Copying/resizing images..."): + p_img = img_files[i] + img = Image.open(p_img) + resized_img = resize_preserve(img, args.size, Image.Resampling.BILINEAR) + resized_img.save(p_imgs_out / f'frame_{i:06d}{p_img.suffix}') # keep the same image format + + if args.masks is not None: + p_masks = Path(args.masks) + p_masks_out = p_project / 'masks' + p_masks_out.mkdir(parents=True, exist_ok=True) + + if any(p_masks_out.iterdir()): + print(f"The project {args.name} already has masks in the workspace. 
Delete them first.") + exit(0) + + from util.palette import davis_palette + lookup_table = np.full(256, 0, dtype=np.uint8) + num_objects = 0 + + mask_files = sorted(p_masks.iterdir()) + + for i in progressbar.progressbar(range(len(mask_files)), prefix="Copying/resizing masks; converting to DAVIS color palette..."): + p_mask = mask_files[i] + mask = Image.open(p_mask) + resized_mask = resize_preserve(mask, args.size, Image.Resampling.NEAREST).convert('P') + + unique_colors = resized_mask.getcolors() + for _, c in unique_colors: + if c == 0: + continue + elif lookup_table[c] == 0: + num_objects += 1 + lookup_table[c] = num_objects + + # We need range indices like (0, 1, 2, 3, 4) for each unique color, in any order, as long as black is still 0 + # If the new colors appear, we'll treat them as new objects + index_array = lookup_table[resized_mask] + index_mask = Image.fromarray(index_array, mode='P') + index_mask.putpalette(davis_palette) + + index_mask.save(p_masks_out / f'frame_{i:06d}{p_mask.suffix}') # keep the same image form + + try: + with open(p_project / 'info.json') as f: + data = json.load(f) + except Exception: + data = {} + + data['num_objects'] = num_objects + + with open(p_project / 'info.json', 'wt') as f_out: + json.dump(data, f_out, indent=4) + diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index dc193a9..9389016 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -203,14 +203,16 @@ def _save_info(self): with open(self.workspace_info_file, 'wt') as f: data = {'references': sorted(self.references), 'num_objects': self._num_objects} - json.dump(data, f) + json.dump(data, f, indent=4) def _try_load_info(self): try: with open(self.workspace_info_file) as f: data = json.load(f) - self.references = set(data['references']) self._num_objects = data['num_objects'] + + # We might have num_objects, but not references if imported the project + self.references = set(data['references']) except Exception: pass From f71e4bf1aceaf5010c27adf6d569509b53f6c89e Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 27 Jul 2023 16:17:20 +0400 Subject: [PATCH 39/49] Added color picker on the side to easily switch between object annotations --- inference/interact/gui.py | 7 +++++ inference/interact/gui_utils.py | 56 ++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 2b1c997..5a0bbbb 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -266,7 +266,13 @@ def __init__(self, net: XMem, navi.addWidget(self.backward_run_button) # Drawing area, main canvas and minimap + self.color_picker = ColorPicker(self.num_objects, davis_palette) + self.color_picker.clicked.connect(self.hit_number_key) + color_picker_wrapper = QVBoxLayout() + color_picker_wrapper.setAlignment(Qt.AlignmentFlag.AlignTop) + color_picker_wrapper.addWidget(self.color_picker) draw_area = QHBoxLayout() + draw_area.addLayout(color_picker_wrapper) draw_area.addWidget(self.main_canvas, 4) self.tabs = QTabWidget() @@ -917,6 +923,7 @@ def hit_number_key(self, number): self.vis_brush(self.last_ex, self.last_ey) self.update_interact_vis() self.show_current_frame() + self.color_picker.select(self.current_object) def clear_brush(self): self.brush_vis_map.fill(0) diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index be30410..5598039 100644 --- a/inference/interact/gui_utils.py +++ 
b/inference/interact/gui_utils.py @@ -2,11 +2,13 @@ from typing import Optional, Union import time import traceback, sys +from PyQt5 import QtCore +from PyQt5.QtGui import QPalette, QColor from PyQt5.QtCore import Qt, QRunnable, pyqtSlot, pyqtSignal, QObject, QPoint, QRect, QSize from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar, QDialog, QWidget, QProgressDialog, QScrollArea, QLayout, QLayoutItem, QStyle, QSizePolicy, QSpacerItem, - QFrame, QPushButton, QSlider, QMessageBox) + QFrame, QPushButton, QSlider, QMessageBox, QGridLayout) class WorkerSignals(QObject): ''' @@ -368,9 +370,6 @@ def __init__(self, on_click: callable, load_image: callable, delete_image: calla def add_image(self, img_idx): image = self.load_image(img_idx) - # TODO: add frame number - # layout = QVBoxLayout() - # frame_num = QLabel(f"Frame {img_idx}") img_widget = ClickableLabel() img_widget.setPixmap(image) @@ -399,3 +398,52 @@ def on_close_click(self, img_idx): self.remove_image(img_idx) if self.delete_image is not None: self.delete_image(img_idx) + +class ColorPicker(QWidget): + clicked = pyqtSignal(int) + + def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.num_colors = num_colors + + self.inner_layout = QGridLayout(self) # 2 x N/2 + self.palette = color_palette + self.previously_selected = None + + for i in range(self.num_colors): + index = i + 1 + + color_widget = ClickableLabel(str(index)) + + color = self.palette[index * 3: index*3 + 3] + + color_widget.setStyleSheet(f"QLabel {{font-family: Monospace; color:white; font-weight: 900; background-color: rgb{tuple(color)}}} QLabel.selected {{border: 4px solid}}") + color_widget.setAlignment(Qt.AlignmentFlag.AlignCenter) + + color_widget.setFixedSize(40, 40) + self.inner_layout.addWidget(color_widget, int(i / 2), i % 2) + + color_widget.clicked.connect(partial(self._on_color_clicked, index)) + self.select(1) + + def _on_color_clicked(self, index: int): + self.clicked.emit(index) + pass + + def select(self, index: int): # 1-based, not 0-based + widget = self.inner_layout.itemAt(index - 1).widget() + widget.setProperty("class", "selected") + widget.style().unpolish(widget) + widget.style().polish(widget) + widget.update() + + # print(widget.text()) + # print(widget.styleSheet()) + + if self.previously_selected is not None: + self.previously_selected.setProperty("class", "") + self.previously_selected.style().unpolish(self.previously_selected) + self.previously_selected.style().polish(self.previously_selected) + self.previously_selected.update() + + self.previously_selected = self.inner_layout.itemAt(index - 1).widget() From f50d2b987362a4c4ab906bdf07cc28c805cd767d Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 27 Jul 2023 16:35:51 +0400 Subject: [PATCH 40/49] Added text instructions to color picker, simplified gui.py arrangement --- inference/interact/gui.py | 7 ++----- inference/interact/gui_utils.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 5a0bbbb..037369e 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -104,7 +104,7 @@ def __init__(self, net: XMem, self.lcd = QTextEdit() self.lcd.setReadOnly(True) self.lcd.setMaximumHeight(28) - self.lcd.setMaximumWidth(120) + self.lcd.setFixedWidth(120) self.lcd.setText('{: 4d} / {: 4d}'.format(0, self.num_frames-1)) # timeline slider @@ -268,11 +268,8 @@ def __init__(self, net: XMem, # 
Drawing area, main canvas and minimap self.color_picker = ColorPicker(self.num_objects, davis_palette) self.color_picker.clicked.connect(self.hit_number_key) - color_picker_wrapper = QVBoxLayout() - color_picker_wrapper.setAlignment(Qt.AlignmentFlag.AlignTop) - color_picker_wrapper.addWidget(self.color_picker) draw_area = QHBoxLayout() - draw_area.addLayout(color_picker_wrapper) + draw_area.addWidget(self.color_picker) draw_area.addWidget(self.main_canvas, 4) self.tabs = QTabWidget() diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index 5598039..a1b9b7f 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -406,7 +406,13 @@ def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.num_colors = num_colors - self.inner_layout = QGridLayout(self) # 2 x N/2 + self.outer_layout = QVBoxLayout(self) + self.outer_layout.setAlignment(Qt.AlignmentFlag.AlignTop) + + self.inner_layout = QGridLayout() # 2 x N/2 + self.inner_layout_wrapper = QHBoxLayout() + self.inner_layout_wrapper.setAlignment(Qt.AlignmentFlag.AlignHCenter) + self.inner_layout_wrapper.addLayout(self.inner_layout) self.palette = color_palette self.previously_selected = None @@ -424,7 +430,20 @@ def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: self.inner_layout.addWidget(color_widget, int(i / 2), i % 2) color_widget.clicked.connect(partial(self._on_color_clicked, index)) - self.select(1) + + color_picker_name = QLabel("Object selector") + color_picker_name.setAlignment(Qt.AlignmentFlag.AlignCenter) + color_picker_name.setStyleSheet("QLabel {font-family: Monospace; background-color: rgb(225, 225, 225); font-weight: 900}") + + color_picker_instruction = QLabel("Click or use\nnumerical keys") + color_picker_instruction.setStyleSheet("QLabel {font-family: Monospace; background-color: rgb(225, 225, 225)}") + color_picker_instruction.setAlignment(Qt.AlignmentFlag.AlignCenter) + + self.outer_layout.addWidget(color_picker_name) + self.outer_layout.addWidget(color_picker_instruction) + self.outer_layout.addLayout(self.inner_layout_wrapper) + + self.select(1) # First object selected by default def _on_color_clicked(self, index: int): self.clicked.emit(index) From d743726531be68a05fbcdf0f92329904b420fda0 Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 27 Jul 2023 16:44:50 +0400 Subject: [PATCH 41/49] Minor visual tweaks, added an object count label --- inference/interact/gui_utils.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/inference/interact/gui_utils.py b/inference/interact/gui_utils.py index a1b9b7f..c99145c 100644 --- a/inference/interact/gui_utils.py +++ b/inference/interact/gui_utils.py @@ -410,9 +410,8 @@ def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: self.outer_layout.setAlignment(Qt.AlignmentFlag.AlignTop) self.inner_layout = QGridLayout() # 2 x N/2 - self.inner_layout_wrapper = QHBoxLayout() - self.inner_layout_wrapper.setAlignment(Qt.AlignmentFlag.AlignHCenter) - self.inner_layout_wrapper.addLayout(self.inner_layout) + # self.inner_layout_wrapper = QHBoxLayout() + self.inner_layout.setAlignment(Qt.AlignmentFlag.AlignCenter) self.palette = color_palette self.previously_selected = None @@ -433,15 +432,26 @@ def __init__(self, num_colors, color_palette: bytes, *args, **kwargs) -> None: color_picker_name = QLabel("Object selector") color_picker_name.setAlignment(Qt.AlignmentFlag.AlignCenter) - 
color_picker_name.setStyleSheet("QLabel {font-family: Monospace; background-color: rgb(225, 225, 225); font-weight: 900}") + color_picker_name.setStyleSheet("QLabel {font-family: Monospace; font-weight: 900}") + + num_objects_label = QLabel(f"({self.num_colors} objects)") + num_objects_label.setStyleSheet("QLabel {font-family: Monospace; font-weight: 900}") + num_objects_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + color_picker_instruction = QLabel("Click or use\nnumerical keys") - color_picker_instruction.setStyleSheet("QLabel {font-family: Monospace; background-color: rgb(225, 225, 225)}") + color_picker_instruction.setStyleSheet("QLabel {font-family: Monospace; font-style: italic}") color_picker_instruction.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.outer_layout.addWidget(color_picker_name) - self.outer_layout.addWidget(color_picker_instruction) - self.outer_layout.addLayout(self.inner_layout_wrapper) + text_wrapper_widget = QWidget() + text_wrapper_widget.setStyleSheet("QWidget {background-color: rgb(225, 225, 225);}") + text_layout = QVBoxLayout(text_wrapper_widget) + text_layout.addWidget(color_picker_name) + text_layout.addWidget(num_objects_label) + text_layout.addWidget(color_picker_instruction) + + self.outer_layout.addWidget(text_wrapper_widget) + self.outer_layout.addLayout(self.inner_layout) self.select(1) # First object selected by default From 2d4a75cab98a177bc74119f5948096d3ed59c182 Mon Sep 17 00:00:00 2001 From: max810 Date: Fri, 28 Jul 2023 17:10:33 +0400 Subject: [PATCH 42/49] iMoved palette->ojbect index mapping to a separate class, sped up all mask importing, fixed a bug when the app crashed when deleting a reference not the current frame showing; fixed incorrect working memory size gauge --- import_existing.py | 21 +++------ inference/interact/gui.py | 59 ++++++++++++++++++++------ inference/interact/resource_manager.py | 40 +++++++++-------- 3 files changed, 74 insertions(+), 46 deletions(-) diff --git a/import_existing.py b/import_existing.py index a00b255..de51404 100644 --- a/import_existing.py +++ b/import_existing.py @@ -7,6 +7,8 @@ import progressbar from tqdm import tqdm +from util.image_loader import PaletteConverter + def resize_preserve(img, size, interpolation): h, w = img.height, img.width @@ -59,8 +61,7 @@ def resize_preserve(img, size, interpolation): exit(0) from util.palette import davis_palette - lookup_table = np.full(256, 0, dtype=np.uint8) - num_objects = 0 + palette_converter = PaletteConverter(davis_palette) mask_files = sorted(p_masks.iterdir()) @@ -69,19 +70,7 @@ def resize_preserve(img, size, interpolation): mask = Image.open(p_mask) resized_mask = resize_preserve(mask, args.size, Image.Resampling.NEAREST).convert('P') - unique_colors = resized_mask.getcolors() - for _, c in unique_colors: - if c == 0: - continue - elif lookup_table[c] == 0: - num_objects += 1 - lookup_table[c] = num_objects - - # We need range indices like (0, 1, 2, 3, 4) for each unique color, in any order, as long as black is still 0 - # If the new colors appear, we'll treat them as new objects - index_array = lookup_table[resized_mask] - index_mask = Image.fromarray(index_array, mode='P') - index_mask.putpalette(davis_palette) + index_mask = palette_converter.image_to_index_mask(resized_mask) index_mask.save(p_masks_out / f'frame_{i:06d}{p_mask.suffix}') # keep the same image form @@ -91,7 +80,7 @@ def resize_preserve(img, size, interpolation): except Exception: data = {} - data['num_objects'] = num_objects + data['num_objects'] = 
palette_converter._num_objects with open(p_project / 'info.json', 'wt') as f_out: json.dump(data, f_out, indent=4) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 037369e..cbd21f1 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -854,7 +854,7 @@ def on_save_reference(self): def on_remove_reference(self, img_idx): self.processor.remove_from_permanent_memory(img_idx) self.reference_ids.remove(img_idx) - self.res_man.remove_reference(self.cursur) + self.res_man.remove_reference(img_idx) self.show_current_frame() @@ -1064,7 +1064,7 @@ def on_gpu_timer(self): def update_memory_size(self): try: - max_work_elements = self.processor.memory.max_work_elements + max_work_elements = self.processor.memory.max_work_elements + self.processor.memory.permanent_work_mem.size max_long_elements = self.processor.memory.max_long_elements curr_work_elements = self.processor.memory.temporary_work_mem.size + self.processor.memory.permanent_work_mem.size @@ -1142,11 +1142,32 @@ def on_import_all_masks(self): qm.showMessage(f"Files with incorrect names: {broken_file_names}") else: - for i, p_f in zip(frame_ids, files_paths): - self.scroll_to(i) - self.on_import_mask(str(p_f)) - - def on_import_mask(self, mask_file_path=None): + if len(frame_ids) > 10: + qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") + question = f"There are more than 10 masks to import, so confirmations for each individual one would not be asked. Are you willing to continue?" + ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) + if ret == qm.Yes: + progress_dialog = QProgressDialog("Importing masks", None, 0, len(frame_ids), self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) + progress_dialog.open() + a = perf_counter() + for i, p_f in zip(frame_ids, files_paths): + # Only showing progress bar to speed up + self.cursur = i + self.on_import_mask(str(p_f), ask_confirmation=False) + progress_dialog.setValue(i + 1) + QApplication.processEvents() + b = perf_counter() + self.console_push_text(f"Importing {len(frame_ids)} masks took {b-a:.2f} seconds ({len(frame_ids)/(b-a):.2f} FPS)") + self.cursur = 0 + + else: + for i, p_f in zip(frame_ids, files_paths): + self.scroll_to(i) + self.on_import_mask(str(p_f), ask_confirmation=True) + + from profilehooks import profile + @profile(stdout=False, immediate=False, filename='on_import_mask.profile') + def on_import_mask(self, mask_file_path=None, ask_confirmation=True): if mask_file_path: file_name = mask_file_path else: @@ -1171,17 +1192,29 @@ def on_import_mask(self, mask_file_path=None): elif not object_condition: self.console_push_text(f'Expected {self.num_objects} objects. Got {mask.max()} objects instead.') else: - qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") - question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?" - ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) + if ask_confirmation: + qm = QMessageBox(QMessageBox.Icon.Question, "Confirm mask replacement", "") + question = f"Replace mask for current frame {self.cursur} with {Path(file_name).name}?" 
+ ret = qm.question(self, 'Confirm mask replacement', question, qm.Yes | qm.No) - if ret == qm.Yes: + if not ask_confirmation or ret == qm.Yes: self.console_push_text(f'Mask file {file_name} loaded.') self.current_image_torch = self.current_prob = None self.current_mask = mask - self.show_current_frame() + + if ask_confirmation: + # for speedup purposes + self.curr_frame_dirty = False + self.reset_this_interaction() + self.show_current_frame() + self.save_current_mask() - self.on_save_reference() + + if ask_confirmation: + # Only save references if it's an individual image or a few (< 10) + # If the user is importing 1000+ masks, the memory is going to explode + self.on_save_reference() + def on_import_layer(self): file_name = self._open_file('Layer') diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 9389016..99a7c36 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -10,6 +10,8 @@ import torch from torchvision.transforms import Resize, InterpolationMode +from util.image_loader import PaletteConverter + if not hasattr(Image, 'Resampling'): # Pillow<9.0 Image.Resampling = Image import numpy as np @@ -49,6 +51,7 @@ def __init__(self, config): self.workspace = config['workspace'] self.size = config['size'] self.palette = davis_palette + self.palette_converter = PaletteConverter(self.palette) # create temporary workspace if not specified if self.workspace is None: @@ -196,6 +199,7 @@ def add_reference(self, frame_id: int): self._save_info() def remove_reference(self, frame_id: int): + print(self.references) self.references.remove(frame_id) self._save_info() @@ -267,25 +271,27 @@ def read_external_image(self, file_name, size=None, force_mask=False): # PIL uses (width, height) image = image.resize((size[1], size[0]), resample=Image.Resampling.NEAREST if is_mask or force_mask else Image.Resampling.BICUBIC) + if force_mask and image.mode != 'P': - if image.mode in ['RGB', 'L'] and len(image.getcolors()) <= 2: - image = np.array(image.convert('L')) - # hardcoded for b&w images - image = np.where(image, 1, 0) # 255 (or whatever) -> binarize - - return image.astype('uint8') - elif image.mode == 'RGB': - image = image.convert('P', palette=self.palette) - tmp_image = np.array(image) - out_image = np.zeros_like(tmp_image) - for i, c in enumerate(np.unique(tmp_image)): - if i == 0: - continue - out_image[tmp_image == c] = i # palette indices into 0, 1, 2, ... - self.palette = image.getpalette() - return out_image + image = self.palette_converter.image_to_index_mask(image) + # if image.mode in ['RGB', 'L'] and len(image.getcolors()) <= 2: + # image = np.array(image.convert('L')) + # # hardcoded for b&w images + # image = np.where(image, 1, 0) # 255 (or whatever) -> binarize + + # return image.astype('uint8') + # elif image.mode == 'RGB': + # image = image.convert('P', palette=self.palette) + # tmp_image = np.array(image) + # out_image = np.zeros_like(tmp_image) + # for i, c in enumerate(np.unique(tmp_image)): + # if i == 0: + # continue + # out_image[tmp_image == c] = i # palette indices into 0, 1, 2, ... + # self.palette = image.getpalette() + # return out_image - image = image.convert('P', palette=self.palette) # saved without DAVIS palette, just number objects 0, 1, ... + # image = image.convert('P', palette=self.palette) # saved without DAVIS palette, just number objects 0, 1, ... 
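# The conversion above now delegates to PaletteConverter (added in util/image_loader.py, whose body is
# not shown in this series). A minimal sketch of what image_to_index_mask is assumed to do, reconstructed
# from the lookup-table code removed from import_existing.py - an illustration, not the actual class:
import numpy as np
from PIL import Image

class PaletteConverterSketch:
    def __init__(self, palette: bytes):
        self._palette = palette
        self._lookup = np.zeros(256, dtype=np.uint8)  # palette index -> object id
        self._num_objects = 0

    def image_to_index_mask(self, img: Image.Image) -> Image.Image:
        img = img.convert('P')
        for _, color in img.getcolors():  # (count, palette index) pairs
            if color != 0 and self._lookup[color] == 0:
                self._num_objects += 1  # unseen color -> new object id
                self._lookup[color] = self._num_objects
        index_array = self._lookup[np.array(img)]  # remap every pixel to its object id
        out = Image.fromarray(index_array, mode='P')
        out.putpalette(self._palette)  # recolor with the DAVIS palette
        return out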
image = np.array(image) return image From d7596d304f68b9351ac5ba02bf586dbe242c5e81 Mon Sep 17 00:00:00 2001 From: max810 Date: Wed, 2 Aug 2023 18:00:34 +0400 Subject: [PATCH 43/49] Huge refactoring, simplified run_on_video, bug fixes, saving images now in parallel in the background (for cmd, not for gui) --- inference/data/video_reader.py | 37 +- inference/frame_selection/frame_selection.py | 48 +- .../frame_selection/frame_selection_utils.py | 5 +- inference/interact/gui.py | 2 +- inference/interact/resource_manager.py | 2 + inference/run_on_video.py | 572 +++++++----------- main.py | 58 +- util/configuration.py | 26 + util/image_loader.py | 85 +++ util/image_saver.py | 158 ++++- 10 files changed, 589 insertions(+), 404 deletions(-) create mode 100644 util/image_loader.py diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index b389678..9be6ca1 100644 --- a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -1,6 +1,9 @@ +from dataclasses import dataclass, replace import os from os import path +from typing import Optional +import torch from torch.utils.data.dataset import Dataset from torchvision import transforms from torchvision.transforms import InterpolationMode @@ -12,6 +15,17 @@ from dataset.range_transform import im_normalization +@dataclass +class Sample: + rgb: torch.Tensor + raw_image_pil: Image.Image + frame: str + save: bool + shape: tuple + need_resize: bool + mask: Optional[torch.Tensor] = None + + class VideoReader(Dataset): """ This class is used to read a video, one frame at a time @@ -54,12 +68,9 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all self.size = size - def __getitem__(self, idx): + def __getitem__(self, idx) -> Sample: frame = self.frames[idx] - info = {} data = {} - info['frame'] = frame - info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) im_path = path.join(self.image_dir, frame) img = Image.open(im_path).convert('RGB') @@ -75,7 +86,7 @@ def __getitem__(self, idx): if not os.path.exists(gt_path): gt_path = path.join(self.mask_dir, frame[:-4]+'.PNG') - data['raw_image_tensor'] = FT.to_tensor(img) # for dataloaders it cannot be raw PIL.Image, only tensors + data['raw_image_pil'] = img # for dataloaders it cannot be raw PIL.Image, only tensors img = self.im_transform(img) load_mask = self.use_all_masks or (gt_path == self.first_gt_path) @@ -84,10 +95,15 @@ def __getitem__(self, idx): mask = np.array(mask, dtype=np.uint8) data['mask'] = mask + info = {} + info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) + info['frame'] = frame info['shape'] = shape info['need_resize'] = not (self.size < 0) + data['rgb'] = img - data['info'] = info + + data = Sample(**data, **info) return data @@ -104,4 +120,11 @@ def map_the_colors_back(self, pred_mask: Image.Image): return pred_mask.quantize(palette=self.reference_mask, dither=Image.Dither.NONE).convert('RGB') def __len__(self): - return len(self.frames) \ No newline at end of file + return len(self.frames) + + @staticmethod + def collate_fn_identity(x): + if x.mask is not None: + return replace(x, mask=torch.tensor(x.mask)) + else: + return x \ No newline at end of file diff --git a/inference/frame_selection/frame_selection.py b/inference/frame_selection/frame_selection.py index 908aeaa..90fe345 100644 --- a/inference/frame_selection/frame_selection.py +++ b/inference/frame_selection/frame_selection.py @@ -97,7 +97,7 @@ def calculate_proposals_for_annotations_with_iterative_distance_cycle_MASKS(data return 
chosen_frames -def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=1.0, min_mask_presence_percent=0.25, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=0.5, h=30, w=54): +def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: List[torch.tensor], num_next_candidates: int, previously_chosen_candidates: List[int] = (0,), print_progress=False, alpha=0.5, min_mask_presence_percent=0.25, device: torch.device = 'cuda:0', progress_callback=None, only_new_candidates=True, epsilon=0.5): assert len(keys) == len(masks) assert len(keys) > 0 # assert keys[0].shape[-2:] == masks[0].shape[-2:] @@ -113,32 +113,44 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li Parameters ---------- - `keys` : `List[torch.Tensor]` - A list of "key" feature maps for all frames of the video. - `masks` : `List[torch.Tensor]` + keys : torch.Tensor + A list of "key" feature maps for all frames of the video (from XMem key encoder) + shrinkages : [Type] + A list of "shrinkage" feature maps for all frames of the video (from XMem key encoder). Used for similarity computation. + selections : [Type] + A list of "sellection" feature maps for all frames of the video (from XMem key encoder). Used for similarity computation. + masks : List[torch.Tensor] A list of masks for each frame (predicted or user-provided). - `num_next_candidates` : `int` + num_next_candidates : int The number of candidate frames to select. - `previously_chosen_candidates` : `List[int]`, optional - A list of previously chosen candidates. Default is (0,). - `print_progress` : `bool`, optional + previously_chosen_candidates : List[int], optional + A list of previously chosen candidate indices. Default is (0,). + print_progress : bool, optional Whether to print progress information. Default is False. - `alpha` : `float`, optional - The weight for cycle consistency in the candidate selection process. Default is 1.0. - `min_mask_presence_px` : `int`, optional - The minimum number of pixels for a valid mask. Default is 9. + alpha : float, optional + The weight for the masks in the candidate selection process, [0..1]. If 0 - masks will be ignored, the same frames will be chosen for the same video. If 1.0 - ONLY regions of the frames containing the mask will be compared. Default is 0.5. + If you trust your masks and want object-specific selections, set higher. If your predictions are really bad, set lower + min_mask_presence_percent : float, optional + The minimum percentage of pixels for a valid mask. Default is 0.25. Used to ignore frames with a tiny mask (when heavily occluded or just some random wrong prediction) + device : torch.device, optional + The device to run the computation on. Default is 'cuda:0'. + progress_callback : callable, optional + A callback function for progress updates. Used in GUI for a progress bar. Default is None. + only_new_candidates : bool, optional + Whether to return only the newly selected candidates or include previous as well. Default is True. + epsilon : float, optional + Threshold for foreground/background [0..1]. Default is 0.5 Returns ------- - `List[int]` + List[int] A list of indices of the selected candidate frames. Notes ----- This function uses a dissimilarity measure and cycle consistency to select candidate frames for the user to annotate. 
The dissimilarity measure ensures that the selected frames are as diverse as possible, while the cycle consistency - ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise). - + ensures that the dissimilarity D(A->A)=0, while D(A->B)>0, and is larger the more different A and B are (pixel-wise, on feature map level - so both semantically and spatially). """ with torch.no_grad(): @@ -152,7 +164,7 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li invalid = 0 for i, mask in enumerate(masks): mask_3ch = mask if mask.ndim == 3 else mask.unsqueeze(0) - mask_bin = mask_3ch.max(dim=0).values + mask_bin = mask_3ch.max(dim=0).values # for multiple objects -> use them as one large mask (simplest solution) mask_size_px = (mask_bin > epsilon).sum() ratio = mask_size_px / mask_bin.numel() * 100 @@ -179,7 +191,7 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li composite_keys.append(composite_key.to(dtype=keys[i].dtype, device=device)) - print(f"INVALID: {invalid} / {len(masks)}") + print(f"Frames with invalid (empty or too small) masks: {invalid} / {len(masks)}") chosen_candidates = list(previously_chosen_candidates) chosen_candidate_keys = [composite_keys[i] for i in chosen_candidates] @@ -212,7 +224,7 @@ def select_next_candidates(keys: torch.Tensor, shrinkages, selections, masks: Li cycle_dissimilarity_score = F.relu(cycle_dissimilarity_per_pixel).sum() / cycle_dissimilarity_per_pixel.numel() dissimilarities_across_mem_keys.append(cycle_dissimilarity_score) - # filtering our existing or very similar frames + # filtering out existing or very similar frames # if the key has already been used or is very similar to at least one of the chosen candidates # dissimilarity_min_across_all -> 0 (or close to) dissimilarity_min_across_all = min(dissimilarities_across_mem_keys) diff --git a/inference/frame_selection/frame_selection_utils.py b/inference/frame_selection/frame_selection_utils.py index 1d3bc85..8abe6ed 100644 --- a/inference/frame_selection/frame_selection_utils.py +++ b/inference/frame_selection/frame_selection_utils.py @@ -8,6 +8,8 @@ from torchvision.transforms import ColorJitter, Grayscale, RandomPosterize, RandomAdjustSharpness, ToTensor, RandomAffine from tqdm import tqdm +from inference.data.video_reader import Sample + def extract_keys(dataloder, processor, print_progress=False, flatten=True, **kwargs): frame_keys = [] @@ -18,7 +20,8 @@ def extract_keys(dataloder, processor, print_progress=False, flatten=True, **kwa key_sum = None for ti, data in enumerate(tqdm(dataloder, disable=not print_progress, desc='Calculating key features')): - rgb = data['rgb'].cuda()[0] + data: Sample = data + rgb = data.rgb.cuda() key, shrinkage, selection = processor.encode_frame_key(rgb) if key_sum is None: diff --git a/inference/interact/gui.py b/inference/interact/gui.py index cbd21f1..7de54de 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -815,7 +815,7 @@ def _update_progress(i): alpha = self.candidates_alpha_slider.value() candidate_progress = QProgressDialog("Selecting candidates", None, 0, k, self, Qt.WindowFlags(Qt.WindowType.Dialog | ~Qt.WindowCloseButtonHint)) worker = Worker(select_next_candidates, self.res_man.keys, self.res_man.shrinkages, self.res_man.selections, self.res_man.small_masks, k, self.reference_ids, - print_progress=False, min_mask_presence_percent=float(self.candidates_min_mask_size_edit.text()), alpha=alpha, h=self.res_man.key_h, 
w=self.res_man.key_w) # Any other args, kwargs are passed to the run function + print_progress=False, min_mask_presence_percent=float(self.candidates_min_mask_size_edit.text()), alpha=alpha) # Any other args, kwargs are passed to the run function worker.signals.result.connect(_update_candidates) worker.signals.progress.connect(_update_progress) diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 99a7c36..58f694a 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -204,6 +204,8 @@ def remove_reference(self, frame_id: int): self._save_info() def _save_info(self): + p_workspace_subdir = Path(self.workspace_info_file).parent + p_workspace_subdir.mkdir(parents=True, exist_ok=True) with open(self.workspace_info_file, 'wt') as f: data = {'references': sorted(self.references), 'num_objects': self._num_objects} diff --git a/inference/run_on_video.py b/inference/run_on_video.py index 798f66f..13736c0 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -1,10 +1,11 @@ +from dataclasses import replace from functools import partial -import os +from multiprocessing import Process, Queue from os import PathLike, path from tempfile import TemporaryDirectory from time import perf_counter -from typing import Iterable, Literal, Union, List -from collections import defaultdict +import time +from typing import Iterable, Literal, Optional, Union, List from pathlib import Path from warnings import warn @@ -14,105 +15,177 @@ import torch.nn.functional as F from torchvision.transforms import functional as FT, ToTensor from torch.utils.data import DataLoader -from baal.active.heuristics import BALD -from scipy.stats import entropy from tqdm import tqdm from PIL import Image from inference.frame_selection.frame_selection import select_next_candidates from model.network import XMem -from util.image_saver import create_overlay, save_image -from util.tensor_util import compute_array_iou, compute_tensor_iou +from util.configuration import VIDEO_INFERENCE_CONFIG +from util.image_saver import ParallelImageSaver, create_overlay, save_image +from util.tensor_util import compute_array_iou from inference.inference_core import InferenceCore -from inference.data.video_reader import VideoReader +from inference.data.video_reader import Sample, VideoReader from inference.data.mask_mapper import MaskMapper -# from inference.frame_selection.frame_selection import KNOWN_ANNOTATION_PREDICTORS -from inference.frame_selection.frame_selection_utils import disparity_func, extract_keys, get_determenistic_augmentations - - -def save_frames(dataset, frame_indices, output_folder): - p_out = Path(output_folder) - - if not p_out.exists(): - p_out.mkdir(parents=True) - - for i in frame_indices: - sample = dataset[i] - rgb_raw_tensor = sample['raw_image_tensor'].cpu().squeeze() - img = FT.to_pil_image(rgb_raw_tensor) - - img.save(p_out / f'frame_{i:06d}.png') +from inference.frame_selection.frame_selection_utils import extract_keys, get_determenistic_augmentations +from profilehooks import profile +# @profile(stdout=False, immediate=False, filename='run_on_video_aug_2023_parallel_composite_and_saving.profile') def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, original_memory_mechanism=False, - compute_iou=False, compute_uncertainty=False, manually_curated_masks=False, print_progress=True, + compute_iou=False, + manually_curated_masks=False, + print_progress=True, augment_images_with_masks=False, 
- uncertainty_name: str = None, - only_predict_frames_to_annotate_and_quit=0, overwrite_config: dict = None, - frame_selector_func: callable = None, save_overlay=True, - b_and_w_color=(255, 0, 0), measure_fps=False): + object_color_if_single_object=(255, 255, 255), + print_fps=False, + image_saving_max_queue_size=200): + torch.autograd.set_grad_enabled(False) frames_with_masks = set(frames_with_masks) - config = { - 'buffer_size': 100, - 'deep_update_every': -1, - 'enable_long_term': True, - 'enable_long_term_count_usage': True, - 'fbrs_model': 'saves/fbrs.pth', - 'hidden_dim': 64, - 'images': None, - 'key_dim': 64, - 'max_long_term_elements': 10000, - 'max_mid_term_frames': 10, - 'mem_every': 10, - 'min_mid_term_frames': 5, - 'model': './saves/XMem.pth', - 'no_amp': False, - 'num_objects': 1, - 'num_prototypes': 128, - 's2m_model': 'saves/s2m.pth', - 'size': 480, - 'top_k': 30, - 'value_dim': 512, - # f'../VIDEOS/RESULTS/XMem_memory/thanks_two_face_5_frames/', - 'masks_out_path': masks_out_path, - 'workspace': None, - 'save_masks': True - } - - if overwrite_config is not None: - config.update(overwrite_config) - - if compute_uncertainty: - assert uncertainty_name is not None - uncertainty_name = uncertainty_name.lower() - assert uncertainty_name in {'entropy', - 'bald', 'disparity', 'disparity_large'} - compute_disparity = uncertainty_name.startswith('disparity') + + config = VIDEO_INFERENCE_CONFIG.copy() + overwrite_config = {} if overwrite_config is None else overwrite_config + overwrite_config['masks_out_path'] = masks_out_path + config.update(overwrite_config) + + mapper, processor, vid_reader, loader = _load_main_objects(imgs_in_path, masks_in_path, config) + vid_name = vid_reader.vid_name + vid_length = len(loader) + + at_least_one_mask_loaded = False + total_preloading_time = 0.0 + + if original_memory_mechanism: + # only the first frame goes into permanent memory originally + frames_to_put_in_permanent_memory = [0] + # the rest are going to be processed later else: - compute_disparity = False + # in our modification, all frames with provided masks go into permanent memory + frames_to_put_in_permanent_memory = frames_with_masks + at_least_one_mask_loaded, total_preloading_time = _preload_permanent_memory(frames_to_put_in_permanent_memory, vid_reader, mapper, processor, augment_images_with_masks=augment_images_with_masks) - vid_reader = VideoReader( - "", - imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', - masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', - size=config['size'], - use_all_masks=(only_predict_frames_to_annotate_and_quit == 0) - ) + if not at_least_one_mask_loaded: + raise ValueError("No valid masks provided!") + + stats = [] + + total_processing_time = 0.0 + with ParallelImageSaver(config['masks_out_path'], vid_name=vid_name, overlay_color_if_b_and_w=object_color_if_single_object, max_queue_size=image_saving_max_queue_size) as im_saver: + for ti, data in enumerate(tqdm(loader, disable=not print_progress)): + with torch.cuda.amp.autocast(enabled=True): + data: Sample = data # Just for Intellisense + # No batch dimension here, just single samples + sample = replace(data, rgb=data.rgb.cuda()) + + if ti in frames_with_masks: + msk = sample.mask + else: + msk = None + + # Map possibly non-continuous labels to continuous ones + if msk is not None: + # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + msk, labels = mapper.convert_mask( + msk.numpy(), 
exhaustive=True) + msk = torch.Tensor(msk).cuda() + if sample.need_resize: + msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + processor.set_all_labels(list(mapper.remappings.values())) + else: + labels = None + + if original_memory_mechanism: + # we only ignore the first mask, since it's already in the permanent memory + do_not_add_mask_to_memory = (ti == 0) + else: + # we ignore all frames with masks, since they are already preloaded in the permanent memory + do_not_add_mask_to_memory = msk is not None + # Run the model on this frame + # 2+ channels, classes+ and background + a = perf_counter() + prob = processor.step(sample.rgb, msk, labels, end=(ti == vid_length-1), + manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) + + # Upsample to original size if needed + out_mask = _post_process(sample, prob) + b = perf_counter() + total_processing_time += (b - a) + + curr_stat = {'frame': sample.frame, 'mask_provided': msk is not None} + if compute_iou: + gt = sample.mask # for IoU computations, original mask or None, NOT msk + if gt is not None and msk is None: # There exists a ground truth, but the model didn't see it + iou = float(compute_array_iou(out_mask, gt)) + else: + iou = -1 # skipping frames where the model saw the GT + curr_stat['iou'] = iou + stats.append(curr_stat) + + # Save the mask and the overlay (potentially) + + if config['save_masks']: + out_mask = mapper.remap_index_mask(out_mask) + out_img = Image.fromarray(out_mask) + out_img = vid_reader.map_the_colors_back(out_img) + + im_saver.save_mask(mask=out_img, frame_name=sample.frame) + + if save_overlay: + original_img = sample.raw_image_pil + im_saver.save_overlay(orig_img=original_img, mask=out_img, frame_name=sample.frame) + im_saver.wait_for_jobs_to_finish(verbose=True) + + if print_fps: + print(f"TOTAL PRELOADING TIME: {total_preloading_time:.4f}s") + print(f"TOTAL PROCESSING TIME: {total_processing_time:.4f}s") + print(f"TOTAL TIME (excluding image saving): {total_preloading_time + total_processing_time:.4f}s") + print(f"TOTAL PROCESSING FPS: {len(loader) / total_processing_time:.4f}") + print(f"TOTAL FPS (excluding image saving): {len(loader) / (total_preloading_time + total_processing_time):.4f}") + return pd.DataFrame(stats) + +def _load_main_objects(imgs_in_path, masks_in_path, config): model_path = config['model'] network = XMem(config, model_path).cuda().eval() if model_path is not None: model_weights = torch.load(model_path) network.load_weights(model_weights, init_as_zero_if_needed=True) else: - print('No model loaded.') + warn('No model weights were loaded, as config["model"] was not specified.') + + mapper = MaskMapper() + processor = InferenceCore(network, config=config) + + vid_reader, loader = _create_dataloaders(imgs_in_path, masks_in_path, config) + return mapper,processor,vid_reader,loader + + +def _post_process(sample, prob): + if sample.need_resize: + prob = F.interpolate(prob.unsqueeze( + 1), sample.shape, mode='bilinear', align_corners=False)[:, 0] + + # Probability mask -> index mask + out_mask = torch.argmax(prob, dim=0) + out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) + return out_mask + + +def _create_dataloaders(imgs_in_path: Union[str, PathLike], masks_in_path: Union[str, PathLike], config: dict): + vid_reader = VideoReader( + "", + imgs_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/JPEGImages', + masks_in_path, # f'/home/maksym/RESEARCH/VIDEOS/thanks_no_ears_5_annot/Annotations_binarized_two_face', + 
size=config['size'], + use_all_masks=True + ) + + # Just return the samples as they are; only using DataLoader for preloading frames from the disk + loader = DataLoader(vid_reader, batch_size=None, shuffle=False, num_workers=8, collate_fn=VideoReader.collate_fn_identity) - loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=8) - vid_name = vid_reader.vid_name vid_length = len(loader) # no need to count usage for LT if the video is not that long anyway config['enable_long_term_count_usage'] = ( @@ -122,63 +195,41 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou * config['num_prototypes']) >= config['max_long_term_elements'] ) + + return vid_reader,loader - mapper = MaskMapper() - processor = InferenceCore(network, config=config) - first_mask_loaded = False - - if only_predict_frames_to_annotate_and_quit > 0: - assert frame_selector_func is not None - chosen_annotation_candidate_frames = frame_selector_func( - loader, processor, print_progress=print_progress, how_many_frames=only_predict_frames_to_annotate_and_quit, existing_masks_path=masks_out_path) - - return chosen_annotation_candidate_frames - - frames_ = [] - masks_ = [] - total_preloading_time = 0.0 - if original_memory_mechanism: - # only the first frame goes into permanent memory originally - frames_to_put_in_permanent_memory = [0] - # the rest are going to be processed later - else: - # in our modification, all frames with provided masks go into permanent memory - frames_to_put_in_permanent_memory = frames_with_masks +def _preload_permanent_memory(frames_to_put_in_permanent_memory: List[int], vid_reader: VideoReader, mapper: MaskMapper, processor: InferenceCore, augment_images_with_masks=False): + total_preloading_time = 0 + at_least_one_mask_loaded = False for j in frames_to_put_in_permanent_memory: - sample = vid_reader[j] - rgb = sample['rgb'].cuda() - rgb_raw_tensor = sample['raw_image_tensor'].cpu() - msk = sample['mask'] - info = sample['info'] - need_resize = info['need_resize'] + sample: Sample = vid_reader[j] + sample = replace(sample, rgb=sample.rgb.cuda()) # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True - msk, labels = mapper.convert_mask(msk, exhaustive=True) + msk, labels = mapper.convert_mask(sample.mask, exhaustive=True) msk = torch.Tensor(msk).cuda() if min(msk.shape) == 0: # empty mask, e.g. 
[1, 0, 720, 1280] - print(f"Skipping adding frame {j} to memory, as the mask is empty") + warn(f"Skipping adding frame {j} to permanent memory, as the mask is empty") continue # just don't add anything to the memory - if need_resize: + if sample.need_resize: msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] + # sample = replace(sample, mask=msk) processor.set_all_labels(list(mapper.remappings.values())) a = perf_counter() - processor.put_to_permanent_memory(rgb, msk) + processor.put_to_permanent_memory(sample.rgb, msk) b = perf_counter() total_preloading_time += (b - a) - if not first_mask_loaded: - first_mask_loaded = True - - frames_.append(rgb) - masks_.append(msk) + if not at_least_one_mask_loaded: + at_least_one_mask_loaded = True if augment_images_with_masks: augs = get_determenistic_augmentations( - rgb.shape, msk, subset='best_all') - rgb_raw = FT.to_pil_image(rgb_raw_tensor) + sample.rgb.shape, msk, subset='best_all') + rgb_raw = sample.raw_image_pil for img_aug, mask_aug in augs: # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies @@ -187,157 +238,8 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou msk_aug = mask_aug(msk) processor.put_to_permanent_memory(rgb_aug, msk_aug) - - if not first_mask_loaded: - raise ValueError("No valid masks provided!") - - stats = [] - - if compute_uncertainty and uncertainty_name == 'bald': - bald = BALD() - - total_processing_time = 0.0 - for ti, data in enumerate(tqdm(loader, disable=not print_progress)): - with torch.cuda.amp.autocast(enabled=True): - rgb = data['rgb'].cuda()[0] - rgb_raw_tensor = data['raw_image_tensor'].cpu()[0] - - gt = data.get('mask') # for IoU computations - if ti in frames_with_masks: - msk = data['mask'] - else: - msk = None - - info = data['info'] - frame = info['frame'][0] - shape = info['shape'] - need_resize = info['need_resize'][0] - curr_stat = {'frame': frame, 'mask_provided': msk is not None} - - # not important anymore as long as at least one mask is in permanent memory - if original_memory_mechanism and not first_mask_loaded: - if msk is not None: - first_mask_loaded = True - else: - # no point to do anything without a mask - continue - - # Map possibly non-continuous labels to continuous ones - if msk is not None: - # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True - msk, labels = mapper.convert_mask( - msk[0].numpy(), exhaustive=True) - msk = torch.Tensor(msk).cuda() - if need_resize: - msk = vid_reader.resize_mask(msk.unsqueeze(0))[0] - processor.set_all_labels(list(mapper.remappings.values())) - - else: - labels = None - - if (compute_uncertainty and uncertainty_name == 'bald') or compute_disparity: - dry_run_preds = [] - augged_images = [] - augs = get_determenistic_augmentations(subset='original_only') - rgb_raw = FT.to_pil_image(rgb_raw_tensor) - for img_aug, mask_aug in augs: - # tensor -> PIL.Image -> tensor -> whatever normalization vid_reader applies - augged_img = img_aug(rgb_raw) - augged_images.append(augged_img) - rgb_aug = vid_reader.im_transform(augged_img).cuda() - - # does not do anything, since original_only=True augmentations don't alter the mask at all - msk = mask_aug(msk) - - dry_run_prob = processor.step(rgb_aug, msk, labels, end=(ti == vid_length-1), - manually_curated_masks=manually_curated_masks, disable_memory_updates=True) - dry_run_preds.append(dry_run_prob.cpu()) - - if original_memory_mechanism: - # we only ignore the first mask, since it's already in the permanent memory - 
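For readers skimming the diff, this is the essence of `_preload_permanent_memory` above with the timing and optional augmentations stripped out, assuming `vid_reader`, `mapper` and `processor` were built by `_load_main_objects`: every user-annotated frame, not just frame 0, is pushed into XMem's permanent memory before propagation starts.

```
import torch

frames_with_masks = [0, 14, 33]  # illustrative user-annotated frame indices
for j in frames_with_masks:
    sample = vid_reader[j]
    if sample.mask is None:
        continue  # nothing to preload for this frame
    msk, labels = mapper.convert_mask(sample.mask, exhaustive=True)
    msk = torch.Tensor(msk).cuda()
    if sample.need_resize:
        msk = vid_reader.resize_mask(msk.unsqueeze(0))[0]
    processor.set_all_labels(list(mapper.remappings.values()))
    processor.put_to_permanent_memory(sample.rgb.cuda(), msk)
```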
do_not_add_mask_to_memory = (ti == 0) - else: - # we ignore all frames with masks, since they are already preloaded in the permanent memory - do_not_add_mask_to_memory = msk is not None - # Run the model on this frame - # 2+ channels, classes+ and background - a = perf_counter() - prob = processor.step(rgb, msk, labels, end=(ti == vid_length-1), - manually_curated_masks=manually_curated_masks, do_not_add_mask_to_memory=do_not_add_mask_to_memory) - - if compute_uncertainty: - if uncertainty_name == 'bald': - # [batch=1, num_classes, ..., num_iterations] - all_samples = torch.stack( - [x.unsqueeze(0) for x in dry_run_preds + [prob.cpu()]], dim=-1).numpy() - score = bald.compute_score(all_samples) - curr_stat['bald'] = float(np.squeeze(score).mean()) - elif compute_disparity: - disparity_stats = disparity_func( - predictions=[prob] + dry_run_preds, augs=[img_aug for img_aug, _ in augs], images=[rgb_raw] + augged_images, output_save_path=None) - curr_stat['disparity'] = float(disparity_stats['avg']) - curr_stat['disparity_large'] = float( - disparity_stats['large']) - else: - e = entropy(prob.cpu()) - e_mean = np.mean(e) - curr_stat['entropy'] = float(e_mean) - - # Upsample to original size if needed - if need_resize: - prob = F.interpolate(prob.unsqueeze( - 1), shape, mode='bilinear', align_corners=False)[:, 0] - - # Probability mask -> index mask - out_mask = torch.argmax(prob, dim=0) - out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) - b = perf_counter() - total_processing_time += (b - a) - - if compute_iou: - # mask is [0, 1] - # gt is [0, 255] - # both -> [False, True] - if gt is not None and msk is None: - # skipping frames - iou = float(compute_array_iou(out_mask, gt)) - else: - iou = -1 - curr_stat['iou'] = iou - - # Save the mask - if config['save_masks']: - original_img = FT.to_pil_image(rgb_raw_tensor) - - out_mask = mapper.remap_index_mask(out_mask) - out_img = Image.fromarray(out_mask) - out_img = vid_reader.map_the_colors_back(out_img) - save_image(out_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='masks', extension='.png') - - if save_overlay: - overlaid_img = create_overlay(original_img, out_img, color_if_black_and_white=b_and_w_color) - save_image(overlaid_img, frame, vid_name, general_dir_path=config['masks_out_path'], sub_dir_name='overlay', extension='.jpg') - - if False: # args.save_scores: - np_path = path.join(args.output, 'Scores', vid_name) - os.makedirs(np_path, exist_ok=True) - if ti == len(loader)-1: - hkl.dump(mapper.remappings, path.join( - np_path, f'backward.hkl'), mode='w') - if args.save_all or info['save'][0]: - hkl.dump(prob, path.join( - np_path, f'{frame[:-4]}.hkl'), mode='w', compression='lzf') - - stats.append(curr_stat) - - if measure_fps: - print(f"TOTAL PRELOADING TIME: {total_preloading_time:.4f}s") - print(f"TOTAL PROCESSING TIME: {total_processing_time:.4f}s") - print(f"TOTAL TIME: {total_preloading_time + total_processing_time:.4f}s") - print(f"TOTAL PROCESSING FPS: {len(loader) / total_processing_time:.4f}") - print(f"TOTAL FPS: {len(loader) / (total_preloading_time + total_processing_time):.4f}") - - return pd.DataFrame(stats) + + return at_least_one_mask_loaded, total_preloading_time def run_on_video( @@ -372,117 +274,95 @@ def run_on_video( masks_in_path=masks_in_path, masks_out_path=masks_out_path, frames_with_masks=frames_with_masks, - compute_uncertainty=False, compute_iou=compute_iou, print_progress=print_progress, - manually_curated_masks=False, **kwargs ) -def predict_annotation_candidates( 
+def select_k_next_best_annotation_candidates( imgs_in_path: Union[str, PathLike], - candidate_selection_function: callable, - masks_in_path: Union[str, PathLike] = None, - num_candidates: int = 1, + masks_in_path: Union[str, PathLike], # at least the 1st frame + masks_out_path: Optional[Union[str, PathLike]] = None, + k: int = 5, print_progress=True, + previously_chosen_candidates=[0], + use_previously_predicted_masks=True, + # Candidate selection hyperparameters + alpha=0.5, + min_mask_presence_percent=0.25, **kwargs -) -> List[int]: - - warn('predict_annotation_candidates is deprecated, used ', DeprecationWarning, stacklevel=2) - +): """ - Args: - imgs_in_path (Union[str, PathLike]): Path to the directory containing video frames in the following format: `frame_000000.png` .jpg works too. - - if num_candidates == 1: - return [0] # First frame is hard-coded to always be used - - # p_bar.update() + Selects the next best annotation candidate frames based on the provided frames and mask paths. + + Parameters: + imgs_in_path (Union[str, PathLike]): The path to the directory containing input images. + masks_in_path (Union[str, PathLike]): The path to the directory containing the first frame masks. + masks_out_path (Optional[Union[str, PathLike]], optional): The path to save the generated masks. + If not provided, a temporary directory will be used. Defaults to None. + k (int, optional): The number of next best annotation candidate frames to select. Defaults to 5. + print_progress (bool, optional): Whether to print progress during processing. Defaults to True. + previously_chosen_candidates (list, optional): List of indices of frames with previously chosen candidates. + Defaults to [0]. + use_previously_predicted_masks (bool, optional): Whether to use previously predicted masks. + If True, `masks_out_path` must be provided. Defaults to True. + alpha (float, optional): Hyperparameter controlling the candidate selection process. Defaults to 0.5. + min_mask_presence_percent (float, optional): Minimum mask presence percentage for candidate selection. + Defaults to 0.25. + **kwargs: Additional keyword arguments to pass to `run_on_video`. Returns: - annotation_candidates (List[int]): A list of frames indices (0-based) chosen as annotation candidates, sorted by importance (most -> least). Always contains [0] - first frame - at index 0. + list: A list of indices representing the selected next best annotation candidate frames. 
""" + mapper, processor, vid_reader, loader = _load_main_objects(imgs_in_path, masks_in_path, VIDEO_INFERENCE_CONFIG) - # candidate_selection_function = KNOWN_ANNOTATION_PREDICTORS[approach] - - assert num_candidates >= 1 - - if num_candidates == 1: - return [0] # First frame is hard-coded to always be used - - try: - masks_out_path = kwargs.pop('masks_out_path') - except KeyError: - masks_out_path = None - - return _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, # Ignored - masks_out_path=masks_out_path, # Used for some frame selectors - frames_with_masks=[0], - compute_uncertainty=False, - compute_iou=False, - print_progress=print_progress, - manually_curated_masks=False, - only_predict_frames_to_annotate_and_quit=num_candidates, - frame_selector_func=candidate_selection_function - ) - -def select_k_next_best_annotation_candidates( - imgs_in_path: Union[str, PathLike], - masks_in_path: Union[str, PathLike], - k: int = 5, - print_progress=True, - previously_chosen_candidates=[0], - **kwargs -): + # Extracting "key" feature maps + # Could be combined with inference (like in GUI), but the code would be a mess + frame_keys, shrinkages, selections, *_ = extract_keys(loader, processor, print_progress=print_progress, flatten=False) # extracting the keys and corresponding matrices - keys, shrinkages, selections, *_ = _inference_on_video( - imgs_in_path=imgs_in_path, - masks_in_path=masks_in_path, # Ignored - masks_out_path=None, # Used for some frame selectors - frames_with_masks=previously_chosen_candidates, - compute_uncertainty=False, - compute_iou=False, - print_progress=print_progress, - manually_curated_masks=False, - only_predict_frames_to_annotate_and_quit=True, # exact number is ignored here - frame_selector_func=partial(extract_keys, flatten=False), - **kwargs - ) - # running inference once to obtain masks to_tensor = ToTensor() - with TemporaryDirectory() as d: - p_masks_out = Path(d) - _inference_on_video( + if masks_out_path is not None: + p_masks_out = Path(masks_out_path) + + if use_previously_predicted_masks: + print("Using existing predicted masks, no need to run inference.") + assert masks_out_path is not None, "When `use_existing_masks=True`, you need to put the path to previously predicted masks in `masks_out_path`" + try: + masks = [to_tensor(Image.open(p)) for p in sorted((p_masks_out / 'masks').iterdir())] + except Exception as e: + warn("Loading previously predicting masks failed for `select_k_next_best_annotation_candidates`.") + raise e + if len(masks) != len(frame_keys): + raise FileNotFoundError(f"Not enough masks ({len(masks)}) for {len(frame_keys)} frames provided when using `use_previously_predicted_masks=True`!") + else: + print("Existing predictions were not given, will run full inference and save masks in `masks_out_path` or a temporary directory if `masks_out_path` is not given.") + if masks_out_path is None: + d = TemporaryDirectory() + p_masks_out = Path(d) + + # running inference once to obtain masks + run_on_video( imgs_in_path=imgs_in_path, masks_in_path=masks_in_path, # Ignored masks_out_path=p_masks_out, # Used for some frame selectors frames_with_masks=previously_chosen_candidates, - compute_uncertainty=False, compute_iou=False, print_progress=print_progress, - manually_curated_masks=False, **kwargs ) masks = [to_tensor(Image.open(p)) for p in sorted((p_masks_out / 'masks').iterdir())] - keys = torch.cat(keys) + keys = torch.cat(frame_keys) shrinkages = torch.cat(shrinkages) selections = torch.cat(selections) - # TODO: 
fix shapes - print(f"[xxx] Running with previously chosen candidates: {previously_chosen_candidates}") - min_mask_presence_percent = 0.25 - try: - all_selected_frames = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=0.5, only_new_candidates=False, min_mask_presence_percent=min_mask_presence_percent) - except ValueError: - print(f"INVALID in video {imgs_in_path}") - min_mask_presence_percent = 0.01 - all_selected_frames = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=0.5, only_new_candidates=False, min_mask_presence_percent=min_mask_presence_percent) + new_selected_candidates = select_next_candidates(keys, shrinkages=shrinkages, selections=selections, masks=masks, num_next_candidates=k, previously_chosen_candidates=previously_chosen_candidates, print_progress=print_progress, alpha=alpha, only_new_candidates=True, min_mask_presence_percent=min_mask_presence_percent) + if masks_out_path is None: + # Remove the temporary directory + d.cleanup() - return all_selected_frames \ No newline at end of file + return new_selected_candidates \ No newline at end of file diff --git a/main.py b/main.py index e7d6a9a..ba49306 100644 --- a/main.py +++ b/main.py @@ -1,36 +1,36 @@ -from inference.run_on_video import run_on_video, predict_annotation_candidates - +import os +import random +from inference.run_on_video import run_on_video, select_k_next_best_annotation_candidates if __name__ == '__main__': # If pytorch cannot download the weights due to an ssl error, uncomment the following lines # import ssl # ssl._create_default_https_context = ssl._create_unverified_context - # Example for a fully-labeled video - video_frames_path = 'example_videos/DAVIS-bmx/frames' - video_masks_path = 'example_videos/DAVIS-bmx/masks' - output_masks_path_baseline = 'output/DAVIS-bmx/baseline' - output_masks_path_5_frames = 'output/DAVIS-bmx/5_frames' - - num_annotation_candidates = 5 - - # The following step is not necessary, you as a human can also choose suitable frames and provide annotations - compute_iou = True - chosen_annotation_frames = predict_annotation_candidates(video_frames_path, num_candidates=num_annotation_candidates) - - print(f"The following frames were chosen as annotation candidates: {chosen_annotation_frames}") - - stats_first_frame_only = run_on_video(video_frames_path, video_masks_path, output_masks_path_baseline, frames_with_masks=[0], compute_iou=True) - stats_5_frames = run_on_video(video_frames_path, video_masks_path, output_masks_path_5_frames, frames_with_masks=chosen_annotation_frames, compute_iou=True) - - print(f"Average IoU for the video: {float(stats_first_frame_only['iou'].mean())} (first frame only)") - print(f"Average IoU for the video: {float(stats_5_frames['iou'].mean())} ({num_annotation_candidates} chosen annotated frames)") - - # Example for a video with only a few annotations present - video_frames_path = 'example_videos/imbalanced-scenes/frames' - video_masks_path = 'example_videos/imbalanced-scenes/masks' - output_masks_path_baseline = 'output/imbalanced-scenes/baseline' - output_masks_path_3_frames = 'output/imbalanced-scenes/3_frames' - run_on_video(video_frames_path, video_masks_path, output_masks_path_baseline, frames_with_masks=[0], compute_iou=False) - 
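The same selection path is shared by the GUI button above and by `select_k_next_best_annotation_candidates`: encode every frame once with the key encoder, then greedily pick the most dissimilar, mask-weighted frames. A sketch, assuming `loader` and `processor` come from `_load_main_objects` and `masks` holds one mask tensor per frame with values in [0, 1]:

```
import torch
from inference.frame_selection.frame_selection import select_next_candidates
from inference.frame_selection.frame_selection_utils import extract_keys

frame_keys, shrinkages, selections, *_ = extract_keys(loader, processor,
                                                      print_progress=True, flatten=False)
keys = torch.cat(frame_keys)          # one key feature map per frame
shrinkages = torch.cat(shrinkages)
selections = torch.cat(selections)

next_frames = select_next_candidates(
    keys, shrinkages=shrinkages, selections=selections, masks=masks,
    num_next_candidates=5,
    previously_chosen_candidates=[0],
    alpha=0.5,                        # 0 -> ignore masks entirely, 1 -> compare masked regions only
    min_mask_presence_percent=0.25,   # skip frames whose mask covers less than 0.25% of the image
    only_new_candidates=True)
print(f"Annotate these frames next: {sorted(next_frames)}")
```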
run_on_video(video_frames_path, video_masks_path, output_masks_path_3_frames, frames_with_masks=[0, 140, 830], compute_iou=False) + imgs_path = 'example_videos/caps/JPEGImages' + masks_path = 'example_videos/caps/Annotations' + output_path = 'output/example_video_caps' + frames_with_masks = [0, 14, 33, 43, 66] + + # Run inference with preselected annotations + run_on_video(imgs_path, masks_path, output_path, frames_with_masks) + + # Get proposals for the next 3 best annotation candidates using previously predicted masks + # If you don't have previous predictions, just put `use_previously_predicted_masks=False`, the algorithm will run inference internally + next_candidates = select_k_next_best_annotation_candidates(imgs_path, masks_path, output_path, previously_chosen_candidates=frames_with_masks, use_previously_predicted_masks=True) + print("Next candidates for annotations are: ") + for idx in next_candidates: + print(f"\tFrame {idx}") + + # Run inference on a video with all annotations provided, compute IoU + imgs_path = 'example_videos/chair/JPEGImages' + masks_path = 'example_videos/chair/Annotations' + output_path = 'output/example_video_chair' + + num_frames = len(os.listdir(imgs_path)) + frames_with_masks = random.sample(range(0, num_frames), 3) # Give 3 random masks as GT annotations + + stats = run_on_video(imgs_path, masks_path, output_path, frames_with_masks, compute_iou=True) # stats: pandas DataFrame + mean_iou = stats[stats['iou'] != -1]['iou'].mean() # -1 is for GT annotations, we just skip them + print(f"Average IoU: {mean_iou}") # Should be 90%+ as a sanity check \ No newline at end of file diff --git a/util/configuration.py b/util/configuration.py index 8445dad..ed2a1c8 100644 --- a/util/configuration.py +++ b/util/configuration.py @@ -134,6 +134,32 @@ def __setitem__(self, key, value): def __str__(self): return str(self.args) + +VIDEO_INFERENCE_CONFIG = { + 'buffer_size': 100, + 'deep_update_every': -1, + 'enable_long_term': True, + 'enable_long_term_count_usage': True, + 'fbrs_model': 'saves/fbrs.pth', + 'hidden_dim': 64, + 'images': None, + 'key_dim': 64, + 'max_long_term_elements': 10000, + 'max_mid_term_frames': 10, + 'mem_every': 10, + 'min_mid_term_frames': 5, + 'model': './saves/XMem.pth', + 'no_amp': False, + 'num_objects': 1, + 'num_prototypes': 128, + 's2m_model': 'saves/s2m.pth', + 'size': 480, + 'top_k': 30, + 'value_dim': 512, + 'masks_out_path': None, + 'workspace': None, + 'save_masks': True + } if __name__ == '__main__': c = Configuration() c.parse() diff --git a/util/image_loader.py b/util/image_loader.py new file mode 100644 index 0000000..79b850d --- /dev/null +++ b/util/image_loader.py @@ -0,0 +1,85 @@ +import numpy as np +from PIL import Image + +class PaletteConverter: + """ + A class to convert images to index masks using a given palette. + + This class allows converting images to index masks by mapping unique colors + in the input image to corresponding object indices in the output index mask. + The palette provided during initialization is used to assign colors to the objects. + Color black is assumed to be background and is ignored, thus index_mask's indices start with 1. + + Recommended to use over a set of images (e.g. masks for a single video), + as it provides CONSISTENT object indices even if some of them are not present in certain frames. + + Parameters: + ----------- + palette : bytes + A bytes object representing the palette used for color mapping. See `util.palette` module. 
+ num_potential_colors : int, optional + The number of potential colors in the lookup table. Default is 256. + + Properties: + ----------- + palette : bytes + The palette used for color mapping. + lookup : numpy.ndarray + An array to keep track of color-to-object index mapping. + num_objects : int + The number of unique objects detected in all the images. + + Methods: + -------- + image_to_index_mask(img: Image.Image) -> Image.Image: + Convert an input image to an index mask using the palette and stored object mapping. + + Example: + -------- + # Create a palette converter object + ``` + palette = b'\\xff\\x00\\x00\\xff\\xff\\xff\\x00\\x00\\x00\\x00\\xff\\x00' # or use palettes from `util.palette` module + converter = PaletteConverter(palette) + + # Convert an image to index mask + from PIL import Image + for img_path in [...]: + input_img = Image.open(img_path) + index_mask = converter.image_to_index_mask(input_img) # lookup gets updated internally to preserve consistent object indices across multiple images + ``` + """ + def __init__(self, palette: bytes, num_potential_colors=256) -> None: + self._palette = palette + self._lookup = np.zeros(num_potential_colors, dtype=np.uint8) + self._num_objects = 0 + + def image_to_index_mask(self, img: Image.Image) -> Image.Image: + img_p = img.convert('P') + unique_colors = img_p.getcolors() + for _, c in unique_colors: + if c == 0: + # Blacks is always 0 and is ignored + continue + elif self._lookup[c] == 0: + self._num_objects += 1 + self._lookup[c] = self._num_objects + + # We need range indices like (0, 1, 2, 3, 4) for each unique color, in any order, as long as black is still 0 + # If the new colors appear, we'll treat them as new objects + index_array = self._lookup[img_p] # We use the lookup as the "image", and the actual P images as the "indices" for it (thus color_id -> object_id) + index_mask = Image.fromarray(index_array, mode='P') + index_mask.putpalette(self._palette) + + return index_mask + + @property + def palette(self): + return self._palette + + @property + def lookup(self): + return self._lookup + + @property + def num_objects(self): + return self._num_objects diff --git a/util/image_saver.py b/util/image_saver.py index f858671..25fbde6 100644 --- a/util/image_saver.py +++ b/util/image_saver.py @@ -1,5 +1,9 @@ +from multiprocessing import Process, Queue, Value import os +from pathlib import Path +import queue from time import perf_counter +import time import cv2 import numpy as np from PIL import Image @@ -8,6 +12,8 @@ from dataset.range_transform import inv_im_trans from collections import defaultdict +from inference.interact.interactive_utils import overlay_davis + def tensor_to_numpy(image): image_np = (image.numpy() * 255).astype('uint8') return image_np @@ -152,7 +158,7 @@ def _check_if_black_and_white(img: Image.Image): return False -def create_overlay(img: Image.Image, mask: Image.Image, mask_alpha=0.5, color_if_black_and_white=(255, 0, 0)): # all RGB +def create_overlay(img: Image.Image, mask: Image.Image, mask_alpha=0.5, color_if_black_and_white=(255, 255, 255)): # all RGB; Use (128, 0, 0) to mimic DAVIS color palette if you want mask = mask.convert('RGB') is_b_and_w = _check_if_black_and_white(mask) @@ -177,4 +183,152 @@ def save_image(img: Image.Image, frame_name, video_name, general_dir_path, sub_d os.makedirs(this_out_path, exist_ok=True) img_save_path = os.path.join(this_out_path, frame_name[:-4] + extension) - cv2.imwrite(img_save_path, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) \ No newline at end of file + 
img.save(img_save_path) + # cv2.imwrite(img_save_path, cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) + +class ParallelImageSaver: + """ + A class for parallel saving of masks and / or overlay images using multiple processes. + Composing overlays and saving images on the drive is pretty slow, this class does it in the background. + + Parameters + ---------- + general_output_path : str + The general path where images and masks will be saved. + vid_name : str + The name of the video or identifier for the output files. + overlay_color_if_b_and_w : tuple, optional + The RGB color to use for masks when there is only one object. Default is (255, 255, 255) (white). + max_queue_size : int, optional + The maximum size of the mask and overlay queues. Default is 200. + + Methods + ------- + save_mask(mask, frame_name) + Start saving a mask in the background. + save_overlay(orig_img, mask, frame_name) + Create an overlay given an image and a mask, and start saving it in the background. + qsize() -> Tuple(int, int) + Get the current size of the mask and overlay queues (how many frames are still left to process). + __enter__() + Enter the context manager and return the instance itself. + __exit__(exc_type, exc_value, exc_tb) + Exit the context manager and handle cleanup. + wait_for_jobs_to_finish(verbose=False) + Wait for all saving jobs to finish. Optional, will be called automatically in __exit__. Only recommened to use if you want to print verbose progress. + + Examples + -------- + # Example usage of ParallelImageSaver class + with ParallelImageSaver("/output/directory", "video_1", overlay_color_if_b_and_w=(100, 100, 100)) as image_saver: + image = Image.open("img.jpg") + mask = Image.open("mask.png") + + # These will be saved in parallel in background processes + image_saver.save_mask(mask_image, "frame_000001") + image_saver.save_overlay(image, mask, "frame_000001") + + image_saver.wait_for_jobs_to_finish(verbose=True) # Optional + + # The images will be saved in separate processes in the background. 
+ """ + + def __init__(self, general_output_path: str, vid_name: str, overlay_color_if_b_and_w=(255, 255, 255), max_queue_size=200) -> None: + self._mask_queue = Queue(max_queue_size) + self._overlay_queue = Queue(max_queue_size) + + self._mask_saver_worker = None + self._overlay_saver_worker = None + + self._p_out = Path(general_output_path) + self._vid_name = vid_name + self._object_color = overlay_color_if_b_and_w + self._finished = Value('b', False) + + def save_mask(self, mask: Image.Image, frame_name: str): + self._mask_queue.put((mask, frame_name, 'masks', '.png')) + + if self._mask_saver_worker is None: + self._mask_saver_worker = Process(target=self._save_mask_fn) + self._mask_saver_worker.start() + + def save_overlay(self, orig_img: Image.Image, mask: Image.Image, frame_name: str): + self._overlay_queue.put((orig_img, mask, frame_name, 'overlay', '.jpg')) + + if self._overlay_saver_worker is None: + self._overlay_saver_worker = Process(target=self._save_overlay_fn) + self._overlay_saver_worker.start() + + def _save_mask_fn(self): + while True: + try: + mask, frame_name, subdir, extension = self._mask_queue.get_nowait() + except queue.Empty: + if self._finished.value: + return + else: + time.sleep(1) + continue + save_image(mask, frame_name, self._vid_name, self._p_out, subdir, extension) + + def _save_overlay_fn(self): + while True: + try: + orig_image, mask, frame_name, subdir, extension = self._overlay_queue.get_nowait() + except queue.Empty: + if self._finished.value: + return + else: + time.sleep(1) + continue + overlaid_img = create_overlay(orig_image, mask, color_if_black_and_white=self._object_color) + save_image(overlaid_img, frame_name, self._vid_name, self._p_out, subdir, extension) + + def qsize(self): + return self._mask_queue.qsize(), self._overlay_queue.qsize() + + def __enter__(self): + # No need to initialize anything here + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + if exc_type is not None: + # Just kill everything for cleaner exit + # Yeah, the child processed should be immediately killed if the main one exits, but just in case + if self._mask_saver_worker is not None: + self._mask_saver_worker.kill() + if self._mask_saver_worker is not None: + self._mask_saver_worker.kill() + + raise exc_value + else: + self.wait_for_jobs_to_finish(verbose=False) + if self._mask_saver_worker is not None: + self._mask_saver_worker.close() + + if self._overlay_saver_worker is not None: + self._overlay_saver_worker.close() + + def wait_for_jobs_to_finish(self, verbose=False): + # Optional, no need to call unless you want the verbose output + # Will be called automatically by the __exit__ method + self._finished.value = True # No need for a lock, as it's a single write with multiple reads + + if not verbose: + if self._mask_saver_worker is not None: + self._mask_saver_worker.join() + + if self._overlay_saver_worker is not None: + self._overlay_saver_worker.join() + + else: + while True: + masks_left, overlays_left = self.qsize() + if max(masks_left, overlays_left) > 0: + print(f"Finishing saving the results, {masks_left:>4d} masks and {overlays_left:>4d} overlays left.") + time.sleep(1) + else: + break + + self.wait_for_jobs_to_finish(verbose=False) # just to `.join()` them both + print("All saving jobs finished") \ No newline at end of file From 21828a4fca813c355ca0732abf3c3a2558c86b0f Mon Sep 17 00:00:00 2001 From: max810 Date: Thu, 3 Aug 2023 17:40:48 +0400 Subject: [PATCH 44/49] `run_on_video` now runs on videos. 
--- inference/data/video_reader.py | 76 +++++++++++++++++++------- inference/interact/resource_manager.py | 2 +- inference/run_on_video.py | 4 +- main.py | 9 ++- 4 files changed, 68 insertions(+), 23 deletions(-) diff --git a/inference/data/video_reader.py b/inference/data/video_reader.py index 9be6ca1..cad0d52 100644 --- a/inference/data/video_reader.py +++ b/inference/data/video_reader.py @@ -1,7 +1,10 @@ from dataclasses import dataclass, replace import os from os import path +from tempfile import TemporaryDirectory from typing import Optional +import cv2 +import progressbar import torch from torch.utils.data.dataset import Dataset @@ -30,7 +33,7 @@ class VideoReader(Dataset): """ This class is used to read a video, one frame at a time """ - def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_masks=False, size_dir=None): + def __init__(self, vid_name, video_path, mask_dir, size=-1, to_save=None, use_all_masks=False, size_dir=None): """ image_dir - points to a directory of jpg images mask_dir - points to a directory of png masks @@ -41,16 +44,11 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all Default false. Set to true for YouTubeVOS validation. """ self.vid_name = vid_name - self.image_dir = image_dir + self.video_path = video_path self.mask_dir = mask_dir self.to_save = to_save self.use_all_masks = use_all_masks - if size_dir is None: - self.size_dir = self.image_dir - else: - self.size_dir = size_dir - self.frames = sorted(os.listdir(self.image_dir)) self.reference_mask = Image.open(path.join(mask_dir, sorted(os.listdir(mask_dir))[0])).convert('P') self.first_gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[0]) @@ -67,26 +65,38 @@ def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all ]) self.size = size + if os.path.isfile(self.video_path): + self.tmp_dir = TemporaryDirectory() + self.image_dir = self.tmp_dir.name + self._extract_frames() + else: + self.image_dir = video_path + + if size_dir is None: + self.size_dir = self.image_dir + else: + self.size_dir = size_dir + + self.frames = sorted(os.listdir(self.image_dir)) def __getitem__(self, idx) -> Sample: - frame = self.frames[idx] data = {} - - im_path = path.join(self.image_dir, frame) + frame_name = self.frames[idx] + im_path = path.join(self.image_dir, frame_name) img = Image.open(im_path).convert('RGB') if self.image_dir == self.size_dir: shape = np.array(img).shape[:2] else: - size_path = path.join(self.size_dir, frame) + size_path = path.join(self.size_dir, frame_name) size_im = Image.open(size_path).convert('RGB') shape = np.array(size_im).shape[:2] - gt_path = path.join(self.mask_dir, frame[:-4]+'.png') + gt_path = path.join(self.mask_dir, frame_name[:-4]+'.png') if not os.path.exists(gt_path): - gt_path = path.join(self.mask_dir, frame[:-4]+'.PNG') - - data['raw_image_pil'] = img # for dataloaders it cannot be raw PIL.Image, only tensors + gt_path = path.join(self.mask_dir, frame_name[:-4]+'.PNG') + + data['raw_image_pil'] = img img = self.im_transform(img) load_mask = self.use_all_masks or (gt_path == self.first_gt_path) @@ -96,8 +106,8 @@ def __getitem__(self, idx) -> Sample: data['mask'] = mask info = {} - info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) - info['frame'] = frame + info['save'] = (self.to_save is None) or (frame_name[:-4] in self.to_save) + info['frame'] = frame_name info['shape'] = shape info['need_resize'] = not (self.size < 0) @@ -106,6 +116,35 @@ def __getitem__(self, 
idx) -> Sample: data = Sample(**data, **info) return data + + def __len__(self): + return len(self.frames) + + def __del__(self): + if hasattr(self, 'tmp_dir'): + self.tmp_dir.cleanup() + + def _extract_frames(self): + cap = cv2.VideoCapture(self.video_path) + frame_index = 0 + print(f'Extracting frames from {self.video_path} into a temporary dir...') + bar = progressbar.ProgressBar(max_value=int(cap.get(cv2.CAP_PROP_FRAME_COUNT))) + while(cap.isOpened()): + _, frame = cap.read() + if frame is None: + break + if self.size > 0: + h, w = frame.shape[:2] + new_w = (w*self.size//min(w, h)) + new_h = (h*self.size//min(w, h)) + if new_w != w or new_h != h: + frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) + cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame) + frame_index += 1 + bar.update(frame_index) + bar.finish() + print('Done!') + def resize_mask(self, mask): # mask transform is applied AFTER mapper, so we need to post-process it in eval.py @@ -119,9 +158,6 @@ def map_the_colors_back(self, pred_mask: Image.Image): # dither=Dither.NONE just in case return pred_mask.quantize(palette=self.reference_mask, dither=Image.Dither.NONE).convert('RGB') - def __len__(self): - return len(self.frames) - @staticmethod def collate_fn_identity(x): if x.mask is not None: diff --git a/inference/interact/resource_manager.py b/inference/interact/resource_manager.py index 58f694a..80680fc 100644 --- a/inference/interact/resource_manager.py +++ b/inference/interact/resource_manager.py @@ -145,7 +145,7 @@ def _extract_frames(self, video): new_h = (h*self.size//min(w, h)) if new_w != w or new_h != h: frame = cv2.resize(frame,dsize=(new_w,new_h),interpolation=cv2.INTER_AREA) - cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame) + cv2.imwrite(path.join(self.image_dir, f'frame_{frame_index:06d}.jpg'), frame) frame_index += 1 bar.update(frame_index) bar.finish() diff --git a/inference/run_on_video.py b/inference/run_on_video.py index 13736c0..0efca6a 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -184,7 +184,7 @@ def _create_dataloaders(imgs_in_path: Union[str, PathLike], masks_in_path: Union ) # Just return the samples as they are; only using DataLoader for preloading frames from the disk - loader = DataLoader(vid_reader, batch_size=None, shuffle=False, num_workers=8, collate_fn=VideoReader.collate_fn_identity) + loader = DataLoader(vid_reader, batch_size=None, shuffle=False, num_workers=1, collate_fn=VideoReader.collate_fn_identity) vid_length = len(loader) # no need to count usage for LT if the video is not that long anyway @@ -207,6 +207,8 @@ def _preload_permanent_memory(frames_to_put_in_permanent_memory: List[int], vid_ sample = replace(sample, rgb=sample.rgb.cuda()) # https://github.com/hkchengrex/XMem/issues/21 just make exhaustive = True + if sample.mask is None: + raise FileNotFoundError(f"Couldn't find mask {j}! 
Check that the filename is either the same as for frame {j} or follows the `frame_%06d.png` format if using a video file for input.") msk, labels = mapper.convert_mask(sample.mask, exhaustive=True) msk = torch.Tensor(msk).cuda() diff --git a/main.py b/main.py index ba49306..c225a7f 100644 --- a/main.py +++ b/main.py @@ -7,13 +7,20 @@ # import ssl # ssl._create_default_https_context = ssl._create_unverified_context + # Run inference on a video file with preselected annotated frames + video_path = 'example_videos/chair/chair.mp4' + masks_path = 'example_videos/chair/Annotations' + output_path = 'output/example_video_chair_from_mp4' + frames_with_masks = [5, 10, 15] + + run_on_video(video_path, masks_path, output_path, frames_with_masks) + # Run inference on extracted .jpg frames with preselected annotations imgs_path = 'example_videos/caps/JPEGImages' masks_path = 'example_videos/caps/Annotations' output_path = 'output/example_video_caps' frames_with_masks = [0, 14, 33, 43, 66] - # Run inference with preselected annotations run_on_video(imgs_path, masks_path, output_path, frames_with_masks) # Get proposals for the next 3 best annotation candidates using previously predicted masks From 0f7d44b5ff9ecccf23265b46cdf8c9be6a07411a Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 7 Aug 2023 16:44:05 +0400 Subject: [PATCH 45/49] Added tooltips for buttons --- inference/interact/gui.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 7de54de..35eccc3 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -75,26 +75,34 @@ def __init__(self, net: XMem, # some buttons self.play_button = QPushButton('Play Video') + self.play_button.setToolTip("Play/Pause the video") self.play_button.clicked.connect(self.on_play_video) self.commit_button = QPushButton('Commit') + self.commit_button.setToolTip("Finish current interaction with the mask") self.commit_button.clicked.connect(self.on_commit) self.save_reference_button = QPushButton('Save reference') + self.save_reference_button.setToolTip("Save current mask in the permanent memory.\nUsed by the model as a reference ground truth.") self.save_reference_button.clicked.connect(self.on_save_reference) self.compute_candidates_button = QPushButton('Compute Annotation candidates') + self.compute_candidates_button.setToolTip("Get next k frames that you should annotate.") self.compute_candidates_button.clicked.connect(self.on_compute_candidates) self.full_run_button = QPushButton('FULL Propagate') + self.full_run_button.setToolTip("Clear the temporary memory, scroll to beginning and predict new masks for all the frames.") self.full_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='full')) self.forward_run_button = QPushButton('Forward Propagate') + self.forward_run_button.setToolTip("Predict new masks for all the frames starting with the current one.") self.forward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='forward')) self.forward_run_button.setMinimumWidth(200) self.backward_run_button = QPushButton('Backward Propagate') + self.backward_run_button.setToolTip("Predict new masks for all the frames before with the current one.") self.backward_run_button.clicked.connect(partial(self.general_propagation_callback, propagation_type='backward')) self.backward_run_button.setMinimumWidth(200) - self.reset_button = QPushButton('Reset Frame') + self.reset_button = 
QPushButton('Delete Mask')
+        self.reset_button.setToolTip("Delete the mask for the current frame. Cannot be undone!")
         self.reset_button.clicked.connect(self.on_reset_mask)
 
         self.spacebar = QShortcut(QKeySequence(Qt.Key_Space), self)
@@ -148,8 +156,11 @@ def __init__(self, net: XMem,
         self.curr_interaction = 'Free'
         self.interaction_group = QButtonGroup()
         self.radio_fbrs = QRadioButton('Click')
+        self.radio_fbrs.setToolTip("Click in/out of the current mask. Careful - this will delete the existing mask!")
         self.radio_s2m = QRadioButton('Scribble')
+        self.radio_s2m.setToolTip('Draw a line in/out of the current mask. Edits existing masks directly.')
         self.radio_free = QRadioButton('Free')
+        self.radio_free.setToolTip('Free drawing')
         self.interaction_group.addButton(self.radio_fbrs)
         self.interaction_group.addButton(self.radio_s2m)
         self.interaction_group.addButton(self.radio_free)
@@ -185,9 +196,13 @@ def __init__(self, net: XMem,
 
         # Parameters setting
         self.clear_mem_button = QPushButton('Clear TEMP and LONG memory')
+        self.clear_mem_button.setToolTip("Temporary and long-term memory can hold features from the previous model run."
+                                         "\nIf the previous predictions had errors, they might influence the new masks."
+                                         "\nSo for a new model run, either clear the memory or just use FULL propagate.")
         self.clear_mem_button.clicked.connect(self.on_clear_memory)
 
         self.work_mem_gauge, self.work_mem_gauge_layout = create_gauge('Working memory size')
+        self.work_mem_gauge.setToolTip("Temporary and Permanent memory together.")
         self.long_mem_gauge, self.long_mem_gauge_layout = create_gauge('Long-term memory size')
         self.gpu_mem_gauge, self.gpu_mem_gauge_layout = create_gauge('GPU mem. (all processes, w/ caching)')
         self.torch_mem_gauge, self.torch_mem_gauge_layout = create_gauge('GPU mem. (used by torch, w/o caching)')
@@ -214,9 +229,11 @@ def __init__(self, net: XMem,
 
         # import mask/layer
         self.import_mask_button = QPushButton('Import mask')
+        self.import_mask_button.setToolTip("Import an existing .png file with a mask for the current frame.\nReplaces the existing mask.")
         self.import_mask_button.clicked.connect(self.on_import_mask)
 
         self.import_all_masks_button = QPushButton('Import ALL masks')
+        self.import_all_masks_button.setToolTip("Import a list of masks for some or all frames in the video.\nIf more than 10 are imported, individual confirmations will not be shown.")
         self.import_all_masks_button.clicked.connect(self.on_import_all_masks)
         self.import_layer_button = QPushButton('Import layer')
         self.import_layer_button.clicked.connect(self.on_import_layer)
@@ -340,11 +357,18 @@ def __init__(self, net: XMem,
 
         candidates_area = QVBoxLayout()
         self.candidates_min_mask_size_edit = QLineEdit()
+        self.candidates_min_mask_size_edit.setToolTip("Minimal size a mask should have to be considered, % of the total image size."
+                                                      "\nIf it's smaller than the value specified, the frame will be ignored."
+                                                      "\nUsed to filter out \"junk\" frames or frames with very heavy occlusions.")
         float_validator = QRegExpValidator(QRegExp(r"^(100(\.0+)?|[1-9]?\d(\.\d+)?|0(\.\d+)?)$"))
         self.candidates_min_mask_size_edit.setValidator(float_validator)
         self.candidates_min_mask_size_edit.setText("0.25")
 
         self.candidates_k_slider = NamedSlider("k", 1, 20, 1, default=5)
+        self.candidates_k_slider.setToolTip("How many annotation candidates to select.")
         self.candidates_alpha_slider = NamedSlider("α", 0, 100, 1, default=50, multiplier=0.01, min_text='Frames', max_text='Masks')
+        self.candidates_alpha_slider.setToolTip("Target importance."
+                                                "\nIf 0 the candidates will be the same regardless of which object is being segmented."
+                                                "\n
If 1 the only part of the image considered will be the one occupied by the mask.") candidates_area.addWidget(QLabel("Min mask size, % of the total image size, 0-100")) candidates_area.addWidget(self.candidates_min_mask_size_edit) candidates_area.addWidget(QLabel("Candidates calculation hyperparameters")) From 7f0643ce89c5355ec6e11bd9487acaa3e79872d1 Mon Sep 17 00:00:00 2001 From: max810 Date: Mon, 7 Aug 2023 18:13:50 +0400 Subject: [PATCH 46/49] Removed profiling code from the codebase --- inference/interact/gui.py | 2 -- inference/run_on_video.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/inference/interact/gui.py b/inference/interact/gui.py index 35eccc3..3beaa50 100644 --- a/inference/interact/gui.py +++ b/inference/interact/gui.py @@ -1189,8 +1189,6 @@ def on_import_all_masks(self): self.scroll_to(i) self.on_import_mask(str(p_f), ask_confirmation=True) - from profilehooks import profile - @profile(stdout=False, immediate=False, filename='on_import_mask.profile') def on_import_mask(self, mask_file_path=None, ask_confirmation=True): if mask_file_path: file_name = mask_file_path diff --git a/inference/run_on_video.py b/inference/run_on_video.py index 0efca6a..ca927a2 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -28,9 +28,6 @@ from inference.data.mask_mapper import MaskMapper from inference.frame_selection.frame_selection_utils import extract_keys, get_determenistic_augmentations -from profilehooks import profile - -# @profile(stdout=False, immediate=False, filename='run_on_video_aug_2023_parallel_composite_and_saving.profile') def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_out_path, original_memory_mechanism=False, compute_iou=False, From 64747cda51b1bd009cf6dc2a826b946f1b4bc4fb Mon Sep 17 00:00:00 2001 From: Maksym Bekuzarov Date: Fri, 11 Aug 2023 13:17:56 +0400 Subject: [PATCH 47/49] Docker (#1) * Working Dockerfile for inference configuration, added a script that conveniently runs the inference on video, XMem now won't load unnecessary resnet imagenet weights * Fully working inference and GUI docker configurations * Cleaned and simplified Docker running code, minor bugfixes --- .dockerignore | 152 +++++++++++++++++++++++++++++++++++++ Dockerfile | 25 ++++++ inference/run_on_video.py | 4 +- interactive_demo.py | 2 +- model/modules.py | 8 +- model/network.py | 6 +- process_video.py | 30 ++++++++ run_gui_in_docker.sh | 65 ++++++++++++++++ run_inference_in_docker.sh | 65 ++++++++++++++++ 9 files changed, 347 insertions(+), 10 deletions(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 process_video.py create mode 100644 run_gui_in_docker.sh create mode 100644 run_inference_in_docker.sh diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..ab672ed --- /dev/null +++ b/.dockerignore @@ -0,0 +1,152 @@ +log/ +output/ +.vscode/ +workspace/ +run*.sh + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +output/ +example_videos/ +torchlogs/ +*.png +*.csv +*.txt +*junk* +*.profile + +!requirements*.txt + +.DS_Store +*.jpg +*.zip +*.sh +Dockerfile \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8e6f32f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +# Use the specified PyTorch base image with CUDA support +FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime AS xmem2-base-inference + +# Set the working directory in the container +WORKDIR /app + +RUN python -m pip install --no-cache-dir opencv-python-headless scikit-learn Pillow==9.2.0 + +# Install Python dependencies from requirements.txt +COPY requirements.txt /app/requirements.txt +RUN python -m pip install --no-cache-dir -r requirements.txt + +# Copy the application files into the container +COPY . 
/app + +# FOR GUI - only a few extra dependencies +FROM xmem2-base-inference AS xmem2-gui + +# Qt dependencies +RUN apt-get update && apt-get install -y build-essential libgl1 libglib2.0-0 libxkbcommon-x11-0 '^libxcb.*-dev' libx11-xcb-dev libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev libxkbcommon-x11-dev libfontconfig libdbus-1-3 mesa-utils libgl1-mesa-glx +RUN /bin/bash -c 'gcc --version' + +RUN python -m pip install --no-cache-dir -r requirements_demo.txt +# To avoid error messages when launching PyQT +ENV LIBGL_ALWAYS_INDIRECT=1 \ No newline at end of file diff --git a/inference/run_on_video.py b/inference/run_on_video.py index ca927a2..0b4eb67 100644 --- a/inference/run_on_video.py +++ b/inference/run_on_video.py @@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory from time import perf_counter import time -from typing import Iterable, Literal, Optional, Union, List +from typing import Iterable, Optional, Union, List from pathlib import Path from warnings import warn @@ -146,7 +146,7 @@ def _inference_on_video(frames_with_masks, imgs_in_path, masks_in_path, masks_ou def _load_main_objects(imgs_in_path, masks_in_path, config): model_path = config['model'] - network = XMem(config, model_path).cuda().eval() + network = XMem(config, model_path, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval() if model_path is not None: model_weights = torch.load(model_path) network.load_weights(model_weights, init_as_zero_if_needed=True) diff --git a/interactive_demo.py b/interactive_demo.py index 9dcba6b..1fc490d 100644 --- a/interactive_demo.py +++ b/interactive_demo.py @@ -71,7 +71,7 @@ with torch.cuda.amp.autocast(enabled=not args.no_amp): # Load our checkpoint - network = XMem(config, args.model).cuda().eval() + network = XMem(config, args.model, pretrained_key_encoder=False, pretrained_value_encoder=False).cuda().eval() # Loads the S2M model if args.s2m_model is not None: diff --git a/model/modules.py b/model/modules.py index 82d1ac4..652cf3a 100644 --- a/model/modules.py +++ b/model/modules.py @@ -100,11 +100,11 @@ def forward(self, g, h): class ValueEncoder(nn.Module): - def __init__(self, value_dim, hidden_dim, single_object=False): + def __init__(self, value_dim, hidden_dim, single_object=False, pretrained=True): super().__init__() self.single_object = single_object - network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2) + network = resnet.resnet18(pretrained=pretrained, extra_dim=1 if single_object else 2) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 @@ -151,9 +151,9 @@ def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True): class KeyEncoder(nn.Module): - def __init__(self): + def __init__(self, pretrained=True): super().__init__() - network = resnet.resnet50(pretrained=True) + network = resnet.resnet50(pretrained=pretrained) self.conv1 = network.conv1 self.bn1 = network.bn1 self.relu = network.relu # 1/2, 64 diff --git a/model/network.py b/model/network.py index c5f179d..124dd42 100644 --- a/model/network.py +++ b/model/network.py @@ -15,7 +15,7 @@ class XMem(nn.Module): - def __init__(self, config, model_path=None, map_location=None): + def __init__(self, config, model_path=None, map_location=None, pretrained_key_encoder=True, pretrained_value_encoder=True): """ model_path/map_location are used in evaluation only map_location is for converting models saved in cuda to cpu @@ -26,8 +26,8 @@ def __init__(self, config, model_path=None, map_location=None): 
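The new `pretrained_key_encoder` / `pretrained_value_encoder` flags only control whether the ResNet backbones download ImageNet weights at construction time; at inference the XMem checkpoint overwrites those weights anyway. A minimal usage sketch, mirroring the run_on_video.py change above (`config` and `model_path` are assumed to be set as elsewhere in this patch):

    # Skip the ImageNet download: the XMem checkpoint supplies all encoder weights.
    network = XMem(config, model_path,
                   pretrained_key_encoder=False,
                   pretrained_value_encoder=False).cuda().eval()
    network.load_weights(torch.load(model_path), init_as_zero_if_needed=True)
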
self.single_object = config.get('single_object', False) print(f'Single object mode: {self.single_object}') - self.key_encoder = KeyEncoder() - self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object) + self.key_encoder = KeyEncoder(pretrained=pretrained_key_encoder) + self.value_encoder = ValueEncoder(self.value_dim, self.hidden_dim, self.single_object, pretrained=pretrained_value_encoder) # Projection from f16 feature space to key/value space self.key_proj = KeyProjection(1024, self.key_dim) diff --git a/process_video.py b/process_video.py new file mode 100644 index 0000000..cea3132 --- /dev/null +++ b/process_video.py @@ -0,0 +1,30 @@ +import argparse +import re +from pathlib import Path + +from inference.run_on_video import run_on_video + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process video frames given a few (1+) existing annotation masks') + parser.add_argument('--video', type=str, help='Path to the video file or directory with .jpg video frames to process', required=True) + parser.add_argument('--masks', type=str, help='Path to the directory with individual .png masks for corresponding video frames, named `frame_000000.png`, `frame_000123.png`, ... or similarly (the script searches for the first integer value in the filename). ' + 'Will use all masks int the directory.', required=True) + parser.add_argument('--output', type=str, help='Path to the output directory where to save the resulting segmentation masks and overlays. ' + 'Will be automatically created if does not exist', required=True) + + args = parser.parse_args() + + frames_with_masks = [] + for file_path in (p for p in Path(args.masks).iterdir() if p.is_file()): + frame_number_match = re.search(r'\d+', file_path.stem) + if frame_number_match is None: + print(f"ERROR: file {file_path} does not contain a frame number. Cannot load it as a mask.") + exit(1) + frames_with_masks.append(int(frame_number_match.group())) + + print("Using masks for frames: ", frames_with_masks) + + p_out = Path(args.output) + p_out.mkdir(parents=True, exist_ok=True) + run_on_video(args.video, args.masks, args.output) diff --git a/run_gui_in_docker.sh b/run_gui_in_docker.sh new file mode 100644 index 0000000..839786b --- /dev/null +++ b/run_gui_in_docker.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +key='' +key_name='' +value='' + +# Parsing keyword arguments +while [ $# -gt 0 ]; do + if [ -z "${key}" ]; then + case "$1" in + --images|--video|--workspace) + key="other" + key_name="${1}" + ;; + --num_objects) + key="--num_objects" + ;; + *) + printf "***************************\n" + printf "* Error: Invalid argument ${1}\n" + printf "* Specify one of --images --video or --workspace with