Skip to content

Commit

Permalink
group in same file CollisionGTComparison and CollisionGTStudy
Browse files Browse the repository at this point in the history
group in same file CorrelogramGTComparison and CorrelogramGTStudy
  • Loading branch information
samuelgarcia committed Sep 19, 2023
1 parent 9b5b28b commit 8d9ce49
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 167 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import numpy as np

from .paircomparisons import GroundTruthComparison
from .groundtruthstudy import GroundTruthStudy
from .studytools import iter_computed_sorting ## TODO remove this
from .comparisontools import make_collision_events

import numpy as np






class CollisionGTComparison(GroundTruthComparison):
"""
Expand Down Expand Up @@ -156,3 +162,87 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
pair_names = pair_names[order]

return similarities, recall_scores, pair_names



class CollisionGTStudy(GroundTruthStudy):
    """Ground-truth study specialized for spike-collision analysis.

    Builds one :class:`CollisionGTComparison` per (recording, sorter) pair and
    aggregates collision recall scores as a function of template similarity
    across all recordings of the study.
    """

    def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
        """Create a CollisionGTComparison for every computed sorting.

        Parameters
        ----------
        exhaustive_gt : bool, default: True
            Whether the ground truth is assumed exhaustive (forwarded to the
            comparison objects).
        collision_lag : float, default: 2.0
            Collision lag forwarded to CollisionGTComparison (presumably in ms —
            TODO confirm against CollisionGTComparison's signature).
        nbins : int, default: 11
            Number of lag bins forwarded to CollisionGTComparison.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.comparisons = {}
        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
            gt_sorting = self.get_ground_truth(rec_name)
            comp = CollisionGTComparison(
                gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins
            )
            self.comparisons[(rec_name, sorter_name)] = comp
        self.exhaustive_gt = exhaustive_gt
        self.collision_lag = collision_lag

    def get_lags(self):
        """Return the collision lag axis in milliseconds.

        The lag bins are taken from the first comparison; all comparisons are
        built with the same ``nbins``/``collision_lag`` so they share one axis.
        """
        comp = self.comparisons[(self.rec_names[0], self.sorter_names[0])]
        fs = comp.sorting1.get_sampling_frequency()
        # bins are expressed in samples; convert to ms
        return comp.bins / fs * 1000.0

    def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9):
        """Aggregate per-pair collision recall scores over all recordings.

        Fills ``self.all_similarities`` and ``self.all_recall_scores`` (one
        entry per sorter name, concatenated over recordings). Results are
        cached: nothing is recomputed when called again with the same
        parameters.

        Parameters
        ----------
        good_only : bool, default: True
            Forwarded to ``compute_collision_by_similarity``.
        min_accuracy : float, default: 0.9
            Forwarded to ``compute_collision_by_similarity``.
        """
        # BUG FIX: the original set ``self.good_only`` but tested
        # ``self._good_only``, so the cache guard never matched and the whole
        # aggregation ran on every call. ``min_accuracy`` is now part of the
        # cache key as well, so changing it also triggers a recompute.
        if (
            getattr(self, "_good_only", None) == good_only
            and getattr(self, "_min_accuracy", None) == min_accuracy
        ):
            return

        import sklearn

        # Cosine similarity between flattened templates, per recording.
        similarity_matrix = {}
        for rec_name in self.rec_names:
            templates = self.get_templates(rec_name)
            flat_templates = templates.reshape(templates.shape[0], -1)
            similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)

        self.all_similarities = {}
        self.all_recall_scores = {}
        self.good_only = good_only  # kept public for backward compatibility
        self._good_only = good_only
        self._min_accuracy = min_accuracy

        for sorter_name in self.sorter_names:
            # loop over recordings; a comparison may be missing for a pair
            # (e.g. the sorter failed on that recording), so skip those.
            all_similarities = []
            all_recall_scores = []

            for rec_name in self.rec_names:
                if (rec_name, sorter_name) in self.comparisons:
                    comp = self.comparisons[(rec_name, sorter_name)]
                    similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
                        similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy
                    )

                    all_similarities.append(similarities)
                    all_recall_scores.append(recall_scores)

            self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
            self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0)

    def get_mean_over_similarity_range(self, similarity_range, sorter_name):
        """Mean recall-score profile over pairs whose similarity lies in
        ``[similarity_range[0], similarity_range[1]]`` (inclusive).

        Requires ``precompute_scores_by_similarities`` to have been called.
        Returns an array of per-lag mean recall scores (NaNs ignored).
        """
        idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & (
            self.all_similarities[sorter_name] <= similarity_range[1]
        )
        all_similarities = self.all_similarities[sorter_name][idx]
        all_recall_scores = self.all_recall_scores[sorter_name][idx]

        order = np.argsort(all_similarities)
        all_similarities = all_similarities[order]
        all_recall_scores = all_recall_scores[order, :]

        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)

        return mean_recall_scores

    def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name):
        """Mean recall-score profile per similarity bin.

        Parameters
        ----------
        similarity_bins : np.ndarray
            Monotonically increasing bin edges; bin i covers
            ``[similarity_bins[i], similarity_bins[i + 1])``.
        sorter_name : str
            Key into the precomputed aggregates.

        Returns
        -------
        dict
            Maps ``(low_edge, high_edge)`` to the per-lag mean recall scores
            of the pairs falling in that similarity bin (NaN if the bin is
            empty, via ``np.nanmean`` of an empty slice).
        """
        all_similarities = self.all_similarities[sorter_name]
        all_recall_scores = self.all_recall_scores[sorter_name]

        # sort once so each bin is a contiguous slice found via searchsorted
        order = np.argsort(all_similarities)
        all_similarities = all_similarities[order]
        all_recall_scores = all_recall_scores[order, :]

        result = {}

        for i in range(similarity_bins.size - 1):
            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
            result[(cmin, cmax)] = mean_recall_scores

        return result
88 changes: 0 additions & 88 deletions src/spikeinterface/comparison/collisionstudy.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import numpy as np
from .paircomparisons import GroundTruthComparison
from .groundtruthstudy import GroundTruthStudy
from .studytools import iter_computed_sorting ## TODO remove this
from spikeinterface.postprocessing import compute_correlograms


import numpy as np



class CorrelogramGTComparison(GroundTruthComparison):
"""
This class is an extension of GroundTruthComparison by focusing
Expand Down Expand Up @@ -108,3 +113,75 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None):
errors = errors[order, :]

return similarities, errors



class CorrelogramGTStudy(GroundTruthStudy):
    """Ground-truth study specialized for correlogram analysis.

    Builds one :class:`CorrelogramGTComparison` per (recording, sorter) pair
    and aggregates correlogram errors as a function of template similarity
    across all recordings of the study.
    """

    def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs):
        """Create a CorrelogramGTComparison for every computed sorting.

        Parameters
        ----------
        exhaustive_gt : bool, default: True
            Whether the ground truth is assumed exhaustive.
        window_ms : float, default: 100.0
            Correlogram window forwarded to CorrelogramGTComparison.
        bin_ms : float, default: 1.0
            Correlogram bin size forwarded to CorrelogramGTComparison.
        well_detected_score : float, default: 0.8
            Score threshold forwarded to CorrelogramGTComparison.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        self.comparisons = {}
        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
            gt_sorting = self.get_ground_truth(rec_name)
            comp = CorrelogramGTComparison(
                gt_sorting,
                sorting,
                exhaustive_gt=exhaustive_gt,
                window_ms=window_ms,
                bin_ms=bin_ms,
                well_detected_score=well_detected_score,
            )
            self.comparisons[(rec_name, sorter_name)] = comp

        self.exhaustive_gt = exhaustive_gt

    @property
    def time_bins(self):
        """Time-bin axis taken from any one comparison (all comparisons are
        built with the same window/bin parameters, so they share one axis).
        Returns None if ``run_comparisons`` has not been called yet."""
        for comp in self.comparisons.values():
            return comp.time_bins
        return None

    def precompute_scores_by_similarities(self, good_only=True):
        """Aggregate per-pair correlogram errors over all recordings.

        Fills ``self.all_similarities`` and ``self.all_errors`` (one entry per
        sorter name, concatenated over recordings). Computed once; subsequent
        calls are no-ops.

        Parameters
        ----------
        good_only : bool, default: True
            NOTE(review): accepted for API symmetry with CollisionGTStudy but
            currently unused — compute_correlogram_by_similarity takes no such
            argument here.
        """
        if hasattr(self, "_computed"):
            return

        import sklearn

        # Cosine similarity between flattened templates, per recording.
        similarity_matrix = {}
        for rec_name in self.rec_names:
            templates = self.get_templates(rec_name)
            flat_templates = templates.reshape(templates.shape[0], -1)
            similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)

        self.all_similarities = {}
        self.all_errors = {}
        self._computed = True

        for sorter_name in self.sorter_names:
            # loop over recordings; a comparison may be missing for a pair
            # (e.g. the sorter failed on that recording), so skip those.
            # FIX: was a blanket ``except Exception: pass`` that silently
            # swallowed every error in the loop, not just missing keys;
            # an explicit membership test (as in CollisionGTStudy) keeps the
            # best-effort skip while letting real errors surface.
            all_errors = []
            all_similarities = []
            for rec_name in self.rec_names:
                if (rec_name, sorter_name) in self.comparisons:
                    comp = self.comparisons[(rec_name, sorter_name)]
                    similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name])
                    all_similarities.append(similarities)
                    all_errors.append(errors)

            self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
            self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0)

    def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name):
        """Mean correlogram-error profile per similarity bin.

        Parameters
        ----------
        similarity_bins : np.ndarray
            Monotonically increasing bin edges; bin i covers
            ``[similarity_bins[i], similarity_bins[i + 1])``.
        sorter_name : str
            Key into the precomputed aggregates.

        Returns
        -------
        dict
            Maps ``(low_edge, high_edge)`` to the mean error profile of the
            pairs falling in that similarity bin (NaN if the bin is empty,
            via ``np.nanmean`` of an empty slice).
        """
        all_similarities = self.all_similarities[sorter_name]
        all_errors = self.all_errors[sorter_name]

        # sort once so each bin is a contiguous slice found via searchsorted
        order = np.argsort(all_similarities)
        all_similarities = all_similarities[order]
        all_errors = all_errors[order, :]

        result = {}

        for i in range(similarity_bins.size - 1):
            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
            result[(cmin, cmax)] = mean_errors

        return result
76 changes: 0 additions & 76 deletions src/spikeinterface/comparison/correlogramstudy.py

This file was deleted.

0 comments on commit 8d9ce49

Please sign in to comment.