diff --git a/src/spikeinterface/comparison/collisioncomparison.py b/src/spikeinterface/comparison/collision.py
similarity index 58%
rename from src/spikeinterface/comparison/collisioncomparison.py
rename to src/spikeinterface/comparison/collision.py
index 3b279717b7..864809b04b 100644
--- a/src/spikeinterface/comparison/collisioncomparison.py
+++ b/src/spikeinterface/comparison/collision.py
@@ -1,8 +1,14 @@
-import numpy as np
-
 from .paircomparisons import GroundTruthComparison
+from .groundtruthstudy import GroundTruthStudy
+from .studytools import iter_computed_sorting  ## TODO remove this
 from .comparisontools import make_collision_events
 
+import numpy as np
+
+
+
+
+
 
 class CollisionGTComparison(GroundTruthComparison):
     """
@@ -156,3 +162,87 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
         pair_names = pair_names[order]
 
         return similarities, recall_scores, pair_names
+
+
+
+class CollisionGTStudy(GroundTruthStudy):
+    def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
+        self.comparisons = {}
+        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
+            gt_sorting = self.get_ground_truth(rec_name)
+            comp = CollisionGTComparison(
+                gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins
+            )
+            self.comparisons[(rec_name, sorter_name)] = comp
+        self.exhaustive_gt = exhaustive_gt
+        self.collision_lag = collision_lag
+
+    def get_lags(self):
+        fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency()
+        lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000
+        return lags
+
+    def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9):
+        if not hasattr(self, "_good_only") or self._good_only != good_only:
+            import sklearn.metrics.pairwise
+
+            similarity_matrix = {}
+            for rec_name in self.rec_names:
+                templates = self.get_templates(rec_name)
+                flat_templates = templates.reshape(templates.shape[0], -1)
+                similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+
+            self.all_similarities = {}
+            self.all_recall_scores = {}
+            self._good_only = good_only
+
+            for sorter_ind, sorter_name in enumerate(self.sorter_names):
+                # loop over recordings
+                all_similarities = []
+                all_recall_scores = []
+
+                for rec_name in self.rec_names:
+                    if (rec_name, sorter_name) in self.comparisons.keys():
+                        comp = self.comparisons[(rec_name, sorter_name)]
+                        similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
+                            similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy
+                        )
+
+                        all_similarities.append(similarities)
+                        all_recall_scores.append(recall_scores)
+
+                self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
+                self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0)
+
+    def get_mean_over_similarity_range(self, similarity_range, sorter_name):
+        idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & (
+            self.all_similarities[sorter_name] <= similarity_range[1]
+        )
+        all_similarities = self.all_similarities[sorter_name][idx]
+        all_recall_scores = self.all_recall_scores[sorter_name][idx]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_recall_scores = all_recall_scores[order, :]
+
+        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
+
+        return mean_recall_scores
+
+    def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name):
+        all_similarities = self.all_similarities[sorter_name]
+        all_recall_scores = self.all_recall_scores[sorter_name]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_recall_scores = all_recall_scores[order, :]
+
+        result = {}
+
+        for i in range(similarity_bins.size - 1):
+            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
+            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
+            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
+            result[(cmin, cmax)] = mean_recall_scores
+
+        return result
diff --git a/src/spikeinterface/comparison/collisionstudy.py b/src/spikeinterface/comparison/collisionstudy.py
deleted file mode 100644
index 34a556e8b9..0000000000
--- a/src/spikeinterface/comparison/collisionstudy.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from .groundtruthstudy import GroundTruthStudy
-from .studytools import iter_computed_sorting
-from .collisioncomparison import CollisionGTComparison
-
-import numpy as np
-
-
-class CollisionGTStudy(GroundTruthStudy):
-    def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
-        self.comparisons = {}
-        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-            gt_sorting = self.get_ground_truth(rec_name)
-            comp = CollisionGTComparison(
-                gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins
-            )
-            self.comparisons[(rec_name, sorter_name)] = comp
-        self.exhaustive_gt = exhaustive_gt
-        self.collision_lag = collision_lag
-
-    def get_lags(self):
-        fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency()
-        lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000
-        return lags
-
-    def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9):
-        if not hasattr(self, "_good_only") or self._good_only != good_only:
-            import sklearn
-
-            similarity_matrix = {}
-            for rec_name in self.rec_names:
-                templates = self.get_templates(rec_name)
-                flat_templates = templates.reshape(templates.shape[0], -1)
-                similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-
-            self.all_similarities = {}
-            self.all_recall_scores = {}
-            self.good_only = good_only
-
-            for sorter_ind, sorter_name in enumerate(self.sorter_names):
-                # loop over recordings
-                all_similarities = []
-                all_recall_scores = []
-
-                for rec_name in self.rec_names:
-                    if (rec_name, sorter_name) in self.comparisons.keys():
-                        comp = self.comparisons[(rec_name, sorter_name)]
-                        similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
-                            similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy
-                        )
-
-                        all_similarities.append(similarities)
-                        all_recall_scores.append(recall_scores)
-
-                self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
-                self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0)
-
-    def get_mean_over_similarity_range(self, similarity_range, sorter_name):
-        idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & (
-            self.all_similarities[sorter_name] <= similarity_range[1]
-        )
-        all_similarities = self.all_similarities[sorter_name][idx]
-        all_recall_scores = self.all_recall_scores[sorter_name][idx]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_recall_scores = all_recall_scores[order, :]
-
-        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
-
-        return mean_recall_scores
-
-    def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name):
-        all_similarities = self.all_similarities[sorter_name]
-        all_recall_scores = self.all_recall_scores[sorter_name]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_recall_scores = all_recall_scores[order, :]
-
-        result = {}
-
-        for i in range(similarity_bins.size - 1):
-            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
-            result[(cmin, cmax)] = mean_recall_scores
-
-        return result
diff --git a/src/spikeinterface/comparison/correlogramcomparison.py b/src/spikeinterface/comparison/correlogram.py
similarity index 58%
rename from src/spikeinterface/comparison/correlogramcomparison.py
rename to src/spikeinterface/comparison/correlogram.py
index 80e881a152..9c5e1e91cf 100644
--- a/src/spikeinterface/comparison/correlogramcomparison.py
+++ b/src/spikeinterface/comparison/correlogram.py
@@ -1,8 +1,13 @@
-import numpy as np
 from .paircomparisons import GroundTruthComparison
+from .groundtruthstudy import GroundTruthStudy
+from .studytools import iter_computed_sorting  ## TODO remove this
 from spikeinterface.postprocessing import compute_correlograms
 
+import numpy as np
+
+
+
 
 class CorrelogramGTComparison(GroundTruthComparison):
     """
     This class is an extension of GroundTruthComparison by focusing
@@ -108,3 +113,75 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None):
         errors = errors[order, :]
 
         return similarities, errors
+
+
+
+class CorrelogramGTStudy(GroundTruthStudy):
+    def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs):
+        self.comparisons = {}
+        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
+            gt_sorting = self.get_ground_truth(rec_name)
+            comp = CorrelogramGTComparison(
+                gt_sorting,
+                sorting,
+                exhaustive_gt=exhaustive_gt,
+                window_ms=window_ms,
+                bin_ms=bin_ms,
+                well_detected_score=well_detected_score,
+            )
+            self.comparisons[(rec_name, sorter_name)] = comp
+
+        self.exhaustive_gt = exhaustive_gt
+
+    @property
+    def time_bins(self):
+        for key, value in self.comparisons.items():
+            return value.time_bins
+
+    def precompute_scores_by_similarities(self, good_only=True):
+        if not hasattr(self, "_computed"):
+            import sklearn.metrics.pairwise
+
+            similarity_matrix = {}
+            for rec_name in self.rec_names:
+                templates = self.get_templates(rec_name)
+                flat_templates = templates.reshape(templates.shape[0], -1)
+                similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+
+            self.all_similarities = {}
+            self.all_errors = {}
+            self._computed = True
+
+            for sorter_ind, sorter_name in enumerate(self.sorter_names):
+                # loop over recordings
+                all_errors = []
+                all_similarities = []
+                for rec_name in self.rec_names:
+                    try:
+                        comp = self.comparisons[(rec_name, sorter_name)]
+                        similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name])
+                        all_similarities.append(similarities)
+                        all_errors.append(errors)
+                    except Exception:
+                        pass
+
+                self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
+                self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0)
+
+    def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name):
+        all_similarities = self.all_similarities[sorter_name]
+        all_errors = self.all_errors[sorter_name]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_errors = all_errors[order, :]
+
+        result = {}
+
+        for i in range(similarity_bins.size - 1):
+            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
+            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
+            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
+            result[(cmin, cmax)] = mean_errors
+
+        return result
diff --git a/src/spikeinterface/comparison/correlogramstudy.py b/src/spikeinterface/comparison/correlogramstudy.py
deleted file mode 100644
index fb00c08157..0000000000
--- a/src/spikeinterface/comparison/correlogramstudy.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from .groundtruthstudy import GroundTruthStudy
-from .studytools import iter_computed_sorting
-from .correlogramcomparison import CorrelogramGTComparison
-
-import numpy as np
-
-
-class CorrelogramGTStudy(GroundTruthStudy):
-    def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs):
-        self.comparisons = {}
-        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-            gt_sorting = self.get_ground_truth(rec_name)
-            comp = CorrelogramGTComparison(
-                gt_sorting,
-                sorting,
-                exhaustive_gt=exhaustive_gt,
-                window_ms=window_ms,
-                bin_ms=bin_ms,
-                well_detected_score=well_detected_score,
-            )
-            self.comparisons[(rec_name, sorter_name)] = comp
-
-        self.exhaustive_gt = exhaustive_gt
-
-    @property
-    def time_bins(self):
-        for key, value in self.comparisons.items():
-            return value.time_bins
-
-    def precompute_scores_by_similarities(self, good_only=True):
-        if not hasattr(self, "_computed"):
-            import sklearn
-
-            similarity_matrix = {}
-            for rec_name in self.rec_names:
-                templates = self.get_templates(rec_name)
-                flat_templates = templates.reshape(templates.shape[0], -1)
-                similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-
-            self.all_similarities = {}
-            self.all_errors = {}
-            self._computed = True
-
-            for sorter_ind, sorter_name in enumerate(self.sorter_names):
-                # loop over recordings
-                all_errors = []
-                all_similarities = []
-                for rec_name in self.rec_names:
-                    try:
-                        comp = self.comparisons[(rec_name, sorter_name)]
-                        similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name])
-                        all_similarities.append(similarities)
-                        all_errors.append(errors)
-                    except Exception:
-                        pass
-
-                self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
-                self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0)
-
-    def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name):
-        all_similarities = self.all_similarities[sorter_name]
-        all_errors = self.all_errors[sorter_name]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_errors = all_errors[order, :]
-
-        result = {}
-
-        for i in range(similarity_bins.size - 1):
-            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
-            result[(cmin, cmax)] = mean_errors
-
-        return result
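
Reviewer note: a minimal usage sketch of the relocated CollisionGTStudy, under stated assumptions. It assumes the package re-exports the class from `spikeinterface.comparison` after this move, and that a study folder was previously prepared with GroundTruthStudy and sorter outputs computed; the folder path and the sorter name below are hypothetical.

```python
import numpy as np

# Assumes spikeinterface.comparison re-exports CollisionGTStudy after the rename.
from spikeinterface.comparison import CollisionGTStudy

# Hypothetical study folder, prepared beforehand with GroundTruthStudy,
# with sorter outputs already on disk (iter_computed_sorting discovers them).
study = CollisionGTStudy("path/to/study_folder")

# Run a CollisionGTComparison for every computed (recording, sorter) pair.
study.run_comparisons(exhaustive_gt=True, collision_lag=2.0, nbins=11)

# Aggregate collision recall against template cosine similarity, per sorter.
study.precompute_scores_by_similarities(good_only=True, min_accuracy=0.9)

lags = study.get_lags()  # lag axis in ms, derived from the sampling frequency

# Mean recall profile per similarity bin; "tridesclous" is a hypothetical sorter name.
profiles = study.get_lag_profile_over_similarity_bins(
    similarity_bins=np.array([0.0, 0.5, 1.0]), sorter_name="tridesclous"
)
```

CorrelogramGTStudy follows the same pattern (run_comparisons, precompute_scores_by_similarities, then get_error_profile_over_similarity_bins).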