Skip to content

Commit

Permalink
feat(ml): fix vid_dataloader conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 13, 2024
1 parent 1b15aef commit 7ee826a
Showing 1 changed file with 55 additions and 24 deletions.
79 changes: 55 additions & 24 deletions data/self_supervised_temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import random
import re

import bisect
import torch

from collections import OrderedDict
from data.base_dataset import BaseDataset, get_transform_list
from data.image_folder import make_labeled_path_dataset
from data.online_creation import crop_image
Expand Down Expand Up @@ -61,6 +61,29 @@ def __init__(self, opt, phase, name=""):

self.A_size = len(self.A_img_paths) # get the size of dataset A

self.frames_counts = OrderedDict()
for path in self.A_img_paths:
vid_series_paths = os.path.dirname(path)
if vid_series_paths not in self.frames_counts:
self.frames_counts[vid_series_paths] = (
0,
-self.num_frames * self.frame_step,
)
count, count_minus = self.frames_counts[vid_series_paths]
count += 1
self.frames_counts[vid_series_paths] = (
count,
count - self.num_frames * self.frame_step,
)

self.cumulative_sums = []
cumulative_sum = 0
self.vid_series_keys = list(self.frames_counts.keys())
for _, (_, count_minus) in self.frames_counts.items():
if count_minus > 0:
cumulative_sum += count_minus
self.cumulative_sums.append(cumulative_sum)

def get_img(
self,
A_img_path,
Expand All @@ -72,30 +95,38 @@ def get_img(
index=None,
): # all params are unused

while True:

if len(self.frames_counts) == 1: # single video mario
index_A = random.randint(0, self.range_A - 1)
else: # video series
range_A = sum(
count_minu
for count, count_minu in self.frames_counts.values()
if count_minu > 0
)
index_A = random.randint(0, range_A - 1)
selected_index = bisect.bisect_right(self.cumulative_sums, index_A)
selected_key = self.vid_series_keys[selected_index]
path_num = (
index_A - self.cumulative_sums[selected_index - 1]
if selected_index > 0
else 0
)
filtered_paths = [
path
for path in self.A_img_paths
if os.path.dirname(path) == selected_key
]
if path_num < len(filtered_paths): # this is absolu true
selected_path = filtered_paths[path_num - 1]
else:
print("Path number exceeds available paths in the selected directory")

images_A = []
labels_A = []

ref_A_img_path = self.A_img_paths[index_A]
ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char]
ref_A_name = ref_A_img_path.split("/")[-1] # fullname of the ref_A

vid_series_path = os.path.dirname(ref_A_img_path)
vid_series = vid_series_path.split("/")[-1]
if ref_A_name.startswith(
vid_series
): # dataset contains different video in form of img/vid_series/vid_seriesframe.jpg
series_count = sum(vid_series_path in path for path in self.A_img_paths)
frame_num = int(ref_A_name[len(vid_series) : -4]) # remove ".jpg"
if frame_num < (series_count - self.num_frames):
break
else:
print("Condition not met, generating a new index_A...")
else: # dataset from one video in form of img/frames.jpg
break
index_A = self.A_img_paths.index(selected_path)

images_A = []
labels_A = []
ref_A_img_path = self.A_img_paths[index_A]
ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char]

for i in range(self.num_frames):
cur_index_A = index_A + i * self.frame_step
Expand Down

0 comments on commit 7ee826a

Please sign in to comment.