Skip to content

Commit

Permalink
knn
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Jun 29, 2024
1 parent 63da29e commit c3e2115
Showing 1 changed file with 15 additions and 31 deletions.
46 changes: 15 additions & 31 deletions src/spikeinterface/curation/auto_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def get_potential_auto_merge(
return potential_merges


def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, sparse_distances=False):
def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None):

sorting = sorting_analyzer.sorting
unit_ids = sorting.unit_ids
Expand All @@ -349,42 +349,26 @@ def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, sparse_distan
from sklearn.neighbors import NearestNeighbors

data = (data - data.mean(0)) / data.std(0)

if sparse_distances:
import scipy.sparse
import sklearn.metrics

distances = scipy.sparse.lil_matrix((len(data), len(data)), dtype=np.float32)

for unit_ind1 in range(2):
valid = pair_mask[unit_ind1, unit_ind1+1:]
valid_indices = np.arange(unit_ind1+1, n)[valid]
mask_2 = np.isin(spikes["unit_index"], valid_indices)
if np.sum(mask_2) > 0:
mask_1 = spikes["unit_index"] == unit_ind1
tmp = sklearn.metrics.pairwise_distances(data[mask_1], data[mask_2])
distances[mask_1][:, mask_2] = tmp

all_spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit()
all_spike_counts = np.array(list(all_spike_counts.keys()))

if sparse_distances:
kdtree = NearestNeighbors(n_neighbors=k_nn, n_jobs=-1, metric="precomputed")
kdtree.fit(distances)
else:
kdtree = NearestNeighbors(n_neighbors=k_nn, n_jobs=-1)
kdtree.fit(data)
kdtree = NearestNeighbors(n_neighbors=k_nn, n_jobs=-1)
kdtree.fit(data)

for unit_ind in range(n):
print(unit_ind)
mask = spikes["unit_index"] == unit_ind
ind = kdtree.kneighbors(data[mask], return_distance=False)
ind = ind.flatten()
chan_inds, all_counts = np.unique(spikes["unit_index"][ind], return_counts=True)
all_counts = all_counts.astype(float)
#all_counts /= all_spike_counts[chan_inds]
best_indices = np.argsort(all_counts)[::-1][1:]
pair_mask[unit_ind] &= np.isin(np.arange(n), chan_inds[best_indices])
valid = pair_mask[unit_ind, unit_ind+1:]
valid_indices = np.arange(unit_ind+1, n)[valid]
if len(valid_indices) > 0:
ind = kdtree.kneighbors(data[mask], return_distance=False)
ind = ind.flatten()
mask_2 = np.isin(spikes["unit_index"][ind], valid_indices)
ind = ind[mask_2]
chan_inds, all_counts = np.unique(spikes["unit_index"][ind], return_counts=True)
all_counts = all_counts.astype(float)
#all_counts /= all_spike_counts[chan_inds]
best_indices = np.argsort(all_counts)[::-1][0:]
pair_mask[unit_ind] &= np.isin(np.arange(n), chan_inds[best_indices])
return pair_mask


Expand Down

0 comments on commit c3e2115

Please sign in to comment.