diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst
index b452307e3c..76ab7855c6 100644
--- a/doc/modules/comparison.rst
+++ b/doc/modules/comparison.rst
@@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for
 
 We also have a high level class to compare many sorters against ground truth:
 :py:func:`~spikeinterface.comparison.GroundTruthStudy()`
 
-A study is a systematic performance comparison of several ground truth recordings with several sorters.
+A study is a systematic performance comparison of several ground truth recordings with several sorters or several "cases",
+such as different parameter sets.
 
-The study class proposes high-level tool functions to run many ground truth comparisons with many sorters
+The study class provides high-level tool functions to run many ground truth comparisons with many "cases"
 on many recordings and then collect and aggregate results in an easy way.
 
 The whole mechanism is based on an intrinsic organization
 into a "study_folder" with several subfolders:
 
-  * raw_files : contain a copy of recordings in binary format
-  * sorter_folders : contains outputs of sorters
-  * ground_truth : contains a copy of sorting ground truth in npz format
-  * sortings: contains light copy of all sorting in npz format
-  * tables: some tables in csv format
-
-In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format:
-binary (for recordings) and npz (for sortings).
+  * datasets : contains ground truth datasets
+  * sorters : contains outputs of sorters
+  * sortings : contains a lightweight copy of all sortings
+  * metrics : contains metrics
+  * ...
 
 .. code-block:: python
 
@@ -274,28 +272,51 @@ binary (for recordings) and npz (for sortings).
     import spikeinterface.widgets as sw
     from spikeinterface.comparison import GroundTruthStudy
+    # the example below also needs this import
+    from spikeinterface.core import generate_ground_truth_recording
 
-    # Setup study folder
-    rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1)
-    rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1)
-    gt_dict = {
-        'rec0': (rec0, gt_sorting0),
-        'rec1': (rec1, gt_sorting1),
+
+    # generate 2 simulated datasets (could also be MEArec files)
+    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42)
+    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91)
+
+    datasets = {
+        "toy0": (rec0, gt_sorting0),
+        "toy1": (rec1, gt_sorting1),
     }
-    study_folder = 'a_study_folder'
-    study = GroundTruthStudy.create(study_folder, gt_dict)
 
-    # all sorters for all recordings in one function.
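+    # the study folder used below (any path works; this name is illustrative)
+    study_folder = "a_study_folder"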
-    sorter_list = ['herdingspikes', 'tridesclous', ]
-    study.run_sorters(sorter_list, mode_if_folder_exists="keep")
+    # define some "cases": here we want to test tridesclous2 on 2 datasets and spykingcircus on one dataset
+    # so it is a two level study (sorter_name, dataset)
+    # this could be more complicated like (sorter_name, dataset, params)
+    cases = {
+        ("tdc2", "toy0"): {
+            "label": "tridesclous2 on tetrode0",
+            "dataset": "toy0",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+        },
+        ("tdc2", "toy1"): {
+            "label": "tridesclous2 on tetrode1",
+            "dataset": "toy1",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+        },
+
+        ("sc", "toy0"): {
+            "label": "spykingcircus2 on tetrode0",
+            "dataset": "toy0",
+            "run_sorter_params": {
+                "sorter_name": "spykingcircus",
+                "docker_image": True
+            },
+        },
+    }
+    # this initializes the study folder
+    study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases,
+                                    levels=["sorter_name", "dataset"])
 
-    # You can re-run **run_study_sorters** as many times as you want.
-    # By default **mode='keep'** so only uncomputed sorters are re-run.
-    # For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run
-    # only one sorter on one recording.
-    #
-    # Then we copy the spike sorting outputs into a separate subfolder.
-    # This allow us to remove the "large" sorter_folders.
-    study.copy_sortings()
+
+    # all cases in one function
+    study.run_sorters()
 
     # Collect comparisons
     #  
@@ -306,11 +327,11 @@ binary (for recordings) and npz (for sortings).
     # Note: use exhaustive_gt=True when you know exactly how many
     # units in ground truth (for synthetic datasets)
 
+    # run all comparisons and loop over the results
     study.run_comparisons(exhaustive_gt=True)
-
-    for (rec_name, sorter_name), comp in study.comparisons.items():
+    for key, comp in study.comparisons.items():
         print('*' * 10)
-        print(rec_name, sorter_name)
+        print(key)
         # raw counting of tp/fp/...
         print(comp.count_score)
         # summary
@@ -323,26 +344,27 @@ binary (for recordings) and npz (for sortings).
     # Collect synthetic dataframes and display
 
     # As shown previously, the performance is returned as a pandas dataframe.
-    # The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function,
+    # The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function,
    # gathers all the outputs in the study folder and merges them in a single dataframe.
+    # Same idea for :py:func:`~spikeinterface.comparison.get_count_units()`
 
-    dataframes = study.aggregate_dataframes()
+    # this is a dataframe
+    perfs = study.get_performance_by_unit()
 
-    # Pandas dataframes can be nicely displayed as tables in the notebook.
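+    # a hedged slicing example: perfs is indexed by the case levels defined above,
+    # e.g. perfs.xs("tdc2", level="sorter_name") selects all tridesclous2 cases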
-    print(dataframes.keys())
+    # this is a dataframe
+    unit_counts = study.get_count_units()
 
     # we can also access run times
-    print(dataframes['run_times'])
+    run_times = study.get_run_times()
+    print(run_times)
 
     # Easy plot with seaborn
-    run_times = dataframes['run_times']
     fig1, ax1 = plt.subplots()
     sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1)
     ax1.set_title('Run times')
 
 ##############################################################################
 
-    perfs = dataframes['perf_by_unit']
     fig2, ax2 = plt.subplots()
     sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2)
     ax2.set_title('Recall')
diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py
index a390bb7689..bff85dde4a 100644
--- a/src/spikeinterface/comparison/__init__.py
+++ b/src/spikeinterface/comparison/__init__.py
@@ -28,12 +28,11 @@
     compare_multiple_templates,
     MultiTemplateComparison,
 )
-from .collisioncomparison import CollisionGTComparison
-from .correlogramcomparison import CorrelogramGTComparison
+
 from .groundtruthstudy import GroundTruthStudy
-from .collisionstudy import CollisionGTStudy
-from .correlogramstudy import CorrelogramGTStudy
-from .studytools import aggregate_performances_table
+from .collision import CollisionGTComparison, CollisionGTStudy
+from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy
+
 from .hybrid import (
     HybridSpikesRecording,
     HybridUnitsRecording,
diff --git a/src/spikeinterface/comparison/collisioncomparison.py b/src/spikeinterface/comparison/collision.py
similarity index 64%
rename from src/spikeinterface/comparison/collisioncomparison.py
rename to src/spikeinterface/comparison/collision.py
index 3b279717b7..dd04b2c72d 100644
--- a/src/spikeinterface/comparison/collisioncomparison.py
+++ b/src/spikeinterface/comparison/collision.py
@@ -1,13 +1,15 @@
-import numpy as np
-
 from .paircomparisons import GroundTruthComparison
+from .groundtruthstudy import GroundTruthStudy
 from .comparisontools import make_collision_events
 
+import numpy as np
+
 
 class CollisionGTComparison(GroundTruthComparison):
     """
-    This class is an extension of GroundTruthComparison by focusing
-    to benchmark spike in collision
+    This class is an extension of GroundTruthComparison focused on benchmarking spikes in collision.
+
+    This class needs maintenance and a bit of refactoring.
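+
+    A minimal usage sketch (``gt_sorting``/``tested_sorting`` are assumed to exist; parameter
+    values are illustrative)::
+
+        comp = CollisionGTComparison(gt_sorting, tested_sorting, collision_lag=2.0, nbins=11)
+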
    collision_lag: float
@@ -156,3 +158,73 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good
             pair_names = pair_names[order]
 
         return similarities, recall_scores, pair_names
+
+
+class CollisionGTStudy(GroundTruthStudy):
+    def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
+        _kwargs = dict()
+        _kwargs.update(kwargs)
+        _kwargs["exhaustive_gt"] = exhaustive_gt
+        _kwargs["collision_lag"] = collision_lag
+        _kwargs["nbins"] = nbins
+        GroundTruthStudy.run_comparisons(self, case_keys=case_keys, comparison_class=CollisionGTComparison, **_kwargs)
+        self.exhaustive_gt = exhaustive_gt
+        self.collision_lag = collision_lag
+
+    def get_lags(self, key):
+        comp = self.comparisons[key]
+        fs = comp.sorting1.get_sampling_frequency()
+        lags = comp.bins / fs * 1000.0
+        return lags
+
+    def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9):
+        import sklearn.metrics
+
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        self.all_similarities = {}
+        self.all_recall_scores = {}
+        self.good_only = good_only
+
+        for key in case_keys:
+            templates = self.get_templates(key)
+            flat_templates = templates.reshape(templates.shape[0], -1)
+            similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+            comp = self.comparisons[key]
+            similarities, recall_scores, pair_names = comp.compute_collision_by_similarity(
+                similarity, good_only=good_only, min_accuracy=min_accuracy
+            )
+            self.all_similarities[key] = similarities
+            self.all_recall_scores[key] = recall_scores
+
+    def get_mean_over_similarity_range(self, similarity_range, key):
+        idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1])
+        all_similarities = self.all_similarities[key][idx]
+        all_recall_scores = self.all_recall_scores[key][idx]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_recall_scores = all_recall_scores[order, :]
+
+        mean_recall_scores = np.nanmean(all_recall_scores, axis=0)
+
+        return mean_recall_scores
+
+    def get_lag_profile_over_similarity_bins(self, similarity_bins, key):
+        all_similarities = self.all_similarities[key]
+        all_recall_scores = self.all_recall_scores[key]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_recall_scores = all_recall_scores[order, :]
+
+        result = {}
+
+        for i in range(similarity_bins.size - 1):
+            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
+            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
+            mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0)
+            result[(cmin, cmax)] = mean_recall_scores
+
+        return result
diff --git a/src/spikeinterface/comparison/collisionstudy.py b/src/spikeinterface/comparison/collisionstudy.py
deleted file mode 100644
index 34a556e8b9..0000000000
--- a/src/spikeinterface/comparison/collisionstudy.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from .groundtruthstudy import GroundTruthStudy
-from .studytools import iter_computed_sorting
-from .collisioncomparison import CollisionGTComparison
-
-import numpy as np
-
-
-class CollisionGTStudy(GroundTruthStudy):
-    def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs):
-        self.comparisons = {}
-        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-            gt_sorting = self.get_ground_truth(rec_name)
-            comp = CollisionGTComparison(
-                gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag,
nbins=nbins - ) - self.comparisons[(rec_name, sorter_name)] = comp - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self): - fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() - lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 - return lags - - def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): - if not hasattr(self, "_good_only") or self._good_only != good_only: - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_similarities = [] - all_recall_scores = [] - - for rec_name in self.rec_names: - if (rec_name, sorter_name) in self.comparisons.keys(): - comp = self.comparisons[(rec_name, sorter_name)] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy - ) - - all_similarities.append(similarities) - all_recall_scores.append(recall_scores) - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) - - def get_mean_over_similarity_range(self, similarity_range, sorter_name): - idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( - self.all_similarities[sorter_name] <= similarity_range[1] - ) - all_similarities = self.all_similarities[sorter_name][idx] - all_recall_scores = self.all_recall_scores[sorter_name][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_recall_scores = self.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result diff --git a/src/spikeinterface/comparison/correlogramcomparison.py b/src/spikeinterface/comparison/correlogram.py similarity index 64% rename from src/spikeinterface/comparison/correlogramcomparison.py rename to src/spikeinterface/comparison/correlogram.py index 80e881a152..aaffef1887 100644 --- a/src/spikeinterface/comparison/correlogramcomparison.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -1,16 +1,17 @@ -import numpy as np from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy from spikeinterface.postprocessing import compute_correlograms +import numpy as np + + class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing - to 
benchmark correlation reconstruction
-
+    on benchmarking correlation reconstruction.
 
-    collision_lag: float
-        Collision lag in ms.
+    This class needs maintenance and a bit of refactoring.
     """
 
@@ -105,6 +106,62 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None):
 
         order = np.argsort(similarities)
         similarities = similarities[order]
-        errors = errors[order, :]
+        errors = errors[order]
 
         return similarities, errors
+
+
+class CorrelogramGTStudy(GroundTruthStudy):
+    def run_comparisons(
+        self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs
+    ):
+        _kwargs = dict()
+        _kwargs.update(kwargs)
+        _kwargs["exhaustive_gt"] = exhaustive_gt
+        _kwargs["window_ms"] = window_ms
+        _kwargs["bin_ms"] = bin_ms
+        _kwargs["well_detected_score"] = well_detected_score
+        GroundTruthStudy.run_comparisons(self, case_keys=case_keys, comparison_class=CorrelogramGTComparison, **_kwargs)
+        self.exhaustive_gt = exhaustive_gt
+
+    @property
+    def time_bins(self):
+        for key, value in self.comparisons.items():
+            return value.time_bins
+
+    def precompute_scores_by_similarities(self, case_keys=None, good_only=True):
+        import sklearn.metrics
+
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        self.all_similarities = {}
+        self.all_errors = {}
+
+        for key in case_keys:
+            templates = self.get_templates(key)
+            flat_templates = templates.reshape(templates.shape[0], -1)
+            similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
+            comp = self.comparisons[key]
+            similarities, errors = comp.compute_correlogram_by_similarity(similarity)
+
+            self.all_similarities[key] = similarities
+            self.all_errors[key] = errors
+
+    def get_error_profile_over_similarity_bins(self, similarity_bins, key):
+        all_similarities = self.all_similarities[key]
+        all_errors = self.all_errors[key]
+
+        order = np.argsort(all_similarities)
+        all_similarities = all_similarities[order]
+        all_errors = all_errors[order, :]
+
+        result = {}
+
+        for i in range(similarity_bins.size - 1):
+            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
+            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
+            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
+            result[(cmin, cmax)] = mean_errors
+
+        return result
diff --git a/src/spikeinterface/comparison/correlogramstudy.py b/src/spikeinterface/comparison/correlogramstudy.py
deleted file mode 100644
index fb00c08157..0000000000
--- a/src/spikeinterface/comparison/correlogramstudy.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from .groundtruthstudy import GroundTruthStudy
-from .studytools import iter_computed_sorting
-from .correlogramcomparison import CorrelogramGTComparison
-
-import numpy as np
-
-
-class CorrelogramGTStudy(GroundTruthStudy):
-    def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs):
-        self.comparisons = {}
-        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-            gt_sorting = self.get_ground_truth(rec_name)
-            comp = CorrelogramGTComparison(
-                gt_sorting,
-                sorting,
-                exhaustive_gt=exhaustive_gt,
-                window_ms=window_ms,
-                bin_ms=bin_ms,
-                well_detected_score=well_detected_score,
-            )
-            self.comparisons[(rec_name, sorter_name)] = comp
-
-        self.exhaustive_gt = exhaustive_gt
-
-    @property
-    def time_bins(self):
-        for key, value in self.comparisons.items():
-            return value.time_bins
-
-    def precompute_scores_by_similarities(self, good_only=True):
-        if not hasattr(self, "_computed"):
-            import sklearn
-
-            similarity_matrix = {}
-            for rec_name in self.rec_names:
-                templates = self.get_templates(rec_name)
-                flat_templates = templates.reshape(templates.shape[0], -1)
-                similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates)
-
-            self.all_similarities = {}
-            self.all_errors = {}
-            self._computed = True
-
-            for sorter_ind, sorter_name in enumerate(self.sorter_names):
-                # loop over recordings
-                all_errors = []
-                all_similarities = []
-                for rec_name in self.rec_names:
-                    try:
-                        comp = self.comparisons[(rec_name, sorter_name)]
-                        similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name])
-                        all_similarities.append(similarities)
-                        all_errors.append(errors)
-                    except Exception:
-                        pass
-
-                self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0)
-                self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0)
-
-    def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name):
-        all_similarities = self.all_similarities[sorter_name]
-        all_errors = self.all_errors[sorter_name]
-
-        order = np.argsort(all_similarities)
-        all_similarities = all_similarities[order]
-        all_errors = all_errors[order, :]
-
-        result = {}
-
-        for i in range(similarity_bins.size - 1):
-            cmin, cmax = similarity_bins[i], similarity_bins[i + 1]
-            amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
-            mean_errors = np.nanmean(all_errors[amin:amax], axis=0)
-            result[(cmin, cmax)] = mean_errors
-
-        return result
diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 7b146f07bc..d43727cb44 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -1,327 +1,403 @@
 from pathlib import Path
 import shutil
+import os
+import json
+import pickle
 
 import numpy as np
 
-from spikeinterface.core import load_extractor
-from spikeinterface.extractors import NpzSortingExtractor
-from spikeinterface.sorters import sorter_dict, run_sorters
+from spikeinterface.core import load_extractor, extract_waveforms, load_waveforms
+from spikeinterface.core.core_tools import SIJsonEncoder
+
+from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder
 
 from spikeinterface import WaveformExtractor
 from spikeinterface.qualitymetrics import compute_quality_metrics
 
-from .paircomparisons import compare_sorter_to_ground_truth
-
-from .studytools import (
-    setup_comparison_study,
-    get_rec_names,
-    get_recordings,
-    iter_working_folder,
-    iter_computed_names,
-    iter_computed_sorting,
-    collect_run_times,
-)
+from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison
 
-class GroundTruthStudy:
-    def __init__(self, study_folder=None):
-        import pandas as pd
+# TODO later: save comparisons in folders once the comparison object can be serialized
 
-        self.study_folder = Path(study_folder)
-        self._is_scanned = False
-        self.computed_names = None
-        self.rec_names = None
-        self.sorter_names = None
-        self.scan_folder()
+# This is to separate names when the keys are tuples when saving folders
+_key_separator = " ## "
 
-        self.comparisons = None
-        self.exhaustive_gt = None
 
-    def __repr__(self):
-        t = "Ground truth study\n"
-        t += "  " + str(self.study_folder) + "\n"
-        t += "  recordings: {} {}\n".format(len(self.rec_names), self.rec_names)
-        if len(self.sorter_names):
-            t += "  sorters: {} {}\n".format(len(self.sorter_names), self.sorter_names)
+class GroundTruthStudy:
+    """
+    This class is a helper to run any comparison on several "cases" for many
ground-truth datasets.
 
-        return t
+    "cases" refer to:
+      * several sorters for comparisons
+      * same sorter with different parameters
+      * any combination of these (and more)
 
-    def scan_folder(self):
-        self.rec_names = get_rec_names(self.study_folder)
-        # scan computed names
-        self.computed_names = list(iter_computed_names(self.study_folder))  # list of pair (rec_name, sorter_name)
-        self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist()
-        self._is_scanned = True
+    For increased flexibility, cases keys can be a tuple so that we can vary complexity along several
+    "levels" or "axes" (parameters or sorters).
+    In this case, the result dataframes will have `MultiIndex` to handle the different levels.
 
-    @classmethod
-    def create(cls, study_folder, gt_dict, **job_kwargs):
-        setup_comparison_study(study_folder, gt_dict, **job_kwargs)
-        return cls(study_folder)
+    A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see
+    :py:func:`~spikeinterface.core.generate.generate_ground_truth_recording()`).
 
-    def run_sorters(self, sorter_list, mode_if_folder_exists="keep", remove_sorter_folders=False, **kwargs):
-        sorter_folders = self.study_folder / "sorter_folders"
-        recording_dict = get_recordings(self.study_folder)
-
-        run_sorters(
-            sorter_list,
-            recording_dict,
-            sorter_folders,
-            with_output=False,
-            mode_if_folder_exists=mode_if_folder_exists,
-            **kwargs,
-        )
-
-        # results are copied so the heavy sorter_folders can be removed
-        self.copy_sortings()
-
-        if remove_sorter_folders:
-            shutil.rmtree(self.study_folder / "sorter_folders")
-
-    def _check_rec_name(self, rec_name):
-        if not self._is_scanned:
-            self.scan_folder()
-        if len(self.rec_names) > 1 and rec_name is None:
-            raise Exception("Pass 'rec_name' parameter to select which recording to use.")
-        elif len(self.rec_names) == 1:
-            rec_name = self.rec_names[0]
-        else:
-            rec_name = self.rec_names[self.rec_names.index(rec_name)]
-        return rec_name
-
-    def get_ground_truth(self, rec_name=None):
-        rec_name = self._check_rec_name(rec_name)
-        sorting = load_extractor(self.study_folder / "ground_truth" / rec_name)
-        return sorting
-
-    def get_recording(self, rec_name=None):
-        rec_name = self._check_rec_name(rec_name)
-        rec = load_extractor(self.study_folder / "raw_files" / rec_name)
-        return rec
-
-    def get_sorting(self, sort_name, rec_name=None):
-        rec_name = self._check_rec_name(rec_name)
-
-        selected_sorting = None
-        if sort_name in self.sorter_names:
-            for r_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-                if sort_name == sorter_name and r_name == rec_name:
-                    selected_sorting = sorting
-        return selected_sorting
-
-    def copy_sortings(self):
-        sorter_folders = self.study_folder / "sorter_folders"
-        sorting_folders = self.study_folder / "sortings"
-        log_olders = self.study_folder / "sortings" / "run_log"
-
-        log_olders.mkdir(parents=True, exist_ok=True)
-
-        for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders):
-            SorterClass = sorter_dict[sorter_name]
-            fname = rec_name + "[#]" + sorter_name
-            npz_filename = sorting_folders / (fname + ".npz")
-
-            try:
-                sorting = SorterClass.get_result_from_folder(output_folder)
-                NpzSortingExtractor.write_sorting(sorting, npz_filename)
-            except:
-                if npz_filename.is_file():
-                    npz_filename.unlink()
-
-            if (output_folder / "spikeinterface_log.json").is_file():
-                shutil.copyfile(
-                    output_folder / "spikeinterface_log.json", sorting_folders / "run_log" / (fname + ".json")
"run_log" / (fname + ".json") - ) + This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. + Note that the underlying folder structure is not backward compatible! + """ - self.scan_folder() + def __init__(self, study_folder): + self.folder = Path(study_folder) - def run_comparisons(self, exhaustive_gt=False, **kwargs): + self.datasets = {} + self.cases = {} + self.sortings = {} self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt, **kwargs) - self.comparisons[(rec_name, sorter_name)] = sc - self.exhaustive_gt = exhaustive_gt - def aggregate_run_times(self): - return collect_run_times(self.study_folder) - - def aggregate_performance_by_unit(self): - assert self.comparisons is not None, "run_comparisons first" + self.scan_folder() - perf_by_unit = [] - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - comp = self.comparisons[(rec_name, sorter_name)] + @classmethod + def create(cls, study_folder, datasets={}, cases={}, levels=None): + # check that cases keys are homogeneous + key0 = list(cases.keys())[0] + if isinstance(key0, str): + assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" + num_levels = len(key0) + assert all( + len(key) == num_levels for key in cases.keys() + ), "Keys for cases are not homogeneous, tuple negth differ" + if levels is None: + levels = [f"level{i}" for i in range(num_levels)] + else: + levels = list(levels) + assert len(levels) == num_levels + else: + raise ValueError("Keys for cases must str or tuple") - perf = comp.get_performance(method="by_unit", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - perf_by_unit.append(perf) + study_folder = Path(study_folder) + study_folder.mkdir(exist_ok=False, parents=True) - import pandas as pd + (study_folder / "datasets").mkdir() + (study_folder / "datasets" / "recordings").mkdir() + (study_folder / "datasets" / "gt_sortings").mkdir() + (study_folder / "sorters").mkdir() + (study_folder / "sortings").mkdir() + (study_folder / "sortings" / "run_logs").mkdir() + (study_folder / "metrics").mkdir() - perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(["rec_name", "sorter_name", "gt_unit_id"]) + for key, (rec, gt_sorting) in datasets.items(): + assert "/" not in key, "'/' cannot be in the key name!" + assert "\\" not in key, "'\\' cannot be in the key name!" 
-        return perf_by_unit
+            # recordings are pickled
+            rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle")
 
-    def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None):
-        assert self.comparisons is not None, "run_comparisons first"
+            # sortings are pickled + saved as NumpyFolderSorting
+            gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle")
+            gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}")
 
-        import pandas as pd
+        info = {}
+        info["levels"] = levels
+        (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8")
 
-        index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"])
+        # cases are dumped to a pickle file; json is not possible because of the tuple keys
+        (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases))
 
-        count_units = pd.DataFrame(
-            index=index,
-            columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"],
-            dtype=int,
-        )
+        return cls(study_folder)
 
-        if self.exhaustive_gt:
-            count_units["num_false_positive"] = pd.Series(dtype=int)
-            count_units["num_bad"] = pd.Series(dtype=int)
+    def scan_folder(self):
+        if not (self.folder / "datasets").exists():
+            raise ValueError(f"This folder is not a GroundTruthStudy : {self.folder.absolute()}")
 
-        for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder):
-            gt_sorting = self.get_ground_truth(rec_name)
-            comp = self.comparisons[(rec_name, sorter_name)]
+        with open(self.folder / "info.json", "r") as f:
+            self.info = json.load(f)
 
-            count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids())
-            count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids())
-            count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units(
-                well_detected_score
-            )
+        self.levels = self.info["levels"]
 
-            if self.exhaustive_gt:
-                count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units(
-                    overmerged_score
-                )
-                count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score)
-                count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units(
-                    redundant_score
-                )
-                count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units()
+        for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"):
+            key = rec_file.stem
+            rec = load_extractor(rec_file)
+            gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key)
+            self.datasets[key] = (rec, gt_sorting)
 
-        return count_units
+        with open(self.folder / "cases.pickle", "rb") as f:
+            self.cases = pickle.load(f)
 
-    def aggregate_dataframes(self, copy_into_folder=True, **karg_thresh):
-        dataframes = {}
-        dataframes["run_times"] = self.aggregate_run_times().reset_index()
-        perfs = self.aggregate_performance_by_unit()
+        self.comparisons = {k: None for k in self.cases}
 
-        dataframes["perf_by_unit"] = perfs.reset_index()
-        dataframes["count_units"] = self.aggregate_count_units(**karg_thresh).reset_index()
+        self.sortings = {}
+        for key in self.cases:
+            sorting_folder = self.folder / "sortings" / self.key_to_str(key)
+            if sorting_folder.exists():
+                sorting = load_extractor(sorting_folder)
+            else:
+                sorting = None
+            self.sortings[key] = sorting
 
-        if copy_into_folder:
-            tables_folder = self.study_folder / "tables"
-            tables_folder.mkdir(parents=True, exist_ok=True)
 
-            for name, df in dataframes.items():
-                df.to_csv(str(tables_folder / (name + ".csv")), sep="\t", index=False)
-
-        return dataframes
+    def __repr__(self):
+        t = f"{self.__class__.__name__} {self.folder.stem} \n"
+        t += f"  datasets: {len(self.datasets)} {list(self.datasets.keys())}\n"
+        t += f"  cases: {len(self.cases)} {list(self.cases.keys())}\n"
+        num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None])
+        t += f"  computed: {num_computed}\n"
 
-    def get_waveform_extractor(self, rec_name, sorter_name=None):
-        rec = self.get_recording(rec_name)
+        return t
 
-        if sorter_name is None:
-            name = "GroundTruth"
-            sorting = self.get_ground_truth(rec_name)
+    def key_to_str(self, key):
+        if isinstance(key, str):
+            return key
+        elif isinstance(key, tuple):
+            return _key_separator.join(key)
         else:
-            assert sorter_name in self.sorter_names
-            name = sorter_name
-            sorting = self.get_sorting(sorter_name, rec_name)
-
-        waveform_folder = self.study_folder / "waveforms" / f"waveforms_{name}_{rec_name}"
+            raise ValueError("Keys for cases must be str or tuple")
+
+    def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        job_list = []
+        for key in case_keys:
+            sorting_folder = self.folder / "sortings" / self.key_to_str(key)
+            sorting_exists = sorting_folder.exists()
+
+            sorter_folder = self.folder / "sorters" / self.key_to_str(key)
+            sorter_folder_exists = sorter_folder.exists()
+
+            if keep:
+                if sorting_exists:
+                    continue
+                if sorter_folder_exists:
+                    # the sorter folder exists but hasn't been copied to the sortings folder
+                    sorting = read_sorter_folder(sorter_folder, raise_error=False)
+                    if sorting is not None:
+                        # save and skip
+                        self.copy_sortings(case_keys=[key])
+                        continue
+
+            if sorting_exists:
+                # delete older sorting + log before running sorters
+                shutil.rmtree(sorting_folder)
+                log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
+                if log_file.exists():
+                    log_file.unlink()
+
+            params = self.cases[key]["run_sorter_params"].copy()
+            # this ensures that sorter_name is given
+            recording, _ = self.datasets[self.cases[key]["dataset"]]
+            sorter_name = params.pop("sorter_name")
+            job = dict(
+                sorter_name=sorter_name,
+                recording=recording,
+                output_folder=sorter_folder,
+            )
+            job.update(params)
+            # verbose is overwritten and global to all run_sorters
+            job["verbose"] = verbose
+            job["with_output"] = False
+            job_list.append(job)
+
+        run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False)
+
+        # TODO later: create a list in launcher for engine blocking and non-blocking
+        if engine not in ("slurm",):
+            self.copy_sortings(case_keys)
+
+    def copy_sortings(self, case_keys=None, force=True):
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        for key in case_keys:
+            sorting_folder = self.folder / "sortings" / self.key_to_str(key)
+            sorter_folder = self.folder / "sorters" / self.key_to_str(key)
+            log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
+
+            if (sorter_folder / "spikeinterface_log.json").exists():
+                sorting = read_sorter_folder(
+                    sorter_folder, raise_error=False, register_recording=False, sorting_info=False
+                )
+            else:
+                sorting = None
+
+            if sorting is not None:
+                if sorting_folder.exists():
+                    if force:
+                        # delete folder + log
+                        shutil.rmtree(sorting_folder)
+                        if log_file.exists():
+                            log_file.unlink()
+                    else:
+                        continue
+
+                sorting = sorting.save(format="numpy_folder", folder=sorting_folder)
+                self.sortings[key] = sorting
+
+                # copy logs
+                shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file)
+
+    def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs):
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        for key in case_keys:
+            dataset_key = self.cases[key]["dataset"]
+            _, gt_sorting = self.datasets[dataset_key]
+            sorting = self.sortings[key]
+            if sorting is None:
+                self.comparisons[key] = None
+                continue
+            comp = comparison_class(gt_sorting, sorting, **kwargs)
+            self.comparisons[key] = comp
+
+    def get_run_times(self, case_keys=None):
+        import pandas as pd
 
-        if waveform_folder.is_dir():
-            we = WaveformExtractor.load(waveform_folder)
-        else:
-            we = WaveformExtractor.create(rec, sorting, waveform_folder)
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        log_folder = self.folder / "sortings" / "run_logs"
+
+        run_times = {}
+        for key in case_keys:
+            log_file = log_folder / f"{self.key_to_str(key)}.json"
+            with open(log_file, mode="r") as logfile:
+                log = json.load(logfile)
+                run_time = log.get("run_time", None)
+            run_times[key] = run_time
+
+        return pd.Series(run_times, name="run_time")
+
+    def extract_waveforms_gt(self, case_keys=None, **extract_kwargs):
+        if case_keys is None:
+            case_keys = self.cases.keys()
+
+        base_folder = self.folder / "waveforms"
+        base_folder.mkdir(exist_ok=True)
+
+        dataset_keys = [self.cases[key]["dataset"] for key in case_keys]
+        dataset_keys = set(dataset_keys)
+        for dataset_key in dataset_keys:
+            # the waveforms depend on the dataset key
+            wf_folder = base_folder / self.key_to_str(dataset_key)
+            recording, gt_sorting = self.datasets[dataset_key]
+            we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs)
+
+    def get_waveform_extractor(self, key):
+        # some recordings are not dumpable to json and the waveform extractor needs it!
+        # so we load it without the recording and set the recording afterwards
+        # this should be fixed in PR 2027, so remove this afterwards
+
+        dataset_key = self.cases[key]["dataset"]
+        wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key)
+        we = load_waveforms(wf_folder, with_recording=False)
+        recording, _ = self.datasets[dataset_key]
+        we.set_recording(recording)
 
         return we
 
-    def compute_waveforms(
-        self,
-        rec_name,
-        sorter_name=None,
-        ms_before=3.0,
-        ms_after=4.0,
-        max_spikes_per_unit=500,
-        n_jobs=-1,
-        total_memory="1G",
-    ):
-        we = self.get_waveform_extractor(rec_name, sorter_name)
-        we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit)
-        we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory)
-
-    def get_templates(self, rec_name, sorter_name=None, mode="median"):
-        """
-        Get template for a given recording.
-
-        If sorter_name=None then template are from the ground truth.
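+    # a usage sketch (the key values below are illustrative):
+    #   we = study.get_waveform_extractor(("tdc2", "toy0"))
+    #   templates = study.get_templates(("tdc2", "toy0"), mode="average")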
- - """ - we = self.get_waveform_extractor(rec_name, sorter_name=sorter_name) + def get_templates(self, key, mode="average"): + we = self.get_waveform_extractor(key) templates = we.get_all_templates(mode=mode) return templates - def compute_metrics( - self, - rec_name, - metric_names=["snr"], - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - # metrics - metrics = compute_quality_metrics(we, metric_names=metric_names) - folder = self.study_folder / "metrics" - folder.mkdir(exist_ok=True) - filename = folder / f"metrics _{rec_name}.txt" - metrics.to_csv(filename, sep="\t", index=True) + def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): + if case_keys is None: + case_keys = self.cases.keys() + + done = [] + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + if dataset_key in done: + # some case can share the same waveform extractor + continue + done.append(dataset_key) + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if filename.exists(): + if force: + os.remove(filename) + else: + continue + we = self.get_waveform_extractor(key) + metrics = compute_quality_metrics(we, metric_names=metric_names) + metrics.to_csv(filename, sep="\t", index=True) + + def get_metrics(self, key): + import pandas as pd + + dataset_key = self.cases[key]["dataset"] + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + if not filename.exists(): + return + metrics = pd.read_csv(filename, sep="\t", index_col=0) + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + metrics.index = gt_sorting.unit_ids return metrics - def get_metrics(self, rec_name=None, **metric_kwargs): - """ - Load or compute units metrics for a given recording. 
- """ - rec_name = self._check_rec_name(rec_name) - metrics_folder = self.study_folder / "metrics" - metrics_folder.mkdir(parents=True, exist_ok=True) + def get_units_snr(self, key): + """ """ + return self.get_metrics(key)["snr"] - filename = self.study_folder / "metrics" / f"metrics _{rec_name}.txt" + def get_performance_by_unit(self, case_keys=None): import pandas as pd - if filename.is_file(): - metrics = pd.read_csv(filename, sep="\t", index_col=0) - gt_sorting = self.get_ground_truth(rec_name) - metrics.index = gt_sorting.unit_ids + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" + + perf = comp.get_performance(method="by_unit", output="pandas") + if isinstance(key, str): + perf[self.levels] = key + elif isinstance(key, tuple): + for col, k in zip(self.levels, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(self.levels) + return perf_by_unit + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) else: - metrics = self.compute_metrics(rec_name, **metric_kwargs) + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - metrics.index.name = "unit_id" - # add rec name columns - metrics["rec_name"] = rec_name + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.comparisons[case_keys[0]] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - return metrics + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" - def get_units_snr(self, rec_name=None, **metric_kwargs): - """ """ - metric = self.get_metrics(rec_name=rec_name, **metric_kwargs) - return metric["snr"] - - def concat_all_snr(self): - metrics = [] - for rec_name in self.rec_names: - df = self.get_metrics(rec_name) - df = df.reset_index() - metrics.append(df) - metrics = pd.concat(metrics) - metrics = metrics.set_index(["rec_name", "unit_id"]) - return metrics["snr"] + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py deleted file mode 100644 index 26d2c1ad6f..0000000000 --- a/src/spikeinterface/comparison/studytools.py +++ /dev/null @@ -1,349 +0,0 @@ -""" -High level tools to run many ground-truth comparison with -many sorter on many recordings and then collect and aggregate 
results -in an easy way. - -The all mechanism is based on an intrinsic organization -into a "study_folder" with several subfolder: - * raw_files : contain a copy in binary format of recordings - * sorter_folders : contains output of sorters - * ground_truth : contains a copy of sorting ground in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some table in cvs format -""" - -from pathlib import Path -import shutil -import json -import os - - -from spikeinterface.core import load_extractor -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.basesorter import is_log_ok - - -from .comparisontools import _perf_keys -from .paircomparisons import compare_sorter_to_ground_truth - - -# This is deprecated and will be removed -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - - -# This is deprecated and will be removed -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - -def setup_comparison_study(study_folder, gt_dict, **job_kwargs): - """ - Based on a dict of (recording, sorting) create the study folder. - - Parameters - ---------- - study_folder: str - The study folder. - gt_dict : a dict of tuple (recording, sorting_gt) - Dict of tuple that contain recording and sorting ground truth - """ - job_kwargs = fix_job_kwargs(job_kwargs) - study_folder = Path(study_folder) - assert not study_folder.is_dir(), "'study_folder' already exists. Please remove it" - - study_folder.mkdir(parents=True, exist_ok=True) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - log_folder.mkdir(parents=True, exist_ok=True) - tables_folder = study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) - - for rec_name, (recording, sorting_gt) in gt_dict.items(): - # write recording using save with binary - folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="numpy_folder") - folder = study_folder / "raw_files" / rec_name - recording.save(folder=folder, format="binary", **job_kwargs) - - # make an index of recording names - with open(study_folder / "names.txt", mode="w", encoding="utf8") as f: - for rec_name in gt_dict: - f.write(rec_name + "\n") - - -def get_rec_names(study_folder): - """ - Get list of keys of recordings. - Read from the 'names.txt' file in study folder. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - rec_names: list - List of names. 
- """ - study_folder = Path(study_folder) - with open(study_folder / "names.txt", mode="r", encoding="utf8") as f: - rec_names = f.read()[:-1].split("\n") - return rec_names - - -def get_recordings(study_folder): - """ - Get ground recording as a dict. - - They are read from the 'raw_files' folder with binary format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - recording_dict: dict - Dict of recording. - """ - study_folder = Path(study_folder) - - rec_names = get_rec_names(study_folder) - recording_dict = {} - for rec_name in rec_names: - rec = load_extractor(study_folder / "raw_files" / rec_name) - recording_dict[rec_name] = rec - - return recording_dict - - -def get_ground_truths(study_folder): - """ - Get ground truth sorting extractor as a dict. - - They are read from the 'ground_truth' folder with npz format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - ground_truths: dict - Dict of sorting_gt. - """ - study_folder = Path(study_folder) - rec_names = get_rec_names(study_folder) - ground_truths = {} - for rec_name in rec_names: - sorting = load_extractor(study_folder / "ground_truth" / rec_name) - ground_truths[rec_name] = sorting - return ground_truths - - -def iter_computed_names(study_folder): - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - yield rec_name, sorter_name - - -def iter_computed_sorting(study_folder): - """ - Iter over sorting files. - """ - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - sorting = NpzSortingExtractor(sorting_folder / filename) - yield rec_name, sorter_name, sorting - - -def collect_run_times(study_folder): - """ - Collect run times in a working folder and store it in CVS files. - - The output is list of (rec_name, sorter_name, run_time) - """ - import pandas as pd - - study_folder = Path(study_folder) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - tables_folder = study_folder / "tables" - - tables_folder.mkdir(parents=True, exist_ok=True) - - run_times = [] - for filename in os.listdir(log_folder): - if filename.endswith(".json") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".json", "").split("[#]") - with open(log_folder / filename, encoding="utf8", mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times.append((rec_name, sorter_name, run_time)) - - run_times = pd.DataFrame(run_times, columns=["rec_name", "sorter_name", "run_time"]) - run_times = run_times.set_index(["rec_name", "sorter_name"]) - - return run_times - - -def aggregate_sorting_comparison(study_folder, exhaustive_gt=False): - """ - Loop over output folder in a tree to collect sorting output and run - ground_truth_comparison on them. - - Parameters - ---------- - study_folder: str - The study folder. - exhaustive_gt: bool (default True) - Tell if the ground true is "exhaustive" or not. In other world if the - GT have all possible units. It allows more performance measurement. 
- For instance, MEArec simulated dataset have exhaustive_gt=True - - Returns - ---------- - comparisons: a dict of SortingComparison - - """ - - study_folder = Path(study_folder) - - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - comparisons = {} - for (rec_name, sorter_name), sorting in results.items(): - gt_sorting = ground_truths[rec_name] - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt) - comparisons[(rec_name, sorter_name)] = sc - - return comparisons - - -def aggregate_performances_table(study_folder, exhaustive_gt=False, **karg_thresh): - """ - Aggregate some results into dataframe to have a "study" overview on all recordingXsorter. - - Tables are: - * run_times: run times per recordingXsorter - * perf_pooled_with_sum: GroundTruthComparison.see get_performance - * perf_pooled_with_average: GroundTruthComparison.see get_performance - * count_units: given some threshold count how many units : 'well_detected', 'redundant', 'false_postive_units, 'bad' - - Parameters - ---------- - study_folder: str - The study folder. - karg_thresh: dict - Threshold parameters used for the "count_units" table. - - Returns - ------- - dataframes: a dict of DataFrame - Return several useful DataFrame to compare all results. - Note that count_units depend on karg_thresh. - """ - import pandas as pd - - study_folder = Path(study_folder) - sorter_folders = study_folder / "sorter_folders" - tables_folder = study_folder / "tables" - - comparisons = aggregate_sorting_comparison(study_folder, exhaustive_gt=exhaustive_gt) - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - study_folder = Path(study_folder) - - dataframes = {} - - # get run times: - run_times = pd.read_csv(str(tables_folder / "run_times.csv"), sep="\t") - run_times.columns = ["rec_name", "sorter_name", "run_time"] - run_times = run_times.set_index( - [ - "rec_name", - "sorter_name", - ] - ) - dataframes["run_times"] = run_times - - perf_pooled_with_sum = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_sum"] = perf_pooled_with_sum - - perf_pooled_with_average = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_average"] = perf_pooled_with_average - - count_units = pd.DataFrame( - index=run_times.index, columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant"] - ) - dataframes["count_units"] = count_units - if exhaustive_gt: - count_units["num_false_positive"] = None - count_units["num_bad"] = None - - perf_by_spiketrain = [] - - for (rec_name, sorter_name), comp in comparisons.items(): - gt_sorting = ground_truths[rec_name] - sorting = results[(rec_name, sorter_name)] - - perf = comp.get_performance(method="pooled_with_sum", output="pandas") - perf_pooled_with_sum.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="pooled_with_average", output="pandas") - perf_pooled_with_average.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="by_spiketrain", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - - perf_by_spiketrain.append(perf) - - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = 
comp.count_well_detected_units(**karg_thresh)
-        count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units()
-        if exhaustive_gt:
-            count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units()
-            count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units()
-
-    perf_by_spiketrain = pd.concat(perf_by_spiketrain)
-    perf_by_spiketrain = perf_by_spiketrain.set_index(["rec_name", "sorter_name", "gt_unit_id"])
-    dataframes["perf_by_spiketrain"] = perf_by_spiketrain
-
-    return dataframes
diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
index 70f8a63c8c..91c8c640e0 100644
--- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
+++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py
@@ -1,19 +1,11 @@
-import importlib
 import shutil
 import pytest
 from pathlib import Path
 
-from spikeinterface.extractors import toy_example
-from spikeinterface.sorters import installed_sorters
+from spikeinterface import generate_ground_truth_recording
+from spikeinterface.preprocessing import bandpass_filter
 from spikeinterface.comparison import GroundTruthStudy
 
-try:
-    import tridesclous
-
-    HAVE_TDC = True
-except ImportError:
-    HAVE_TDC = False
-
 
 if hasattr(pytest, "global_test_folder"):
     cache_folder = pytest.global_test_folder / "comparison"
@@ -27,61 +19,85 @@ def setup_module():
     if study_folder.is_dir():
         shutil.rmtree(study_folder)
-    _setup_comparison_study()
+    create_a_study(study_folder)
+
+
+def simple_preprocess(rec):
+    return bandpass_filter(rec)
 
-def _setup_comparison_study():
-    rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
-    rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1)
 
+def create_a_study(study_folder):
+    rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42)
+    rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91)
 
-    gt_dict = {
+    datasets = {
         "toy_tetrode": (rec0, gt_sorting0),
         "toy_probe32": (rec1, gt_sorting1),
+        "toy_probe32_preprocess": (simple_preprocess(rec1), gt_sorting1),
     }
-    study = GroundTruthStudy.create(study_folder, gt_dict)
 
+    # cases can also be generated via simple loops
+    cases = {
+        #
+        ("tdc2", "no-preprocess", "tetrode"): {
+            "label": "tridesclous2 without preprocessing and standard params",
+            "dataset": "toy_tetrode",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+            "comparison_params": {},
+        },
+        #
+        ("tdc2", "with-preprocess", "probe32"): {
+            "label": "tridesclous2 with preprocessing and standard params",
+            "dataset": "toy_probe32_preprocess",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+            "comparison_params": {},
+        },
+        # commented at the moment because SC2 is quite slow for testing
+        # ("sc2", "no-preprocess", "tetrode"): {
+        #     "label": "spykingcircus2 without preprocessing and standard params",
+        #     "dataset": "toy_tetrode",
+        #     "run_sorter_params": {
+        #         "sorter_name": "spykingcircus2",
+        #     },
+        #     "comparison_params": {
+        #     },
+        # },
    }
 
-@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'")
-def test_run_study_sorters():
-    study = GroundTruthStudy(study_folder)
-    sorter_list = [
-        "tridesclous",
-    ]
-    print(
-        f"\n#################################\nINSTALLED SORTERS\n#################################\n"
-        f"{installed_sorters()}"
+    study = GroundTruthStudy.create(
study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"]
     )
-    study.run_sorters(sorter_list)
+    # print(study)


-@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'")
-def test_extract_sortings():
+def test_GroundTruthStudy():
     study = GroundTruthStudy(study_folder)
+    print(study)

-    study.copy_sortings()
-
-    for rec_name in study.rec_names:
-        gt_sorting = study.get_ground_truth(rec_name)
-
-    for rec_name in study.rec_names:
-        metrics = study.get_metrics(rec_name=rec_name)
+    study.run_sorters(verbose=True)

-        snr = study.get_units_snr(rec_name=rec_name)
+    print(study.sortings)

-    study.copy_sortings()
+    print(study.comparisons)
+    study.run_comparisons()
+    print(study.comparisons)

-    run_times = study.aggregate_run_times()
+    study.extract_waveforms_gt(n_jobs=-1)

-    study.run_comparisons(exhaustive_gt=True)
+    study.compute_metrics()

-    perf = study.aggregate_performance_by_unit()
+    for key in study.cases:
+        metrics = study.get_metrics(key)
+        print(metrics)

-    count_units = study.aggregate_count_units()
-    dataframes = study.aggregate_dataframes()
-    print(dataframes)
+    study.get_performance_by_unit()
+    study.get_count_units()


 if __name__ == "__main__":
-    # setup_module()
-    # test_run_study_sorters()
-    test_extract_sortings()
+    setup_module()
+    test_GroundTruthStudy()
diff --git a/src/spikeinterface/comparison/tests/test_studytools.py b/src/spikeinterface/comparison/tests/test_studytools.py
deleted file mode 100644
index dbc39d5e1d..0000000000
--- a/src/spikeinterface/comparison/tests/test_studytools.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-import shutil
-from pathlib import Path
-
-import pytest
-
-from spikeinterface.extractors import toy_example
-from spikeinterface.comparison.studytools import (
-    setup_comparison_study,
-    iter_computed_names,
-    iter_computed_sorting,
-    get_rec_names,
-    get_ground_truths,
-    get_recordings,
-)
-
-if hasattr(pytest, "global_test_folder"):
-    cache_folder = pytest.global_test_folder / "comparison"
-else:
-    cache_folder = Path("cache_folder") / "comparison"
-
-
-study_folder = cache_folder / "test_studytools"
-
-
-def setup_module():
-    if study_folder.is_dir():
-        shutil.rmtree(study_folder)
-
-
-def test_setup_comparison_study():
-    rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
-    rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1)
-
-    gt_dict = {
-        "toy_tetrode": (rec0, gt_sorting0),
-        "toy_probe32": (rec1, gt_sorting1),
-    }
-    setup_comparison_study(study_folder, gt_dict)
-
-
-def test_get_ground_truths():
-    names = get_rec_names(study_folder)
-    d = get_ground_truths(study_folder)
-    d = get_recordings(study_folder)
-
-
-def test_loops():
-    names = list(iter_computed_names(study_folder))
-    for rec_name, sorter_name, sorting in iter_computed_sorting(study_folder):
-        print(rec_name, sorter_name)
-        print(sorting)
-
-
-if __name__ == "__main__":
-    setup_module()
-    test_setup_comparison_study()
-    test_get_ground_truths()
-    test_loops()
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index e8b3232e13..1430e8fb45 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -424,14 +424,15 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None)
     extractor: RecordingExtractor or SortingExtractor
         The loaded extractor object
     """
-    if dictionary["relative_paths"]:
+    # for pickle dumps "relative_paths" was not in the dict; this ensures backward compatibility
+    if dictionary.get("relative_paths", False):
         assert base_folder is not None, "When relative_paths=True, need to provide base_folder"
         dictionary = _make_paths_absolute(dictionary, base_folder)
     extractor = _load_extractor_from_dict(dictionary)
     folder_metadata = dictionary.get("folder_metadata", None)
     if folder_metadata is not None:
         folder_metadata = Path(folder_metadata)
-        if dictionary["relative_paths"]:
+        if dictionary.get("relative_paths", False):
             folder_metadata = base_folder / folder_metadata
         extractor.load_metadata_from_folder(folder_metadata)
     return extractor
@@ -627,6 +628,7 @@ def dump_to_pickle(
         include_annotations=True,
         include_properties=include_properties,
         folder_metadata=folder_metadata,
+        relative_to=None,
         recursive=False,
     )
     file_path = self._get_file_path(file_path, [".pkl", ".pickle"])
diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index 8d87558191..a956f8c811 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -202,7 +202,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False):
             recording = None
         else:
             recording = load_extractor(json_file, base_folder=output_folder)
-    elif pickle_file.exits():
+    elif pickle_file.exists():
         recording = load_extractor(pickle_file)

     return recording
@@ -324,8 +324,12 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
         if sorting_info:
             # set sorting info to Sorting object
-            with open(output_folder / "spikeinterface_recording.json", "r") as f:
-                rec_dict = json.load(f)
+            if (output_folder / "spikeinterface_recording.json").exists():
+                with open(output_folder / "spikeinterface_recording.json", "r") as f:
+                    rec_dict = json.load(f)
+            else:
+                rec_dict = None
+
             with open(output_folder / "spikeinterface_params.json", "r") as f:
                 params_dict = json.load(f)
             with open(output_folder / "spikeinterface_log.json", "r") as f:
diff --git a/src/spikeinterface/sorters/internal/si_based.py b/src/spikeinterface/sorters/internal/si_based.py
index 1496ffbbd1..989fab1258 100644
--- a/src/spikeinterface/sorters/internal/si_based.py
+++ b/src/spikeinterface/sorters/internal/si_based.py
@@ -1,4 +1,4 @@
-from spikeinterface.core import load_extractor
+from spikeinterface.core import load_extractor, NumpyRecording

 from spikeinterface.sorters import BaseSorter

@@ -14,7 +14,6 @@ def is_installed(cls):
     @classmethod
     def _setup_recording(cls, recording, output_folder, params, verbose):
-        # nothing to do here because the spikeinterface_recording.json is here anyway
         pass

     @classmethod
diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py
index f32a468a22..704f6843f2 100644
--- a/src/spikeinterface/sorters/launcher.py
+++ b/src/spikeinterface/sorters/launcher.py
@@ -66,7 +66,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
     engine_kwargs: dict

     return_output: bool, dfault False
-        Return a sorting or None.
+        Return a list of sortings or None.
+        This also overwrites the "with_output" kwarg in each job passed to run_sorter().

     Returns
     -------
@@ -88,8 +89,12 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal
             "processpoolexecutor",
         ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True."
out = [] + for kwargs in job_list: + kwargs["with_output"] = True else: out = None + for kwargs in job_list: + kwargs["with_output"] = False if engine == "loop": # simple loop in main process diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 9593f14d1c..c10c78cbfc 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -33,22 +33,6 @@ from .sortingperformance import plot_sorting_performance -# ground truth study (=comparison over sorter) -from .gtstudy import ( - StudyComparisonRunTimesWidget, - plot_gt_study_run_times, - StudyComparisonUnitCountsWidget, - StudyComparisonUnitCountsAveragesWidget, - plot_gt_study_unit_counts, - plot_gt_study_unit_counts_averages, - plot_gt_study_performances, - plot_gt_study_performances_averages, - StudyComparisonPerformancesWidget, - StudyComparisonPerformancesAveragesWidget, - plot_gt_study_performances_by_template_similarity, - StudyComparisonPerformancesByTemplateSimilarity, -) - # ground truth comparions (=comparison over sorter) from .gtcomparison import ( plot_gt_performances, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 6d981e1fd4..d25f1ea97b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -1,7 +1,6 @@ import numpy as np from .basewidget import BaseWidget -from spikeinterface.comparison.collisioncomparison import CollisionGTComparison class ComparisonCollisionPairByPairWidget(BaseWidget): diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py deleted file mode 100644 index 573221f528..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performances - * count units -""" -import numpy as np - - -from .basewidget import BaseWidget - - -class StudyComparisonRunTimesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - color: - - - """ - - def __init__(self, study, color="#F7DC6F", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.color = color - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - all_run_times = study.aggregate_run_times() - av_run_times = all_run_times.reset_index().groupby("sorter_name")["run_time"].mean() - - if len(study.rec_names) == 1: - # no errors bars - yerr = None - else: - # errors bars across recording - yerr = all_run_times.reset_index().groupby("sorter_name")["run_time"].std() - - sorter_names = av_run_times.index - - x = np.arange(sorter_names.size) + 1 - ax.bar(x, av_run_times.values, width=0.8, color=self.color, yerr=yerr) - ax.set_ylabel("run times (s)") - ax.set_xticks(x) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_xlim(0, sorter_names.size + 1) - - -def plot_gt_study_run_times(*args, **kwargs): - W = StudyComparisonRunTimesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_run_times.__doc__ = StudyComparisonRunTimesWidget.__doc__ - - -class StudyComparisonUnitCountsAveragesWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - count_units = study.aggregate_count_units() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - ncol = len(columns) - - df = count_units.reset_index() - - m = df.groupby("sorter_name")[columns].mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = df.groupby("sorter_name")[columns].std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values, yerr=yerr, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - ax.legend() - if self.log_scale: - ax.set_yscale("log") - - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -class StudyComparisonUnitCountsWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self): - study = self.study - ax = self.ax - - import seaborn as sns - - study = self.study - - count_units = study.aggregate_count_units() - count_units = count_units.reset_index() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - - ncol = len(columns) - cmap = plt.get_cmap(self.cmap_name, 4) - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = count_units.loc[count_units["rec_name"] == rec_name, :] - m = df.groupby("sorter_name")[columns].mean() - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - ax.bar(x, m[col].values, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - if r == 0: - ax.legend() - - if self.log_scale: - ax.set_yscale("log") - - if r == len(study.rec_names) - 1: - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -def plot_gt_study_unit_counts_averages(*args, **kwargs): - W = StudyComparisonUnitCountsAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts_averages.__doc__ = StudyComparisonUnitCountsAveragesWidget.__doc__ - - -def plot_gt_study_unit_counts(*args, **kwargs): - W = StudyComparisonUnitCountsWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts.__doc__ = StudyComparisonUnitCountsWidget.__doc__ - - -class StudyComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.palette = palette - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self, average=False): - import seaborn as sns - - study = self.study - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = perf_by_units.loc[perf_by_units["rec_name"] == rec_name, :] - df = pd.melt( - df, - id_vars="sorter_name", - var_name="Metric", - value_name="Score", - value_vars=("accuracy", "precision", "recall"), - ).sort_values("sorter_name") - sns.swarmplot( - data=df, x="sorter_name", y="Score", hue="Metric", dodge=True, s=3, ax=ax - ) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Perfs for {rec_name}") - if r < len(study.rec_names) - 1: - ax.set_xlabel("") - ax.set(xticklabels=[]) - - -class StudyComparisonTemplateSimilarityWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None, ylim=(0.6, 1), show_legend=True): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.show_legend = show_legend - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import sklearn - - cmap = plt.get_cmap(self.cmap_name, len(self.study.sorter_names)) - colors = [cmap(i) for i in range(len(self.study.sorter_names))] - - flat_templates_gt = {} - for rec_name in self.study.rec_names: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_GroundTruth_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name) - - templates = self.study.get_templates(rec_name) - flat_templates_gt[rec_name] = templates.reshape(templates.shape[0], -1) - - all_results = {} - - for sorter_name in self.study.sorter_names: - all_results[sorter_name] = {"similarity": [], "accuracy": []} - - for rec_name in self.study.rec_names: - try: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_{sorter_name}_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name, sorter_name) - templates = self.study.get_templates(rec_name, sorter_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity( - flat_templates_gt[rec_name], flat_templates - ) - - comp = self.study.comparisons[(rec_name, sorter_name)] - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results[sorter_name]["similarity"] += [ - similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results[sorter_name]["accuracy"] += [comp.agreement_scores.at[u1, u2]] - except Exception: - pass - - all_results[sorter_name]["similarity"] = np.array(all_results[sorter_name]["similarity"]) - all_results[sorter_name]["accuracy"] = np.array(all_results[sorter_name]["accuracy"]) - - from matplotlib.patches import Ellipse - - similarity_means = [all_results[sorter_name]["similarity"].mean() for sorter_name in self.study.sorter_names] - similarity_stds = [all_results[sorter_name]["similarity"].std() for sorter_name in self.study.sorter_names] - - accuracy_means = [all_results[sorter_name]["accuracy"].mean() for sorter_name in self.study.sorter_names] - accuracy_stds = [all_results[sorter_name]["accuracy"].std() for sorter_name in self.study.sorter_names] - - scount = 0 - for x, y, i, j in zip(similarity_means, accuracy_means, similarity_stds, accuracy_stds): - e = Ellipse((x, y), i, j) - e.set_alpha(0.2) - e.set_facecolor(colors[scount]) - self.ax.add_artist(e) - self.ax.scatter([x], [y], c=colors[scount], label=self.study.sorter_names[scount]) - scount += 1 - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - if self.show_legend: - self.ax.legend() - - -def plot_gt_study_performances(*args, **kwargs): - W = StudyComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances.__doc__ = StudyComparisonPerformancesWidget.__doc__ - - -def plot_gt_study_performances_averages(*args, **kwargs): - W = StudyComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_averages.__doc__ = StudyComparisonPerformancesAveragesWidget.__doc__ - - -def 
plot_gt_study_performances_by_template_similarity(*args, **kwargs): - W = StudyComparisonPerformancesByTemplateSimilarity(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_by_template_similarity.__doc__ = StudyComparisonPerformancesByTemplateSimilarity.__doc__ diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py new file mode 100644 index 0000000000..6a27b78dec --- /dev/null +++ b/src/spikeinterface/widgets/gtstudy.py @@ -0,0 +1,253 @@ +import numpy as np + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + +from ..core import ChannelSparsity +from ..core.waveform_extractor import WaveformExtractor +from ..core.basesorting import BaseSorting + + +class StudyRunTimesWidget(BaseWidget): + """ + Plot sorter run times for a GroundTruthStudy + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + run_times=study.get_run_times(case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, key in enumerate(dp.case_keys): + label = dp.study.cases[key]["label"] + rt = dp.run_times.loc[key] + self.ax.bar(i, rt, width=0.8, label=label) + + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyUnitCountsWidget(BaseWidget): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. 
+
+    """
+
+    def __init__(
+        self,
+        study,
+        case_keys=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if case_keys is None:
+            case_keys = list(study.cases.keys())
+
+        plot_data = dict(
+            study=study,
+            count_units=study.get_count_units(case_keys=case_keys),
+            case_keys=case_keys,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from .utils import get_some_colors
+
+        dp = to_attr(data_plot)
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        columns = dp.count_units.columns.tolist()
+        columns.remove("num_gt")
+        columns.remove("num_sorter")
+
+        ncol = len(columns)
+
+        colors = get_some_colors(columns, color_engine="auto", map_name="hot")
+        colors["num_well_detected"] = "green"
+
+        xticklabels = []
+        for i, key in enumerate(dp.case_keys):
+            for c, col in enumerate(columns):
+                x = i + 1 + c / (ncol + 1)
+                y = dp.count_units.loc[key, col]
+                if "well_detected" not in col:
+                    y = -y
+
+                if i == 0:
+                    label = col.replace("num_", "").replace("_", " ").title()
+                else:
+                    label = None
+
+                self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col])
+
+            xticklabels.append(dp.study.cases[key]["label"])
+
+        self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1)
+        self.ax.set_xticklabels(xticklabels)
+        self.ax.legend()
+
+
+# TODO : plot optionally average on some levels using group by
+class StudyPerformances(BaseWidget):
+    """
+    Plot performances across cases for a study.
+
+
+    Parameters
+    ----------
+    study: GroundTruthStudy
+        A study object.
+    mode: str
+        The plot mode. Only "swarm" is supported for now.
+    case_keys: list or None
+        A selection of cases to plot, if None, then all.
+
+    """
+
+    def __init__(
+        self,
+        study,
+        mode="swarm",
+        case_keys=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if case_keys is None:
+            case_keys = list(study.cases.keys())
+
+        plot_data = dict(
+            study=study,
+            perfs=study.get_performance_by_unit(case_keys=case_keys),
+            mode=mode,
+            case_keys=case_keys,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from .utils import get_some_colors
+
+        import pandas as pd
+        import seaborn as sns
+
+        dp = to_attr(data_plot)
+        perfs = dp.perfs
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        if dp.mode == "swarm":
+            levels = perfs.index.names
+            df = pd.melt(
+                perfs.reset_index(),
+                id_vars=levels,
+                var_name="Metric",
+                value_name="Score",
+                value_vars=("accuracy", "precision", "recall"),
+            )
+            df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1)
+            sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True)
+
+
+class StudyPerformancesVsMetrics(BaseWidget):
+    """
+    Plot a performance measure vs a unit metric (snr for instance) across cases for a study.
+
+
+    Parameters
+    ----------
+    study: GroundTruthStudy
+        A study object.
+    metric_name: str
+        The unit metric plotted on the x-axis, "snr" by default.
+    performance_name: str
+        The performance measure plotted on the y-axis, "accuracy" by default.
+    case_keys: list or None
+        A selection of cases to plot, if None, then all.
+
+    """
+
+    def __init__(
+        self,
+        study,
+        metric_name="snr",
+        performance_name="accuracy",
+        case_keys=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if case_keys is None:
+            case_keys = list(study.cases.keys())
+
+        plot_data = dict(
+            study=study,
+            metric_name=metric_name,
+            performance_name=performance_name,
+            case_keys=case_keys,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from .utils import get_some_colors
+
+        dp = to_attr(data_plot)
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        study = dp.study
+        perfs = study.get_performance_by_unit(case_keys=dp.case_keys)
+
+        max_metric = 0
+        for key in dp.case_keys:
+            x = study.get_metrics(key)[dp.metric_name].values
+            y = perfs.xs(key)[dp.performance_name].values
+            label = dp.study.cases[key]["label"]
+            self.ax.scatter(x, y, label=label)
+            max_metric = max(max_metric, np.max(x))
+
+        self.ax.legend()
+        self.ax.set_xlim(0, max_metric * 1.05)
+        self.ax.set_ylim(0, 1.05)
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 6ea2593432..ed77de6128 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -24,6 +24,7 @@
 from .unit_templates import UnitTemplatesWidget
 from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
 from .unit_waveforms import UnitWaveformsWidget
+from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics


 widget_list = [
@@ -49,6 +50,10 @@
     UnitTemplatesWidget,
     UnitWaveformDensityMapWidget,
     UnitWaveformsWidget,
+    StudyRunTimesWidget,
+    StudyUnitCountsWidget,
+    StudyPerformances,
+    StudyPerformancesVsMetrics,
 ]
@@ -106,6 +111,10 @@
 plot_unit_templates = UnitTemplatesWidget
 plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget
 plot_unit_waveforms = UnitWaveformsWidget
+plot_study_run_times = StudyRunTimesWidget
+plot_study_unit_counts = StudyUnitCountsWidget
+plot_study_performances = StudyPerformances
+plot_study_performances_vs_metrics = StudyPerformancesVsMetrics


 def plot_timeseries(*args, **kwargs):
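To tie the pieces of this patch together, here is a minimal usage sketch of the new study workflow and the four plotting entry points registered in ``widget_list.py``. It is a sketch, not part of the patch: it assumes a study folder already created with ``GroundTruthStudy.create()`` (as in ``test_groundtruthstudy.py`` above), and ``"my_study_folder"`` is a placeholder path.

.. code-block:: python

    import matplotlib.pyplot as plt

    import spikeinterface.widgets as sw
    from spikeinterface.comparison import GroundTruthStudy

    # load a study folder previously built with GroundTruthStudy.create()
    study = GroundTruthStudy("my_study_folder")

    # run every case (sorter x dataset) declared at creation time
    study.run_sorters(verbose=True)

    # compare each sorting against its ground truth
    study.run_comparisons()

    # waveforms and quality metrics are needed for the metric-based plot below
    study.extract_waveforms_gt(n_jobs=-1)
    study.compute_metrics()

    # the four widgets registered in widget_list.py
    sw.plot_study_run_times(study)
    sw.plot_study_unit_counts(study)
    sw.plot_study_performances(study, mode="swarm")
    sw.plot_study_performances_vs_metrics(study, metric_name="snr", performance_name="accuracy")
    plt.show()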