From c2266ab05a74de6d7e8170ff644bd50e408ca5ff Mon Sep 17 00:00:00 2001 From: kapoorlab Date: Mon, 22 Jul 2024 14:12:22 +0000 Subject: [PATCH] sample sub arrays --- src/napatrackmater/Trackvector.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index 51fcadf..b4a2672 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -4028,23 +4028,25 @@ def sample_subarrays(data, num_samples, tracklet_length, total_duration): for i in range(num_samples): start_min = i * interval start_max = min((i + 1) * interval, total_duration - tracklet_length) - if start_min >= start_max: - continue + + if start_max <= start_min: + start_max = start_min + 1 + start_index = np.random.randint(start_min, start_max) end_index = start_index + tracklet_length - sub_data = data[start_index:end_index, :] - if sub_data.shape[0] == tracklet_length: - subarrays.append(sub_data) - - if len(subarrays) < num_samples: - additional_subarrays = [] - for i in range(num_samples - len(subarrays)): - start_index = np.random.randint(0, total_duration - tracklet_length) - end_index = start_index + tracklet_length + if end_index <= total_duration: + sub_data = data[start_index:end_index, :] + if sub_data.shape[0] == tracklet_length: + subarrays.append(sub_data) + + while len(subarrays) < num_samples: + start_index = np.random.randint(0, total_duration - tracklet_length) + end_index = start_index + tracklet_length + if end_index <= total_duration: sub_data = data[start_index:end_index, :] if sub_data.shape[0] == tracklet_length: - additional_subarrays.append(sub_data) - subarrays.extend(additional_subarrays[: num_samples - len(subarrays)]) + subarrays.append(sub_data) + return subarrays