Optimize make_match_count_matrix #2114

Merged: 16 commits, Oct 31, 2023. Changes from 7 commits.
160 changes: 123 additions & 37 deletions src/spikeinterface/comparison/comparisontools.py
@@ -3,7 +3,7 @@
"""

import numpy as np
from joblib import Parallel, delayed
import itertools


def count_matching_events(times1, times2, delta=10):
@@ -109,50 +109,136 @@ def count_match_spikes(times1, all_times2, delta_frames):  # , event_counts1, ev
    return matching_event_counts


-def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1):
-    """
-    Make the match_event_count matrix.
-    Basically it counts the matching events for all given pairs of spike trains from
-    sorting1 and sorting2.
-
-    Parameters
-    ----------
-    sorting1: SortingExtractor
-        The first sorting extractor
-    sorting2: SortingExtractor
-        The second sorting extractor
-    delta_frames: int
-        Number of frames to consider spikes coincident
-    n_jobs: int
-        Number of jobs to run in parallel
-
-    Returns
-    -------
-    match_event_count: array (int64)
-        Matrix of match count spike
-    """
-    import pandas as pd
-
-    unit1_ids = np.array(sorting1.get_unit_ids())
-    unit2_ids = np.array(sorting2.get_unit_ids())
-
-    match_event_counts = np.zeros((len(unit1_ids), len(unit2_ids)), dtype="int64")
-
-    # preload all spiketrains 2 into a list
-    for segment_index in range(sorting1.get_num_segments()):
-        s2_spiketrains = [sorting2.get_unit_spike_train(u2, segment_index=segment_index) for u2 in unit2_ids]
-
-        match_event_count_segment = Parallel(n_jobs=n_jobs)(
-            delayed(count_match_spikes)(
-                sorting1.get_unit_spike_train(u1, segment_index=segment_index), s2_spiketrains, delta_frames
-            )
-            for i1, u1 in enumerate(unit1_ids)
-        )
-        match_event_counts += np.array(match_event_count_segment)
-
-    match_event_counts_df = pd.DataFrame(np.array(match_event_counts), index=unit1_ids, columns=unit2_ids)
-
-    return match_event_counts_df
+def get_optimized_compute_matching_matrix():
+    # Cache for compiled function
+    if hasattr(get_optimized_compute_matching_matrix, "_cached_function"):
+        return get_optimized_compute_matching_matrix._cached_function
+
+    # Dynamic import of numba
+    import numba
+
+    # Nested function
+    @numba.jit(nopython=True, nogil=True)
+    def compute_matching_matrix_inner(
+        frames_spike_train1,
+        frames_spike_train2,
+        unit_indices1,
+        unit_indices2,
+        num_units_sorting1,
+        num_units_sorting2,
+        delta_frames,
+    ):
+        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
+
+        # Used for the Jeremy Magland condition where no unit can be matched twice.
+        previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
+        previous_frame2_match = -np.ones_like(matching_matrix, dtype=np.int64)
+
+        lower_search_limit_in_second_train = 0
+
+        for index1 in range(len(frames_spike_train1)):
+            # Keeps track of which frame in the second spike train should be used as a search start
+            index2 = lower_search_limit_in_second_train
+            frame1 = frames_spike_train1[index1]
+
+            # Determine next_frame1 if the current frame is not the last one
+            not_in_the_last_loop = index1 < len(frames_spike_train1) - 1
+            if not_in_the_last_loop:
+                next_frame1 = frames_spike_train1[index1 + 1]
+
+            while index2 < len(frames_spike_train2):
+                frame2 = frames_spike_train2[index2]
+                not_a_match = abs(frame1 - frame2) > delta_frames
+                if not_a_match:
+                    break
+
+                # Map the match to a matrix
+                row, column = unit_indices1[index1], unit_indices2[index2]
+
+                # Jeremy Magland condition: the same unit can't match twice
+                if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
+                    previous_frame1_match[row, column] = frame1
+                    previous_frame2_match[row, column] = frame2
+
+                    matching_matrix[row, column] += 1
+
+                index2 += 1
+
+                # Advance the minimal index2 if not in the last loop iteration
+                if not_in_the_last_loop:
+                    not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
+                    if not_a_match_with_next:
+                        lower_search_limit_in_second_train = index2
+
+        return matching_matrix
+
+    # Cache the compiled function
+    get_optimized_compute_matching_matrix._cached_function = compute_matching_matrix_inner
+
+    return compute_matching_matrix_inner
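Side note on the pattern above: caching the jitted kernel as an attribute on its own factory means numba compiles it only once per process, and the deferred import keeps numba an optional dependency. A minimal sketch of the same compile-once-and-cache idea in isolation (the heavy_sum kernel is hypothetical, assuming numba is installed):

import numpy as np


def get_heavy_sum():
    # Return the compiled kernel if a previous call already built it
    if hasattr(get_heavy_sum, "_cached_function"):
        return get_heavy_sum._cached_function

    import numba  # deferred import: only paid on first use

    @numba.jit(nopython=True, nogil=True)
    def heavy_sum(x):
        total = 0.0
        for i in range(x.shape[0]):
            total += x[i]
        return total

    get_heavy_sum._cached_function = heavy_sum
    return heavy_sum


kernel = get_heavy_sum()  # first call: imports numba and triggers compilation
kernel(np.arange(1_000_000, dtype=np.float64))
kernel = get_heavy_sum()  # later calls return the cached, already-compiled kernel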


+def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
+    num_units_sorting1 = sorting1.get_num_units()
+    num_units_sorting2 = sorting2.get_num_units()
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    unit1_ids = sorting1.get_unit_ids()
+    unit2_ids = sorting2.get_unit_ids()
+
+    sample_frames1_accumulator = []
+    unit_indices1_accumulator = []
+
+    sample_frames2_accumulator = []
+    unit_indices2_accumulator = []
+
+    for segment_index in range(num_segments_sorting1):
+        spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids]
+        sample_frames1_accumulator.extend(spike_trains1)
+
+        unit_indices1 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]
+        unit_indices1_accumulator.extend(unit_indices1)
+
+    for segment_index in range(num_segments_sorting2):
+        spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids]
+        sample_frames2_accumulator.extend(spike_trains2)
+
+        unit_indices2 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]
+        unit_indices2_accumulator.extend(unit_indices2)
+
+    # Concatenate accumulated data only once
+    sample_frames1_all = np.concatenate(sample_frames1_accumulator)
+    unit_indices1_all = np.concatenate(unit_indices1_accumulator)
+
+    sample_frames2_all = np.concatenate(sample_frames2_accumulator)
+    unit_indices2_all = np.concatenate(unit_indices2_accumulator)
+
+    # Sort the sample_frames and unit_indices arrays
+    sort_indices1 = np.argsort(sample_frames1_all)
+    sample_frames1_sorted = sample_frames1_all[sort_indices1]
+    unit_indices1_sorted = unit_indices1_all[sort_indices1]
+
+    sort_indices2 = np.argsort(sample_frames2_all)
+    sample_frames2_sorted = sample_frames2_all[sort_indices2]
+    unit_indices2_sorted = unit_indices2_all[sort_indices2]
+
+    optimized_compute_matching_matrix = get_optimized_compute_matching_matrix()
+
+    full_matrix = optimized_compute_matching_matrix(
+        sample_frames1_sorted,
+        sample_frames2_sorted,
+        unit_indices1_sorted,
+        unit_indices2_sorted,
+        num_units_sorting1,
+        num_units_sorting2,
+        delta_frames,
+    )
+
+    import pandas as pd
+
+    df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)
+
+    return df
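For orientation, a minimal sketch of calling the new function end to end. It mirrors the first test further down; the NumpySorting.from_times_labels call is an assumption about what the make_sorting test helper wraps:

import numpy as np
from spikeinterface.core import NumpySorting
from spikeinterface.comparison.comparisontools import make_match_count_matrix

# Unit 0 of sorting1 fires at 100/200/400 and unit 1 at 300; in sorting2,
# unit 0 fires at 101/201 and unit 5 at 301, all within delta_frames=10.
sorting1 = NumpySorting.from_times_labels(
    [np.array([100, 200, 300, 400])], [np.array([0, 0, 1, 0])], sampling_frequency=30000.0
)
sorting2 = NumpySorting.from_times_labels(
    [np.array([101, 201, 301])], [np.array([0, 0, 5])], sampling_frequency=30000.0
)

match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames=10)
# pandas DataFrame indexed by unit ids: match_event_count.loc[0, 0] == 2 (100<->101, 200<->201)
# and match_event_count.loc[1, 5] == 1 (300<->301)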


def make_agreement_scores(sorting1, sorting2, delta_frames, n_jobs=1):
105 changes: 97 additions & 8 deletions src/spikeinterface/comparison/tests/test_comparisontools.py
@@ -15,6 +15,7 @@
    do_count_score,
    compute_performance,
)
+from spikeinterface.core.generate import generate_sorting


def make_sorting(times1, labels1, times2, labels2):
@@ -27,25 +28,113 @@ def make_sorting(times1, labels1, times2, labels2):
def test_make_match_count_matrix():
    delta_frames = 10

    # simple match
    sorting1, sorting2 = make_sorting(
        [100, 200, 300, 400],
        [0, 0, 1, 0],
-        [
-            101,
-            201,
-            301,
-        ],
+        [101, 201, 301],
        [0, 0, 5],
    )

-    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=1)
-    # ~ print(match_event_count)
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)

    assert match_event_count.shape[0] == len(sorting1.get_unit_ids())
    assert match_event_count.shape[1] == len(sorting2.get_unit_ids())

+def test_make_match_count_matrix_sorting_with_itself_simple():
+    delta_frames = 10
+
+    # simple sorting with itself
+    sorting1, sorting2 = make_sorting(
+        [100, 200, 300, 400],
+        [0, 0, 1, 0],
+        [100, 200, 300, 400],
+        [0, 0, 1, 0],
+    )
+
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+
+    expected_result = [[3, 0], [0, 1]]
+    assert_array_equal(match_event_count.to_numpy(), expected_result)


+def test_make_match_count_matrix_sorting_with_itself_longer():
+    seed = 2
+    sorting = generate_sorting(num_units=10, sampling_frequency=30000, durations=[5, 5], seed=seed)
+
+    delta_frame_milliseconds = 0.1  # Short, so that we only get matches between a unit and itself
+    delta_frames_seconds = delta_frame_milliseconds / 1000
+    delta_frames = delta_frames_seconds * sorting.get_sampling_frequency()
+    match_event_count = make_match_count_matrix(sorting, sorting, delta_frames)
+
+    match_event_count_as_array = match_event_count.to_numpy()
+    matches_with_itself = np.diag(match_event_count_as_array)
+
+    # The number of matches with itself should be equal to the number of spikes in each unit
+    spikes_per_unit_dict = sorting.count_num_spikes_per_unit()
+    expected_result = np.array([spikes_per_unit_dict[u] for u in spikes_per_unit_dict.keys()])
+    assert_array_equal(matches_with_itself, expected_result)


+def test_make_match_count_matrix_with_mismatched_sortings():
+    delta_frames = 10
+
+    sorting1, sorting2 = make_sorting(
+        [100, 200, 300, 400], [0, 0, 1, 0], [500, 600, 700, 800], [0, 0, 1, 0]  # Completely different spike times
+    )
+
+    match_event_count = make_match_count_matrix(sorting1, sorting2, delta_frames)
+
+    expected_result = [[0, 0], [0, 0]]  # No matches between sorting1 and sorting2
+    assert_array_equal(match_event_count.to_numpy(), expected_result)


+def test_make_match_count_matrix_no_double_matching():
+    # Jeremy Magland condition: no double matching
+    frames_spike_train1 = [100, 105, 120, 1000]
+    unit_indices1 = [0, 1, 0, 0]
+    frames_spike_train2 = [101, 150, 1000]
+    unit_indices2 = [0, 1, 0]
+    delta_frames = 100
+
+    # Here the key is that the third frame in the first sorting (120) should not match anything in the second
+    # sorting, because the matching candidates in the second sorting are already matched to the first two frames
+    # in the first sorting.
+
+    # In detail:
+    # The first frame in sorting 1 (100) from unit 0 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The second frame in sorting 1 (105) from unit 1 should match:
+    # * The first frame in sorting 2 (101) from unit 0
+    # * The second frame in sorting 2 (150) from unit 1
+    # The third frame in sorting 1 (120) from unit 0 should not match anything
+    # The final frame in sorting 1 (1000) from unit 0 should only match the final frame in sorting 2 (1000) from unit 0
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[2, 1], [1, 1]])  # Each spike pair is counted at most once despite potential repeats
+    assert_array_equal(result.to_numpy(), expected_result)


+def test_make_match_count_matrix_repeated_matching_but_no_double_counting():
+    # Challenging condition, this was failing with the previous approach that used np.where and np.diff
+    frames_spike_train1 = [100, 105, 110]  # Will fail with [100, 105, 110, 120]
+    frames_spike_train2 = [100, 105, 110]
+    unit_indices1 = [0, 0, 0]  # Will fail with [0, 0, 0, 0]
+    unit_indices2 = [0, 0, 0]
+    delta_frames = 20  # Long enough that all frames in both sortings are within each other's reach
+
+    sorting1, sorting2 = make_sorting(frames_spike_train1, unit_indices1, frames_spike_train2, unit_indices2)
+
+    result = make_match_count_matrix(sorting1, sorting2, delta_frames=delta_frames)
+
+    expected_result = np.array([[3]])
+    assert_array_equal(result.to_numpy(), expected_result)
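The failure mode this test pins down: with delta_frames = 20, every frame of one train is within reach of every frame of the other, so an unconstrained pairwise count would report 9 matches for this single unit, while the no-double-counting condition keeps it at 3 (one match per spike). A hypothetical naive pairwise count for comparison:

import numpy as np

train1 = np.array([100, 105, 110])
train2 = np.array([100, 105, 110])
delta_frames = 20

# Counts every (i, j) pair within delta_frames, reusing spikes freely
naive_count = np.sum(np.abs(train1[:, None] - train2[None, :]) <= delta_frames)
print(naive_count)  # 9, versus the 3 asserted above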


def test_make_agreement_scores():
    delta_frames = 10

4 changes: 3 additions & 1 deletion src/spikeinterface/core/basesorting.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import warnings
from typing import List, Optional, Union

@@ -268,7 +270,7 @@ def get_total_num_spikes(self):
        )
        return self.count_num_spikes_per_unit()

-    def count_num_spikes_per_unit(self):
+    def count_num_spikes_per_unit(self) -> dict:
        """
        For each unit : get number of spikes across segments.
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
@@ -173,7 +173,7 @@ def generate_sorting(
            duration=durations[segment_index],
            refractory_period_ms=refractory_period_ms,
            firing_rates=firing_rates,
-            seed=seed,
+            seed=seed + segment_index,
        )

    if empty_units is not None:
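The segment_index offset above makes each segment draw from a distinct random stream; with a plain seed=seed, every segment of a fixed-seed sorting would contain identical spike trains. A toy illustration of the effect using NumPy generators (a sketch, not the library's internals):

import numpy as np

seed = 42
for segment_index in range(2):
    rng = np.random.default_rng(seed + segment_index)  # distinct stream per segment
    print(rng.integers(0, 30000, size=3))  # different "spike frames" in each segment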
@@ -210,7 +210,7 @@ def test_peak_sign(self):
        # invert recording
        rec_inv = scale(rec, gain=-1.0)

-        we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv")
+        we_inv = extract_waveforms(rec_inv, sort, self.cache_folder / "toy_waveforms_inv", seed=0)

        # compute amplitudes
        _ = compute_spike_amplitudes(we, peak_sign="neg")