From 8da861482c755bf4abba2d8fcb35bf759e485a55 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 24 Oct 2023 15:57:13 +0200
Subject: [PATCH] added spike vector and separate segments

---
 .../comparison/comparisontools.py | 93 +++++++------------
 1 file changed, 36 insertions(+), 57 deletions(-)

diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py
index 924fe70174..f01db549bb 100644
--- a/src/spikeinterface/comparison/comparisontools.py
+++ b/src/spikeinterface/comparison/comparisontools.py
@@ -179,7 +179,7 @@ def compute_matching_matrix(
         the metrics section in SpikeForest documentation.
         """
 
-        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.int64)
+        matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
 
         # Used to avoid the same spike matching twice
         previous_frame1_match = -np.ones_like(matching_matrix, dtype=np.int64)
@@ -188,7 +188,7 @@ def compute_matching_matrix(
         lower_search_limit_in_second_train = 0
 
         for index1 in range(len(frames_spike_train1)):
-            # Keeps track of which frame in the second spike train should be used as a search start
+            # Keeps track of which frame in the second spike train should be used as a search start for matches
             index2 = lower_search_limit_in_second_train
             frame1 = frames_spike_train1[index1]
 
@@ -201,14 +201,13 @@ def compute_matching_matrix(
                 frame2 = frames_spike_train2[index2]
                 not_a_match = abs(frame1 - frame2) > delta_frames
                 if not_a_match:
+                    # Go to the next frame in the first train
                     break
 
                 # Map the match to a matrix
                 row, column = unit_indices1[index1], unit_indices2[index2]
 
-                # The same spike cannot be matched twice
-                # This condition is interpret matches asn accuracy task, see the documentation of the module
-                # Or the metrics section in spike forest to see the reason for this.
+                # The same spike cannot be matched twice; see the notes in the docstring for more info on this constraint
                 if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                     previous_frame1_match[row, column] = frame1
                     previous_frame2_match[row, column] = frame2
@@ -217,7 +216,7 @@
 
                 index2 += 1
 
-            # Advance the minimal index 2 if not in the last loop iteration
+            # Advance the lower_search_limit_in_second_train if the next frame in the first train does not match
             not_a_match_with_next = abs(next_frame1 - frame2) > delta_frames
             if not_a_match_with_next:
                 lower_search_limit_in_second_train = index2
@@ -233,62 +232,42 @@
 def make_match_count_matrix(sorting1, sorting2, delta_frames, n_jobs=None):
     num_units_sorting1 = sorting1.get_num_units()
     num_units_sorting2 = sorting2.get_num_units()
-    num_segments_sorting1 = sorting1.get_num_segments()
-    num_segments_sorting2 = sorting2.get_num_segments()
-    unit1_ids = sorting1.get_unit_ids()
-    unit2_ids = sorting2.get_unit_ids()
-
-    sample_frames1_accumulator = []
-    unit_indices1_accumulator = []
-
-    sample_frames2_accumulator = []
-    unit_indices2_accumulator = []
-
-    for segment_index in range(num_segments_sorting1):
-        spike_trains1 = [sorting1.get_unit_spike_train(unit_id, segment_index) for unit_id in unit1_ids]
-        sample_frames1_accumulator.extend(spike_trains1)
-
-        unit_indices1 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains1)]
-        unit_indices1_accumulator.extend(unit_indices1)
-
-    for segment_index in range(num_segments_sorting2):
-        spike_trains2 = [sorting2.get_unit_spike_train(unit_id, segment_index) for unit_id in unit2_ids]
-        sample_frames2_accumulator.extend(spike_trains2)
-
-        unit_indices2 = [np.full(len(train), unit) for unit, train in enumerate(spike_trains2)]
-        unit_indices2_accumulator.extend(unit_indices2)
-
-    # Concatenate accumulated data only once
-    sample_frames1_all = np.concatenate(sample_frames1_accumulator)
-    unit_indices1_all = np.concatenate(unit_indices1_accumulator)
+    matching_matrix = np.zeros((num_units_sorting1, num_units_sorting2), dtype=np.uint16)
 
-    sample_frames2_all = np.concatenate(sample_frames2_accumulator)
-    unit_indices2_all = np.concatenate(unit_indices2_accumulator)
+    spike_vector1_segments = sorting1.to_spike_vector(concatenated=False)
+    spike_vector2_segments = sorting2.to_spike_vector(concatenated=False)
 
-    # Sort the sample_frames and unit_indices arrays
-    sort_indices1 = np.argsort(sample_frames1_all)
-    sample_frames1_sorted = sample_frames1_all[sort_indices1]
-    unit_indices1_sorted = unit_indices1_all[sort_indices1]
-
-    sort_indices2 = np.argsort(sample_frames2_all)
-    sample_frames2_sorted = sample_frames2_all[sort_indices2]
-    unit_indices2_sorted = unit_indices2_all[sort_indices2]
-
-    compute_matching_matrix = get_optimized_compute_matching_matrix()
-
-    full_matrix = compute_matching_matrix(
-        sample_frames1_sorted,
-        sample_frames2_sorted,
-        unit_indices1_sorted,
-        unit_indices2_sorted,
-        num_units_sorting1,
-        num_units_sorting2,
-        delta_frames,
-    )
+    num_segments_sorting1 = sorting1.get_num_segments()
+    num_segments_sorting2 = sorting2.get_num_segments()
+    max_segment_to_compare = max(num_segments_sorting1, num_segments_sorting2)
+
+    # Segments should be matched one by one
+    for segment_index in range(max_segment_to_compare):
+        spike_vector1 = spike_vector1_segments[segment_index]
+        spike_vector2 = spike_vector2_segments[segment_index]
+
+        sample_frames1_sorted = spike_vector1["sample_index"]
+        sample_frames2_sorted = spike_vector2["sample_index"]
+
+        unit_indices1_sorted = spike_vector1["unit_index"]
+        unit_indices2_sorted = spike_vector2["unit_index"]
+
+        matching_matrix += get_optimized_compute_matching_matrix()(
+            sample_frames1_sorted,
+            sample_frames2_sorted,
+            unit_indices1_sorted,
+            unit_indices2_sorted,
+            num_units_sorting1,
+            num_units_sorting2,
+            delta_frames,
+        )
+
     # Build a data frame from the matching matrix
     import pandas as pd
 
-    match_event_counts_df = pd.DataFrame(full_matrix, index=unit1_ids, columns=unit2_ids)
+    unit_ids_of_sorting1 = sorting1.get_unit_ids()
+    unit_ids_of_sorting2 = sorting2.get_unit_ids()
+    match_event_counts_df = pd.DataFrame(matching_matrix, index=unit_ids_of_sorting1, columns=unit_ids_of_sorting2)
 
     return match_event_counts_df
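
As context for the kernel this patch retunes, the matching rule is easiest to see in a naive form. The sketch below is a pure-NumPy reference, not spikeinterface code: the helper name count_matches_naive is hypothetical, the double loop is the plain O(n1 * n2) version of the optimized two-pointer search inside compute_matching_matrix, and the toy arrays mimic one segment of a spike vector (sample frames sorted in time with a parallel array of unit indices).

import numpy as np


def count_matches_naive(frames1, frames2, units1, units2, n1, n2, delta_frames):
    # Hypothetical reference implementation, not part of spikeinterface.
    # Counts spike pairs at most delta_frames apart, per (unit1, unit2) pair.
    matching_matrix = np.zeros((n1, n2), dtype=np.uint16)  # uint16, as in the patch
    # Bookkeeping so the same spike is never counted twice for a unit pair
    previous_frame1_match = -np.ones((n1, n2), dtype=np.int64)
    previous_frame2_match = -np.ones((n1, n2), dtype=np.int64)
    for index1, frame1 in enumerate(frames1):
        for index2, frame2 in enumerate(frames2):
            if abs(frame1 - frame2) <= delta_frames:
                row, column = units1[index1], units2[index2]
                if frame1 != previous_frame1_match[row, column] and frame2 != previous_frame2_match[row, column]:
                    previous_frame1_match[row, column] = frame1
                    previous_frame2_match[row, column] = frame2
                    matching_matrix[row, column] += 1
    return matching_matrix


# Toy data shaped like one segment of a spike vector (frames sorted in time)
frames1 = np.array([100, 200, 300, 400])
units1 = np.array([0, 1, 0, 1])
frames2 = np.array([101, 201, 305, 900])
units2 = np.array([0, 1, 1, 0])
print(count_matches_naive(frames1, frames2, units1, units2, 2, 2, delta_frames=10))

With delta_frames=10 this prints [[1 1] [0 1]]: one match each for the unit pairs (0, 0), (0, 1) and (1, 1). The previous-match bookkeeping is what keeps a single spike from being matched twice for the same unit pair, which is exactly the constraint the rewritten comments in the patch describe.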