From 5231ad30e10787bc0e7e0034ec99115f6b413f16 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 18 Oct 2023 17:18:36 +0200
Subject: [PATCH 01/13] added testing and method

---
 .../comparison/comparisontools.py             | 124 +++++++++++++-----
 .../comparison/tests/test_comparisontools.py  | 105 +++++++++++++--
 src/spikeinterface/core/basesorting.py        |   4 +-
 3 files changed, 188 insertions(+), 45 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 20ee7910b4..dc0aaaf9a8 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -3,6 +3,7 @@
 """

 import numpy as np
+import numba
 from joblib import Parallel, delayed


@@ -109,50 +110,101 @@ def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, ev
     return matching_event_counts


-def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1):
-    """
-    Make the match_event_count matrix.
-    Basically it counts the matching events for all given pairs of spike trains from
-    sorting1 and sorting2.
+@numba.jit(nopython=True, nogil=True)
+def compute_matching_matrix(
+    frames_spike_train1,
+    frames_spike_train2,
+    unit_indices1,
+    unit_indices2,
+    num_units_sorting1,
+    num_units_sorting2,
+    delta_frames,
+):
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)

-    Parameters
-    ----------
-    sorting1: SortingExtractor
-        The first sorting extractor
-    sorting2: SortingExtractor
-        The second sorting extractor
-    delta_frames: int
-        Number of frames to consider spikes coincident
-    n_jobs: int
-        Number of jobs to run in parallel
+    # Used for Jeremy Magland condition where no unit can be matched twice.
+    previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
+    previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)

-    Returns
-    -------
-    match_event_count: array (int64)
-        Matrix of match count spike
-    """
-    import pandas as pd
+    lower_search_limit_in_second_train = 0

-    unit1_ids = np.array(sorting1.get_unit_ids())
-    unit2_ids = np.array(sorting2.get_unit_ids())
+    for index1 in range(len(frames_spike_train1)):
+        # Keeps track of which frame in the second spike train should be used as a search start
+        index2 = lower_search_limit_in_second_train
+        frame1 = frames_spike_train1[index1]

-    match_event_counts = np.zeros((len(unit1_ids), len(unit2_ids)), dtype="int64")
+        # Determine next_frame1 if current frame is not the last frame
+        not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
+        if not_in_the_last_loop:
+            next_frame1 = frames_spike_train1[index1 + 1]

-    # preload all spiketrains 2 into a list
-    for segment_index in range(sorting1.get_num_segments()):
-        s2_spiketrains = [sorting2.get_unit_spike_train(u2, segment_index=segment_index) for u2 in unit2_ids]
+        while index2 < len(frames_spike_train2):
+            frame2 = frames_spike_train2[index2]
+            not_a_match = abs(frame1 - frame2) > delta_frames
+            if not_a_match:
+                break

-        match_event_count_segment = Parallel(n_jobs=n_jobs)(
-            delayed(count_match_spikes)(
-                sorting1.get_unit_spike_train(u1, segment_index=segment_index), s2_spiketrains, delta_frames
-            )
-            for i1, u1 in enumerate(unit1_ids)
-        )
-        match_event_counts += np.array(match_event_count_segment)
+            # Map the match to a matrix
+            row, column = unit_indices1[index1], unit_indices2[index2]
+
+            # Jeremy Magland condition, the same unit can't match twice
+            if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, 
column]: + previous_frame1_match[row, column] = frame1 + previous_frame2_match[row, column] = frame2 + + matching_matrix[row, column] += 1 + + index2 += 1 + + # Advance the minimal index 2 if not in the last loop iteration + if not_in_the_last_loop: + not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames + if not_a_match_with_next: + lower_search_limit_in_second_train = index2 + + return matching_matrix + + +def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): + num_units_sorting1 = sorting1.get_num_units() + num_units_sorting2 = sorting2.get_num_units() + + unit1_ids = sorting1.get_unit_ids() + unit2_ids = sorting2.get_unit_ids() + spike_trains1 = [sorting1.get_unit_spike_train(unit_id) for unit_id in unit1_ids] + spike_trains2 = [sorting2.get_unit_spike_train(unit_id) for unit_id in unit2_ids] + + sample_frames1 = np.concatenate(spike_trains1) + sample_frames2 = np.concatenate(spike_trains2) + + # Directly creating unit indices without intermediate lists + unit_indices1 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]) + unit_indices2 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]) + + # Sort the sample_frames and unit_indices arrays + sort_indices1 = np.argsort(sample_frames1) + sample_frames1 = sample_frames1[sort_indices1] + unit_indices1 = unit_indices1[sort_indices1] + + sort_indices2 = np.argsort(sample_frames2) + sample_frames2 = sample_frames2[sort_indices2] + unit_indices2 = unit_indices2[sort_indices2] + + full_matrix = compute_matching_matrix( + sample_frames1, + sample_frames2, + unit_indices1, + unit_indices2, + num_units_sorting1, + num_units_sorting2, + delta_frames, + ) + + import pandas as pd - match_event_counts_df = pd.DataFrame(np.array(match_event_counts), index=unit1_ids, columns=unit2_ids) + df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids) - return match_event_counts_df + return df def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1): diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py index 5d5c56d15c..e46f904351 100644 --- a/src/spikeinterface/comparison/tests/test_comparisontools.py +++ b/src/spikeinterface/comparison/tests/test_comparisontools.py @@ -15,6 +15,7 @@ do_count_score, compute_performance, ) +from spikeinterface.core.generate import generate_sorting def make_sorting(times1, labels1, times2, labels2): @@ -27,25 +28,113 @@ def make_sorting(times1, labels1, times2, labels2): def test_make_match_count_matrix(): delta_frames = 10 - # simple match sorting1, sorting2 = make_sorting( [100, 200, 300, 400], [0, 0, 1, 0], - [ - 101, - 201, - 301, - ], + [101, 201, 301], [0, 0, 5], ) - match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1) - # ~ print(match_event_count) + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) assert match_event_count.shape[0] == len(sorting1.get_unit_ids()) assert match_event_count.shape[1] == len(sorting2.get_unit_ids()) +def test_make_match_count_matrix_sorting_with_itself_simple(): + delta_frames = 10 + + # simple sorting with itself + sorting1, sorting2 = make_sorting( + [100, 200, 300, 400], + [0, 0, 1, 0], + [100, 200, 300, 400], + [0, 0, 1, 0], + ) + + match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames) + + expected_result = [[3, 0], [0, 1]] + assert_array_equal(match_event_count.to_numpy(), expected_result) + + 
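# Illustrative sketch, assuming NumpySorting.from_times_labels (the helper that
# make_sorting wraps at the top of this test file); this exercises the same
# self-comparison as the test above, outside pytest:
import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison.comparisontools import make_match_count_matrix

sorting = NumpySorting.from_times_labels(
    [np.array([100, 200, 300, 400])], [np.array([0, 0, 1, 0])], sampling_frequency=30000.0
)
match_event_count = make_match_count_matrix(sorting, sorting, delta_frames=10)
# The DataFrame is indexed by unit ids on both axes; the diagonal holds the
# self-match counts, i.e. [[3, 0], [0, 1]] here, matching the assertion above.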
+def test_make_match_count_matrix_sorting_with_itself_longer():
+    seed = 2
+    sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[10], seed=seed)
+
+    delta_frame_milliseconds = 0.1  # Short so that we only get matches between a unit and itself
+    delta_frames_seconds = delta_frame_milliseconds / 1000
+    delta_frames = delta_frames_seconds * sorting.get_sampling_frequency()
+    match_event_count = make_match_count_matrix(sorting, sorting, delta_frames)
+
+    match_event_count_as_array = match_event_count.to_numpy()
+    matches_with_itself = np.diag(match_event_count_as_array)
+
+    # The number of matches with itself should be equal to the number of spikes in each unit
+    spikes_per_unit_dict = sorting.count_num_spikes_per_unit()
+    expected_result = np.array([spikes_per_unit_dict[u] for u in spikes_per_unit_dict.keys()])
+    assert_array_equal(matches_with_itself, expected_result)
+
+
+def test_make_match_count_matrix_with_mismatched_sortings():
+    delta_frames = 10
+
+    sorting1, sorting2 = make_sorting(
+        [100, 200, 300, 400], [0, 0, 1, 0], [500, 600, 700, 800], [0, 0, 1, 0]  # Completely different spike times
+    )
+
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+
+    expected_result = [[0, 0], [0, 0]]  # No matches between sorting1 and sorting2
+    assert_array_equal(match_event_count.to_numpy(), expected_result)
+
+
+def test_no_double_matching():
+    # Jeremy Magland condition: no double matching
+    frames_spike_train1 = [100, 105, 120, 1000]
+    unit_indices1 = [0, 1, 0, 0]
+    frames_spike_train2 = [101, 150, 1000]
+    unit_indices2 = [0, 1, 0]
+    delta_frames = 100
+
+    # Here the key is that the third frame in the first sorting (120) should not match anything in the second sorting
+    # Because the matching candidates in the second sorting are already matched to the first two frames
+    # in the first sorting
+
+    # In detail:
+    # The first frame in sorting 1 (100) from unit 0 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The second frame in sorting 1 (105) from unit 1 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The third frame in sorting 1 (120) from unit 0 should not match anything
+    # The final frame in sorting 1 (1000) from unit 0 should only match the final frame in sorting 2 (1000) from unit 0
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[2, 1], [1, 1]])  # Each spike can be matched at most once, despite multiple candidates
+    assert_array_equal(result.to_numpy(), expected_result)
+
+
+def test_repeated_matching_but_no_double_counting():
+    # Challenging condition, this was failing with the previous approach that used np.where and np.diff
+    frames_spike_train1 = [100, 105, 110]
+    frames_spike_train2 = [100, 105, 110]
+    unit_indices1 = [0, 0, 0]
+    unit_indices2 = [0, 0, 0]
+    delta_frames = 20  # long enough, so all frames in both sortings are within each other's reach
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[3]])
+    assert_array_equal(result.to_numpy(), expected_result)
+
+
 def test_make_agreement_scores():
     delta_frames = 10

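These tests pin down a counting rule that is easy to restate in plain Python. The sketch below is an illustrative, unoptimized transcription of the numba kernel introduced in this patch; it assumes both frame arrays are sorted in ascending order, as the kernel does, and is not a replacement for it:

    def match_counts_reference(frames1, units1, frames2, units2, num_units1, num_units2, delta_frames):
        import numpy as np

        counts = np.zeros((num_units1, num_units2), dtype=np.int64)
        # last matched frames per unit pair, so the same spike is never counted twice
        last_frame1 = -np.ones((num_units1, num_units2), dtype=np.int64)
        last_frame2 = -np.ones((num_units1, num_units2), dtype=np.int64)

        start2 = 0  # lower search limit in the second train
        for index1, frame1 in enumerate(frames1):
            index2 = start2
            while index2 < len(frames2):
                frame2 = frames2[index2]
                if abs(frame1 - frame2) > delta_frames:
                    break
                row, column = units1[index1], units2[index2]
                if frame1 != last_frame1[row, column] and frame2 != last_frame2[row, column]:
                    last_frame1[row, column] = frame1
                    last_frame2[row, column] = frame2
                    counts[row, column] += 1
                index2 += 1
                # advance the search start once the next frame1 can no longer reach frame2
                if index1 < len(frames1) - 1 and abs(frames1[index1 + 1] - frame2) > delta_frames:
                    start2 = index2
        return counts

For instance, the repeated-matching case above reproduces the expected result:

    match_counts_reference([100, 105, 110], [0, 0, 0], [100, 105, 110], [0, 0, 0], 1, 1, 20)  # -> [[3]]
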
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index e6d08d38f7..c40c4492f6 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from typing import List, Optional, Union

@@ -268,7 +270,7 @@ def get_total_num_spikes(self):
         )
         return self.count_num_spikes_per_unit()

-    def count_num_spikes_per_unit(self):
+    def count_num_spikes_per_unit(self) -> dict:
         """
         For each unit : get number of spikes across segments.


From b2ddc4a1729f67ed5bc5e271e45db14d02894a6c Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Fri, 20 Oct 2023 16:26:46 +0200
Subject: [PATCH 02/13] added jitting dispatch mechanism

---
 src/spikeinterface/comparison/comparisontools.py | 11 +++++++++--
 .../comparison/tests/test_comparisontools.py     | 10 +++++-----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index dc0aaaf9a8..0848731812 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -4,7 +4,6 @@

 import numpy as np
 import numba
-from joblib import Parallel, delayed


 def count_matching_events(times1, times2, delta=10):
@@ -190,7 +189,15 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     sample_frames2 = sample_frames2[sort_indices2]
     unit_indices2 = unit_indices2[sort_indices2]

-    full_matrix = compute_matching_matrix(
+    import numba
+
+    # Check if compute_matching_matrix is already jitted
+    if not isinstance(compute_matching_matrix, numba.core.registry.CPUDispatcher):
+        optimized_compute_matching_matrix = numba.jit(nopython=True, nogil=True)(compute_matching_matrix)
+    else:
+        optimized_compute_matching_matrix = compute_matching_matrix
+
+    full_matrix = optimized_compute_matching_matrix(
diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py
index e46f904351..c6494b04d1 100644
--- a/src/spikeinterface/comparison/tests/test_comparisontools.py
+++ b/src/spikeinterface/comparison/tests/test_comparisontools.py
@@ -60,7 +60,7 @@ def test_make_match_count_matrix_sorting_with_itself_longer():
     seed = 2
-    sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[10], seed=seed)
+    sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[5, 5], seed=seed)

     delta_frame_milliseconds = 0.1  # Short so that we only get matches between a unit and itself
@@ -89,7 +89,7 @@ def test_make_match_count_matrix_with_mismatched_sortings():
     assert_array_equal(match_event_count.to_numpy(), expected_result)


-def test_no_double_matching():
+def test_make_match_count_matrix_no_double_matching():
     # Jeremy Magland condition: no double matching
     frames_spike_train1 = [100, 105, 120, 1000]
     unit_indices1 = [0, 1, 0, 0]
@@ -119,11 +119,11 @@ def test_make_match_count_matrix_no_double_matching():


-def test_repeated_matching_but_no_double_counting():
+def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
     # Challenging condition, this was failing with the previous approach that used np.where and np.diff
-    frames_spike_train1 = [100, 105, 110]
+    frames_spike_train1 = [100, 105, 110]  # Will fail with [100, 105, 110, 120]
     frames_spike_train2 = [100, 105, 110]
-    unit_indices1 = [0, 0, 0]
+    unit_indices1 = [0, 0, 0]  # Will fail with [0, 0, 0, 0]
     unit_indices2 = [0, 0, 0]
     delta_frames = 20  # long enough, so all frames in both sortings are within each other's reach

     sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)

From a8b09ff1ec4bf67711a9d729cc798448f10c4152 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Fri, 20 Oct 2023 16:51:09 +0200
Subject: [PATCH 03/13] test passing

---
 .../comparison/comparisontools.py   | 57 +++++++++++++------
 src/spikeinterface/core/generate.py |  2 +-
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 0848731812..84cacc0caa 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -167,27 +167,48 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
 def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     num_units_sorting1 = sorting1.get_num_units()
     num_units_sorting2 = sorting2.get_num_units()
-
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
     unit1_ids = sorting1.get_unit_ids()
     unit2_ids = sorting2.get_unit_ids()
-    spike_trains1 = [sorting1.get_unit_spike_train(unit_id) for unit_id in unit1_ids]
-    spike_trains2 = [sorting2.get_unit_spike_train(unit_id) for unit_id in unit2_ids]

-    sample_frames1 = np.concatenate(spike_trains1)
-    sample_frames2 = np.concatenate(spike_trains2)
+    sample_frames1_accumulator = []
+    unit_indices1_accumulator = []
+
+    sample_frames2_accumulator = []
+    unit_indices2_accumulator = []
+
+    for segment_index in range(num_segments_sorting1):
+        spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids]
+        sample_frames1 = np.concatenate(spike_trains1)
+        unit_indices1 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains1)])
+
+        sample_frames1_accumulator.append(sample_frames1)
+        unit_indices1_accumulator.append(unit_indices1)
+
+    for segment_index in range(num_segments_sorting2):
+        spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids]
+        sample_frames2 = np.concatenate(spike_trains2)
+        unit_indices2 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains2)])
+
+        sample_frames2_accumulator.append(sample_frames2)
+        unit_indices2_accumulator.append(unit_indices2)
+
+    # Concatenate accumulated data
+    sample_frames1_all = np.concatenate(sample_frames1_accumulator)
+    unit_indices1_all = np.concatenate(unit_indices1_accumulator)

-    # Directly creating unit indices without intermediate lists
-    unit_indices1 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains1)])
-    unit_indices2 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains2)])
+    sample_frames2_all = np.concatenate(sample_frames2_accumulator)
+    unit_indices2_all = np.concatenate(unit_indices2_accumulator)

     # Sort the sample_frames and unit_indices arrays
-    sort_indices1 = np.argsort(sample_frames1)
-    sample_frames1 = sample_frames1[sort_indices1]
-    unit_indices1 = unit_indices1[sort_indices1]
+    sort_indices1 = np.argsort(sample_frames1_all)
+    sample_frames1_sorted = sample_frames1_all[sort_indices1]
+    unit_indices1_sorted = unit_indices1_all[sort_indices1]

-    sort_indices2 = np.argsort(sample_frames2)
-    sample_frames2 = sample_frames2[sort_indices2]
-    unit_indices2 = unit_indices2[sort_indices2]
+ sort_indices2 = np.argsort(sample_frames2_all) + sample_frames2_sorted = sample_frames2_all[sort_indices2] + unit_indices2_sorted = unit_indices2_all[sort_indices2] import numba @@ -198,10 +219,10 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): optimized_compute_matching_matrix = compute_matching_matrix full_matrix = optimized_compute_matching_matrix( - sample_frames1, - sample_frames2, - unit_indices1, - unit_indices2, + sample_frames1_sorted, + sample_frames2_sorted, + unit_indices1_sorted, + unit_indices2_sorted, num_units_sorting1, num_units_sorting2, delta_frames, diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index eeb1e8af60..72b9ec6450 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -162,7 +162,7 @@ def generate_sorting( duration=durations[segment_index], refractory_period_ms=refractory_period_ms, firing_rates=firing_rates, - seed=seed, + seed=seed + segment_index, ) if empty_units is not None: From 94f8df19fe271d6d73bdedebed920728b1098b9a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 20 Oct 2023 16:54:31 +0200 Subject: [PATCH 04/13] avoid concatenation --- .../comparison/comparisontools.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 84cacc0caa..235b46196b 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -3,7 +3,7 @@ """ import numpy as np -import numba +import itertools def count_matching_events(times1, times2, delta=10): @@ -180,21 +180,19 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): for segment_index in range(num_segments_sorting1): spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids] - sample_frames1 = np.concatenate(spike_trains1) - unit_indices1 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]) + sample_frames1_accumulator.extend(spike_trains1) - sample_frames1_accumulator.append(sample_frames1) - unit_indices1_accumulator.append(unit_indices1) + unit_indices1 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains1)] + unit_indices1_accumulator.extend(unit_indices1) for segment_index in range(num_segments_sorting2): spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids] - sample_frames2 = np.concatenate(spike_trains2) - unit_indices2 = np.concatenate([np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]) + sample_frames2_accumulator.extend(spike_trains2) - sample_frames2_accumulator.append(sample_frames2) - unit_indices2_accumulator.append(unit_indices2) + unit_indices2 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains2)] + unit_indices2_accumulator.extend(unit_indices2) - # Concatenate accumulated data + # Concatenate accumulated data only once sample_frames1_all = np.concatenate(sample_frames1_accumulator) unit_indices1_all = np.concatenate(unit_indices1_accumulator) From 29d0d7dc768383fccce559203ad254b9edf1a3a8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 20 Oct 2023 17:16:10 +0200 Subject: [PATCH 05/13] monster of dynamic import --- .../comparison/comparisontools.py | 124 ++++++++++-------- 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/src/spikeinterface/comparison/comparisontools.py 
b/src/spikeinterface/comparison/comparisontools.py
index 235b46196b..2f3b8396fc 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -109,59 +109,73 @@ def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, ev
     return matching_event_counts


-@numba.jit(nopython=True, nogil=True)
-def compute_matching_matrix(
-    frames_spike_train1,
-    frames_spike_train2,
-    unit_indices1,
-    unit_indices2,
-    num_units_sorting1,
-    num_units_sorting2,
-    delta_frames,
-):
-    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
-
-    # Used for Jeremy Magland condition where no unit can be matched twice.
-    previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
-    previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)
-
-    lower_search_limit_in_second_train = 0
-
-    for index1 in range(len(frames_spike_train1)):
-        # Keeps track of which frame in the second spike train should be used as a search start
-        index2 = lower_search_limit_in_second_train
-        frame1 = frames_spike_train1[index1]
-
-        # Determine next_frame1 if current frame is not the last frame
-        not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
-        if not_in_the_last_loop:
-            next_frame1 = frames_spike_train1[index1 + 1]
-
-        while index2 < len(frames_spike_train2):
-            frame2 = frames_spike_train2[index2]
-            not_a_match = abs(frame1 - frame2) > delta_frames
-            if not_a_match:
-                break
-
+def get_optimized_compute_matching_matrix():
+    # Cache for compiled function
+    if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
+        return get_optimized_compute_matching_matrix._cached_function
+
+    # Dynamic import of numba
+    import numba
+
+    # Nested function
+    @numba.jit(nopython=True, nogil=True)
+    def compute_matching_matrix_inner(
+        frames_spike_train1,
+        frames_spike_train2,
+        unit_indices1,
+        unit_indices2,
+        num_units_sorting1,
+        num_units_sorting2,
+        delta_frames,
+    ):
+        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
+
+        # Used for Jeremy Magland condition where no unit can be matched twice. 
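+        # (These two arrays share the matrix shape (num_units_sorting1, num_units_sorting2);
+        # entry [row, column] stores the last frame of each train that produced a match for
+        # that unit pair, so a spike that already matched the pair is never counted again.)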
+ previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) + previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) + + lower_search_limit_in_second_train = 0 + + for index1 in range(len(frames_spike_train1)): + # Keeps track of which frame in the second spike train should be used as a search start + index2 = lower_search_limit_in_second_train + frame1 = frames_spike_train1[index1] + + # Determine next_frame1 if current frame is not the last frame + not_in_the_last_loop = index1 < len(frames_spike_train1) - 1 if not_in_the_last_loop: - not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames - if not_a_match_with_next: - lower_search_limit_in_second_train = index2 + next_frame1 = frames_spike_train1[index1 + 1] + + while index2 < len(frames_spike_train2): + frame2 = frames_spike_train2[index2] + not_a_match = abs(frame1 - frame2) > delta_frames + if not_a_match: + break - return matching_matrix + # Map the match to a matrix + row, column = unit_indices1[index1], unit_indices2[index2] + + # Jeremy Magland condition, the same unit can't match twice + if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: + previous_frame1_match[row, column] = frame1 + previous_frame2_match[row, column] = frame2 + + matching_matrix[row, column] += 1 + + index2 += 1 + + # Advance the minimal index 2 if not in the last loop iteration + if not_in_the_last_loop: + not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames + if not_a_match_with_next: + lower_search_limit_in_second_train = index2 + + return matching_matrix + + # Cache the compiled function + get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix_inner + + return compute_matching_matrix_inner def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): @@ -208,13 +222,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): sample_frames2_sorted = sample_frames2_all[sort_indices2] unit_indices2_sorted = unit_indices2_all[sort_indices2] - import numba - - # Check if compute_matching_matrix is already jitted - if not isinstance(compute_matching_matrix, numba.core.registry.CPUDispatcher): - optimized_compute_matching_matrix = numba.jit(nopython=True, nogil=True)(compute_matching_matrix) - else: - optimized_compute_matching_matrix = compute_matching_matrix + optimized_compute_matching_matrix = get_optimized_compute_matching_matrix() full_matrix = optimized_compute_matching_matrix( sample_frames1_sorted, From 52261e2199d5095258b066c251ceae943cc3b939 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Oct 2023 17:56:50 +0200 Subject: [PATCH 06/13] Add missing seed in QM tests --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 73bbee611b..eb8317e4df 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -210,7 +210,7 @@ def test_peak_sign(self): # invert recording rec_inv = scale(rec, gain=-1.0) - we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv") + we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv", seed=0) # compute amplitudes _ = compute_spike_amplitudes(we, peak_sign="neg") From 
6d5a98ea274460f7c92b572ddad98081bd186752 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 24 Oct 2023 12:26:51 +0200
Subject: [PATCH 07/13] added docstring

---
 .../comparison/comparisontools.py | 77 ++++++++++++++++---
 1 file changed, 65 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 2f3b8396fc..3ed3e386da 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -110,16 +110,20 @@ def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, ev


 def get_optimized_compute_matching_matrix():
-    # Cache for compiled function
+    """
+    This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function
+    which uses numba. I tested using the numba dispatcher programmatically to avoid this,
+    but the performance improvements were lost. Think you can do better? Don't forget to measure performance against
+    the current implementation!
+    """
+
     if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
         return get_optimized_compute_matching_matrix._cached_function

-    # Dynamic import of numba
     import numba

-    # Nested function
     @numba.jit(nopython=True, nogil=True)
-    def compute_matching_matrix_inner(
+    def compute_matching_matrix(
         frames_spike_train1,
         frames_spike_train2,
         unit_indices1,
         unit_indices2,
         num_units_sorting1,
         num_units_sorting2,
         delta_frames,
     ):
+        """
+        Compute a matrix representing the matches between two spike trains.
+
+        Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
+        defined by `delta_frames`. The resulting matrix indicates the number of matches between units
+        in `frames_spike_train1` and `frames_spike_train2`.
+
+        Parameters
+        ----------
+        frames_spike_train1 : ndarray
+            Array of frames for the first spike train. Should be ordered in ascending order.
+        frames_spike_train2 : ndarray
+            Array of frames for the second spike train. Should be ordered in ascending order.
+        unit_indices1 : ndarray
+            Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
+        unit_indices2 : ndarray
+            Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
+        num_units_sorting1 : int
+            Total number of units in the first spike train.
+        num_units_sorting2 : int
+            Total number of units in the second spike train.
+        delta_frames : int
+            Maximum difference in frames between two spikes to consider them as a match.
+
+        Returns
+        -------
+        matching_matrix : ndarray
+            A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
+            the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
+
+        Notes
+        -----
+        This algorithm identifies matching spikes between two ordered spike trains.
+        By iterating through each spike in the first train, it compares them against spikes in the second train,
+        determining matches based on the two spikes' frames being within `delta_frames` of each other.
+
+        To avoid redundant comparisons, the algorithm maintains a reference, `lower_search_limit_in_second_train`,
+        which signifies the minimal index in the second spike train that might match the upcoming spike
+        in the first train. This means that the start of the search moves forward in the second train as the
+        matches between the two trains are found, decreasing the number of comparisons needed.
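+
+        For example, two single-unit trains with frames_spike_train1 = [100, 200, 300] and
+        frames_spike_train2 = [105, 205, 500] at delta_frames = 10 produce matching_matrix = [[2]]:
+        (100, 105) and (200, 205) are within reach of each other, while 300 and 500 match nothing.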
+
+        An important condition here is that the same spike is not matched twice. This is managed by keeping track
+        of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`.
+
+        For more details on the rationale behind this approach, refer to the documentation of this module and/or
+        the metrics section in SpikeForest documentation.
+        """
+
         matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)

-        # Used for Jeremy Magland condition where no unit can be matched twice.
+        # Used to avoid the same spike matching twice
         previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
         previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)
@@ -155,7 +207,9 @@ def compute_matching_matrix_inner(
             # Map the match to a matrix
             row, column = unit_indices1[index1], unit_indices2[index2]

-            # Jeremy Magland condition, the same unit can't match twice
+            # The same spike cannot be matched twice
+            # This condition interprets matches as an accuracy task, see the documentation of the module
+            # Or the metrics section in SpikeForest to see the reason for this.
             if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                 previous_frame1_match[row, column] = frame1
                 previous_frame2_match[row, column] = frame2
@@ -165,17 +219,16 @@ def compute_matching_matrix_inner(

             index2 += 1

             # Advance the minimal index 2 if not in the last loop iteration
-            if not_in_the_last_loop:
-                not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
-                if not_a_match_with_next:
-                    lower_search_limit_in_second_train = index2
+            not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
+            if not_a_match_with_next:
+                lower_search_limit_in_second_train = index2

         return matching_matrix

     # Cache the compiled function
-    get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix_inner
+    get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix

-    return compute_matching_matrix_inner
+    return compute_matching_matrix

From f740e8323a835a737910e0b618a9d88e1c514783 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 24 Oct 2023 12:41:56 +0200
Subject: [PATCH 08/13] some small fixes, remove unnecessary variable change,
 import of itertools

---
 src/spikeinterface/comparison/comparisontools.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 3ed3e386da..924fe70174 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -3,7 +3,6 @@
 """

 import numpy as np
-import itertools


 def count_matching_events(times1, times2, delta=10):
@@ -275,9 +274,9 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     sample_frames2_sorted = sample_frames2_all[sort_indices2]
     unit_indices2_sorted = unit_indices2_all[sort_indices2]

-    optimized_compute_matching_matrix = get_optimized_compute_matching_matrix()
+    compute_matching_matrix = get_optimized_compute_matching_matrix()

-    full_matrix = optimized_compute_matching_matrix(
+    full_matrix = compute_matching_matrix(
         sample_frames1_sorted,
         sample_frames2_sorted,
         unit_indices1_sorted,
         unit_indices2_sorted,
         num_units_sorting1,
         num_units_sorting2,
         delta_frames,
     )

     import pandas as pd

-    df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)
+    match_event_counts_df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)

-    return df
+    return match_event_counts_df


 def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):

From 8da861482c755bf4abba2d8fcb35bf759e485a55 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 24 Oct 2023 15:57:13 +0200
Subject: [PATCH 09/13] added spike vector and separate segments

---
 .../comparison/comparisontools.py | 93 +++++++------------
 1 file changed, 36 insertions(+), 57 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 924fe70174..f01db549bb 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -179,7 +179,7 @@ def compute_matching_matrix(
         the metrics section in SpikeForest documentation.
         """

-        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
+        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

         # Used to avoid the same spike matching twice
         previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
@@ -188,7 +188,7 @@ def compute_matching_matrix(
         lower_search_limit_in_second_train = 0

         for index1 in range(len(frames_spike_train1)):
-            # Keeps track of which frame in the second spike train should be used as a search start
+            # Keeps track of which frame in the second spike train should be used as a search start for matches
             index2 = lower_search_limit_in_second_train
             frame1 = frames_spike_train1[index1]
@@ -201,14 +201,13 @@ def compute_matching_matrix(
             frame2 = frames_spike_train2[index2]
             not_a_match = abs(frame1 - frame2) > delta_frames
             if not_a_match:
+                # Go to the next frame in the first train
                 break

             # Map the match to a matrix
             row, column = unit_indices1[index1], unit_indices2[index2]

-            # The same spike cannot be matched twice
-            # This condition interprets matches as an accuracy task, see the documentation of the module
-            # Or the metrics section in SpikeForest to see the reason for this.
+            # The same spike cannot be matched twice; see the notes in the docstring for more info on this constraint
             if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                 previous_frame1_match[row, column] = frame1
                 previous_frame2_match[row, column] = frame2
@@ -216,7 +216,7 @@ def compute_matching_matrix(

             index2 += 1

-            # Advance the minimal index 2 if not in the last loop iteration
+            # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
             not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
             if not_a_match_with_next:
                 lower_search_limit_in_second_train = index2
@@ -233,62 +232,42 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
 def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     num_units_sorting1 = sorting1.get_num_units()
     num_units_sorting2 = sorting2.get_num_units()
-    num_segments_sorting1 = sorting1.get_num_segments()
-    num_segments_sorting2 = sorting2.get_num_segments()
-    unit1_ids = sorting1.get_unit_ids()
-    unit2_ids = sorting2.get_unit_ids()
-
-    sample_frames1_accumulator = []
-    unit_indices1_accumulator = []
-
-    sample_frames2_accumulator = []
-    unit_indices2_accumulator = []
-
-    for segment_index in range(num_segments_sorting1):
-        spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids]
-        sample_frames1_accumulator.extend(spike_trains1)
-
-        unit_indices1 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]
-        unit_indices1_accumulator.extend(unit_indices1)
-
-    for segment_index in range(num_segments_sorting2):
-        spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids]
-        sample_frames2_accumulator.extend(spike_trains2)
-
-        unit_indices2 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]
-        unit_indices2_accumulator.extend(unit_indices2)
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

-    # Concatenate accumulated data only once
-    sample_frames1_all = np.concatenate(sample_frames1_accumulator)
-    unit_indices1_all = np.concatenate(unit_indices1_accumulator)
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)

-    sample_frames2_all = np.concatenate(sample_frames2_accumulator)
-    unit_indices2_all = np.concatenate(unit_indices2_accumulator)
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    max_segment_to_compare = max(num_segments_sorting1, num_segments_sorting2)
+
+    # Segments should be matched one by one
+    for segment_index in range(max_segment_to_compare):
+        spike_vector1 = spike_vector1_segments[segment_index]
+        spike_vector2 = spike_vector2_segments[segment_index]
+
+        sample_frames1_sorted = spike_vector1["sample_index"]
+        sample_frames2_sorted = spike_vector2["sample_index"]
+
+        unit_indices1_sorted = spike_vector1["unit_index"]
+        unit_indices2_sorted = spike_vector2["unit_index"]
+
+        matching_matrix += get_optimized_compute_matching_matrix()(
+            sample_frames1_sorted,
+            sample_frames2_sorted,
+            unit_indices1_sorted,
+            unit_indices2_sorted,
+            num_units_sorting1,
+            num_units_sorting2,
+            delta_frames,
+        )

+    # Build a data frame from the matching matrix
     import pandas as pd

-    match_event_counts_df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)
+    unit_ids_of_sorting1 = sorting1.get_unit_ids()
+    unit_ids_of_sorting2 = sorting2.get_unit_ids()
+    match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2)

     return match_event_counts_df

From ab4e612c01f891ed32642b5d9cba5b0b4bf7d3fe Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Tue, 31 Oct 2023 17:35:17 +0100
Subject: [PATCH 10/13] Update src/spikeinterface/comparison/comparisontools.py

---
 src/spikeinterface/comparison/comparisontools.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index f01db549bb..d065392e72 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -114,6 +114,7 @@ def get_optimized_compute_matching_matrix():
     which uses numba. I tested using the numba dispatcher programmatically to avoid this,
     but the performance improvements were lost. Think you can do better? Don't forget to measure performance against
     the current implementation!
+    TODO: unify numba decorator across all modules
     """

From 2742108729b7a3edb7599d59e6e9212413f92b69 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Tue, 31 Oct 2023 17:41:07 +0100
Subject: [PATCH 11/13] Update src/spikeinterface/comparison/comparisontools.py

---
 src/spikeinterface/comparison/comparisontools.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index d065392e72..44d5bd01d7 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -240,10 +240,10 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):

     num_segments_sorting1 = sorting1.get_num_segments()
     num_segments_sorting2 = sorting2.get_num_segments()
-    max_segment_to_compare = max(num_segments_sorting1, num_segments_sorting2)
+    assert num_segments_sorting1 == num_segments_sorting2, "make_match_count_matrix : sorting1 and sorting must have the same segment number"

     # Segments should be matched one by one
-    for segment_index in range(max_segment_to_compare):
+    for segment_index in range(num_segments_sorting1):

From 5df47e18cd6e5bfce4c7f09ecaa14aae73584665 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 31 Oct 2023 16:41:24 +0000
Subject: [PATCH 12/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

 for more information, see https://pre-commit.ci

---
 src/spikeinterface/comparison/comparisontools.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 
44d5bd01d7..6f2b8796cc 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -240,7 +240,9 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): num_segments_sorting1 = sorting1.get_num_segments() num_segments_sorting2 = sorting2.get_num_segments() - assert num_segments_sorting1 == num_segments_sorting2, "make_match_count_matrix : sorting1 and sorting must have the same segment number" + assert ( + num_segments_sorting1 == num_segments_sorting2 + ), "make_match_count_matrix : sorting1 and sorting must have the same segment number" # Segments should be matched one by one for segment_index in range(num_segments_sorting1): From 345cb9f7d7d807a456ab27e1411676366fdf00ea Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 31 Oct 2023 17:57:54 +0100 Subject: [PATCH 13/13] Update src/spikeinterface/comparison/comparisontools.py Co-authored-by: Alessio Buccino --- src/spikeinterface/comparison/comparisontools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index d3e40df6c8..7a1fb87175 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -242,7 +242,7 @@ def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None): num_segments_sorting2 = sorting2.get_num_segments() assert ( num_segments_sorting1 == num_segments_sorting2 - ), "make_match_count_matrix : sorting1 and sorting must have the same segment number" + ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number" # Segments should be matched one by one for segment_index in range(num_segments_sorting1):
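
Where the series lands, make_match_count_matrix builds per-segment spike vectors and accumulates a single uint16 count matrix across segments. A usage sketch (assuming generate_sorting as imported in the tests above; unit ids label both axes of the returned DataFrame):

    from spikeinterface.core.generate import generate_sorting
    from spikeinterface.comparison.comparisontools import make_match_count_matrix

    # multi-segment sortings are compared segment by segment and, after PATCH 11,
    # must have the same number of segments
    sorting1 = generate_sorting(num_units=4, sampling_frequency=30000, durations=[5, 5], seed=0)
    sorting2 = generate_sorting(num_units=4, sampling_frequency=30000, durations=[5, 5], seed=1)

    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames=10)
    print(match_event_count.shape)  # (4, 4) pandas DataFrame of per-pair match counts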