From 7ee826a9cf747ac476af54fed2979627c67f792e Mon Sep 17 00:00:00 2001 From: julie wang Date: Tue, 3 Sep 2024 09:11:50 +0200 Subject: [PATCH] feat(ml): fix vid_dataloader conflict --- ...ed_temporal_labeled_mask_online_dataset.py | 79 +++++++++++++------ 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/data/self_supervised_temporal_labeled_mask_online_dataset.py b/data/self_supervised_temporal_labeled_mask_online_dataset.py index 896515b88..8174a3654 100644 --- a/data/self_supervised_temporal_labeled_mask_online_dataset.py +++ b/data/self_supervised_temporal_labeled_mask_online_dataset.py @@ -1,9 +1,9 @@ import os import random import re - +import bisect import torch - +from collections import OrderedDict from data.base_dataset import BaseDataset, get_transform_list from data.image_folder import make_labeled_path_dataset from data.online_creation import crop_image @@ -61,6 +61,29 @@ def __init__(self, opt, phase, name=""): self.A_size = len(self.A_img_paths) # get the size of dataset A + self.frames_counts = OrderedDict() + for path in self.A_img_paths: + vid_series_paths = os.path.dirname(path) + if vid_series_paths not in self.frames_counts: + self.frames_counts[vid_series_paths] = ( + 0, + -self.num_frames * self.frame_step, + ) + count, count_minus = self.frames_counts[vid_series_paths] + count += 1 + self.frames_counts[vid_series_paths] = ( + count, + count - self.num_frames * self.frame_step, + ) + + self.cumulative_sums = [] + cumulative_sum = 0 + self.vid_series_keys = list(self.frames_counts.keys()) + for _, (_, count_minus) in self.frames_counts.items(): + if count_minus > 0: + cumulative_sum += count_minus + self.cumulative_sums.append(cumulative_sum) + def get_img( self, A_img_path, @@ -72,30 +95,38 @@ def get_img( index=None, ): # all params are unused - while True: - + if len(self.frames_counts) == 1: # single video mario index_A = random.randint(0, self.range_A - 1) + else: # video series + range_A = sum( + count_minu + for count, count_minu in self.frames_counts.values() + if count_minu > 0 + ) + index_A = random.randint(0, range_A - 1) + selected_index = bisect.bisect_right(self.cumulative_sums, index_A) + selected_key = self.vid_series_keys[selected_index] + path_num = ( + index_A - self.cumulative_sums[selected_index - 1] + if selected_index > 0 + else 0 + ) + filtered_paths = [ + path + for path in self.A_img_paths + if os.path.dirname(path) == selected_key + ] + if path_num < len(filtered_paths): # this is absolu true + selected_path = filtered_paths[path_num - 1] + else: + print("Path number exceeds available paths in the selected directory") - images_A = [] - labels_A = [] - - ref_A_img_path = self.A_img_paths[index_A] - ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char] - ref_A_name = ref_A_img_path.split("/")[-1] # fullname of the ref_A - - vid_series_path = os.path.dirname(ref_A_img_path) - vid_series = vid_series_path.split("/")[-1] - if ref_A_name.startswith( - vid_series - ): # dataset contains different video in form of img/vid_series/vid_seriesframe.jpg - series_count = sum(vid_series_path in path for path in self.A_img_paths) - frame_num = int(ref_A_name[len(vid_series) : -4]) # remove ".jpg" - if frame_num < (series_count - self.num_frames): - break - else: - print("Condition not met, generating a new index_A...") - else: # dataset from one video in form of img/frames.jpg - break + index_A = self.A_img_paths.index(selected_path) + + images_A = [] + labels_A = [] + ref_A_img_path = self.A_img_paths[index_A] + ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char] for i in range(self.num_frames): cur_index_A = index_A + i * self.frame_step