Optimize make_match_count_matrix #2114

Merged: 16 commits, Oct 31, 2023
Changes from 15 commits
198 changes: 159 additions & 39 deletions src/spikeinterface/comparison/comparisontools.py
@@ -3,7 +3,6 @@
"""

import numpy as np
from joblib import Parallel, delayed


def count_matching_events(times1, times2, delta=10):
@@ -109,48 +108,169 @@ def count_match_spikes(times1, all_times2, delta_frames): # , event_counts1, ev
return matching_event_counts


def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1):
"""
Make the match_event_count matrix.
Basically it counts the matching events for all given pairs of spike trains from
sorting1 and sorting2.
def get_optimized_compute_matching_matrix():
"""
This function avoids the bare try-except pattern when importing the compute_matching_matrix function,
which uses numba. I tested using the numba dispatcher programmatically to avoid this,
but the performance improvements were lost. Think you can do better? Don't forget to measure performance against
the current implementation!
TODO: unify numba decorator across all modules
"""

if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
return get_optimized_compute_matching_matrix._cached_function

import numba

@numba.jit(nopython=True, nogil=True)
def compute_matching_matrix(
frames_spike_train1,
frames_spike_train2,
unit_indices1,
unit_indices2,
num_units_sorting1,
num_units_sorting2,
delta_frames,
):
"""
Compute a matrix representing the matches between two spike trains.

Given two spike trains, this function finds matching spikes based on a temporal proximity criterion
defined by `delta_frames`. The resulting matrix indicates the number of matches between units
in `frames_spike_train1` and `frames_spike_train2`.

Parameters
----------
frames_spike_train1 : ndarray
Array of frames for the first spike train. Should be ordered in ascending order.
frames_spike_train2 : ndarray
Array of frames for the second spike train. Should be ordered in ascending order.
unit_indices1 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train1`.
unit_indices2 : ndarray
Array indicating the unit indices corresponding to each spike in `frames_spike_train2`.
num_units_sorting1 : int
Total number of units in the first spike train.
num_units_sorting2 : int
Total number of units in the second spike train.
delta_frames : int
Maximum difference in frames between two spikes to consider them as a match.

Returns
-------
matching_matrix : ndarray
A matrix of shape (num_units_sorting1, num_units_sorting2) where each entry [i, j] represents
the number of matching spikes between unit i of `frames_spike_train1` and unit j of `frames_spike_train2`.

Notes
-----
This algorithm identifies matching spikes between two ordered spike trains.
By iterating through each spike in the first train, it compares them against spikes in the second train,
determining matches based on the two spikes' frames being within `delta_frames` of each other.

To avoid redundant comparisons, the algorithm maintains a reference, `lower_search_limit_in_second_train`,
which signifies the minimal index in the second spike train that might match the upcoming spike
in the first train. This means that the start of the search moves forward in the second train as
matches between the two trains are found, decreasing the number of comparisons needed.

An important condition here is that the same spike is not matched twice. This is managed by keeping track
of the last matched frame for each unit pair in `previous_frame1_match` and `previous_frame2_match`.

For more details on the rationale behind this approach, refer to the documentation of this module and/or
the metrics section in the SpikeForest documentation.
"""

matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

# Used to avoid the same spike matching twice
previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)

lower_search_limit_in_second_train = 0

for index1 in range(len(frames_spike_train1)):
# Keeps track of which frame in the second spike train should be used as a search start for matches
index2 = lower_search_limit_in_second_train
frame1 = frames_spike_train1[index1]

# Determine next_frame1 if current frame is not the last frame
not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
if not_in_the_last_loop:
next_frame1 = frames_spike_train1[index1 + 1]

while index2 < len(frames_spike_train2):
frame2 = frames_spike_train2[index2]
not_a_match = abs(frame1 - frame2) > delta_frames
if not_a_match:
# Go to the next frame in the first train
break

# Map the match to a matrix
row, column = unit_indices1[index1], unit_indices2[index2]

# The same spike cannot be matched twice; see the notes in the docstring for more info on this constraint
if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
previous_frame1_match[row, column] = frame1
previous_frame2_match[row, column] = frame2

matching_matrix[row, column] += 1

index2 += 1

# Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
if not_a_match_with_next:
lower_search_limit_in_second_train = index2

return matching_matrix

# Cache the compiled function
get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix

return compute_matching_matrix


def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
num_units_sorting1 = sorting1.get_num_units()
num_units_sorting2 = sorting2.get_num_units()
matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)

num_segments_sorting1 = sorting1.get_num_segments()
num_segments_sorting2 = sorting2.get_num_segments()
assert (
num_segments_sorting1 == num_segments_sorting2
), "make_match_count_matrix : sorting1 and sorting must have the same segment number"

# Segments should be matched one by one
for segment_index in range(num_segments_sorting1):
spike_vector1 = spike_vector1_segments[segment_index]
spike_vector2 = spike_vector2_segments[segment_index]

sample_frames1_sorted = spike_vector1["sample_index"]
sample_frames2_sorted = spike_vector2["sample_index"]

unit_indices1_sorted = spike_vector1["unit_index"]
unit_indices2_sorted = spike_vector2["unit_index"]

Parameters
----------
sorting1: SortingExtractor
The first sorting extractor
sorting2: SortingExtractor
The second sorting extractor
delta_frames: int
Number of frames to consider spikes coincident
n_jobs: int
Number of jobs to run in parallel
matching_matrix += get_optimized_compute_matching_matrix()(
sample_frames1_sorted,
sample_frames2_sorted,
unit_indices1_sorted,
unit_indices2_sorted,
num_units_sorting1,
num_units_sorting2,
delta_frames,
)

Returns
-------
match_event_count: array (int64)
Matrix of match count spike
"""
# Build a data frame from the matching matrix
import pandas as pd

unit1_ids = np.array(sorting1.get_unit_ids())
unit2_ids = np.array(sorting2.get_unit_ids())

match_event_counts = np.zeros((len(unit1_ids), len(unit2_ids)), dtype="int64")

# preload all spiketrains 2 into a list
for segment_index in range(sorting1.get_num_segments()):
s2_spiketrains = [sorting2.get_unit_spike_train(u2, segment_index=segment_index) for u2 in unit2_ids]

match_event_count_segment = Parallel(n_jobs=n_jobs)(
delayed(count_match_spikes)(
sorting1.get_unit_spike_train(u1, segment_index=segment_index), s2_spiketrains, delta_frames
)
for i1, u1 in enumerate(unit1_ids)
)
match_event_counts += np.array(match_event_count_segment)

match_event_counts_df = pd.DataFrame(np.array(match_event_counts), index=unit1_ids, columns=unit2_ids)
unit_ids_of_sorting1 = sorting1.get_unit_ids()
unit_ids_of_sorting2 = sorting2.get_unit_ids()
match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2)

return match_event_counts_df
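
For context, here is a minimal sketch of how the new kernel behaves on toy data. The frames, unit indices, and delta_frames below are made up for illustration, and the example assumes numba is installed (the getter imports and compiles it on first use):

import numpy as np
from spikeinterface.comparison.comparisontools import get_optimized_compute_matching_matrix

frames1 = np.array([100, 200, 300, 400])  # sorting1 spike frames, ascending
units1 = np.array([0, 0, 1, 0])           # unit index of each spike in sorting1
frames2 = np.array([101, 201, 301])       # sorting2 spike frames, ascending
units2 = np.array([0, 0, 1])              # unit index of each spike in sorting2

matcher = get_optimized_compute_matching_matrix()  # compiled once, then cached on the getter
counts = matcher(frames1, frames2, units1, units2, 2, 2, 10)  # 2 units per sorting, delta_frames=10
# Unit 0 of sorting1 matches the spikes at 101 and 201, and unit 1 matches the spike at 301,
# so counts == [[2, 0], [0, 1]]. make_match_count_matrix() accumulates this per segment and
# wraps the result in a pandas DataFrame indexed by the unit ids of both sortings.

Repeated calls to get_optimized_compute_matching_matrix() return the same compiled function through the _cached_function attribute, so the numba compilation cost is paid only once per process.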

105 changes: 97 additions & 8 deletions src/spikeinterface/comparison/tests/test_comparisontools.py
@@ -15,6 +15,7 @@
do_count_score,
compute_performance,
)
from spikeinterface.core.generate import generate_sorting


def make_sorting(times1, labels1, times2, labels2):
@@ -27,25 +28,113 @@ def make_sorting(times1, labels1, times2, labels2):
def test_make_match_count_matrix():
delta_frames = 10

# simple match
sorting1, sorting2 = make_sorting(
[100, 200, 300, 400],
[0, 0, 1, 0],
[
101,
201,
301,
],
[101, 201, 301],
[0, 0, 5],
)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
# ~ print(match_event_count)
match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)

assert match_event_count.shape[0] == len(sorting1.get_unit_ids())
assert match_event_count.shape[1] == len(sorting2.get_unit_ids())


def test_make_match_count_matrix_sorting_with_itself_simple():
delta_frames = 10

# simple sorting with itself
sorting1, sorting2 = make_sorting(
[100, 200, 300, 400],
[0, 0, 1, 0],
[100, 200, 300, 400],
[0, 0, 1, 0],
)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)

expected_result = [[3, 0], [0, 1]]
assert_array_equal(match_event_count.to_numpy(), expected_result)


def test_make_match_count_matrix_sorting_with_itself_longer():
seed = 2
sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[5, 5], seed=seed)

delta_frame_milliseconds = 0.1  # Short enough that we only get matches between a unit and itself
delta_frames_seconds = delta_frame_milliseconds / 1000
delta_frames = delta_frames_seconds * sorting.get_sampling_frequency()
match_event_count = make_match_count_matrix(sorting, sorting, delta_frames)

match_event_count_as_array = match_event_count.to_numpy()
matches_with_itself = np.diag(match_event_count_as_array)

# The number of matches with itself should be equal to the number of spikes in each unit
spikes_per_unit_dict = sorting.count_num_spikes_per_unit()
expected_result = np.array([spikes_per_unit_dict[u] for u in spikes_per_unit_dict.keys()])
assert_array_equal(matches_with_itself, expected_result)


def test_make_match_count_matrix_with_mismatched_sortings():
delta_frames = 10

sorting1, sorting2 = make_sorting(
[100, 200, 300, 400], [0, 0, 1, 0], [500, 600, 700, 800], [0, 0, 1, 0] # Completely different spike times
)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)

expected_result = [[0, 0], [0, 0]] # No matches between sorting1 and sorting2
assert_array_equal(match_event_count.to_numpy(), expected_result)


def test_make_match_count_matrix_no_double_matching():
# Jeremy Magland condition: no double matching
frames_spike_train1 = [100, 105, 120, 1000]
unit_indices1 = [0, 1, 0, 0]
frames_spike_train2 = [101, 150, 1000]
unit_indices2 = [0, 1, 0]
delta_frames = 100

# Here the key is that the third frame in the first sorting (120) should not match anything in the second sorting
# Because the matching candidates in the second sorting are already matched to the first two frames
# in the first sorting

# In detail:
# The first frame in sorting 1 (100) from unit 0 should match:
# * The first frame in sorting 2 (101) from unit 0
# * The second frame in sorting 2 (150) from unit 1
# The second frame in sorting 1 (105) from unit 1 should match:
# * The first frame in sorting 2 (101) from unit 0
# * The second frame in sorting 2 (150) from unit 1
# The third frame in sorting 1 (120) from unit 0 should not match anything
# The final frame in sorting 1 (1000) from unit 0 should only match the final frame in sorting 2 (1000) from unit 0

sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)

result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)

expected_result = np.array([[2, 1], [1, 1]])  # Each spike is counted at most once despite multiple candidates within delta_frames
assert_array_equal(result.to_numpy(), expected_result)


def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
# Challenging condition: this was failing with the previous approach that used np.where and np.diff
frames_spike_train1 = [100, 105, 110] # Will fail with [100, 105, 110, 120]
frames_spike_train2 = [100, 105, 110]
unit_indices1 = [0, 0, 0] # Will fail with [0, 0, 0, 0]
unit_indices2 = [0, 0, 0]
delta_frames = 20  # long enough that all frames in both sortings are within each other's reach

sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)

result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)

expected_result = np.array([[3]])
assert_array_equal(result.to_numpy(), expected_result)


def test_make_agreement_scores():
delta_frames = 10

4 changes: 3 additions & 1 deletion src/spikeinterface/core/basesorting.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from typing import List, Optional, Union

@@ -267,7 +269,7 @@ def get_total_num_spikes(self):
)
return self.count_num_spikes_per_unit()

def count_num_spikes_per_unit(self):
def count_num_spikes_per_unit(self) -> dict:
"""
For each unit : get number of spikes across segments.

2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
@@ -172,7 +172,7 @@ def generate_sorting(
duration=durations[segment_index],
refractory_period_ms=refractory_period_ms,
firing_rates=firing_rates,
seed=seed,
seed=seed + segment_index,
)

if empty_units is not None:
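
A brief note on the seed change above: seeding each segment with seed + segment_index keeps generate_sorting reproducible while preventing all segments from receiving identical spike trains. A rough sketch of the effect using plain NumPy (not the library's internal code):

import numpy as np

seed = 2
num_segments = 2
# One generator per segment, offset by the segment index as in the diff above
trains = [np.sort(np.random.default_rng(seed + seg).integers(0, 150_000, size=50)) for seg in range(num_segments)]
assert not np.array_equal(trains[0], trains[1])  # segments differ yet stay reproducible across runs
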
@@ -210,7 +210,7 @@ def test_peak_sign(self):
# invert recording
rec_inv = scale(rec, gain=-1.0)

we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv")
we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv", seed=0)

# compute amplitudes
_ = compute_spike_amplitudes(we, peak_sign="neg")