From 9a9c0e794e100b516715dd48eb55ef44d4d65916 Mon Sep 17 00:00:00 2001 From: Georgios Efstathiadis Date: Fri, 3 Nov 2023 17:51:05 -0400 Subject: [PATCH] bug in find_continuous_dominant_peaks calculation --- forest/oak/base.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/forest/oak/base.py b/forest/oak/base.py index 974771f7..2c2f6224 100644 --- a/forest/oak/base.py +++ b/forest/oak/base.py @@ -367,10 +367,12 @@ def find_continuous_dominant_peaks(valid_peaks: np.ndarray, min_t: int, cont_peaks = np.zeros_like(extended_peaks) - for slice_ind in range(num_cols - min_t): + for slice_ind in range(num_cols + 1 - min_t): slice_mat = extended_peaks[:, slice_ind:slice_ind + min_t] - for win_ind in range(min_t): + windows = list(range(min_t)) + list(range(min_t-2, -1, -1)) + + for win_ind in windows: pr = np.where(slice_mat[:, win_ind] != 0)[0] stop = True @@ -380,16 +382,26 @@ def find_continuous_dominant_peaks(valid_peaks: np.ndarray, min_t: int, min(p + delta + 1, num_rows) ) - peaks = slice_mat[p, win_ind] - if win_ind > 0: - peaks += slice_mat[index, win_ind - 1] - if win_ind < min_t - 1: - peaks += slice_mat[index, win_ind + 1] + peaks1 = slice_mat[p, win_ind] + peaks2 = peaks1 + if win_ind == 0: + peaks1 += slice_mat[index, win_ind + 1] + elif win_ind == min_t - 1: + peaks1 += slice_mat[index, win_ind - 1] + else: + peaks1 += slice_mat[index, win_ind - 1] + peaks2 += slice_mat[index, win_ind + 1] - if np.any(peaks > 1): - stop = False + if win_ind == 0 or win_ind == min_t - 1: + if np.any(peaks1 > 1): + stop = False + else: + slice_mat[p, win_ind] = 0 else: - slice_mat[p, win_ind] = 0 + if np.any(peaks1 > 1) and np.any(peaks2 > 1): + stop = False + else: + slice_mat[p, win_ind] = 0 if stop: break