From 611724dfe019bf37c0d28640cb172a46862d64d3 Mon Sep 17 00:00:00 2001 From: julie wang Date: Mon, 7 Oct 2024 10:17:25 +0200 Subject: [PATCH] feat(ml):debug dataloader frame --- ...self_supervised_vid_mask_online_dataset.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/data/self_supervised_vid_mask_online_dataset.py b/data/self_supervised_vid_mask_online_dataset.py index 42418a92..7064339a 100644 --- a/data/self_supervised_vid_mask_online_dataset.py +++ b/data/self_supervised_vid_mask_online_dataset.py @@ -87,6 +87,19 @@ def __init__(self, opt, phase, name=""): cumulative_sum += count_minus self.cumulative_sums.append(cumulative_sum) + # Iterate through each video series to get available frame pool + self.available_frame_pool = [] + start_count = 0 + for i in range(len(self.cumulative_sums)): + num_frames = ( + self.cumulative_sums[i] - self.cumulative_sums[i - 1] + if i != 0 + else self.cumulative_sums[i] + ) + end_count = start_count + num_frames + self.available_frame_pool.append(list(range(start_count, end_count))) + start_count = end_count + def get_img( self, A_img_path, @@ -109,7 +122,13 @@ def get_img( ) # chose one frame from available video series # according to the selected_index, get the video series and frame number - selected_index = bisect.bisect_left(self.cumulative_sums, random_A) + selected_index = [ + i for i, t in enumerate(self.available_frame_pool) if random_A in t + ] + if len(selected_index) == 1: + selected_index = selected_index[0] + else: + raise ValueError("random_A not found in any sublist, check dataset") selected_vid = self.vid_series_paths[selected_index] if selected_index > 0: frame_num = random_A - self.cumulative_sums[selected_index - 1] @@ -123,7 +142,7 @@ def get_img( ] # Get path and index_A - selected_path = filtered_paths[frame_num - 1] + selected_path = filtered_paths[frame_num] index_A = self.A_img_paths.index(selected_path) images_A = []