diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index 2f3b8396fc..3ed3e386da 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -110,16 +110,20 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev def get_optimized_compute_matching_matrix(): - # Cache for compiled function + """ + This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function + which uses numba. I tested using the numba dispatcher programmatically to avoid this + but the performance improvements were lost. Think you can do better? Don't forget to measure performance against + the current implementation! + """ + if hasattr(get_optimized_compute_matching_matrix, "_cached_function"): return get_optimized_compute_matching_matrix._cached_function - # Dynamic import of numba import numba - # Nested function @numba.jit(nopython=True, nogil=True) - def compute_matching_matrix_inner( + def compute_matching_matrix( frames_spike_train1, frames_spike_train2, unit_indices1, @@ -128,9 +132,57 @@ def compute_matching_matrix_inner( num_units_sorting2, delta_frames, ): + """ + Compute a matrix representing the matches between two spike trains. + + Given two spike trains, this function finds matching spikes based on a temporal proximity criterion + defined by `delta_frames`. The resulting matrix indicates the number of matches between units + in `frames_spike_train1` and `frames_spike_train2`. + + Parameters + ---------- + frames_spike_train1 : ndarray + Array of frames for the first spike train. Should be ordered in ascending order. + frames_spike_train2 : ndarray + Array of frames for the second spike train. Should be ordered in ascending order. + unit_indices1 : ndarray + Array indicating the unit indices corresponding to each spike in `frames_spike_train1`. 
+ unit_indices2 : ndarray + Array indicating the unit indices corresponding to each spike in `frames_spike_train2`. + num_units_sorting1 : int + Total number of units in the first spike train. + num_units_sorting2 : int + Total number of units in the second spike train. + delta_frames : int + Maximum difference in frames between two spikes to consider them as a match. + + Returns + ------- + matching_matrix : ndarray + A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents + the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`. + + Notes + ----- + This algorithm identifies matching spikes between two ordered spike trains. + By iterating through each spike in the first train, it compares them against spikes in the second train, + determining matches based on the two spikes frames being within `delta_frames` of each other. + + To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`, + which signifies the minimal index in the second spike train that might match the upcoming spike + in the first train. This means that the start of the search moves forward in the second train as the + matches between the two trains are found decreasing the number of comparisons needed. + + An important condition here is that the same spike is not matched twice. This is managed by keeping track + of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`. + + For more details on the rationale behind this approach, refer to the documentation of this module and/or + the metrics section in SpikeForest documentation. + """ + matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64) - # Used for Jeremy Magldan condition where no unit can be matched twice. 
+ # Used to avoid the same spike matching twice previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64) previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64) @@ -155,7 +207,9 @@ def compute_matching_matrix_inner( # Map the match to a matrix row, column = unit_indices1[index1], unit_indices2[index2] - # Jeremy Magland condition, the same unit can't match twice + # The same spike cannot be matched twice + # This condition interprets matching as an accuracy task; see the documentation of the module + # or the metrics section in SpikeForest to see the reason for this. if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]: previous_frame1_match[row, column] = frame1 previous_frame2_match[row, column] = frame2 @@ -165,17 +219,16 @@ def compute_matching_matrix_inner( index2 += 1 # Advance the minimal index 2 if not in the last loop iteration - if not_in_the_last_loop: - not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames - if not_a_match_with_next: - lower_search_limit_in_second_train = index2 + not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames + if not_a_match_with_next: + lower_search_limit_in_second_train = index2 return matching_matrix # Cache the compiled function - get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix_inner + get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix - return compute_matching_matrix_inner + return compute_matching_matrix def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):