From 65c8f9b18da90e1e226b0fb1f409c697c7b3654d Mon Sep 17 00:00:00 2001 From: kapoorlab Date: Tue, 23 Jan 2024 15:54:54 +0100 Subject: [PATCH] compute raw matrix --- src/napatrackmater/Trackvector.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/napatrackmater/Trackvector.py b/src/napatrackmater/Trackvector.py index df7b4df7..e745cb20 100644 --- a/src/napatrackmater/Trackvector.py +++ b/src/napatrackmater/Trackvector.py @@ -2137,19 +2137,21 @@ def convert_tracks_to_simple_arrays( def compute_raw_matrix(track_arrays, t_delta): - track_duration = track_arrays.shape[0] t_delta = int(t_delta) - if track_duration < t_delta: - #zero pad - pad_rows = t_delta - track_duration - - # Pad with the last row and then zeros + + if track_duration < t_delta: + repetitions = t_delta - track_duration last_row = track_arrays[-1, :] - zero_pad = np.zeros((pad_rows, track_arrays.shape[1])) - track_arrays = np.vstack((track_arrays, np.tile(last_row, (pad_rows, 1)), zero_pad)) + repeated_rows = np.tile(last_row, (repetitions, 1)) + result_matrix = np.vstack([track_arrays, repeated_rows]) + elif track_duration > t_delta: + result_matrix = track_arrays[:t_delta, :] + else: + result_matrix = track_arrays + - flattened_array = track_arrays.flatten() + flattened_array = result_matrix.flatten() return flattened_array