Commit

added spike vector and separate segments
h-mayorquin committed Oct 24, 2023
1 parent f740e83 commit 8da8614
Showing 1 changed file with 36 additions and 57 deletions.
93 changes: 36 additions & 57 deletions src/spikeinterface/comparison/comparisontools.py
@@ -179,7 +179,7 @@ def compute_matching_matrix(
     the metrics section in SpikeForest documentation.
     """

-    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

     # Used to avoid the same spike matching twice
     previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
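The dtype change above shrinks the count matrix fourfold. A minimal sketch of the trade-off, assuming hypothetical sortings with 1000 units each (the 65535 cap of uint16 is assumed to be ample for per-pair match counts):

```python
import numpy as np

num_units_sorting1 = num_units_sorting2 = 1000  # hypothetical sizes

as_int64 = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
as_uint16 = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

print(as_int64.nbytes)   # 8000000 bytes
print(as_uint16.nbytes)  # 2000000 bytes
print(np.iinfo(np.uint16).max)  # 65535 matches per unit pair before overflow
```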
@@ -188,7 +188,7 @@ def compute_matching_matrix(
     lower_search_limit_in_second_train = 0

     for index1 in range(len(frames_spike_train1)):
-        # Keeps track of which frame in the second spike train should be used as a search start
+        # Keeps track of which frame in the second spike train should be used as a search start for matches
         index2 = lower_search_limit_in_second_train
         frame1 = frames_spike_train1[index1]

@@ -201,14 +201,13 @@ def compute_matching_matrix(
             frame2 = frames_spike_train2[index2]
             not_a_match = abs(frame1 - frame2) > delta_frames
             if not_a_match:
+                # Go to the next frame in the first train
                 break

             # Map the match to a matrix
             row, column = unit_indices1[index1], unit_indices2[index2]

-            # The same spike cannot be matched twice
-            # This condition is to interpret matches as an accuracy task, see the documentation of the module
-            # or the metrics section in SpikeForest to see the reason for this.
+            # The same spike cannot be matched twice, see the notes in the docstring for more info on this constraint
             if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                 previous_frame1_match[row, column] = frame1
                 previous_frame2_match[row, column] = frame2
@@ -217,7 +216,7 @@ def compute_matching_matrix(

             index2 += 1

-        # Advance the minimal index 2 if not in the last loop iteration
+        # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
         not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
         if not_a_match_with_next:
             lower_search_limit_in_second_train = index2
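The kernel above is JIT-compiled via get_optimized_compute_matching_matrix(). For context, here is a simplified pure-Python/NumPy sketch of the same two-pointer idea; the function name and the lower-bound bookkeeping are illustrative, not the exact implementation:

```python
import numpy as np

def matching_matrix_sketch(frames1, frames2, units1, units2, num_units1, num_units2, delta_frames):
    """Count near-coincident spikes per unit pair; both frame arrays must be sorted."""
    matching_matrix = np.zeros((num_units1, num_units2), dtype=np.uint16)
    # Remember the last matched frames so the same spike is never matched twice
    previous_frame1_match = -np.ones((num_units1, num_units2), dtype=np.int64)
    previous_frame2_match = -np.ones((num_units1, num_units2), dtype=np.int64)

    lower_search_limit = 0
    for index1, frame1 in enumerate(frames1):
        index2 = lower_search_limit
        while index2 < len(frames2):
            frame2 = frames2[index2]
            if frame2 < frame1 - delta_frames:
                # Too early for this and every later frame1: advance the lower bound
                lower_search_limit = index2 + 1
            elif frame2 > frame1 + delta_frames:
                break  # Too late: go to the next frame in the first train
            else:
                row, column = units1[index1], units2[index2]
                if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                    previous_frame1_match[row, column] = frame1
                    previous_frame2_match[row, column] = frame2
                    matching_matrix[row, column] += 1
            index2 += 1
    return matching_matrix

# Two single-unit trains: the pairs (10, 12) and (90, 95) fall within 5 samples
frames1 = np.array([10, 30, 90])
frames2 = np.array([12, 95])
units1 = np.zeros(3, dtype=np.int64)
units2 = np.zeros(2, dtype=np.int64)
print(matching_matrix_sketch(frames1, frames2, units1, units2, 1, 1, 5))  # [[2]]
```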
@@ -233,62 +232,42 @@
 def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     num_units_sorting1 = sorting1.get_num_units()
     num_units_sorting2 = sorting2.get_num_units()
-    num_segments_sorting1 = sorting1.get_num_segments()
-    num_segments_sorting2 = sorting2.get_num_segments()
-    unit1_ids = sorting1.get_unit_ids()
-    unit2_ids = sorting2.get_unit_ids()
-
-    sample_frames1_accumulator = []
-    unit_indices1_accumulator = []
-
-    sample_frames2_accumulator = []
-    unit_indices2_accumulator = []
-
-    for segment_index in range(num_segments_sorting1):
-        spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids]
-        sample_frames1_accumulator.extend(spike_trains1)
-
-        unit_indices1 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]
-        unit_indices1_accumulator.extend(unit_indices1)
-
-    for segment_index in range(num_segments_sorting2):
-        spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids]
-        sample_frames2_accumulator.extend(spike_trains2)
-
-        unit_indices2 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]
-        unit_indices2_accumulator.extend(unit_indices2)
-
-    # Concatenate accumulated data only once
-    sample_frames1_all = np.concatenate(sample_frames1_accumulator)
-    unit_indices1_all = np.concatenate(unit_indices1_accumulator)
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)

-    sample_frames2_all = np.concatenate(sample_frames2_accumulator)
-    unit_indices2_all = np.concatenate(unit_indices2_accumulator)
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)

-    # Sort the sample_frames and unit_indices arrays
-    sort_indices1 = np.argsort(sample_frames1_all)
-    sample_frames1_sorted = sample_frames1_all[sort_indices1]
-    unit_indices1_sorted = unit_indices1_all[sort_indices1]
-
-    sort_indices2 = np.argsort(sample_frames2_all)
-    sample_frames2_sorted = sample_frames2_all[sort_indices2]
-    unit_indices2_sorted = unit_indices2_all[sort_indices2]
-
-    compute_matching_matrix = get_optimized_compute_matching_matrix()
-
-    full_matrix = compute_matching_matrix(
-        sample_frames1_sorted,
-        sample_frames2_sorted,
-        unit_indices1_sorted,
-        unit_indices2_sorted,
-        num_units_sorting1,
-        num_units_sorting2,
-        delta_frames,
-    )
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    max_segment_to_compare = max(num_segments_sorting1, num_segments_sorting2)
+
+    # Segments should be matched one by one
+    for segment_index in range(max_segment_to_compare):
+        spike_vector1 = spike_vector1_segments[segment_index]
+        spike_vector2 = spike_vector2_segments[segment_index]
+
+        sample_frames1_sorted = spike_vector1["sample_index"]
+        sample_frames2_sorted = spike_vector2["sample_index"]
+
+        unit_indices1_sorted = spike_vector1["unit_index"]
+        unit_indices2_sorted = spike_vector2["unit_index"]
+
+        matching_matrix += get_optimized_compute_matching_matrix()(
+            sample_frames1_sorted,
+            sample_frames2_sorted,
+            unit_indices1_sorted,
+            unit_indices2_sorted,
+            num_units_sorting1,
+            num_units_sorting2,
+            delta_frames,
+        )

     # Build a data frame from the matching matrix
     import pandas as pd

-    match_event_counts_df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)
+    unit_ids_of_sorting1 = sorting1.get_unit_ids()
+    unit_ids_of_sorting2 = sorting2.get_unit_ids()
+    match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2)

     return match_event_counts_df
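After the rewrite, make_match_count_matrix sums one matching matrix per segment instead of concatenating and re-sorting all spike trains, since to_spike_vector(concatenated=False) already yields frame-sorted structured arrays per segment, and the += accumulation keeps peak memory at a single (num_units1, num_units2) matrix regardless of segment count. A hedged usage sketch (generate_sorting is a SpikeInterface testing utility; its exact signature may vary between versions):

```python
from spikeinterface.core import generate_sorting
from spikeinterface.comparison.comparisontools import make_match_count_matrix

# Two toy sortings over the same 10 s, single-segment recording (assumed API)
sorting1 = generate_sorting(num_units=3, sampling_frequency=30000.0, durations=[10.0])
sorting2 = generate_sorting(num_units=4, sampling_frequency=30000.0, durations=[10.0])

# Spikes closer than 12 samples (0.4 ms at 30 kHz) count as a match
match_event_counts_df = make_match_count_matrix(sorting1, sorting2, delta_frames=12)
print(match_event_counts_df)  # DataFrame: rows = sorting1 unit ids, columns = sorting2 unit ids
```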
