diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index cd89f042cf..7a1fb87175 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -3,7 +3,6 @@
 """
 
 import numpy as np
-from joblib import Parallel, delayed
 
 
 def count_matching_events(times1, times2, delta=10):
@@ -109,48 +108,169 @@ def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, ev
     return matching_event_counts
 
 
-def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1):
-    """
-    Make the match_event_count matrix.
-    Basically it counts the matching events for all given pairs of spike trains from
-    sorting1 and sorting2.
+def get_optimized_compute_matching_matrix():
+    """
+    This function avoids the bare try-except pattern when importing the compute_matching_matrix function,
+    which uses numba. I tested using the numba dispatcher programmatically to avoid this,
+    but the performance improvements were lost. Think you can do better? Don't forget to measure performance against
+    the current implementation!
+    TODO: unify numba decorator across all modules
+    """
+
+    if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
+        return get_optimized_compute_matching_matrix._cached_function
+
+    import numba
+
+    @numba.jit(nopython=True, nogil=True)
+    def compute_matching_matrix(
+        frames_spike_train1,
+        frames_spike_train2,
+        unit_indices1,
+        unit_indices2,
+        num_units_sorting1,
+        num_units_sorting2,
+        delta_frames,
+    ):
+        """
+        Compute a matrix representing the matches between two spike trains.
+
+        Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
+        defined by `delta_frames`. The resulting matrix indicates the number of matches between units
+        in `frames_spike_train1` and `frames_spike_train2`.
+
+        Parameters
+        ----------
+        frames_spike_train1 : ndarray
+            Array of frames for the first spike train. Should be ordered in ascending order.
+        frames_spike_train2 : ndarray
+            Array of frames for the second spike train. Should be ordered in ascending order.
+        unit_indices1 : ndarray
+            Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
+        unit_indices2 : ndarray
+            Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
+        num_units_sorting1 : int
+            Total number of units in the first spike train.
+        num_units_sorting2 : int
+            Total number of units in the second spike train.
+        delta_frames : int
+            Maximum difference in frames between two spikes to consider them as a match.
+
+        Returns
+        -------
+        matching_matrix : ndarray
+            A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
+            the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
+
+        Notes
+        -----
+        This algorithm identifies matching spikes between two ordered spike trains.
+        By iterating through each spike in the first train, it compares them against spikes in the second train,
+        determining matches based on the two spikes' frames being within `delta_frames` of each other.
+
+        To avoid redundant comparisons, the algorithm maintains a reference, `lower_search_limit_in_second_train`,
+        which signifies the minimal index in the second spike train that might match the upcoming spike
+        in the first train. This means that the start of the search moves forward in the second train as
+        matches between the two trains are found, decreasing the number of comparisons needed.
+
+        An important condition here is that the same spike is not matched twice. This is managed by keeping track
+        of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`.
+
+        For more details on the rationale behind this approach, refer to the documentation of this module and/or
+        the metrics section in the SpikeForest documentation.
+        """
+
+        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
+
+        # Used to avoid the same spike matching twice
+        previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
+        previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)
+
+        lower_search_limit_in_second_train = 0
+
+        for index1 in range(len(frames_spike_train1)):
+            # Keeps track of which frame in the second spike train should be used as a search start for matches
+            index2 = lower_search_limit_in_second_train
+            frame1 = frames_spike_train1[index1]
+
+            # Determine next_frame1 if the current frame is not the last frame
+            not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
+            if not_in_the_last_loop:
+                next_frame1 = frames_spike_train1[index1 + 1]
+
+            while index2 < len(frames_spike_train2):
+                frame2 = frames_spike_train2[index2]
+                not_a_match = abs(frame1 - frame2) > delta_frames
+                if not_a_match:
+                    # Go to the next frame in the first train
+                    break
+
+                # Map the match to a matrix
+                row, column = unit_indices1[index1], unit_indices2[index2]
+
+                # The same spike cannot be matched twice; see the notes in the docstring for more info on this constraint
+                if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
+                    previous_frame1_match[row, column] = frame1
+                    previous_frame2_match[row, column] = frame2
+
+                    matching_matrix[row, column] += 1
+
+                index2 += 1
+
+                # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
+                not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
+                if not_a_match_with_next:
+                    lower_search_limit_in_second_train = index2
+
+        return matching_matrix
+
+    # Cache the compiled function
+    get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix
+
+    return compute_matching_matrix
+
+
+def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
+    num_units_sorting1 = sorting1.get_num_units()
+    num_units_sorting2 = sorting2.get_num_units()
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
+
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)
+
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    assert (
+        num_segments_sorting1 == num_segments_sorting2
+    ), "make_match_count_matrix : sorting1 and sorting2 must have the same segment number"
+
+    # Segments should be matched one by one
+    for segment_index in range(num_segments_sorting1):
+        spike_vector1 = spike_vector1_segments[segment_index]
+        spike_vector2 = spike_vector2_segments[segment_index]
+
+        sample_frames1_sorted = spike_vector1["sample_index"]
+        sample_frames2_sorted = spike_vector2["sample_index"]
+
+        unit_indices1_sorted = spike_vector1["unit_index"]
+        unit_indices2_sorted = spike_vector2["unit_index"]
-
-    Parameters
-    ----------
-    sorting1: SortingExtractor
-        The first sorting extractor
-    sorting2: SortingExtractor
-        The second sorting extractor
-    delta_frames: int
-        Number of frames to consider spikes coincident
-    n_jobs: int
-        Number of jobs to run in parallel
+        matching_matrix += get_optimized_compute_matching_matrix()(
+            sample_frames1_sorted,
+            sample_frames2_sorted,
+            unit_indices1_sorted,
+            unit_indices2_sorted,
+            num_units_sorting1,
+            num_units_sorting2,
+            delta_frames,
+        )
-
-    Returns
-    -------
-    match_event_count: array (int64)
-        Matrix of match count spike
-    """
+    # Build a data frame from the matching matrix
     import pandas as pd
 
-    unit1_ids = np.array(sorting1.get_unit_ids())
-    unit2_ids = np.array(sorting2.get_unit_ids())
-
-    match_event_counts = np.zeros((len(unit1_ids), len(unit2_ids)), dtype="int64")
-
-    # preload all spiketrains 2 into a list
-    for segment_index in range(sorting1.get_num_segments()):
-        s2_spiketrains = [sorting2.get_unit_spike_train(u2, segment_index=segment_index) for u2 in unit2_ids]
-
-        match_event_count_segment = Parallel(n_jobs=n_jobs)(
-            delayed(count_match_spikes)(
-                sorting1.get_unit_spike_train(u1, segment_index=segment_index), s2_spiketrains, delta_frames
-            )
-            for i1, u1 in enumerate(unit1_ids)
-        )
-        match_event_counts += np.array(match_event_count_segment)
-
-    match_event_counts_df = pd.DataFrame(np.array(match_event_counts), index=unit1_ids, columns=unit2_ids)
+    unit_ids_of_sorting1 = sorting1.get_unit_ids()
+    unit_ids_of_sorting2 = sorting2.get_unit_ids()
+    match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2)
 
     return match_event_counts_df
diff --git a/src/spikeinterface/comparison/tests/test_comparisontools.py b/src/spikeinterface/comparison/tests/test_comparisontools.py
index 5d5c56d15c..c6494b04d1 100644
--- a/src/spikeinterface/comparison/tests/test_comparisontools.py
+++ b/src/spikeinterface/comparison/tests/test_comparisontools.py
@@ -15,6 +15,7 @@
     do_count_score,
     compute_performance,
 )
+from spikeinterface.core.generate import generate_sorting
 
 
 def make_sorting(times1, labels1, times2, labels2):
@@ -27,25 +28,113 @@ def make_sorting(times1, labels1, times2, labels2):
 def test_make_match_count_matrix():
     delta_frames = 10
 
-    # simple match
     sorting1, sorting2 = make_sorting(
         [100, 200, 300, 400],
         [0, 0, 1, 0],
-        [
-            101,
-            201,
-            301,
-        ],
+        [101, 201, 301],
         [0, 0, 5],
     )
-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
-    # ~ print(match_event_count)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
 
     assert match_event_count.shape[0] == len(sorting1.get_unit_ids())
     assert match_event_count.shape[1] == len(sorting2.get_unit_ids())
 
 
+def test_make_match_count_matrix_sorting_with_itself_simple():
+    delta_frames = 10
+
+    # simple sorting with itself
+    sorting1, sorting2 = make_sorting(
+        [100, 200, 300, 400],
+        [0, 0, 1, 0],
+        [100, 200, 300, 400],
+        [0, 0, 1, 0],
+    )
+
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+
+    expected_result = [[3, 0], [0, 1]]
+    assert_array_equal(match_event_count.to_numpy(), expected_result)
+
+
+def test_make_match_count_matrix_sorting_with_itself_longer():
+    seed = 2
+    sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[5, 5], seed=seed)
+
+    delta_frame_milliseconds = 0.1  # Short, so that we only get matches between a unit and itself
+    delta_frames_seconds = delta_frame_milliseconds / 1000
+    delta_frames = delta_frames_seconds * sorting.get_sampling_frequency()
+    match_event_count = make_match_count_matrix(sorting, sorting, delta_frames)
+
+    match_event_count_as_array = match_event_count.to_numpy()
+    matches_with_itself = np.diag(match_event_count_as_array)
+
+    # The number of matches with itself should be equal to the number of spikes in each unit
+    spikes_per_unit_dict = sorting.count_num_spikes_per_unit()
+    expected_result = np.array([spikes_per_unit_dict[u] for u in spikes_per_unit_dict.keys()])
+    assert_array_equal(matches_with_itself, expected_result)
+
+
+def test_make_match_count_matrix_with_mismatched_sortings():
+    delta_frames = 10
+
+    sorting1, sorting2 = make_sorting(
+        [100, 200, 300, 400], [0, 0, 1, 0], [500, 600, 700, 800], [0, 0, 1, 0]  # Completely different spike times
+    )
+
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+
+    expected_result = [[0, 0], [0, 0]]  # No matches between sorting1 and sorting2
+    assert_array_equal(match_event_count.to_numpy(), expected_result)
+
+
+def test_make_match_count_matrix_no_double_matching():
+    # Jeremy Magland condition: no double matching
+    frames_spike_train1 = [100, 105, 120, 1000]
+    unit_indices1 = [0, 1, 0, 0]
+    frames_spike_train2 = [101, 150, 1000]
+    unit_indices2 = [0, 1, 0]
+    delta_frames = 100
+
+    # Here the key is that the third frame in the first sorting (120) should not match anything in the second sorting,
+    # because the matching candidates in the second sorting are already matched to the first two frames
+    # in the first sorting.
+
+    # In detail:
+    # The first frame in sorting 1 (100) from unit 0 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The second frame in sorting 1 (105) from unit 1 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The third frame in sorting 1 (120) from unit 0 should not match anything
+    # The final frame in sorting 1 (1000) from unit 0 should only match the final frame in sorting 2 (1000) from unit 0
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[2, 1], [1, 1]])  # Frame 120 adds no extra matches despite potential repeats
+    assert_array_equal(result.to_numpy(), expected_result)
+
+
+def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
+    # Challenging condition; this was failing with the previous approach that used np.where and np.diff
+    frames_spike_train1 = [100, 105, 110]  # Will fail with [100, 105, 110, 120]
+    frames_spike_train2 = [100, 105, 110]
+    unit_indices1 = [0, 0, 0]  # Will fail with [0, 0, 0, 0]
+    unit_indices2 = [0, 0, 0]
+    delta_frames = 20  # Long enough, so all the frames in both sortings are within each other's reach
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[3]])
+    assert_array_equal(result.to_numpy(), expected_result)
+
+
 def test_make_agreement_scores():
     delta_frames = 10
diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index 9d887390bb..94b08d8cc3 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from typing import List, Optional, Union
 
@@ -267,7 +269,7 @@ def get_total_num_spikes(self):
         )
         return self.count_num_spikes_per_unit()
 
-    def count_num_spikes_per_unit(self):
+    def count_num_spikes_per_unit(self) -> dict:
         """
         For each unit : get number of spikes across segments.
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index 3fb01ea02f..c670474f0e 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -172,7 +172,7 @@ def generate_sorting(
             duration=durations[segment_index],
             refractory_period_ms=refractory_period_ms,
             firing_rates=firing_rates,
-            seed=seed,
+            seed=seed + segment_index,
         )
 
         if empty_units is not None:
diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
index 73bbee611b..eb8317e4df 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
@@ -210,7 +210,7 @@ def test_peak_sign(self):
 
         # invert recording
         rec_inv = scale(rec, gain=-1.0)
-        we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv")
+        we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv", seed=0)
 
         # compute amplitudes
         _ = compute_spike_amplitudes(we, peak_sign="neg")
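
Usage note: a minimal sketch of the new make_match_count_matrix API, assuming numba is
installed and using generate_sorting as the tests above do. The seeds and parameter
values below are illustrative, not taken from the patch.

    import numpy as np
    from spikeinterface.core.generate import generate_sorting
    from spikeinterface.comparison.comparisontools import make_match_count_matrix

    # Two independent sortings over the same duration (different seeds on purpose)
    sorting1 = generate_sorting(num_units=5, sampling_frequency=30000, durations=[5.0], seed=0)
    sorting2 = generate_sorting(num_units=5, sampling_frequency=30000, durations=[5.0], seed=1)

    # Spikes closer than delta_frames samples count as a match
    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames=10)
    print(match_event_count)  # pandas DataFrame indexed by each sorting's unit ids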
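
For intuition, on the test cases above the no-double-counting rule enforced by
previous_frame1_match / previous_frame2_match behaves like the brute-force sketch below
for a single unit pair. brute_force_match_count is a hypothetical helper, not part of
the patch, and the numba kernel may differ from this greedy pairing in edge cases.

    import numpy as np

    def brute_force_match_count(frames1, frames2, delta_frames):
        # Greedily pair each spike in frames1 with the first not-yet-used spike
        # in frames2 within delta_frames; each spike is counted at most once.
        used2 = np.zeros(len(frames2), dtype=bool)
        count = 0
        for frame1 in frames1:
            for j, frame2 in enumerate(frames2):
                if not used2[j] and abs(frame1 - frame2) <= delta_frames:
                    used2[j] = True
                    count += 1
                    break
        return count

    # Mirrors test_make_match_count_matrix_repeated_matching_but_no_double_counting
    assert brute_force_match_count([100, 105, 110], [100, 105, 110], delta_frames=20) == 3
    # Mirrors unit pair (0, 0) in test_make_match_count_matrix_no_double_matching
    assert brute_force_match_count([100, 120, 1000], [101, 1000], delta_frames=100) == 2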