From 837b90edbb1f42bba1a492f5810fa6f9b3bcec09 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 2 Nov 2023 15:16:33 +0100 Subject: [PATCH 1/7] add ovewrite option to save to folder --- src/spikeinterface/core/base.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b51bace55f..5c6a6d260c 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -788,7 +788,7 @@ def save_to_memory(self, **kwargs) -> "BaseExtractor": return cached # TODO rename to saveto_binary_folder - def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): + def save_to_folder(self, name=None, folder=None, ovewrite=False, verbose=True, **save_kwargs): """ Save extractor to folder. @@ -819,6 +819,8 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): folder: None str or Path Name of the folder. If "folder" is given, "name" must be None. + ovewrite: bool, default: False + If True, the folder is removed if it already exists Returns ------- @@ -839,7 +841,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): print(f"Use cache_folder={folder}") else: folder = Path(folder) - assert not folder.exists(), f"folder {folder} already exists, choose another name" + if ovewrite and folder.is_dir(): + import shutil + + shutil.rmtree(folder) + + assert not folder.exists(), f"folder {folder} already exists, choose another name or use ovewrite=True" folder.mkdir(parents=True, exist_ok=False) # dump provenance From ba908a5784fd0b5dd0fb3cac9024e7f93023b55b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 3 Nov 2023 11:33:00 +0100 Subject: [PATCH 2/7] fix names --- src/spikeinterface/core/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 5c6a6d260c..1a8674697a 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -788,7 +788,7 @@ def save_to_memory(self, **kwargs) -> "BaseExtractor": return cached # TODO rename to saveto_binary_folder - def save_to_folder(self, name=None, folder=None, ovewrite=False, verbose=True, **save_kwargs): + def save_to_folder(self, name=None, folder=None, overwrite=False, verbose=True, **save_kwargs): """ Save extractor to folder. @@ -819,7 +819,7 @@ def save_to_folder(self, name=None, folder=None, ovewrite=False, verbose=True, * folder: None str or Path Name of the folder. If "folder" is given, "name" must be None. - ovewrite: bool, default: False + overwrite: bool, default: False If True, the folder is removed if it already exists Returns @@ -841,12 +841,12 @@ def save_to_folder(self, name=None, folder=None, ovewrite=False, verbose=True, * print(f"Use cache_folder={folder}") else: folder = Path(folder) - if ovewrite and folder.is_dir(): + if overwrite and folder.is_dir(): import shutil shutil.rmtree(folder) - assert not folder.exists(), f"folder {folder} already exists, choose another name or use ovewrite=True" + assert not folder.exists(), f"folder {folder} already exists, choose another name or use overwrite=True" folder.mkdir(parents=True, exist_ok=False) # dump provenance From c4375ae4a0fe32e7bf66e3b96a7fe9af19459920 Mon Sep 17 00:00:00 2001 From: Zach McKenzie <92116279+zm711@users.noreply.github.com> Date: Fri, 3 Nov 2023 08:20:50 -0400 Subject: [PATCH 3/7] update readme to the version docs --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 883dcdb944..d51f372848 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ With SpikeInterface, users can: ## Documentation -Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.98.2). +Detailed documentation of the latest PyPI release of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/0.99.0). Detailed documentation of the development version of SpikeInterface can be found [here](https://spikeinterface.readthedocs.io/en/latest). From 3712da0ab5bfe22cf53598dc04c47c37ef554c02 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 6 Nov 2023 21:35:01 +0000 Subject: [PATCH 4/7] Handle start / stop frame default `None`. --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index a81212897c..86485aa25d 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -387,6 +387,11 @@ def get_traces(self, start_frame, end_frame, channel_indices): ) # times = np.asarray(self.time_vector[start_frame:end_frame]) else: + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") times /= self.sampling_frequency t0 = start_frame / self.sampling_frequency From 266be6f2861490bd3e8a3d5e3d3a9c49527f6f50 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 6 Nov 2023 21:36:50 +0000 Subject: [PATCH 5/7] Remove redundant `else` statement. --- .../sortingcomponents/motion_interpolation.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 86485aa25d..ba046db85f 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -386,18 +386,18 @@ def get_traces(self, start_frame, end_frame, channel_indices): "time_vector for InterpolateMotionRecording do not work because temporal_bins start from 0" ) # times = np.asarray(self.time_vector[start_frame:end_frame]) - else: - if start_frame is None: - start_frame = 0 - if end_frame is None: - end_frame = self.get_num_samples() - - times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") - times /= self.sampling_frequency - t0 = start_frame / self.sampling_frequency - # if self.t_start is not None: - # t0 = t0 + self.t_start - times += t0 + + if start_frame is None: + start_frame = 0 + if end_frame is None: + end_frame = self.get_num_samples() + + times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") + times /= self.sampling_frequency + t0 = start_frame / self.sampling_frequency + # if self.t_start is not None: + # t0 = t0 + self.t_start + times += t0 traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices=slice(None)) From dcce25a397e7b365a1bd0de22032b32d48128a4c Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:48:16 +0000 Subject: [PATCH 6/7] Update src/spikeinterface/sortingcomponents/motion_interpolation.py Co-authored-by: Alessio Buccino --- src/spikeinterface/sortingcomponents/motion_interpolation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index ba046db85f..93a8ce62c8 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -392,7 +392,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): if end_frame is None: end_frame = self.get_num_samples() - times = np.arange((end_frame or self.get_num_samples()) - (start_frame or 0), dtype="float64") + times = np.arange(end_frame - start_frame, dtype="float64") times /= self.sampling_frequency t0 = start_frame / self.sampling_frequency # if self.t_start is not None: From 329197618a9b48ef876d1e8b8e79f07f4abf5e49 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 9 Nov 2023 10:33:00 +0100 Subject: [PATCH 7/7] Fix compute matching v3 (#2182) * some change to test * another change * another attempt * attempt merge * add condition * add auth * fix test and simpler implementation * small typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * avoid corner cose of doing the matching loop twice * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove n_jobs * Little docs cleanup * Remove internal n_jobs * Remove last internal n_jobs * Apply suggestions from code review * fix test * comment to test * docstring improvements * variable naming * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * new proposal for compute_matching * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Heberto Mayorquin Co-authored-by: Alessio Buccino --- .../comparison/comparisontools.py | 130 +++++++++--------- .../comparison/paircomparisons.py | 4 +- .../comparison/tests/test_comparisontools.py | 66 ++++++--- 3 files changed, 111 insertions(+), 89 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 7a1fb87175..3cd856d662 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -124,12 +124,12 @@ def get_optimized_compute_matching_matrix(): @numba.jit(nopython=True, nogil=True) def compute_matching_matrix( - frames_spike_train1, - frames_spike_train2, + spike_frames_train1, + spike_frames_train2, unit_indices1, unit_indices2, - num_units_sorting1, - num_units_sorting2, + num_units_train1, + num_units_train2, delta_frames, ): """ @@ -137,30 +137,33 @@ def compute_matching_matrix( Given two spike trains, this function finds matching spikes based on a temporal proximity criterion defined by `delta_frames`. The resulting matrix indicates the number of matches between units - in `frames_spike_train1` and `frames_spike_train2`. + in `spike_frames_train1` and `spike_frames_train2`. Parameters ---------- - frames_spike_train1 : ndarray - Array of frames for the first spike train. Should be ordered in ascending order. - frames_spike_train2 : ndarray - Array of frames for the second spike train. Should be ordered in ascending order. + spike_frames_train1 : ndarray + An array of integer frame numbers corresponding to spike times for the first train. Must be in ascending order. + spike_frames_train2 : ndarray + An array of integer frame numbers corresponding to spike times for the second train. Must be in ascending order. unit_indices1 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train1`. + An array of integers where `unit_indices1[i]` gives the unit index associated with the spike at `spike_frames_train1[i]`. unit_indices2 : ndarray - Array indicating the unit indices corresponding to each spike in `frames_spike_train2`. - num_units_sorting1 : int - Total number of units in the first spike train. - num_units_sorting2 : int - Total number of units in the second spike train. + An array of integers where `unit_indices2[i]` gives the unit index associated with the spike at `spike_frames_train2[i]`. + num_units_train1 : int + The total count of unique units in the first spike train. + num_units_train2 : int + The total count of unique units in the second spike train. delta_frames : int - Maximum difference in frames between two spikes to consider them as a match. + The inclusive upper limit on the frame difference for which two spikes are considered matching. That is + if `abs(spike_frames_train1[i] - spike_frames_train2[j]) <= delta_frames` then the spikes at `spike_frames_train1[i]` + and `spike_frames_train2[j]` are considered matching. Returns ------- matching_matrix : ndarray - A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents - the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`. + A 2D numpy array of shape `(num_units_train1, num_units_train2)`. Each element `[i, j]` represents + the count of matching spike pairs between unit `i` from `spike_frames_train1` and unit `j` from `spike_frames_train2`. + Notes ----- @@ -168,59 +171,58 @@ def compute_matching_matrix( By iterating through each spike in the first train, it compares them against spikes in the second train, determining matches based on the two spikes frames being within `delta_frames` of each other. - To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`, + To avoid redundant comparisons the algorithm maintains a reference, `second_train_search_start `, which signifies the minimal index in the second spike train that might match the upcoming spike - in the first train. This means that the start of the search moves forward in the second train as the - matches between the two trains are found decreasing the number of comparisons needed. + in the first train. + + The logic can be summarized as follows: + 1. Iterate through each spike in the first train + 2. For each spike, find the first match in the second train. + 3. Save the index of the first match as the new `second_train_search_start ` + 3. For each match, find as many matches as possible from the first match onwards. - An important condition here is thatthe same spike is not matched twice. This is managed by keeping track - of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match` + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `last_match_frame1` and `last_match_frame2` For more details on the rationale behind this approach, refer to the documentation of this module and/or - the metrics section in SpikeForest documentation. + the metrics section in SpikeForest documentation. """ - matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) + matching_matrix = np.zeros((num_units_train1, num_units_train2), dtype=np.uint16) # Used to avoid the same spike matching twice - previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) - previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) - - lower_search_limit_in_second_train = 0 - - for index1 in range(len(frames_spike_train1)): - # Keeps track of which frame in the second spike train should be used as a search start for matches - index2 = lower_search_limit_in_second_train - frame1 = frames_spike_train1[index1] - - # Determine next_frame1 if current frame is not the last frame - not_in_the_last_loop = index1 < len(frames_spike_train1) - 1 - if not_in_the_last_loop: - next_frame1 = frames_spike_train1[index1 + 1] - - while index2 < len(frames_spike_train2): - frame2 = frames_spike_train2[index2] - not_a_match = abs(frame1 - frame2) > delta_frames - if not_a_match: - # Go to the next frame in the first train + last_match_frame1 = -np.ones_like(matching_matrix, dtype=np.int64) + last_match_frame2 = -np.ones_like(matching_matrix, dtype=np.int64) + + num_spike_frames_train1 = len(spike_frames_train1) + num_spike_frames_train2 = len(spike_frames_train2) + + # Keeps track of which frame in the second spike train should be used as a search start for matches + second_train_search_start = 0 + for index1 in range(num_spike_frames_train1): + frame1 = spike_frames_train1[index1] + + for index2 in range(second_train_search_start, num_spike_frames_train2): + frame2 = spike_frames_train2[index2] + if frame2 < frame1 - delta_frames: + # no match move the left limit for the next loop + second_train_search_start += 1 + continue + elif frame2 > frame1 + delta_frames: + # no match stop search in train2 and continue increment in train1 break + else: + # match + unit_index1, unit_index2 = unit_indices1[index1], unit_indices2[index2] - # Map the match to a matrix - row, column = unit_indices1[index1], unit_indices2[index2] - - # The same spike cannot be matched twice see the notes in the docstring for more info on this constraint - if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: - previous_frame1_match[row, column] = frame1 - previous_frame2_match[row, column] = frame2 - - matching_matrix[row, column] += 1 - - index2 += 1 + if ( + frame1 != last_match_frame1[unit_index1, unit_index2] + and frame2 != last_match_frame2[unit_index1, unit_index2] + ): + last_match_frame1[unit_index1, unit_index2] = frame1 + last_match_frame2[unit_index1, unit_index2] = frame2 - # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match - not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames - if not_a_match_with_next: - lower_search_limit_in_second_train = index2 + matching_matrix[unit_index1, unit_index2] += 1 return matching_matrix @@ -230,7 +232,7 @@ def compute_matching_matrix( return compute_matching_matrix -def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): +def make_match_count_matrix(sorting1, sorting2, delta_frames): num_units_sorting1 = sorting1.get_num_units() num_units_sorting2 = sorting2.get_num_units() matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16) @@ -275,7 +277,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): return match_event_counts_df -def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): +def make_agreement_scores(sorting1, sorting2, delta_frames): """ Make the agreement matrix. No threshold (min_score) is applied at this step. @@ -291,8 +293,6 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): The second sorting extractor delta_frames: int Number of frames to consider spikes coincident - n_jobs: int - Number of jobs to run in parallel Returns ------- @@ -309,7 +309,7 @@ def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): event_counts1 = pd.Series(ev_counts1, index=unit1_ids) event_counts2 = pd.Series(ev_counts2, index=unit2_ids) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=n_jobs) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) agreement_scores = make_agreement_scores_from_count(match_event_count, event_counts1, event_counts2) diff --git a/src/spikeinterface/comparison/paircomparisons.py b/src/spikeinterface/comparison/paircomparisons.py index e2dc30493d..7f21aa657f 100644 --- a/src/spikeinterface/comparison/paircomparisons.py +++ b/src/spikeinterface/comparison/paircomparisons.py @@ -84,9 +84,7 @@ def _do_agreement(self): self.event_counts2 = do_count_event(self.sorting2) # matrix of event match count for each pair - self.match_event_count = make_match_count_matrix( - self.sorting1, self.sorting2, self.delta_frames, n_jobs=self.n_jobs - ) + self.match_event_count = make_match_count_matrix(self.sorting1, self.sorting2, self.delta_frames) # agreement matrix score for each pair self.agreement_scores = make_agreement_scores_from_count( diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index c6494b04d1..ab24678a1e 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -135,6 +135,23 @@ def test_make_match_count_matrix_repeated_matching_but_no_double_counting(): assert_array_equal(result.to_numpy(), expected_result) +def test_make_match_count_matrix_test_proper_search_in_the_second_train(): + "Search exhaustively in the second train, but only within the delta_frames window, do not terminate search early" + frames_spike_train1 = [500, 600, 800] + frames_spike_train2 = [0, 100, 200, 300, 500, 800] + unit_indices1 = [0, 0, 0] + unit_indices2 = [0, 0, 0, 0, 0, 0] + delta_frames = 20 + + sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2) + + result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames) + + expected_result = np.array([[2]]) + + assert_array_equal(result.to_numpy(), expected_result) + + def test_make_agreement_scores(): delta_frames = 10 @@ -150,7 +167,7 @@ def test_make_agreement_scores(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) print(agreement_scores) ok = np.array([[2 / 3, 0], [0, 1.0]], dtype="float64") @@ -158,7 +175,7 @@ def test_make_agreement_scores(): assert_array_equal(agreement_scores.values, ok) # test if symetric - agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames, n_jobs=1) + agreement_scores2 = make_agreement_scores(sorting2, sorting1, delta_frames) assert_array_equal(agreement_scores, agreement_scores2.T) @@ -178,7 +195,7 @@ def test_make_possible_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) possible_match_12, possible_match_21 = make_possible_match(agreement_scores, min_accuracy) @@ -207,7 +224,7 @@ def test_make_best_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) best_match_12, best_match_21 = make_best_match(agreement_scores, min_accuracy) @@ -236,7 +253,7 @@ def test_make_hungarian_match(): [0, 0, 5], ) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) @@ -344,8 +361,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -363,8 +380,8 @@ def test_do_confusion_matrix(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) confusion = do_confusion_matrix(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -391,8 +408,8 @@ def test_do_count_score_and_perf(): event_counts1 = do_count_event(sorting1) event_counts2 = do_count_event(sorting2) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + agreement_scores = make_agreement_scores(sorting1, sorting2, delta_frames) hungarian_match_12, hungarian_match_21 = make_hungarian_match(agreement_scores, min_accuracy) count_score = do_count_score(event_counts1, event_counts2, hungarian_match_12, match_event_count) @@ -415,13 +432,20 @@ def test_do_count_score_and_perf(): if __name__ == "__main__": test_make_match_count_matrix() - test_make_agreement_scores() - - test_make_possible_match() - test_make_best_match() - test_make_hungarian_match() - - test_do_score_labels() - test_compare_spike_trains() - test_do_confusion_matrix() - test_do_count_score_and_perf() + test_make_match_count_matrix_sorting_with_itself_simple() + test_make_match_count_matrix_sorting_with_itself_longer() + test_make_match_count_matrix_with_mismatched_sortings() + test_make_match_count_matrix_no_double_matching() + test_make_match_count_matrix_repeated_matching_but_no_double_counting() + test_make_match_count_matrix_test_proper_search_in_the_second_train() + + # test_make_agreement_scores() + + # test_make_possible_match() + # test_make_best_match() + # test_make_hungarian_match() + + # test_do_score_labels() + # test_compare_spike_trains() + # test_do_confusion_matrix() + # test_do_count_score_and_perf()