From 05c9ff1796c89fb1c0d64612f004e4b5fa1ed00e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 12:13:30 +0100 Subject: [PATCH 01/16] Fix corner case in make_match_count_matrix() Add symetric option and propagate in SymmetricSortingComparison/GroundTruthComparison --- .../comparison/comparisontools.py | 135 ++++++++++-------- .../comparison/paircomparisons.py | 9 +- .../comparison/tests/test_comparisontools.py | 53 ++++++- 3 files changed, 132 insertions(+), 65 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 3cd856d662..aa9adfcb5c 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,61 +132,6 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): - """ - Compute a matrix representing the matches between two spike trains. - - Given two spike trains, this function finds matching spikes based on a temporal proximity criterion - defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `spike_frames_train1` and `spike_frames_train2`. - - Parameters - ---------- - spike_frames_train1 : ndarray - An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. - spike_frames_train2 : ndarray - An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. - unit_indices1 : ndarray - An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. - unit_indices2 : ndarray - An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. - num_units_train1 : int - The total count of unique units in the first spike train. - num_units_train2 : int - The total count of unique units in the second spike train. 
- delta_frames : int - The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - - Returns - ------- - matching_matrix : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents - the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. - - - Notes - ----- - This algorithm identifies matching spikes between two ordered spike trains. - By iterating through each spike in the first train, it compares them against spikes in the second train, - determining matches based on the two spikes frames being within `delta_frames` of each other. - - To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, - which signifies the minimal index in the second spike train that might match the upcoming spike - in the first train. - - The logic can be summarized as follows: - 1. Iterate through each spike in the first train - 2. For each spike, find the first match in the second train. - 3. Save the index of the first match as the new `second_train_search_start ` - 3. For each match, find as many matches as possible from the first match onwards. - - An important condition here is that the same spike is not matched twice. This is managed by keeping track - of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` - - For more details on the rationale behind this approach, refer to the documentation of this module and/or - the metrics section in SpikeForest documentation. 
- """ matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) @@ -232,7 +177,61 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames): +def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): + """ + Compute a matrix representing the matches between two Sorting objects. + + Given two spike trains, this function finds matching spikes based on a temporal proximity criterion + defined by `delta_frames`. The resulting matrix indicates the number of matches between units + in `spike_frames_train1` and `spike_frames_train2` for each pair of units. + + Note that this algo is not symetric and biased toward sorting1 is the ground truth. + + Parameters + ---------- + sorting1 : Sorting + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + sorting2 : Sorting + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + symetric: bool, dfault False + If symetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two + results is taken. + Returns + ------- + matching_matrix : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + + Notes + ----- + This algorithm identifies matching spikes between two ordered spike trains. 
+ By iterating through each spike in the first train, it compares them against spikes in the second train, + determining matches based on the two spikes frames being within `delta_frames` of each other. + + To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, + which signifies the minimal index in the second spike train that might match the upcoming spike + in the first train. + + The logic can be summarized as follows: + 1. Iterate through each spike in the first train + 2. For each spike, find the first match in the second train. + 3. Save the index of the first match as the new `second_train_search_start ` + 3. For each match, find as many matches as possible from the first match onwards. + + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` + There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations + (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, + we applied a final clip. + + For more details on the rationale behind this approach, refer to the documentation of this module and/or + the metrics section in SpikeForest documentation. 
+ """ + num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) @@ -257,7 +256,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): unit_indices1_sorted = spike_vector1["unit_index"] unit_indices2_sorted = spike_vector2["unit_index"] - matching_matrix += get_optimized_compute_matching_matrix()( + matching_matrix_seg = get_optimized_compute_matching_matrix()( sample_frames1_sorted, sample_frames2_sorted, unit_indices1_sorted, @@ -267,6 +266,28 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames): delta_frames, ) + if symetric: + matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( + sample_frames2_sorted, + sample_frames1_sorted, + unit_indices2_sorted, + unit_indices1_sorted, + num_units_sorting2, + num_units_sorting1, + delta_frames, + ) + matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T) + + + matching_matrix += matching_matrix_seg + + + # ensure the number of match do not exceed the number of spike in train 2 + # this is a simple way to handle corner cases for bursting in sorting1 + spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) + spike_count2 = spike_count2[np.newaxis, :] + matching_matrix = np.clip(matching_matrix, None, spike_count2) + # Build a data frame from the matching matrix import pandas as pd diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index 7f21aa657f..e57f2c047a 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,6 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, + symetric=False, n_jobs=1, verbose=False, ): @@ -55,6 +56,8 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() + self.symetric = symetric + 
self._do_agreement() self._do_matching() @@ -84,7 +87,9 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) + self.match_event_count = make_match_count_matrix( + self.sorting1, self.sorting2, self.delta_frames, symetric=self.symetric + ) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( @@ -151,6 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + symetric=True, n_jobs=n_jobs, verbose=verbose, ) @@ -283,6 +289,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, + symetric=False, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index ab24678a1e..137d4cff05 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,43 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): + # More challenging condition, this was failing with the previous approach that used np.where and np.diff + # This actual implementation should fail but the "clip protection" by number of spike make the solution. + # This is cheating but acceptable for really corner cases (burst in the ground truth). 
+ frames_spike_train1 = [100, 105, 110] + frames_spike_train2 = [100, 105, ] + unit_indices1 = [0, 0, 0] + unit_indices2 = [0, 0,] + delta_frames = 20 # long enough, so all frames in both sortings are within each other reach + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + # this is easy because it is sorting2 centric + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + + # this work only because we protect by clipping + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=False) + expected_result = np.array([[2]]) + assert_array_equal(result.to_numpy(), expected_result) + +def test_make_match_count_matrix_symetric(): + frames_spike_train1 = [100, 102, 105, 120, 1000, ] + unit_indices1 = [0, 2, 1, 0, 0] + frames_spike_train2 = [101, 150, 1000] + unit_indices2 = [0, 1, 0] + delta_frames = 100 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=True) + + assert_array_equal(result.T, result_T) + + def test_make_match_count_matrix_test_proper_search_in_the_second_train(): "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" frames_spike_train1 = [500, 600, 800] @@ -431,13 +468,15 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": - test_make_match_count_matrix() - test_make_match_count_matrix_sorting_with_itself_simple() - test_make_match_count_matrix_sorting_with_itself_longer() - test_make_match_count_matrix_with_mismatched_sortings() - test_make_match_count_matrix_no_double_matching() - 
test_make_match_count_matrix_repeated_matching_but_no_double_counting() - test_make_match_count_matrix_test_proper_search_in_the_second_train() + # test_make_match_count_matrix() + # test_make_match_count_matrix_sorting_with_itself_simple() + # test_make_match_count_matrix_sorting_with_itself_longer() + # test_make_match_count_matrix_with_mismatched_sortings() + # test_make_match_count_matrix_no_double_matching() + # test_make_match_count_matrix_repeated_matching_but_no_double_counting() + # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() + # test_make_match_count_matrix_test_proper_search_in_the_second_train() + test_make_match_count_matrix_symetric() # test_make_agreement_scores() From 3aff2255d5a6426b91fd667be13abad8d795df86 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Nov 2023 11:16:08 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comparison/comparisontools.py | 5 +---- .../comparison/tests/test_comparisontools.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index aa9adfcb5c..c8a6edb577 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,7 +132,6 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): - matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice @@ -225,7 +224,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): An important condition here is that the same spike is not matched twice. 
This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations - (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, + (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, we applied a final clip. For more details on the rationale behind this approach, refer to the documentation of this module and/or @@ -278,10 +277,8 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): ) matching_matrix_seg = np.maximum(matching_matrix_seg, matching_matrix_seg_switch.T) - matching_matrix += matching_matrix_seg - # ensure the number of match do not exceed the number of spike in train 2 # this is a simple way to handle corner cases for bursting in sorting1 spike_count2 = np.array(list(sorting2.count_num_spikes_per_unit().values())) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 137d4cff05..5a6f18f0f9 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -140,9 +140,15 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): # This actual implementation should fail but the "clip protection" by number of spike make the solution. # This is cheating but acceptable for really corner cases (burst in the ground truth). 
frames_spike_train1 = [100, 105, 110] - frames_spike_train2 = [100, 105, ] + frames_spike_train2 = [ + 100, + 105, + ] unit_indices1 = [0, 0, 0] - unit_indices2 = [0, 0,] + unit_indices2 = [ + 0, + 0, + ] delta_frames = 20 # long enough, so all frames in both sortings are within each other reach sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) @@ -157,8 +163,15 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) + def test_make_match_count_matrix_symetric(): - frames_spike_train1 = [100, 102, 105, 120, 1000, ] + frames_spike_train1 = [ + 100, + 102, + 105, + 120, + 1000, + ] unit_indices1 = [0, 2, 1, 0, 0] frames_spike_train2 = [101, 150, 1000] unit_indices2 = [0, 1, 0] From c5263f7e7a4fe50bee987aa4947d865c094ef067 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 12:20:22 +0100 Subject: [PATCH 03/16] get_optimized_compute_matching_matrix: protect with index instead of frame --- src/spikeinterface/comparison/comparisontools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index aa9adfcb5c..0a142ed878 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -161,11 +161,11 @@ def compute_matching_matrix( unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] if ( - frame1 != last_match_frame1[unit_index1, unit_index2] - and frame2 != last_match_frame2[unit_index1, unit_index2] + index1 != last_match_frame1[unit_index1, unit_index2] + and index2 != last_match_frame2[unit_index1, unit_index2] ): - last_match_frame1[unit_index1, unit_index2] = frame1 - last_match_frame2[unit_index1, unit_index2] = frame2 + last_match_frame1[unit_index1, unit_index2] = index1 + 
last_match_frame2[unit_index1, unit_index2] = index2 matching_matrix[unit_index1, unit_index2] += 1 From a823a08dde2d2acf0c055d6c0eb93b80bd09212b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 13:14:24 +0100 Subject: [PATCH 04/16] oups --- src/spikeinterface/comparison/paircomparisons.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index e57f2c047a..d6d40c8d8c 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,7 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, - symetric=False, + symmetric=False, n_jobs=1, verbose=False, ): @@ -56,7 +56,7 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() - self.symetric = symetric + self.symmetric = symmetric self._do_agreement() self._do_matching() @@ -88,7 +88,7 @@ def _do_agreement(self): # matrix of event match count for each pair self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, symetric=self.symetric + self.sorting1, self.sorting2, self.delta_frames, symmetric=self.symmetric ) # agreement matrix score for each pair @@ -156,7 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symetric=True, + symmetric=True, n_jobs=n_jobs, verbose=verbose, ) @@ -289,7 +289,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symetric=False, + symmetric=False, n_jobs=n_jobs, verbose=verbose, ) From b03febe478aab1af44f82f1e7357c29771c3fbbc Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 13:15:51 +0100 Subject: [PATCH 05/16] oups --- src/spikeinterface/comparison/comparisontools.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git 
a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 26f220ef73..9dcda06ada 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -176,7 +176,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): +def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): """ Compute a matrix representing the matches between two Sorting objects. @@ -184,7 +184,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): defined by `delta_frames`. The resulting matrix indicates the number of matches between units in `spike_frames_train1` and `spike_frames_train2` for each pair of units. - Note that this algo is not symetric and biased toward sorting1 is the ground truth. + Note that this algo is not symmetric and biased toward sorting1 is the ground truth. Parameters ---------- @@ -196,8 +196,8 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): The inclusive upper limit on the frame difference for which two spikes are considered matching. That is if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching. - symetric: bool, dfault False - If symetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two + symmetric: bool, dfault False + If symmetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two results is taken. 
Returns ------- @@ -265,7 +265,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symetric=False): delta_frames, ) - if symetric: + if symmetric: matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( sample_frames2_sorted, sample_frames1_sorted, From dc55dfc97a66ae8681a501951bbdaa1be8342be9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 10 Nov 2023 16:25:10 +0100 Subject: [PATCH 06/16] oups --- .../comparison/tests/test_comparisontools.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 5a6f18f0f9..b6cd3fc3b4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -154,17 +154,17 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) # this is easy because it is sorting2 centric - result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=False) + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) # this work only because we protect by clipping - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=False) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) -def test_make_match_count_matrix_symetric(): +def test_make_match_count_matrix_symmetric(): frames_spike_train1 = [ 100, 102, @@ -179,8 +179,8 @@ def test_make_match_count_matrix_symetric(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, 
frames_spike_train2, unit_indices2) - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symetric=True) - result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symetric=True) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=True) + result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=True) assert_array_equal(result.T, result_T) @@ -224,7 +224,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) - # test if symetric + # test if symmetric agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -489,7 +489,7 @@ def test_do_count_score_and_perf(): # test_make_match_count_matrix_repeated_matching_but_no_double_counting() # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() # test_make_match_count_matrix_test_proper_search_in_the_second_train() - test_make_match_count_matrix_symetric() + test_make_match_count_matrix_symmetric() # test_make_agreement_scores() From b55991ca59cec7f40e9eca5c6383bef8f156d2ac Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:17 +0100 Subject: [PATCH 07/16] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 9dcda06ada..1b503629bd 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -221,7 +221,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): 3. Save the index of the first match as the new `second_train_search_start ` 3. 
For each match, find as many matches as possible from the first match onwards. - An important condition here is that the same spike is not matched twice. This is managed by keeping track + An important condition is that the same spike is not matched twice. This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, From 866469addcab659552166047a4524976e4fe0687 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:35 +0100 Subject: [PATCH 08/16] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 1b503629bd..2f7ed61427 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -184,7 +184,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): defined by `delta_frames`. The resulting matrix indicates the number of matches between units in `spike_frames_train1` and `spike_frames_train2` for each pair of units. - Note that this algo is not symmetric and biased toward sorting1 is the ground truth. 
+ Note that this algo is not symmetric and is biased with `sorting1` representing ground truth for the comparison Parameters ---------- From 2e315392934a5af6239636791c3eee1c3bf45787 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 10:10:42 +0100 Subject: [PATCH 09/16] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 2f7ed61427..1ad1c59499 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -178,7 +178,7 @@ def compute_matching_matrix( def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): """ - Compute a matrix representing the matches between two Sorting objects. + Computes a matrix representing the matches between two Sorting objects. Given two spike trains, this function finds matching spikes based on a temporal proximity criterion defined by `delta_frames`. 
The resulting matrix indicates the number of matches between units From 4fa5a0cde87e64c818a9bceb5e2ebcd1268a5289 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 13:54:15 +0100 Subject: [PATCH 10/16] symmetric > ensure_symmetry --- .../comparison/comparisontools.py | 55 +++++++++++++++---- .../comparison/paircomparisons.py | 10 ++-- .../comparison/tests/test_comparisontools.py | 50 ++++++++--------- 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 1ad1c59499..d9ab2e685d 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -132,6 +132,36 @@ def compute_matching_matrix( num_units_train2, delta_frames, ): + """ + Internal function used by `make_match_count_matrix()`. + This function is for one segment only. + The llop over segment is done in `make_match_count_matrix()` + + Parameters + ---------- + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. + unit_indices1 : ndarray + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. + unit_indices2 : ndarray + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. + num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. + delta_frames : int + The inclusive upper limit on the frame difference for which two spikes are considered matching. 
That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. + + Returns + ------- + matching_matrix : ndarray + A 2D numpy array of shape `(num_units_train1, num_units_train2)` + + """ matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice @@ -176,7 +206,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): +def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=False): """ Computes a matrix representing the matches between two Sorting objects. @@ -194,11 +224,11 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. delta_frames : int The inclusive upper limit on the frame difference for which two spikes are considered matching. That is - if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` - and `spike_frames_train2[j]` are considered matching. - symmetric: bool, dfault False - If symmetric, the this the algos is run two times by switching sorting1 and sorting2 the minimum of the two - results is taken. + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at + `spike_frames_train1[i]` and `spike_frames_train2[j]` are considered matching. + ensure_symmetry: bool, default False + If ensure_symmetry=True, then the algo is run two times by switching sorting1 and sorting2. + And the maximum of the two results is taken. Returns ------- matching_matrix : ndarray @@ -221,11 +251,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): 3.
Save the index of the first match as the new `second_train_search_start ` 3. For each match, find as many matches as possible from the first match onwards. - An important condition is that the same spike is not matched twice. This is managed by keeping track + An important condition here is that the same spike is not matched twice. This is managed by keeping track of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` - There are corner cases where a spike can be counted twice in the the spiketrain 2 in case of bursting situations - (below delta_frames) in the spiketrain 1. To ensure that the number of match do not exceed the number of spike, - we applied a final clip. + There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity + (below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes, + we apply a final clip. + For more details on the rationale behind this approach, refer to the documentation of this module and/or the metrics section in SpikeForest documentation. 
@@ -265,7 +296,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, symmetric=False): delta_frames, ) - if symmetric: + if ensure_symmetry: matching_matrix_seg_switch = get_optimized_compute_matching_matrix()( sample_frames2_sorted, sample_frames1_sorted, @@ -327,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=True) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index d6d40c8d8c..02e74b7053 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -28,7 +28,7 @@ def __init__( delta_time=0.4, match_score=0.5, chance_score=0.1, - symmetric=False, + ensure_symmetry=False, n_jobs=1, verbose=False, ): @@ -56,7 +56,7 @@ def __init__( self.unit1_ids = self.sorting1.get_unit_ids() self.unit2_ids = self.sorting2.get_unit_ids() - self.symmetric = symmetric + self.ensure_symmetry = ensure_symmetry self._do_agreement() self._do_matching() @@ -88,7 +88,7 @@ def _do_agreement(self): # matrix of event match count for each pair self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, symmetric=self.symmetric + self.sorting1, self.sorting2, self.delta_frames, ensure_symmetry=self.ensure_symmetry ) # agreement matrix score for each pair @@ -156,7 +156,7 @@ def __init__( delta_time=delta_time, match_score=match_score, chance_score=chance_score, - symmetric=True, + ensure_symmetry=True, n_jobs=n_jobs, verbose=verbose, ) @@ -289,7 +289,7 @@ def __init__( delta_time=delta_time, 
match_score=match_score, chance_score=chance_score, - symmetric=False, + ensure_symmetry=False, n_jobs=n_jobs, verbose=verbose, ) diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index b6cd3fc3b4..31adee8ca4 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -154,17 +154,17 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting_2(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) # this is easy because it is sorting2 centric - result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=False) + result = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, ensure_symmetry=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) # this work only because we protect by clipping - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=False) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=False) expected_result = np.array([[2]]) assert_array_equal(result.to_numpy(), expected_result) -def test_make_match_count_matrix_symmetric(): +def test_make_match_count_matrix_ensure_symmetry(): frames_spike_train1 = [ 100, 102, @@ -179,8 +179,8 @@ def test_make_match_count_matrix_symmetric(): sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) - result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, symmetric=True) - result_T = make_match_count_matrix(sorting2, sorting1, delta_frames=delta_frames, symmetric=True) + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames, ensure_symmetry=True) + result_T = make_match_count_matrix(sorting2, sorting1, 
delta_frames=delta_frames, ensure_symmetry=True) assert_array_equal(result.T, result_T) @@ -481,23 +481,23 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": - # test_make_match_count_matrix() - # test_make_match_count_matrix_sorting_with_itself_simple() - # test_make_match_count_matrix_sorting_with_itself_longer() - # test_make_match_count_matrix_with_mismatched_sortings() - # test_make_match_count_matrix_no_double_matching() - # test_make_match_count_matrix_repeated_matching_but_no_double_counting() - # test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() - # test_make_match_count_matrix_test_proper_search_in_the_second_train() - test_make_match_count_matrix_symmetric() - - # test_make_agreement_scores() - - # test_make_possible_match() - # test_make_best_match() - # test_make_hungarian_match() - - # test_do_score_labels() - # test_compare_spike_trains() - # test_do_confusion_matrix() - # test_do_count_score_and_perf() + test_make_match_count_matrix() + test_make_match_count_matrix_sorting_with_itself_simple() + test_make_match_count_matrix_sorting_with_itself_longer() + test_make_match_count_matrix_with_mismatched_sortings() + test_make_match_count_matrix_no_double_matching() + test_make_match_count_matrix_repeated_matching_but_no_double_counting() + test_make_match_count_matrix_repeated_matching_but_no_double_counting_2() + test_make_match_count_matrix_test_proper_search_in_the_second_train() + test_make_match_count_matrix_ensure_symmetry() + + test_make_agreement_scores() + + test_make_possible_match() + test_make_best_match() + test_make_hungarian_match() + + test_do_score_labels() + test_compare_spike_trains() + test_do_confusion_matrix() + test_do_count_score_and_perf() From bb38af83a4ebea4af91ae389306ae08226f8fd46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Nov 2023 12:58:23 +0000 Subject: [PATCH 11/16] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7bab5fe3aa..f475dbf2e1 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -257,7 +257,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa There are corner cases where a spike can be counted twice in the spiketrain 2 if there are bouts of bursting activity (below delta_frames) in the spiketrain 1. To ensure that the number of matches does not exceed the number of spikes, we apply a final clip. - + For more details on the rationale behind this approach, refer to the documentation of this module and/or the metrics section in SpikeForest documentation. From 8d3b830bb79fd0e2f0e4b55fffee69ba5cc59a48 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 14:10:09 +0100 Subject: [PATCH 12/16] Improve do_count_event() --- src/spikeinterface/comparison/comparisontools.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7bab5fe3aa..b3d76f9da5 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -63,6 +63,9 @@ def compute_agreement_score(num_matches, num1, num2): def do_count_event(sorting): """ Count event for each units in a sorting. + + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. 
+ Parameters ---------- sorting: SortingExtractor @@ -75,14 +78,8 @@ def do_count_event(sorting): """ import pandas as pd - unit_ids = sorting.get_unit_ids() - ev_counts = np.zeros(len(unit_ids), dtype="int64") - for segment_index in range(sorting.get_num_segments()): - ev_counts += np.array( - [len(sorting.get_unit_spike_train(u, segment_index=segment_index)) for u in unit_ids], dtype="int64" - ) - event_counts = pd.Series(ev_counts, index=unit_ids) - return event_counts + return pd.Series(sorting.count_num_spikes_per_unit()) + def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, From 27f4e6dbf4d0d48f640e53c530b90fbfb3434166 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Nov 2023 13:11:30 +0000 Subject: [PATCH 13/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/comparisontools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 64b9b757f6..c56d02e3b3 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -63,7 +63,7 @@ def compute_agreement_score(num_matches, num1, num2): def do_count_event(sorting): """ Count event for each units in a sorting. - + Kept for backward compatibility sorting.count_num_spikes_per_unit() is doing the same. Parameters @@ -81,7 +81,6 @@ def do_count_event(sorting): return pd.Series(sorting.count_num_spikes_per_unit()) - def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, event_counts2 unit2_ids, """ Computes matching spikes between one spike train and a list of others. 
From 0349dbb017c75cbd2fb2c54b94e24e053c19b0b3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 13 Nov 2023 14:36:24 +0100 Subject: [PATCH 14/16] Update src/spikeinterface/comparison/comparisontools.py --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index c56d02e3b3..731753287e 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -131,7 +131,7 @@ def compute_matching_matrix( """ Internal function used by `make_match_count_matrix()`. This function is for one segment only. - The llop over segment is done in `make_match_count_matrix()` + The loop over segment is done in `make_match_count_matrix()` Parameters ---------- From 62aa0ad36e632ecac94b9f3d37bf3d7081ea8989 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 13 Nov 2023 14:53:09 +0100 Subject: [PATCH 15/16] expose ensure_symmetry in make_agreement_scores() --- src/spikeinterface/comparison/comparisontools.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 64b9b757f6..5fd8195998 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -324,12 +324,12 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=Fa return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames): +def make_agreement_scores(sorting1, sorting2, delta_frames, ensure_symmetry=True): """ Make the agreement matrix. No threshold (min_score) is applied at this step. - Note : this computation is symmetric. + Note : this computation is symmetric by default. Inverting sorting1 and sorting2 give the transposed matrix. 
Parameters @@ -340,7 +340,9 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - + ensure_symmetry: bool, default: True + If ensure_symmetry is True, then the algo is run two times by switching sorting1 and sorting2. + And the minimum of the two results is taken. Returns ------- agreement_scores: array (float) @@ -356,7 +358,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=True) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, ensure_symmetry=ensure_symmetry) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) From 93cd8fff205a71112bfb5e365dbfae022e0c5d12 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 13 Nov 2023 14:53:57 +0100 Subject: [PATCH 16/16] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Heberto Mayorquin --- src/spikeinterface/comparison/comparisontools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 8e3b984420..19ba6afd27 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -155,7 +155,8 @@ def compute_matching_matrix( Returns ------- matching_matrix : ndarray - A 2D numpy array of shape `(num_units_train1, num_units_train2)` + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. """