Skip to content

Commit

Permalink
added docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Oct 24, 2023
1 parent 52261e2 commit 6d5a98e
Showing 1 changed file with 65 additions and 12 deletions.
77 changes: 65 additions & 12 deletions src/spikeinterface/comparison/comparisontools.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,20 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev


def get_optimized_compute_matching_matrix():
# Cache for compiled function
"""
This function is to avoid the bare try-except pattern when importing the compute_matching_matrix function
which uses numba. I tested using the numba dispatcher programmatically to avoid this
but the performance improvements were lost. Think you can do better? Don't forget to measure performance against
the current implementation!
"""

if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
return get_optimized_compute_matching_matrix._cached_function

# Dynamic import of numba
import numba

# Nested function
@numba.jit(nopython=True, nogil=True)
def compute_matching_matrix_inner(
def compute_matching_matrix(
frames_spike_train1,
frames_spike_train2,
unit_indices1,
Expand All @@ -128,9 +132,57 @@ def compute_matching_matrix_inner(
num_units_sorting2,
delta_frames,
):
"""
Compute a matrix representing the matches between two spike trains.
Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
defined by `delta_frames`. The resulting matrix indicates the number of matches between units
in `frames_spike_train1` and `frames_spike_train2`.
Parameters
----------
frames_spike_train1 : ndarray
Array of frames for the first spike train. Should be ordered in ascending order.
frames_spike_train2 : ndarray
Array of frames for the second spike train. Should be ordered in ascending order.
unit_indices1 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
unit_indices2 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
num_units_sorting1 : int
Total number of units in the first spike train.
num_units_sorting2 : int
Total number of units in the second spike train.
delta_frames : int
Maximum difference in frames between two spikes to consider them as a match.
Returns
-------
matching_matrix : ndarray
A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.
Notes
-----
This algorithm identifies matching spikes between two ordered spike trains.
By iterating through each spike in the first train, it compares that spike against spikes in the second train,
determining matches based on the two spikes' frames being within `delta_frames` of each other.
To avoid redundant comparisons the algorithm maintains a reference, `lower_search_limit_in_second_train`,
which signifies the minimal index in the second spike train that might match the upcoming spike
in the first train. This means that the start of the search moves forward in the second train as the
matches between the two trains are found, decreasing the number of comparisons needed.
An important condition here is that the same spike is not matched twice. This is managed by keeping track
of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`.
For more details on the rationale behind this approach, refer to the documentation of this module and/or
the metrics section in SpikeForest documentation.
"""

matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)

# Used for Jeremy Magland condition where no unit can be matched twice.
# Used to avoid the same spike matching twice
previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)

Expand All @@ -155,7 +207,9 @@ def compute_matching_matrix_inner(
# Map the match to a matrix
row, column = unit_indices1[index1], unit_indices2[index2]

# Jeremy Magland condition, the same unit can't match twice
# The same spike cannot be matched twice
# This condition interprets matching as an accuracy task; see the documentation of the module
# or the metrics section in SpikeForest for the reason behind this.
if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
previous_frame1_match[row, column] = frame1
previous_frame2_match[row, column] = frame2
Expand All @@ -165,17 +219,16 @@ def compute_matching_matrix_inner(
index2 += 1

# Advance the minimal index 2 if not in the last loop iteration
if not_in_the_last_loop:
not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
if not_a_match_with_next:
lower_search_limit_in_second_train = index2
not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
if not_a_match_with_next:
lower_search_limit_in_second_train = index2

return matching_matrix

# Cache the compiled function
get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix_inner
get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix

return compute_matching_matrix_inner
return compute_matching_matrix


def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
Expand Down

0 comments on commit 6d5a98e

Please sign in to comment.