diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index b4a2672..3fbb391 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -4039,14 +4039,16 @@ def sample_subarrays(data, num_samples, tracklet_length, total_duration): if sub_data.shape[0] == tracklet_length: subarrays.append(sub_data) - while len(subarrays) < num_samples: - start_index = np.random.randint(0, total_duration - tracklet_length) - end_index = start_index + tracklet_length - if end_index <= total_duration: + + if len(subarrays) < num_samples: + additional_subarrays = [] + for i in range(num_samples - len(subarrays)): + start_index = np.random.randint(0, total_duration - tracklet_length) + end_index = start_index + tracklet_length sub_data = data[start_index:end_index, :] if sub_data.shape[0] == tracklet_length: - subarrays.append(sub_data) - + additional_subarrays.append(sub_data) + subarrays.extend(additional_subarrays[: num_samples - len(subarrays)]) return subarrays