Skip to content

Commit

Permalink
feat(ml): improve algo
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 17, 2024
1 parent 7ee826a commit d86f20c
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions data/self_supervised_temporal_labeled_mask_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,37 @@ def __init__(self, opt, phase, name=""):

self.A_size = len(self.A_img_paths) # get the size of dataset A

# dataset form img(bbox)/vid_series/vid_series_#frame.png(.txt)
# Dict to track the number of frames in each video series
self.frames_counts = OrderedDict()
for path in self.A_img_paths:
vid_series_paths = os.path.dirname(path)
# If this video series path hasn't been processed yet, initialize it in frames_counts.
# The value is a tuple (count, count_minus), where:
# - 'count' will store the number of frames in the series.
# - 'count_minus' is calculated as a negative offset based on the number of frames and step size.
# This offset acts as a limit to determine which frames to choose from the series.
if vid_series_paths not in self.frames_counts:
self.frames_counts[vid_series_paths] = (
0,
-self.num_frames * self.frame_step,
)
# Retrieve the current count and count_minus for the video series.
count, count_minus = self.frames_counts[vid_series_paths]
count += 1
# Update frames_counts with the new count and recalculate count_minus.
# Count is the total number of frames in the video series
# count_minus is the number of available frames in this video series
self.frames_counts[vid_series_paths] = (
count,
count - self.num_frames * self.frame_step,
)

# Store cumulative sums of available frames in the order of video series.
self.cumulative_sums = []
cumulative_sum = 0
# Create a list of video series paths for tracking later
self.vid_series_keys = list(self.frames_counts.keys())
# Iterate through each video series in frames_counts to compute the cumulative sum of available frame.
for _, (_, count_minus) in self.frames_counts.items():
if count_minus > 0:
cumulative_sum += count_minus
Expand All @@ -98,29 +111,29 @@ def get_img(
if len(self.frames_counts) == 1: # single video mario
index_A = random.randint(0, self.range_A - 1)
else: # video series
range_A = sum(
count_minu
for count, count_minu in self.frames_counts.values()
if count_minu > 0
)
index_A = random.randint(0, range_A - 1)
selected_index = bisect.bisect_right(self.cumulative_sums, index_A)
selected_key = self.vid_series_keys[selected_index]
path_num = (
index_A - self.cumulative_sums[selected_index - 1]
if selected_index > 0
else 0
)
range_A = self.cumulative_sums[
-1
] # total number of frames that can be used as a starting frame
random_A = random.randint(
0, range_A - 1
) # chose one frame from available video series

# according to the selected_index, get the video series and frame number
selected_index = bisect.bisect_left(self.cumulative_sums, random_A)
selected_vid = self.vid_series_keys[selected_index]
if selected_index > 0:
frame_num = random_A - self.cumulative_sums[selected_index - 1]
else:
frame_num = random_A

filtered_paths = [
path
for path in self.A_img_paths
if os.path.dirname(path) == selected_key
if os.path.dirname(path) == selected_vid
]
if path_num < len(filtered_paths): # this is absolu true
selected_path = filtered_paths[path_num - 1]
else:
print("Path number exceeds available paths in the selected directory")

# Get path and index_A
selected_path = filtered_paths[frame_num - 1]
index_A = self.A_img_paths.index(selected_path)

images_A = []
Expand Down

0 comments on commit d86f20c

Please sign in to comment.