From 0acc125e1688a83c66542f19519045ee2f6eadf6 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Sep 2023 10:28:21 +0200 Subject: [PATCH 01/26] Start GroundTruthStudy refactoring. --- .../comparison/groundtruthstudy.py | 66 ++++++++- .../comparison/tests/test_groundtruthstudy.py | 128 ++++++++++++------ 2 files changed, 152 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 7b146f07bc..12588019ba 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -22,8 +22,72 @@ collect_run_times, ) - class GroundTruthStudy: + """ + This class is an helper function to run any comparison on several "cases" for several ground truth dataset. + + "cases" can be: + * several sorter for comparisons + * same sorter with differents parameters + * parameters of comparisons + * any combination of theses + + For enough flexibility cases key can be a tuple so that we can varify complexity along several + "axis" (paremeters or sorter) + + Ground truth dataset need recording+sorting. This can be from meraec file or from the internal generator + :py:fun:`generate_ground_truth_recording()` + + This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. + Folders structures are not backward compatible. + + + + """ + def __init__(self, study_folder=None): + # import pandas as pd + + self.study_folder = Path(study_folder) + + # self.computed_names = None + # self.recording_names = None + # self.cases_names = None + + self.datasets = {} + self.cases = {} + + # self.rec_names = None + # self.sorter_names = None + + self.scan_folder() + + # self.comparisons = None + # self.exhaustive_gt = None + + @classmethod + def create(cls, study_folder, datasets={}, cases={}): + pass + + def __repr__(self): + t = f"GroundTruthStudy {self.study_folder.stem} \n" + t += f" recordings: {len(self.rec_names)} {self.rec_names}\n" + if len(self.sorter_names): + t += " cases: {} {}\n".format(len(self.sorter_names), self.sorter_names) + + return t + + def scan_folder(self): + self.rec_names = get_rec_names(self.study_folder) + # scan computed names + self.computed_names = list(iter_computed_names(self.study_folder)) # list of pair (rec_name, sorter_name) + self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist() + self._is_scanned = True + + + + + +class OLDGroundTruthStudy: def __init__(self, study_folder=None): import pandas as pd diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 70f8a63c8c..f28d901075 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -3,16 +3,18 @@ import pytest from pathlib import Path -from spikeinterface.extractors import toy_example +# from spikeinterface.extractors import toy_example +from spikeinterface import generate_ground_truth_recording +from spikeinterface.preprocessing import bandpass_filter from spikeinterface.sorters import installed_sorters from spikeinterface.comparison import GroundTruthStudy -try: - import tridesclous +# try: +# import tridesclous - HAVE_TDC = True -except ImportError: - HAVE_TDC = False +# HAVE_TDC = True +# except ImportError: +# HAVE_TDC = False if hasattr(pytest, "global_test_folder"): @@ -27,61 +29,105 @@ def setup_module(): if study_folder.is_dir(): 
shutil.rmtree(study_folder) - _setup_comparison_study() + create_study(study_folder) -def _setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) +def simple_preprocess(rec): + return bandpass_filter(rec) - gt_dict = { + +def create_study(study_folder): + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { "toy_tetrode": (rec0, gt_sorting0), "toy_probe32": (rec1, gt_sorting1), + "toy_probe32_preprocess": (simple_preprocess(rec1), gt_sorting1), } - study = GroundTruthStudy.create(study_folder, gt_dict) + + # cases can also be generated via simple loops + cases = { + # + ("tdc2", "no-preprocess", "tetrode"): { + "label": "tridesclous2 without preprocessing and standard params", + "dataset": "toy_tetrode", + "run_sorter_params": { + + }, + "comparison_params": { + + }, + }, + # + ("tdc2", "with-preprocess", "probe32"): { + "label": "tridesclous2 with preprocessing standar params", + "dataset": "toy_probe32_preprocess", + "run_sorter_params": { + + }, + "comparison_params": { + + }, + }, + # + ("sc2", "no-preprocess", "tetrode"): { + "label": "spykingcircus2 without preprocessing standar params", + "dataset": "toy_tetrode", + "run_sorter_params": { + + }, + "comparison_params": { + + }, + }, + } + + study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_run_study_sorters(): - study = GroundTruthStudy(study_folder) - sorter_list = [ - "tridesclous", - ] - print( - f"\n#################################\nINSTALLED SORTERS\n#################################\n" - f"{installed_sorters()}" - ) - study.run_sorters(sorter_list) +# @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") +# def test_run_study_sorters(): +# study = GroundTruthStudy(study_folder) +# sorter_list = [ +# "tridesclous", +# ] +# print( +# f"\n#################################\nINSTALLED SORTERS\n#################################\n" +# f"{installed_sorters()}" +# ) +# study.run_sorters(sorter_list) -@pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -def test_extract_sortings(): - study = GroundTruthStudy(study_folder) +# @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") +# def test_extract_sortings(): +# study = GroundTruthStudy(study_folder) - study.copy_sortings() +# study.copy_sortings() - for rec_name in study.rec_names: - gt_sorting = study.get_ground_truth(rec_name) +# for rec_name in study.rec_names: +# gt_sorting = study.get_ground_truth(rec_name) - for rec_name in study.rec_names: - metrics = study.get_metrics(rec_name=rec_name) +# for rec_name in study.rec_names: +# metrics = study.get_metrics(rec_name=rec_name) - snr = study.get_units_snr(rec_name=rec_name) +# snr = study.get_units_snr(rec_name=rec_name) - study.copy_sortings() +# study.copy_sortings() - run_times = study.aggregate_run_times() +# run_times = study.aggregate_run_times() - study.run_comparisons(exhaustive_gt=True) +# study.run_comparisons(exhaustive_gt=True) - perf = study.aggregate_performance_by_unit() +# perf = study.aggregate_performance_by_unit() - count_units = study.aggregate_count_units() - dataframes = 
study.aggregate_dataframes() - print(dataframes) +# count_units = study.aggregate_count_units() +# dataframes = study.aggregate_dataframes() +# print(dataframes) if __name__ == "__main__": - # setup_module() + setup_module() # test_run_study_sorters() - test_extract_sortings() + # test_extract_sortings() From 462961ff8321c1a060705f27005f38dfd6ef3a66 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Sep 2023 13:44:05 +0200 Subject: [PATCH 02/26] new GroundTruthStudy wip --- .../comparison/groundtruthstudy.py | 153 +++++++++++++++--- .../comparison/tests/test_groundtruthstudy.py | 23 ++- 2 files changed, 146 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 12588019ba..fc4de5a18d 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,26 +1,32 @@ from pathlib import Path import shutil +import json +import pickle import numpy as np from spikeinterface.core import load_extractor -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict, run_sorters +from spikeinterface.core.core_tools import SIJsonEncoder + +from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder from spikeinterface import WaveformExtractor from spikeinterface.qualitymetrics import compute_quality_metrics from .paircomparisons import compare_sorter_to_ground_truth -from .studytools import ( - setup_comparison_study, - get_rec_names, - get_recordings, - iter_working_folder, - iter_computed_names, - iter_computed_sorting, - collect_run_times, -) +# from .studytools import ( +# setup_comparison_study, +# get_rec_names, +# get_recordings, +# iter_working_folder, +# iter_computed_names, +# iter_computed_sorting, +# collect_run_times, +# ) + + +_key_separator = " ## " class GroundTruthStudy: """ @@ -44,10 +50,10 @@ class GroundTruthStudy: """ - def __init__(self, study_folder=None): + def __init__(self, study_folder): # import pandas as pd - self.study_folder = Path(study_folder) + self.folder = Path(study_folder) # self.computed_names = None # self.recording_names = None @@ -66,22 +72,121 @@ def __init__(self, study_folder=None): @classmethod def create(cls, study_folder, datasets={}, cases={}): - pass + study_folder = Path(study_folder) + study_folder.mkdir(exist_ok=False, parents=True) + + (study_folder / "datasets").mkdir() + (study_folder / "datasets/recordings").mkdir() + (study_folder / "datasets/gt_sortings").mkdir() + (study_folder / "sorters").mkdir() + (study_folder / "sortings").mkdir() + + for key, (rec, gt_sorting) in datasets.items(): + assert "/" not in key + assert "\\" not in key + + # rec are pickle + rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") + + # sorting are pickle + saved as NumpyFolderSorting + gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") + gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") + + + # (study_folder / "cases.jon").write_text( + # json.dumps(cases, indent=4, cls=SIJsonEncoder), + # encoding="utf8", + # ) + # cases is dump to a pickle file, json is not possible because of tuple key + (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) + + return cls(study_folder) + + + def scan_folder(self): + if not (self.folder / "datasets").exists(): + raise ValueError(f"This is folder is not a {self.folder} GroundTruthStudy") + + for rec_file in (self.folder 
/ "datasets/recordings").glob("*.pickle"): + key = rec_file.stem + rec = load_extractor(rec_file) + gt_sorting = load_extractor(self.folder / f"datasets/gt_sortings/{key}") + self.datasets[key] = (rec, gt_sorting) + + with open(self.folder / "cases.pickle", "rb") as f: + self.cases = pickle.load(f) def __repr__(self): - t = f"GroundTruthStudy {self.study_folder.stem} \n" - t += f" recordings: {len(self.rec_names)} {self.rec_names}\n" - if len(self.sorter_names): - t += " cases: {} {}\n".format(len(self.sorter_names), self.sorter_names) + t = f"GroundTruthStudy {self.folder.stem} \n" + t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" + t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" return t - def scan_folder(self): - self.rec_names = get_rec_names(self.study_folder) - # scan computed names - self.computed_names = list(iter_computed_names(self.study_folder)) # list of pair (rec_name, sorter_name) - self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist() - self._is_scanned = True + def key_to_str(self, key): + if isinstance(key, str): + return key + elif isinstance(key, tuple): + return _key_separator.join(key) + else: + raise ValueError("Keys for cases must str or tuple") + + def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True, verbose=False): + """ + + """ + if case_keys is None: + case_keys = self.cases.keys() + + job_list = [] + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorting_exists = sorting_folder.exists() + + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + sorter_folder_exists = sorting_folder.exists() + + if keep: + if sorting_exists: + continue + if sorter_folder_exists: + # the sorter folder exists but havent been copied to sortings folder + sorting = read_sorter_folder(sorter_folder, raise_error=False) + if sorting is not None: + # save and skip + sorting.save(format="numpy_folder", folder=sorting_folder) + continue + + params = self.cases[key]["run_sorter_params"].copy() + # this ensure that sorter_name is given + recording, _ = self.datasets[self.cases[key]["dataset"]] + sorter_name = params.pop("sorter_name") + job = dict(sorter_name=sorter_name, + recording=recording, + output_folder=sorter_folder) + job.update(params) + job_list.append(job) + + run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) + + # TODO create a list in laucher for engine blocking and non-blocking + if engine not in ("slurm", ): + self.copy_sortings(case_keys) + + def copy_sortings(self, case_keys=None): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + sorter_folder = self.folder / "sorters" / self.key_to_str(key) + + sorting = read_sorter_folder(sorter_folder, raise_error=False) + if sorting is not None: + sorting.save(format="numpy_folder", folder=sorting_folder) + + def run_comparisons(self): + pass diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index f28d901075..15ba7db2ab 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -25,18 +25,19 @@ study_folder = cache_folder / "test_groundtruthstudy/" +print(study_folder.absolute()) def setup_module(): if study_folder.is_dir(): shutil.rmtree(study_folder) - 
create_study(study_folder) + create_a_study(study_folder) def simple_preprocess(rec): return bandpass_filter(rec) -def create_study(study_folder): +def create_a_study(study_folder): rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) @@ -53,7 +54,7 @@ def create_study(study_folder): "label": "tridesclous2 without preprocessing and standard params", "dataset": "toy_tetrode", "run_sorter_params": { - + "sorter_name": "tridesclous2", }, "comparison_params": { @@ -64,7 +65,7 @@ def create_study(study_folder): "label": "tridesclous2 with preprocessing standar params", "dataset": "toy_probe32_preprocess", "run_sorter_params": { - + "sorter_name": "tridesclous2", }, "comparison_params": { @@ -75,7 +76,7 @@ def create_study(study_folder): "label": "spykingcircus2 without preprocessing standar params", "dataset": "toy_tetrode", "run_sorter_params": { - + "sorter_name": "spykingcircus2", }, "comparison_params": { @@ -87,6 +88,13 @@ def create_study(study_folder): print(study) + +def test_GroundTruthStudy(): + study = GroundTruthStudy(study_folder) + print(study) + + study.run_sorters(verbose=True) + # @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") # def test_run_study_sorters(): # study = GroundTruthStudy(study_folder) @@ -128,6 +136,9 @@ def create_study(study_folder): if __name__ == "__main__": - setup_module() + # setup_module() + test_GroundTruthStudy() + + # test_run_study_sorters() # test_extract_sortings() From e0af88dbae3593a62372706ed842cde3b1736464 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Sep 2023 20:32:11 +0200 Subject: [PATCH 03/26] Make internal sorters able to be run with none dumpable to json recording. 
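
Note on this change: the internal sorters (tridesclous2, spykingcircus2) now reload the
recording from "spikeinterface_recording.pickle" instead of "spikeinterface_recording.json",
so recordings that are picklable but not serializable to json (NoiseGeneratorRecording,
InjectTemplatesRecording) can be sorted directly. A minimal sketch of the intended usage,
assuming a generator-based recording (illustrative only, not part of the patch):

    from spikeinterface import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # this recording can be pickled but is not dumpable to json
    rec, gt_sorting = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42)
    # before this patch, internal sorters required a json-serializable recording
    sorting = run_sorter("tridesclous2", rec, output_folder="tdc2_output")
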
--- src/spikeinterface/comparison/groundtruthstudy.py | 2 +- .../comparison/tests/test_groundtruthstudy.py | 4 ++-- src/spikeinterface/core/base.py | 6 ++++-- src/spikeinterface/sorters/internal/si_based.py | 14 +++++++++++--- .../sorters/internal/spyking_circus2.py | 4 +--- .../sorters/internal/tridesclous2.py | 4 +--- 6 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index fc4de5a18d..2eeb697980 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -105,7 +105,7 @@ def create(cls, study_folder, datasets={}, cases={}): def scan_folder(self): if not (self.folder / "datasets").exists(): - raise ValueError(f"This is folder is not a {self.folder} GroundTruthStudy") + raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") for rec_file in (self.folder / "datasets/recordings").glob("*.pickle"): key = rec_file.stem diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 15ba7db2ab..169c5a12bb 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -85,7 +85,7 @@ def create_a_study(study_folder): } study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases) - print(study) + # print(study) @@ -136,7 +136,7 @@ def test_GroundTruthStudy(): if __name__ == "__main__": - # setup_module() + setup_module() test_GroundTruthStudy() diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 87c0805630..4f6043f16e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -425,14 +425,15 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None) extractor: RecordingExtractor or SortingExtractor The loaded extractor object """ - if dictionary["relative_paths"]: + # for pickle dump relative_path was not in the dict, this ensure compatibility + if dictionary.get("relative_paths", False): assert base_folder is not None, "When relative_paths=True, need to provide base_folder" dictionary = _make_paths_absolute(dictionary, base_folder) extractor = _load_extractor_from_dict(dictionary) folder_metadata = dictionary.get("folder_metadata", None) if folder_metadata is not None: folder_metadata = Path(folder_metadata) - if dictionary["relative_paths"]: + if dictionary.get("relative_paths", False): folder_metadata = base_folder / folder_metadata extractor.load_metadata_from_folder(folder_metadata) return extractor @@ -622,6 +623,7 @@ def dump_to_pickle( include_annotations=True, include_properties=include_properties, folder_metadata=folder_metadata, + relative_to=None, recursive=False, ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) diff --git a/src/spikeinterface/sorters/internal/si_based.py b/src/spikeinterface/sorters/internal/si_based.py index 1496ffbbd1..ee5dcbea0d 100644 --- a/src/spikeinterface/sorters/internal/si_based.py +++ b/src/spikeinterface/sorters/internal/si_based.py @@ -1,4 +1,4 @@ -from spikeinterface.core import load_extractor +from spikeinterface.core import load_extractor, NumpyRecording from spikeinterface.sorters import BaseSorter @@ -14,8 +14,16 @@ def is_installed(cls): @classmethod def _setup_recording(cls, recording, output_folder, params, verbose): - # nothing to do here because the 
spikeinterface_recording.json is here anyway - pass + # Some recording not json serializable but they can be saved to pickle + # * NoiseGeneratorRecording or InjectTemplatesRecording: we force a pickle because this is light + # * for NumpyRecording (this is a bit crazy because it flush the entire buffer!!) + if recording.check_if_dumpable() and not isinstance(recording, NumpyRecording): + rec_file = output_folder.parent / "spikeinterface_recording.pickle" + recording.dump_to_pickle(rec_file) + # TODO (hard) : find a solution for NumpyRecording without any dump + # this will need an internal API change I think + # because the run_sorter is from the "folder" (because of container mainly and also many other reasons) + # and not from the recording itself @classmethod def _get_result_from_folder(cls, output_folder): diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 9de2762562..72171cd5b5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -54,9 +54,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs["verbose"] = verbose job_kwargs["progress_bar"] = verbose - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = load_extractor(sorter_output_folder.parent / "spikeinterface_recording.pickle") sampling_rate = recording.get_sampling_frequency() num_channels = recording.get_num_channels() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 42f51d3a77..7cbf01cf68 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import hdbscan - recording_raw = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording_raw = load_extractor(sorter_output_folder.parent / "spikeinterface_recording.pickle") num_chans = recording_raw.get_num_channels() sampling_frequency = recording_raw.get_sampling_frequency() From 9905bf59fc4447e5f80bbf5acadb71f692337982 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Sep 2023 21:24:24 +0200 Subject: [PATCH 04/26] wip --- src/spikeinterface/comparison/groundtruthstudy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 2eeb697980..d760703ea1 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -165,6 +165,8 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True recording=recording, output_folder=sorter_folder) job.update(params) + # the verbose is overwritten and global to all run_sorters + job["verbose"] = verbose job_list.append(job) run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) From 98fa0f81b280ef79c691444d0d3999abb2c9a160 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 9 Sep 2023 08:57:29 +0200 Subject: [PATCH 05/26] gt_study wip --- .../comparison/groundtruthstudy.py | 59 ++++++++++++++----- .../comparison/tests/test_groundtruthstudy.py | 12 +++- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git 
a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index d760703ea1..3debced277 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core import load_extractor +from spikeinterface.core import load_extractor, extract_waveforms from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder @@ -13,7 +13,7 @@ from spikeinterface import WaveformExtractor from spikeinterface.qualitymetrics import compute_quality_metrics -from .paircomparisons import compare_sorter_to_ground_truth +from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison # from .studytools import ( # setup_comparison_study, @@ -51,25 +51,15 @@ class GroundTruthStudy: """ def __init__(self, study_folder): - # import pandas as pd - self.folder = Path(study_folder) - # self.computed_names = None - # self.recording_names = None - # self.cases_names = None - self.datasets = {} self.cases = {} - - # self.rec_names = None - # self.sorter_names = None + self.sortings = {} + self.comparisons = {} self.scan_folder() - # self.comparisons = None - # self.exhaustive_gt = None - @classmethod def create(cls, study_folder, datasets={}, cases={}): study_folder = Path(study_folder) @@ -116,10 +106,26 @@ def scan_folder(self): with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) + self.comparisons = {k: None for k in self.cases} + + self.sortings = {} + for key in self.cases: + sorting_folder = self.folder / "sortings" / self.key_to_str(key) + print(sorting_folder) + print(sorting_folder.is_dir()) + if sorting_folder.exists(): + sorting = load_extractor(sorting_folder) + else: + sorting = None + self.sortings[key] = sorting + + def __repr__(self): t = f"GroundTruthStudy {self.folder.stem} \n" t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" + num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) + t += f" computed: {num_computed}\n" return t @@ -187,10 +193,31 @@ def copy_sortings(self, case_keys=None): if sorting is not None: sorting.save(format="numpy_folder", folder=sorting_folder) - def run_comparisons(self): - pass + def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + _, gt_sorting = self.datasets[dataset_key] + sorting = self.sortings[key] + comp = comparison_class(gt_sorting, sorting, **kwargs) + self.comparisons[key] = comp + def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): + + if case_keys is None: + case_keys = self.cases.keys() + + base_folder = self.folder / "waveforms" + base_folder.mkdir(exist_ok=True) + + for key in case_keys: + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + wf_folder = base_folder / self.key_to_str(key) + we = extract_waveforms(recording, gt_sorting, folder=wf_folder) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 169c5a12bb..9aaa742184 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -93,7 +93,15 
@@ def test_GroundTruthStudy(): study = GroundTruthStudy(study_folder) print(study) - study.run_sorters(verbose=True) + # study.run_sorters(verbose=True) + + # print(study.sortings) + + # print(study.comparisons) + # study.run_comparisons() + # print(study.comparisons) + + study.extract_waveforms_gt(n_jobs=-1) # @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") # def test_run_study_sorters(): @@ -136,7 +144,7 @@ def test_GroundTruthStudy(): if __name__ == "__main__": - setup_module() + # setup_module() test_GroundTruthStudy() From f0940a5265d9f1db235dc4db66af15e0b513fc51 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sat, 9 Sep 2023 18:19:18 +0200 Subject: [PATCH 06/26] gt study wip --- .../comparison/groundtruthstudy.py | 200 +++++++++++++++++- .../comparison/tests/test_groundtruthstudy.py | 48 +++-- 2 files changed, 224 insertions(+), 24 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 3debced277..9eb771b71a 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -1,11 +1,12 @@ from pathlib import Path import shutil +import os import json import pickle import numpy as np -from spikeinterface.core import load_extractor, extract_waveforms +from spikeinterface.core import load_extractor, extract_waveforms, load_waveforms from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.sorters import run_sorter_jobs, read_sorter_folder @@ -26,7 +27,16 @@ # ) +# TODO : save comparison in folders +# TODO : find a way to set level names + + + +# This is to separate names when the key are tuples when saving folders _key_separator = " ## " +# This would be more funny +# _key_separator = " (°_°) " + class GroundTruthStudy: """ @@ -70,6 +80,10 @@ def create(cls, study_folder, datasets={}, cases={}): (study_folder / "datasets/gt_sortings").mkdir() (study_folder / "sorters").mkdir() (study_folder / "sortings").mkdir() + (study_folder / "sortings" / "run_logs").mkdir() + (study_folder / "metrics").mkdir() + + for key, (rec, gt_sorting) in datasets.items(): assert "/" not in key @@ -111,8 +125,6 @@ def scan_folder(self): self.sortings = {} for key in self.cases: sorting_folder = self.folder / "sortings" / self.key_to_str(key) - print(sorting_folder) - print(sorting_folder.is_dir()) if sorting_folder.exists(): sorting = load_extractor(sorting_folder) else: @@ -160,9 +172,13 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True sorting = read_sorter_folder(sorter_folder, raise_error=False) if sorting is not None: # save and skip - sorting.save(format="numpy_folder", folder=sorting_folder) + self.copy_sortings(case_keys=[key]) continue - + + if sorting_exists: + # TODO : delete sorting + log + pass + params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given recording, _ = self.datasets[self.cases[key]["dataset"]] @@ -181,17 +197,29 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True if engine not in ("slurm", ): self.copy_sortings(case_keys) - def copy_sortings(self, case_keys=None): + def copy_sortings(self, case_keys=None, force=True): if case_keys is None: case_keys = self.cases.keys() for key in case_keys: sorting_folder = self.folder / "sortings" / self.key_to_str(key) sorter_folder = self.folder / "sorters" / self.key_to_str(key) + log_file = self.folder / "sortings" / "run_logs" / 
f"{self.key_to_str(key)}.json" sorting = read_sorter_folder(sorter_folder, raise_error=False) if sorting is not None: - sorting.save(format="numpy_folder", folder=sorting_folder) + if sorting_folder.exists(): + if force: + # TODO delete folder + log + shutil.rmtree(sorting_folder) + else: + continue + + sorting = sorting.save(format="numpy_folder", folder=sorting_folder) + self.sortings[key] = sorting + + # copy logs + shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): @@ -202,9 +230,29 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison dataset_key = self.cases[key]["dataset"] _, gt_sorting = self.datasets[dataset_key] sorting = self.sortings[key] + if sorting is None: + self.comparisons[key] = None + continue comp = comparison_class(gt_sorting, sorting, **kwargs) self.comparisons[key] = comp + def get_run_times(self, case_keys=None): + import pandas as pd + if case_keys is None: + case_keys = self.cases.keys() + + log_folder = self.folder / "sortings" / "run_logs" + + run_times = {} + for key in case_keys: + log_file = log_folder / f"{self.key_to_str(key)}.json" + with open(log_file, mode="r") as logfile: + log = json.load(logfile) + run_time = log.get("run_time", None) + run_times[key] = run_time + + return pd.Series(run_times, name="run_time") + def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): if case_keys is None: @@ -219,6 +267,144 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): wf_folder = base_folder / self.key_to_str(key) we = extract_waveforms(recording, gt_sorting, folder=wf_folder) + def get_waveform_extractor(self, key): + # some recording are not dumpable to json and the waveforms extactor need it! 
+ # so we load it with and put after + we = load_waveforms(self.folder / "waveforms" / self.key_to_str(key), with_recording=False) + dataset_key = self.cases[key]["dataset"] + recording, _ = self.datasets[dataset_key] + we.set_recording(recording) + return we + + def get_templates(self, key, mode="mean"): + we = self.get_waveform_extractor(key) + templates = we.get_all_templates(mode=mode) + return templates + + def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): + if case_keys is None: + case_keys = self.cases.keys() + + for key in case_keys: + filename = self.folder / "metrics" / f"{self.key_to_str(key)}.txt" + if filename.exists(): + if force: + os.remove(filename) + else: + continue + + we = self.get_waveform_extractor(key) + metrics = compute_quality_metrics(we, metric_names=metric_names) + metrics.to_csv(filename, sep="\t", index=True) + + def get_metrics(self, key): + import pandas as pd + filename = self.folder / "metrics" / f"{self.key_to_str(key)}.txt" + if not filename.exists(): + return + metrics = pd.read_csv(filename, sep="\t", index_col=0) + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + metrics.index = gt_sorting.unit_ids + return metrics + + def get_units_snr(self, key): + """ + """ + return self.get_metrics(key)["snr"] + + def aggregate_performance_by_unit(self, case_keys=None): + + import pandas as pd + + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" + + perf = comp.get_performance(method="by_unit", output="pandas") + if isinstance(key, str): + cols = ["level0"] + perf["level0"] = key + + elif isinstance(key, tuple): + cols = [f'level{i}' for i in range(len(key))] + for col, k in zip(cols, key): + perf[col] = k + + perf = perf.reset_index() + perf_by_unit.append(perf) + + + + perf_by_unit = pd.concat(perf_by_unit) + perf_by_unit = perf_by_unit.set_index(cols) + + return perf_by_unit + + # def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None): + + def aggregate_count_units( + self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None + ): + + import pandas as pd + + if case_keys is None: + case_keys = self.cases.keys() + + perf_by_unit = [] + for key in case_keys: + comp = self.comparisons.get(key, None) + assert comp is not None, "You need to do study.run_comparisons() first" + + + + # assert self.comparisons is not None, "run_comparisons first" + + # import pandas as pd + + # index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) + + # count_units = pd.DataFrame( + # index=index, + # columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"], + # dtype=int, + # ) + + # if self.exhaustive_gt: + # count_units["num_false_positive"] = pd.Series(dtype=int) + # count_units["num_bad"] = pd.Series(dtype=int) + + # for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): + # gt_sorting = self.get_ground_truth(rec_name) + # comp = self.comparisons[(rec_name, sorter_name)] + + # count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) + # count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) + # count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units( + # 
well_detected_score + # ) + # if self.exhaustive_gt: + # count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units( + # overmerged_score + # ) + # count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score) + # count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units( + # redundant_score + # ) + # count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() + + # return count_units + + + + + + class OLDGroundTruthStudy: diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 9aaa742184..3593b0b05f 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -71,17 +71,17 @@ def create_a_study(study_folder): }, }, - # - ("sc2", "no-preprocess", "tetrode"): { - "label": "spykingcircus2 without preprocessing standar params", - "dataset": "toy_tetrode", - "run_sorter_params": { - "sorter_name": "spykingcircus2", - }, - "comparison_params": { - - }, - }, + # we comment this at the moement because SC2 is quite slow for testing + # ("sc2", "no-preprocess", "tetrode"): { + # "label": "spykingcircus2 without preprocessing standar params", + # "dataset": "toy_tetrode", + # "run_sorter_params": { + # "sorter_name": "spykingcircus2", + # }, + # "comparison_params": { + + # }, + # }, } study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases) @@ -93,16 +93,30 @@ def test_GroundTruthStudy(): study = GroundTruthStudy(study_folder) print(study) - # study.run_sorters(verbose=True) + study.run_sorters(verbose=True) - # print(study.sortings) + print(study.sortings) - # print(study.comparisons) - # study.run_comparisons() - # print(study.comparisons) + print(study.comparisons) + study.run_comparisons() + print(study.comparisons) study.extract_waveforms_gt(n_jobs=-1) + study.compute_metrics() + + for key in study.cases: + metrics = study.get_metrics(key) + print(metrics) + + study.aggregate_performance_by_unit() + + +# perf = study.aggregate_performance_by_unit() +# count_units = study.aggregate_count_units() + + + # @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") # def test_run_study_sorters(): # study = GroundTruthStudy(study_folder) @@ -144,7 +158,7 @@ def test_GroundTruthStudy(): if __name__ == "__main__": - # setup_module() + setup_module() test_GroundTruthStudy() From b0267dcd72b69c0c1982d57200381c9ab6c1ec0f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Sun, 10 Sep 2023 21:45:40 +0200 Subject: [PATCH 07/26] Add levels concept in GTStudy --- .../comparison/groundtruthstudy.py | 83 ++++++++++++++++--- .../comparison/tests/test_groundtruthstudy.py | 3 +- 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 9eb771b71a..76c019f6b9 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -71,7 +71,29 @@ def __init__(self, study_folder): self.scan_folder() @classmethod - def create(cls, study_folder, datasets={}, cases={}): + def create(cls, study_folder, datasets={}, cases={}, levels=None): + + # check that cases keys are homogeneous + key0 = list(cases.keys())[0] + if isinstance(key0, str): + assert all(isinstance(key, str) for key in cases.keys()), "Keys for cases are 
not homogeneous" + if levels is None: + levels = "level0" + else: + assert isinstance(levels, str) + elif isinstance(key0, tuple): + assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" + num_levels = len(key0) + assert all(len(key) == num_levels for key in cases.keys()), "Keys for cases are not homogeneous, tuple negth differ" + if levels is None: + levels = [f"level{i}" for i in range(num_levels)] + else: + levels = list(levels) + assert len(levels) == num_levels + else: + raise ValueError("Keys for cases must str or tuple") + + study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) @@ -97,6 +119,10 @@ def create(cls, study_folder, datasets={}, cases={}): gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") + info = {} + info["levels"] = levels + (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") + # (study_folder / "cases.jon").write_text( # json.dumps(cases, indent=4, cls=SIJsonEncoder), # encoding="utf8", @@ -111,6 +137,12 @@ def scan_folder(self): if not (self.folder / "datasets").exists(): raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") + with open(self.folder / "info.json", "r") as f: + self.info = json.load(f) + if isinstance(self.levels, list): + # because tuple caoont be stored in json + self.levels = tuple(self.info["levels"]) + for rec_file in (self.folder / "datasets/recordings").glob("*.pickle"): key = rec_file.stem rec = load_extractor(rec_file) @@ -327,12 +359,9 @@ def aggregate_performance_by_unit(self, case_keys=None): perf = comp.get_performance(method="by_unit", output="pandas") if isinstance(key, str): - cols = ["level0"] - perf["level0"] = key - + perf[self.levels] = key elif isinstance(key, tuple): - cols = [f'level{i}' for i in range(len(key))] - for col, k in zip(cols, key): + for col, k in zip(self.levels, key): perf[col] = k perf = perf.reset_index() @@ -341,7 +370,7 @@ def aggregate_performance_by_unit(self, case_keys=None): perf_by_unit = pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(cols) + perf_by_unit = perf_by_unit.set_index(self.levels) return perf_by_unit @@ -354,18 +383,50 @@ def aggregate_count_units( import pandas as pd if case_keys is None: - case_keys = self.cases.keys() + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + + columns = ["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"] + comp = self.comparisons[case_keys[0]] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + - perf_by_unit = [] for key in case_keys: comp = self.comparisons.get(key, None) assert comp is not None, "You need to do study.run_comparisons() first" + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units( + well_detected_score + ) + if comp.exhaustive_gt: + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( + overmerged_score + ) + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_false_positive"] = 
comp.count_false_positive_units( + redundant_score + ) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + # count_units = pd.concat(count_units) + # count_units = count_units.set_index(cols) + return count_units - # assert self.comparisons is not None, "run_comparisons first" - # import pandas as pd + count_units = [] # index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 3593b0b05f..5c5af476e4 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -84,7 +84,7 @@ def create_a_study(study_folder): # }, } - study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases) + study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"]) # print(study) @@ -110,6 +110,7 @@ def test_GroundTruthStudy(): print(metrics) study.aggregate_performance_by_unit() + study.aggregate_count_units() # perf = study.aggregate_performance_by_unit() From 0750638eb13030b22ad30b9db94fa968a60c7fa2 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 11 Sep 2023 16:56:23 +0200 Subject: [PATCH 08/26] wip gtstudy --- .../comparison/groundtruthstudy.py | 53 ++----------------- .../comparison/tests/test_groundtruthstudy.py | 2 +- src/spikeinterface/widgets/widget_list.py | 3 ++ 3 files changed, 9 insertions(+), 49 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 76c019f6b9..049c97c234 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -139,9 +139,11 @@ def scan_folder(self): with open(self.folder / "info.json", "r") as f: self.info = json.load(f) - if isinstance(self.levels, list): - # because tuple caoont be stored in json - self.levels = tuple(self.info["levels"]) + + self.levels = self.info["levels"] + # if isinstance(self.levels, list): + # # because tuple caoont be stored in json + # self.levels = tuple(self.info["levels"]) for rec_file in (self.folder / "datasets/recordings").glob("*.pickle"): key = rec_file.stem @@ -371,11 +373,8 @@ def aggregate_performance_by_unit(self, case_keys=None): perf_by_unit = pd.concat(perf_by_unit) perf_by_unit = perf_by_unit.set_index(self.levels) - return perf_by_unit - # def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None): - def aggregate_count_units( self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None ): @@ -420,51 +419,9 @@ def aggregate_count_units( ) count_units.loc[key, "num_bad"] = comp.count_bad_units() - # count_units = pd.concat(count_units) - # count_units = count_units.set_index(cols) - return count_units - count_units = [] - - # index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) - - # count_units = pd.DataFrame( - # index=index, - # columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"], - # dtype=int, - # ) - - # if self.exhaustive_gt: - # count_units["num_false_positive"] = pd.Series(dtype=int) - # count_units["num_bad"] = pd.Series(dtype=int) - - # for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - # gt_sorting = self.get_ground_truth(rec_name) - # comp = 
self.comparisons[(rec_name, sorter_name)] - - # count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - # count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - # count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units( - # well_detected_score - # ) - # if self.exhaustive_gt: - # count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units( - # overmerged_score - # ) - # count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score) - # count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units( - # redundant_score - # ) - # count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() - - # return count_units - - - - - diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 5c5af476e4..1da79b9efe 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -162,6 +162,6 @@ def test_GroundTruthStudy(): setup_module() test_GroundTruthStudy() - # test_run_study_sorters() # test_extract_sortings() + diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index f3c640ff16..1e9d5301cf 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -20,6 +20,7 @@ from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget +from .gtstudy import StudyRunTimesWidget widget_list = [ @@ -41,6 +42,7 @@ UnitTemplatesWidget, UnitWaveformDensityMapWidget, UnitWaveformsWidget, + StudyRunTimesWidget, ] @@ -88,6 +90,7 @@ plot_unit_templates = UnitTemplatesWidget plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget +plot_study_run_times = StudyRunTimesWidget def plot_timeseries(*args, **kwargs): From ee2eb2f04d5c17817fcb9f014f9814f5192cb624 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 14:23:00 +0200 Subject: [PATCH 09/26] STart porting matplotlib widgets related to ground truth study. 
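
These widgets work on the refactored GroundTruthStudy API. A rough usage sketch with the
functions now registered in widget_list.py (illustrative only; the API is still wip and
signatures may move; study_folder is assumed to be a folder created with
GroundTruthStudy.create()):

    from spikeinterface.comparison import GroundTruthStudy
    from spikeinterface.widgets import (
        plot_study_run_times,
        plot_study_unit_counts,
        plot_study_performances,
    )

    study = GroundTruthStudy(study_folder)
    study.run_comparisons()  # required before unit counts and performances can be plotted

    plot_study_run_times(study)                   # run time per case, read from the sorter logs
    plot_study_unit_counts(study)                 # well detected / false positive / redundant / overmerged
    plot_study_performances(study, mode="swarm")  # accuracy / precision / recall per unit (seaborn swarmplot)
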
--- .../comparison/groundtruthstudy.py | 4 +- .../comparison/tests/test_groundtruthstudy.py | 48 +---- src/spikeinterface/widgets/gtstudy.py | 192 ++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 6 +- 4 files changed, 201 insertions(+), 49 deletions(-) create mode 100644 src/spikeinterface/widgets/gtstudy.py diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 049c97c234..d936c50e5e 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -347,7 +347,7 @@ def get_units_snr(self, key): """ return self.get_metrics(key)["snr"] - def aggregate_performance_by_unit(self, case_keys=None): + def get_performance_by_unit(self, case_keys=None): import pandas as pd @@ -375,7 +375,7 @@ def aggregate_performance_by_unit(self, case_keys=None): perf_by_unit = perf_by_unit.set_index(self.levels) return perf_by_unit - def aggregate_count_units( + def get_count_units( self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None ): diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 1da79b9efe..52d5c73d3b 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -109,54 +109,10 @@ def test_GroundTruthStudy(): metrics = study.get_metrics(key) print(metrics) - study.aggregate_performance_by_unit() - study.aggregate_count_units() + study.get_performance_by_unit() + study.get_count_units() -# perf = study.aggregate_performance_by_unit() -# count_units = study.aggregate_count_units() - - - -# @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -# def test_run_study_sorters(): -# study = GroundTruthStudy(study_folder) -# sorter_list = [ -# "tridesclous", -# ] -# print( -# f"\n#################################\nINSTALLED SORTERS\n#################################\n" -# f"{installed_sorters()}" -# ) -# study.run_sorters(sorter_list) - - -# @pytest.mark.skipif(not HAVE_TDC, reason="Test requires Python package 'tridesclous'") -# def test_extract_sortings(): -# study = GroundTruthStudy(study_folder) - -# study.copy_sortings() - -# for rec_name in study.rec_names: -# gt_sorting = study.get_ground_truth(rec_name) - -# for rec_name in study.rec_names: -# metrics = study.get_metrics(rec_name=rec_name) - -# snr = study.get_units_snr(rec_name=rec_name) - -# study.copy_sortings() - -# run_times = study.aggregate_run_times() - -# study.run_comparisons(exhaustive_gt=True) - -# perf = study.aggregate_performance_by_unit() - -# count_units = study.aggregate_count_units() -# dataframes = study.aggregate_dataframes() -# print(dataframes) - if __name__ == "__main__": setup_module() diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py new file mode 100644 index 0000000000..aa1a80c3d3 --- /dev/null +++ b/src/spikeinterface/widgets/gtstudy.py @@ -0,0 +1,192 @@ +import numpy as np + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + +from ..core import ChannelSparsity +from ..core.waveform_extractor import WaveformExtractor +from ..core.basesorting import BaseSorting + + +class StudyRunTimesWidget(BaseWidget): + """ + Plot sorter run times for a GroundTruthStudy + + + Parameters + ---------- + study: GroundTruthStudy + A study object. 
+ case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + run_times=study.get_run_times(), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + for i, key in enumerate(dp.case_keys): + label = dp.study.cases[key]["label"] + rt = dp.run_times.loc[key] + self.ax.bar(i, rt, width=0.8, label=label) + + self.ax.legend() + + + +# TODO : plot optionally average on some levels using group by +class StudyUnitCountsWidget(BaseWidget): + """ + Plot unit counts for a study: "num_well_detected", "num_false_positive", "num_redundant", "num_overmerged" + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + case_keys: list or None + A selection of cases to plot, if None, then all. + + """ + + def __init__( + self, + study, + case_keys=None, + backend=None, + **backend_kwargs, + ): + + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + count_units = study.get_count_units(case_keys=case_keys), + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + columns = dp.count_units.columns.tolist() + columns.remove("num_gt") + columns.remove("num_sorter") + + ncol = len(columns) + + colors = get_some_colors(columns, color_engine="auto", + map_name="hot") + colors["num_well_detected"] = "green" + + xticklabels = [] + for i, key in enumerate(dp.case_keys): + for c, col in enumerate(columns): + x = i + 1 + c / (ncol + 1) + y = dp.count_units.loc[key, col] + if not "well_detected" in col: + y = -y + + if i == 0: + label = col.replace("num_", "").replace("_", " ").title() + else: + label = None + + self.ax.bar([x], [y], width=1 / (ncol + 2), label=label, color=colors[col]) + + xticklabels.append(dp.study.cases[key]["label"]) + + self.ax.set_xticks(np.arange(len(dp.case_keys)) + 1) + self.ax.set_xticklabels(xticklabels) + self.ax.legend() + + +# TODO : plot optionally average on some levels using group by +class StudyPerformances(BaseWidget): + """ + Plot performances over case for a study. + + + Parameters + ---------- + study: GroundTruthStudy + A study object. + mode: str + Which mode in "swarm" + case_keys: list or None + A selection of cases to plot, if None, then all. 
+ + """ + + def __init__( + self, + study, + mode="swarm", + case_keys=None, + backend=None, + **backend_kwargs, + ): + + if case_keys is None: + case_keys = list(study.cases.keys()) + + plot_data = dict( + study=study, + perfs=study.get_performance_by_unit(case_keys=case_keys), + mode=mode, + case_keys=case_keys, + ) + + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from .utils import get_some_colors + + import pandas as pd + import seaborn as sns + + dp = to_attr(data_plot) + perfs = dp.perfs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + if dp.mode == "swarm": + levels = perfs.index.names + df = pd.melt(perfs.reset_index(), id_vars=levels, var_name='Metric', value_name='Score', + value_vars=('accuracy','precision', 'recall')) + df['x'] = df.apply(lambda r: ' '.join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x='x', y='Score', hue='Metric', dodge=True) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 1e9d5301cf..4bc91e0737 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -20,7 +20,7 @@ from .unit_templates import UnitTemplatesWidget from .unit_waveforms_density_map import UnitWaveformDensityMapWidget from .unit_waveforms import UnitWaveformsWidget -from .gtstudy import StudyRunTimesWidget +from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances widget_list = [ @@ -43,6 +43,8 @@ UnitWaveformDensityMapWidget, UnitWaveformsWidget, StudyRunTimesWidget, + StudyUnitCountsWidget, + StudyPerformances ] @@ -91,6 +93,8 @@ plot_unit_waveforms_density_map = UnitWaveformDensityMapWidget plot_unit_waveforms = UnitWaveformsWidget plot_study_run_times = StudyRunTimesWidget +plot_study_unit_counts = StudyUnitCountsWidget +plot_study_performances = StudyPerformances def plot_timeseries(*args, **kwargs): From d80341ca2cd84852988cc5704bafc1c0a6d16540 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Sep 2023 18:16:26 +0200 Subject: [PATCH 10/26] remove gtstudy widgets from legacy and port some of then in the API. 
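
Assumed mapping from the removed legacy functions to the new API (a sketch, not verified
one-to-one; the averaged variants have no direct replacement yet, since averaging over
levels is still a TODO in the new widgets):

    plot_gt_study_run_times(study)              ->  plot_study_run_times(study)
    plot_gt_study_unit_counts(study)            ->  plot_study_unit_counts(study)
    plot_gt_study_performances(study)           ->  plot_study_performances(study)
    plot_gt_study_unit_counts_averages(study)   ->  (no direct equivalent yet)
    plot_gt_study_performances_averages(study)  ->  (no direct equivalent yet)
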
--- .../comparison/tests/test_groundtruthstudy.py | 15 +- .../widgets/_legacy_mpl_widgets/__init__.py | 16 - .../widgets/_legacy_mpl_widgets/gtstudy.py | 574 ------------------ src/spikeinterface/widgets/gtstudy.py | 60 ++ src/spikeinterface/widgets/widget_list.py | 6 +- 5 files changed, 66 insertions(+), 605 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 52d5c73d3b..a75ac272be 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -1,20 +1,11 @@ -import importlib import shutil import pytest from pathlib import Path -# from spikeinterface.extractors import toy_example from spikeinterface import generate_ground_truth_recording from spikeinterface.preprocessing import bandpass_filter -from spikeinterface.sorters import installed_sorters from spikeinterface.comparison import GroundTruthStudy -# try: -# import tridesclous - -# HAVE_TDC = True -# except ImportError: -# HAVE_TDC = False if hasattr(pytest, "global_test_folder"): @@ -71,7 +62,7 @@ def create_a_study(study_folder): }, }, - # we comment this at the moement because SC2 is quite slow for testing + # we comment this at the moement because SC2 is quite slow for testing # ("sc2", "no-preprocess", "tetrode"): { # "label": "spykingcircus2 without preprocessing standar params", # "dataset": "toy_tetrode", @@ -118,6 +109,4 @@ def test_GroundTruthStudy(): setup_module() test_GroundTruthStudy() - # test_run_study_sorters() - # test_extract_sortings() - + \ No newline at end of file diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..bf28c891f5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -41,22 +41,6 @@ from .sortingperformance import plot_sorting_performance -# ground truth study (=comparison over sorter) -from .gtstudy import ( - StudyComparisonRunTimesWidget, - plot_gt_study_run_times, - StudyComparisonUnitCountsWidget, - StudyComparisonUnitCountsAveragesWidget, - plot_gt_study_unit_counts, - plot_gt_study_unit_counts_averages, - plot_gt_study_performances, - plot_gt_study_performances_averages, - StudyComparisonPerformancesWidget, - StudyComparisonPerformancesAveragesWidget, - plot_gt_study_performances_by_template_similarity, - StudyComparisonPerformancesByTemplateSimilarity, -) - # ground truth comparions (=comparison over sorter) from .gtcomparison import ( plot_gt_performances, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py deleted file mode 100644 index 573221f528..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/gtstudy.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Various widgets on top of GroundTruthStudy to summary results: - * run times - * performances - * count units -""" -import numpy as np - - -from .basewidget import BaseWidget - - -class StudyComparisonRunTimesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - color: - - - """ - - def __init__(self, study, color="#F7DC6F", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.color = color - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - all_run_times = study.aggregate_run_times() - av_run_times = all_run_times.reset_index().groupby("sorter_name")["run_time"].mean() - - if len(study.rec_names) == 1: - # no errors bars - yerr = None - else: - # errors bars across recording - yerr = all_run_times.reset_index().groupby("sorter_name")["run_time"].std() - - sorter_names = av_run_times.index - - x = np.arange(sorter_names.size) + 1 - ax.bar(x, av_run_times.values, width=0.8, color=self.color, yerr=yerr) - ax.set_ylabel("run times (s)") - ax.set_xticks(x) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_xlim(0, sorter_names.size + 1) - - -def plot_gt_study_run_times(*args, **kwargs): - W = StudyComparisonRunTimesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_run_times.__doc__ = StudyComparisonRunTimesWidget.__doc__ - - -class StudyComparisonUnitCountsAveragesWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - study = self.study - ax = self.ax - - count_units = study.aggregate_count_units() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - ncol = len(columns) - - df = count_units.reset_index() - - m = df.groupby("sorter_name")[columns].mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = df.groupby("sorter_name")[columns].std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values, yerr=yerr, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - ax.legend() - if self.log_scale: - ax.set_yscale("log") - - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -class StudyComparisonUnitCountsWidget(BaseWidget): - """ - Plot averages over found units for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - log_scale: if the y-axis should be displayed as log scaled - - """ - - def __init__(self, study, cmap_name="Set2", log_scale=False, ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.log_scale = log_scale - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self): - study = self.study - ax = self.ax - - import seaborn as sns - - study = self.study - - count_units = study.aggregate_count_units() - count_units = count_units.reset_index() - - if study.exhaustive_gt: - columns = ["num_well_detected", "num_false_positive", "num_redundant", "num_overmerged"] - else: - columns = ["num_well_detected", "num_redundant", "num_overmerged"] - - ncol = len(columns) - cmap = plt.get_cmap(self.cmap_name, 4) - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = count_units.loc[count_units["rec_name"] == rec_name, :] - m = df.groupby("sorter_name")[columns].mean() - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - ax.bar(x, m[col].values, width=1 / (ncol + 2), color=cmap(c), label=clean_labels[c]) - - if r == 0: - ax.legend() - - if self.log_scale: - ax.set_yscale("log") - - if r == len(study.rec_names) - 1: - ax.set_xticks(np.arange(sorter_names.size) + 1) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("# units") - ax.set_xlim(0, sorter_names.size + 1) - - if count_units["num_gt"].unique().size == 1: - num_gt = count_units["num_gt"].unique()[0] - ax.axhline(num_gt, ls="--", color="k") - - -def plot_gt_study_unit_counts_averages(*args, **kwargs): - W = StudyComparisonUnitCountsAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts_averages.__doc__ = StudyComparisonUnitCountsAveragesWidget.__doc__ - - -def plot_gt_study_unit_counts(*args, **kwargs): - W = StudyComparisonUnitCountsWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_unit_counts.__doc__ = StudyComparisonUnitCountsWidget.__doc__ - - -class StudyComparisonPerformancesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, palette="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.palette = palette - - num_rec = len(study.rec_names) - if ax is None: - fig, axes = plt.subplots(ncols=1, nrows=num_rec, squeeze=False) - else: - axes = np.array([ax]).T - - BaseWidget.__init__(self, axes=axes) - - def plot(self, average=False): - import seaborn as sns - - study = self.study - - sns.set_palette(sns.color_palette(self.palette)) - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - for r, rec_name in enumerate(study.rec_names): - ax = self.axes[r, 0] - ax.set_title(rec_name) - df = perf_by_units.loc[perf_by_units["rec_name"] == rec_name, :] - df = pd.melt( - df, - id_vars="sorter_name", - var_name="Metric", - value_name="Score", - value_vars=("accuracy", "precision", "recall"), - ).sort_values("sorter_name") - sns.swarmplot( - data=df, x="sorter_name", y="Score", hue="Metric", dodge=True, s=3, ax=ax - ) # order=sorter_list, - # ~ ax.set_xticklabels(sorter_names_short, rotation=30, ha='center') - # ~ ax.legend(bbox_to_anchor=(1.0, 1), loc=2, borderaxespad=0., frameon=False, fontsize=8, markerscale=0.5) - - ax.set_ylim(0, 1.05) - ax.set_ylabel(f"Perfs for {rec_name}") - if r < len(study.rec_names) - 1: - ax.set_xlabel("") - ax.set(xticklabels=[]) - - -class StudyComparisonTemplateSimilarityWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesAveragesWidget(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import seaborn as sns - - study = self.study - ax = self.ax - - perf_by_units = study.aggregate_performance_by_unit() - perf_by_units = perf_by_units.reset_index() - - columns = ["accuracy", "precision", "recall"] - to_agg = {} - ncol = len(columns) - - for column in columns: - perf_by_units[column] = pd.to_numeric(perf_by_units[column], downcast="float") - to_agg[column] = ["mean"] - - data = perf_by_units.groupby(["sorter_name", "rec_name"]).agg(to_agg) - - m = data.groupby("sorter_name").mean() - - cmap = plt.get_cmap(self.cmap_name, 4) - - if len(study.rec_names) == 1: - # no errors bars - stds = None - else: - # errors bars across recording - stds = data.groupby("sorter_name").std() - - sorter_names = m.index - clean_labels = [col.replace("num_", "").replace("_", " ").title() for col in columns] - - width = 1 / (ncol + 2) - - for c, col in enumerate(columns): - x = np.arange(sorter_names.size) + 1 + c / (ncol + 2) - if stds is None: - yerr = None - else: - yerr = stds[col].values - ax.bar(x, m[col].values.flatten(), yerr=yerr.flatten(), width=width, color=cmap(c), label=clean_labels[c]) - - ax.legend() - - ax.set_xticks(np.arange(sorter_names.size) + 1 + width) - ax.set_xticklabels(sorter_names, rotation=45) - ax.set_ylabel("metric") - ax.set_xlim(0, sorter_names.size + 1) - - -class StudyComparisonPerformancesByTemplateSimilarity(BaseWidget): - """ - Plot run times for a study. - - Parameters - ---------- - study: GroundTruthStudy - The study object to consider - ax: matplotlib ax - The ax to be used. 
If not given a figure is created - cmap_name - - """ - - def __init__(self, study, cmap_name="Set1", ax=None, ylim=(0.6, 1), show_legend=True): - from matplotlib import pyplot as plt - import pandas as pd - - self.study = study - self.cmap_name = cmap_name - self.show_legend = show_legend - self.ylim = ylim - - BaseWidget.__init__(self, ax=ax) - - def plot(self): - import sklearn - - cmap = plt.get_cmap(self.cmap_name, len(self.study.sorter_names)) - colors = [cmap(i) for i in range(len(self.study.sorter_names))] - - flat_templates_gt = {} - for rec_name in self.study.rec_names: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_GroundTruth_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name) - - templates = self.study.get_templates(rec_name) - flat_templates_gt[rec_name] = templates.reshape(templates.shape[0], -1) - - all_results = {} - - for sorter_name in self.study.sorter_names: - all_results[sorter_name] = {"similarity": [], "accuracy": []} - - for rec_name in self.study.rec_names: - try: - waveform_folder = self.study.study_folder / "waveforms" / f"waveforms_{sorter_name}_{rec_name}" - if not waveform_folder.is_dir(): - self.study.compute_waveforms(rec_name, sorter_name) - templates = self.study.get_templates(rec_name, sorter_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix = sklearn.metrics.pairwise.cosine_similarity( - flat_templates_gt[rec_name], flat_templates - ) - - comp = self.study.comparisons[(rec_name, sorter_name)] - - for i, u1 in enumerate(comp.sorting1.unit_ids): - u2 = comp.best_match_12[u1] - if u2 != -1: - all_results[sorter_name]["similarity"] += [ - similarity_matrix[comp.sorting1.id_to_index(u1), comp.sorting2.id_to_index(u2)] - ] - all_results[sorter_name]["accuracy"] += [comp.agreement_scores.at[u1, u2]] - except Exception: - pass - - all_results[sorter_name]["similarity"] = np.array(all_results[sorter_name]["similarity"]) - all_results[sorter_name]["accuracy"] = np.array(all_results[sorter_name]["accuracy"]) - - from matplotlib.patches import Ellipse - - similarity_means = [all_results[sorter_name]["similarity"].mean() for sorter_name in self.study.sorter_names] - similarity_stds = [all_results[sorter_name]["similarity"].std() for sorter_name in self.study.sorter_names] - - accuracy_means = [all_results[sorter_name]["accuracy"].mean() for sorter_name in self.study.sorter_names] - accuracy_stds = [all_results[sorter_name]["accuracy"].std() for sorter_name in self.study.sorter_names] - - scount = 0 - for x, y, i, j in zip(similarity_means, accuracy_means, similarity_stds, accuracy_stds): - e = Ellipse((x, y), i, j) - e.set_alpha(0.2) - e.set_facecolor(colors[scount]) - self.ax.add_artist(e) - self.ax.scatter([x], [y], c=colors[scount], label=self.study.sorter_names[scount]) - scount += 1 - - self.ax.set_ylabel("accuracy") - self.ax.set_xlabel("cosine similarity") - if self.ylim is not None: - self.ax.set_ylim(self.ylim) - - if self.show_legend: - self.ax.legend() - - -def plot_gt_study_performances(*args, **kwargs): - W = StudyComparisonPerformancesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances.__doc__ = StudyComparisonPerformancesWidget.__doc__ - - -def plot_gt_study_performances_averages(*args, **kwargs): - W = StudyComparisonPerformancesAveragesWidget(*args, **kwargs) - W.plot() - return W - - -plot_gt_study_performances_averages.__doc__ = StudyComparisonPerformancesAveragesWidget.__doc__ - - -def 
plot_gt_study_performances_by_template_similarity(*args, **kwargs):
-    W = StudyComparisonPerformancesByTemplateSimilarity(*args, **kwargs)
-    W.plot()
-    return W
-
-
-plot_gt_study_performances_by_template_similarity.__doc__ = StudyComparisonPerformancesByTemplateSimilarity.__doc__
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index aa1a80c3d3..304cf1a44a 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -190,3 +190,63 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
                          value_vars=('accuracy', 'precision', 'recall'))
             df['x'] = df.apply(lambda r: ' '.join([r[col] for col in levels]), axis=1)
             sns.swarmplot(data=df, x='x', y='Score', hue='Metric', dodge=True)
+
+
+
+class StudyPerformancesVsMetrics(BaseWidget):
+    """
+    Plot a performance measure (accuracy for instance) versus a quality metric (snr for instance) across cases for a study.
+
+
+    Parameters
+    ----------
+    study: GroundTruthStudy
+        A study object.
+    metric_name / performance_name: str
+        Quality metric for the x-axis ("snr" by default) and performance measure for the y-axis ("accuracy" by default).
+    case_keys: list or None
+        A selection of cases to plot, if None, then all.
+
+    """
+
+    def __init__(
+        self,
+        study,
+        metric_name="snr",
+        performance_name="accuracy",
+        case_keys=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+
+        if case_keys is None:
+            case_keys = list(study.cases.keys())
+
+        plot_data = dict(
+            study=study,
+            metric_name=metric_name,
+            performance_name=performance_name,
+            case_keys=case_keys,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from .utils import get_some_colors
+
+        dp = to_attr(data_plot)
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+
+        study = dp.study
+        perfs = study.get_performance_by_unit(case_keys=dp.case_keys)
+
+        for key in dp.case_keys:
+            x = study.get_metrics(key)[dp.metric_name].values
+            y = perfs.xs(key)[dp.performance_name].values
+            label = dp.study.cases[key]["label"]
+            self.ax.scatter(x, y, label=label)
+
+        self.ax.legend()
\ No newline at end of file
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 4bc91e0737..3a1bdd12dc 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -20,7 +20,7 @@
 from .unit_templates import UnitTemplatesWidget
 from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
 from .unit_waveforms import UnitWaveformsWidget
-from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances
+from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics


 widget_list = [
@@ -44,7 +44,8 @@
     UnitWaveformsWidget,
     StudyRunTimesWidget,
     StudyUnitCountsWidget,
-    StudyPerformances
+    StudyPerformances,
+    StudyPerformancesVsMetrics
 ]


@@ -95,6 +96,7 @@
 plot_study_run_times = StudyRunTimesWidget
 plot_study_unit_counts = StudyUnitCountsWidget
 plot_study_performances = StudyPerformances
+plot_study_performances_vs_metrics = StudyPerformancesVsMetrics


 def plot_timeseries(*args, **kwargs):

From f97f76a7948f87cdf6873ce0a0b378f1120040b7 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Tue, 12 Sep 2023 18:23:43 +0200
Subject: [PATCH 11/26] Clean up groundtruthstudy.py

---
 .../comparison/groundtruthstudy.py | 340 +-----------------
 1 file changed, 10 insertions(+), 330 deletions(-)

diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index 
d936c50e5e..8d43fb5f0c 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -16,19 +16,10 @@
 from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison

-# from .studytools import (
-#     setup_comparison_study,
-#     get_rec_names,
-#     get_recordings,
-#     iter_working_folder,
-#     iter_computed_names,
-#     iter_computed_sorting,
-#     collect_run_times,
-# )
-
-# TODO : save comparison in folders
-# TODO : find a way to set level names
+# TODO : save comparison in folders when COmparison object will be able to serialize
+# TODO ??: make an internal optional binary copy when running several external sorter
+# on the same dataset to avoid multiple save binary ? even when the recording is float32 (ks need int16)



@@ -48,17 +39,16 @@ class GroundTruthStudy:
     * parameters of comparisons
     * any combination of theses

-    For enough flexibility cases key can be a tuple so that we can varify complexity along several
-    "axis" (paremeters or sorter)
+    For enough flexibility, a case key can be a tuple so that complexity can be varied along several
+    "levels" or "axes" (parameters or sorters).
+
+    Generated dataframes will then optionally have a multi-level index.

-    Ground truth dataset need recording+sorting. This can be from meraec file or from the internal generator
+    A ground truth dataset needs a recording and a sorting. This can come from a mearec file or from the internal generator
     :py:fun:`generate_ground_truth_recording()`

     This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions.
-    Folders structures are not backward compatible.
-
-
-
+    Folder structures are not backward compatible at all.
     """
     def __init__(self, study_folder):
         self.folder = Path(study_folder)
@@ -105,8 +95,6 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
             (study_folder / "sortings" / "run_logs").mkdir()
             (study_folder / "metrics").mkdir()

-
-
         for key, (rec, gt_sorting) in datasets.items():
             assert "/" not in key
             assert "\\" not in key
@@ -341,7 +329,7 @@ def get_metrics(self, key):
         recording, gt_sorting = self.datasets[dataset_key]
         metrics.index = gt_sorting.unit_ids
         return metrics
-    
+
     def get_units_snr(self, key):
         """ """
@@ -369,8 +357,6 @@ def get_performance_by_unit(self, case_keys=None):
             perf = perf.reset_index()
             perf_by_unit.append(perf)

-
-
         perf_by_unit = pd.concat(perf_by_unit)
         perf_by_unit = perf_by_unit.set_index(self.levels)
         return perf_by_unit
@@ -421,309 +407,3 @@ def get_count_units(
         return count_units
-
-
-
-
-class OLDGroundTruthStudy:
-    def __init__(self, study_folder=None):
-        import pandas as pd
-
-        self.study_folder = Path(study_folder)
-        self._is_scanned = False
-        self.computed_names = None
-        self.rec_names = None
-        self.sorter_names = None
-
-        self.scan_folder()
-
-        self.comparisons = None
-        self.exhaustive_gt = None
-
-    def __repr__(self):
-        t = "Ground truth study\n"
-        t += "  " + str(self.study_folder) + "\n"
-        t += "  recordings: {} {}\n".format(len(self.rec_names), self.rec_names)
-        if len(self.sorter_names):
-            t += "  sorters: {} {}\n".format(len(self.sorter_names), self.sorter_names)
-
-        return t
-
-    def scan_folder(self):
-        self.rec_names = get_rec_names(self.study_folder)
-        # scan computed names
-        self.computed_names = list(iter_computed_names(self.study_folder))  # list of pair (rec_name, sorter_name)
-        self.sorter_names = np.unique([e for _, e in iter_computed_names(self.study_folder)]).tolist()
-        self._is_scanned = True
-
-    @classmethod
-    def create(cls, study_folder, gt_dict, 
**job_kwargs): - setup_comparison_study(study_folder, gt_dict, **job_kwargs) - return cls(study_folder) - - def run_sorters(self, sorter_list, mode_if_folder_exists="keep", remove_sorter_folders=False, **kwargs): - sorter_folders = self.study_folder / "sorter_folders" - recording_dict = get_recordings(self.study_folder) - - run_sorters( - sorter_list, - recording_dict, - sorter_folders, - with_output=False, - mode_if_folder_exists=mode_if_folder_exists, - **kwargs, - ) - - # results are copied so the heavy sorter_folders can be removed - self.copy_sortings() - - if remove_sorter_folders: - shutil.rmtree(self.study_folder / "sorter_folders") - - def _check_rec_name(self, rec_name): - if not self._is_scanned: - self.scan_folder() - if len(self.rec_names) > 1 and rec_name is None: - raise Exception("Pass 'rec_name' parameter to select which recording to use.") - elif len(self.rec_names) == 1: - rec_name = self.rec_names[0] - else: - rec_name = self.rec_names[self.rec_names.index(rec_name)] - return rec_name - - def get_ground_truth(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - sorting = load_extractor(self.study_folder / "ground_truth" / rec_name) - return sorting - - def get_recording(self, rec_name=None): - rec_name = self._check_rec_name(rec_name) - rec = load_extractor(self.study_folder / "raw_files" / rec_name) - return rec - - def get_sorting(self, sort_name, rec_name=None): - rec_name = self._check_rec_name(rec_name) - - selected_sorting = None - if sort_name in self.sorter_names: - for r_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - if sort_name == sorter_name and r_name == rec_name: - selected_sorting = sorting - return selected_sorting - - def copy_sortings(self): - sorter_folders = self.study_folder / "sorter_folders" - sorting_folders = self.study_folder / "sortings" - log_olders = self.study_folder / "sortings" / "run_log" - - log_olders.mkdir(parents=True, exist_ok=True) - - for rec_name, sorter_name, output_folder in iter_working_folder(sorter_folders): - SorterClass = sorter_dict[sorter_name] - fname = rec_name + "[#]" + sorter_name - npz_filename = sorting_folders / (fname + ".npz") - - try: - sorting = SorterClass.get_result_from_folder(output_folder) - NpzSortingExtractor.write_sorting(sorting, npz_filename) - except: - if npz_filename.is_file(): - npz_filename.unlink() - if (output_folder / "spikeinterface_log.json").is_file(): - shutil.copyfile( - output_folder / "spikeinterface_log.json", sorting_folders / "run_log" / (fname + ".json") - ) - - self.scan_folder() - - def run_comparisons(self, exhaustive_gt=False, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt, **kwargs) - self.comparisons[(rec_name, sorter_name)] = sc - self.exhaustive_gt = exhaustive_gt - - def aggregate_run_times(self): - return collect_run_times(self.study_folder) - - def aggregate_performance_by_unit(self): - assert self.comparisons is not None, "run_comparisons first" - - perf_by_unit = [] - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - comp = self.comparisons[(rec_name, sorter_name)] - - perf = comp.get_performance(method="by_unit", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - perf_by_unit.append(perf) - - import pandas as pd - - perf_by_unit = 
pd.concat(perf_by_unit) - perf_by_unit = perf_by_unit.set_index(["rec_name", "sorter_name", "gt_unit_id"]) - - return perf_by_unit - - def aggregate_count_units(self, well_detected_score=None, redundant_score=None, overmerged_score=None): - assert self.comparisons is not None, "run_comparisons first" - - import pandas as pd - - index = pd.MultiIndex.from_tuples(self.computed_names, names=["rec_name", "sorter_name"]) - - count_units = pd.DataFrame( - index=index, - columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"], - dtype=int, - ) - - if self.exhaustive_gt: - count_units["num_false_positive"] = pd.Series(dtype=int) - count_units["num_bad"] = pd.Series(dtype=int) - - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = self.comparisons[(rec_name, sorter_name)] - - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = comp.count_well_detected_units( - well_detected_score - ) - if self.exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units(redundant_score) - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units( - redundant_score - ) - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() - - return count_units - - def aggregate_dataframes(self, copy_into_folder=True, **karg_thresh): - dataframes = {} - dataframes["run_times"] = self.aggregate_run_times().reset_index() - perfs = self.aggregate_performance_by_unit() - - dataframes["perf_by_unit"] = perfs.reset_index() - dataframes["count_units"] = self.aggregate_count_units(**karg_thresh).reset_index() - - if copy_into_folder: - tables_folder = self.study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) - - for name, df in dataframes.items(): - df.to_csv(str(tables_folder / (name + ".csv")), sep="\t", index=False) - - return dataframes - - def get_waveform_extractor(self, rec_name, sorter_name=None): - rec = self.get_recording(rec_name) - - if sorter_name is None: - name = "GroundTruth" - sorting = self.get_ground_truth(rec_name) - else: - assert sorter_name in self.sorter_names - name = sorter_name - sorting = self.get_sorting(sorter_name, rec_name) - - waveform_folder = self.study_folder / "waveforms" / f"waveforms_{name}_{rec_name}" - - if waveform_folder.is_dir(): - we = WaveformExtractor.load(waveform_folder) - else: - we = WaveformExtractor.create(rec, sorting, waveform_folder) - return we - - def compute_waveforms( - self, - rec_name, - sorter_name=None, - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name, sorter_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - def get_templates(self, rec_name, sorter_name=None, mode="median"): - """ - Get template for a given recording. - - If sorter_name=None then template are from the ground truth. 
- - """ - we = self.get_waveform_extractor(rec_name, sorter_name=sorter_name) - templates = we.get_all_templates(mode=mode) - return templates - - def compute_metrics( - self, - rec_name, - metric_names=["snr"], - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - n_jobs=-1, - total_memory="1G", - ): - we = self.get_waveform_extractor(rec_name) - we.set_params(ms_before=ms_before, ms_after=ms_after, max_spikes_per_unit=max_spikes_per_unit) - we.run_extract_waveforms(n_jobs=n_jobs, total_memory=total_memory) - - # metrics - metrics = compute_quality_metrics(we, metric_names=metric_names) - folder = self.study_folder / "metrics" - folder.mkdir(exist_ok=True) - filename = folder / f"metrics _{rec_name}.txt" - metrics.to_csv(filename, sep="\t", index=True) - - return metrics - - def get_metrics(self, rec_name=None, **metric_kwargs): - """ - Load or compute units metrics for a given recording. - """ - rec_name = self._check_rec_name(rec_name) - metrics_folder = self.study_folder / "metrics" - metrics_folder.mkdir(parents=True, exist_ok=True) - - filename = self.study_folder / "metrics" / f"metrics _{rec_name}.txt" - import pandas as pd - - if filename.is_file(): - metrics = pd.read_csv(filename, sep="\t", index_col=0) - gt_sorting = self.get_ground_truth(rec_name) - metrics.index = gt_sorting.unit_ids - else: - metrics = self.compute_metrics(rec_name, **metric_kwargs) - - metrics.index.name = "unit_id" - # add rec name columns - metrics["rec_name"] = rec_name - - return metrics - - def get_units_snr(self, rec_name=None, **metric_kwargs): - """ """ - metric = self.get_metrics(rec_name=rec_name, **metric_kwargs) - return metric["snr"] - - def concat_all_snr(self): - metrics = [] - for rec_name in self.rec_names: - df = self.get_metrics(rec_name) - df = df.reset_index() - metrics.append(df) - metrics = pd.concat(metrics) - metrics = metrics.set_index(["rec_name", "unit_id"]) - return metrics["snr"] From ba2e961bd9b26fd7acc226183b19bc5b3a85401b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 14:58:14 +0200 Subject: [PATCH 12/26] small fix --- src/spikeinterface/comparison/groundtruthstudy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 8d43fb5f0c..9f0039b9cb 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -228,7 +228,12 @@ def copy_sortings(self, case_keys=None, force=True): sorter_folder = self.folder / "sorters" / self.key_to_str(key) log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - sorting = read_sorter_folder(sorter_folder, raise_error=False) + + if (sorter_folder / "spikeinterface_log.json").exists(): + sorting = read_sorter_folder(sorter_folder, raise_error=False) + else: + sorting = None + if sorting is not None: if sorting_folder.exists(): if force: From 9b5b28b9b6cf0b7d7e313d12cf2015253087f032 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 13 Sep 2023 15:03:57 +0200 Subject: [PATCH 13/26] small fix --- src/spikeinterface/widgets/gtstudy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 304cf1a44a..bc2c1246b7 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -35,7 +35,7 @@ def __init__( plot_data = dict( study=study, - run_times=study.get_run_times(), + 
run_times=study.get_run_times(case_keys), case_keys=case_keys, ) From 8d9ce49d14df99c1901854a398c2862c13184ceb Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 10:00:38 +0200 Subject: [PATCH 14/26] group in same file CollisionGTComparison and CollisionGTStudy group in same file CorrelogramGTComparison and CorrelogramGTStudy --- .../{collisioncomparison.py => collision.py} | 94 ++++++++++++++++++- .../comparison/collisionstudy.py | 88 ----------------- ...orrelogramcomparison.py => correlogram.py} | 79 +++++++++++++++- .../comparison/correlogramstudy.py | 76 --------------- 4 files changed, 170 insertions(+), 167 deletions(-) rename src/spikeinterface/comparison/{collisioncomparison.py => collision.py} (58%) delete mode 100644 src/spikeinterface/comparison/collisionstudy.py rename src/spikeinterface/comparison/{correlogramcomparison.py => correlogram.py} (58%) delete mode 100644 src/spikeinterface/comparison/correlogramstudy.py diff --git a/src/spikeinterface/comparison/collisioncomparison.py b/src/spikeinterface/comparison/collision.py similarity index 58% rename from src/spikeinterface/comparison/collisioncomparison.py rename to src/spikeinterface/comparison/collision.py index 3b279717b7..864809b04b 100644 --- a/src/spikeinterface/comparison/collisioncomparison.py +++ b/src/spikeinterface/comparison/collision.py @@ -1,8 +1,14 @@ -import numpy as np - from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy +from .studytools import iter_computed_sorting ## TODO remove this from .comparisontools import make_collision_events +import numpy as np + + + + + class CollisionGTComparison(GroundTruthComparison): """ @@ -156,3 +162,87 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good pair_names = pair_names[order] return similarities, recall_scores, pair_names + + + +class CollisionGTStudy(GroundTruthStudy): + def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): + self.comparisons = {} + for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): + gt_sorting = self.get_ground_truth(rec_name) + comp = CollisionGTComparison( + gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins + ) + self.comparisons[(rec_name, sorter_name)] = comp + self.exhaustive_gt = exhaustive_gt + self.collision_lag = collision_lag + + def get_lags(self): + fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() + lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 + return lags + + def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): + if not hasattr(self, "_good_only") or self._good_only != good_only: + import sklearn + + similarity_matrix = {} + for rec_name in self.rec_names: + templates = self.get_templates(rec_name) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + + self.all_similarities = {} + self.all_recall_scores = {} + self.good_only = good_only + + for sorter_ind, sorter_name in enumerate(self.sorter_names): + # loop over recordings + all_similarities = [] + all_recall_scores = [] + + for rec_name in self.rec_names: + if (rec_name, sorter_name) in self.comparisons.keys(): + comp = self.comparisons[(rec_name, sorter_name)] + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + 
similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy + ) + + all_similarities.append(similarities) + all_recall_scores.append(recall_scores) + + self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) + self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) + + def get_mean_over_similarity_range(self, similarity_range, sorter_name): + idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( + self.all_similarities[sorter_name] <= similarity_range[1] + ) + all_similarities = self.all_similarities[sorter_name][idx] + all_recall_scores = self.all_recall_scores[sorter_name][idx] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + mean_recall_scores = np.nanmean(all_recall_scores, axis=0) + + return mean_recall_scores + + def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): + all_similarities = self.all_similarities[sorter_name] + all_recall_scores = self.all_recall_scores[sorter_name] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_recall_scores = all_recall_scores[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) + result[(cmin, cmax)] = mean_recall_scores + + return result diff --git a/src/spikeinterface/comparison/collisionstudy.py b/src/spikeinterface/comparison/collisionstudy.py deleted file mode 100644 index 34a556e8b9..0000000000 --- a/src/spikeinterface/comparison/collisionstudy.py +++ /dev/null @@ -1,88 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .collisioncomparison import CollisionGTComparison - -import numpy as np - - -class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CollisionGTComparison( - gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins - ) - self.comparisons[(rec_name, sorter_name)] = comp - self.exhaustive_gt = exhaustive_gt - self.collision_lag = collision_lag - - def get_lags(self): - fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() - lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 - return lags - - def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): - if not hasattr(self, "_good_only") or self._good_only != good_only: - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_similarities = [] - all_recall_scores = [] - - for rec_name in self.rec_names: - if (rec_name, sorter_name) in self.comparisons.keys(): - comp = self.comparisons[(rec_name, sorter_name)] - similarities, 
recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy - ) - - all_similarities.append(similarities) - all_recall_scores.append(recall_scores) - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) - - def get_mean_over_similarity_range(self, similarity_range, sorter_name): - idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( - self.all_similarities[sorter_name] <= similarity_range[1] - ) - all_similarities = self.all_similarities[sorter_name][idx] - all_recall_scores = self.all_recall_scores[sorter_name][idx] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - mean_recall_scores = np.nanmean(all_recall_scores, axis=0) - - return mean_recall_scores - - def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_recall_scores = self.all_recall_scores[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_recall_scores = all_recall_scores[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_recall_scores = np.nanmean(all_recall_scores[amin:amax], axis=0) - result[(cmin, cmax)] = mean_recall_scores - - return result diff --git a/src/spikeinterface/comparison/correlogramcomparison.py b/src/spikeinterface/comparison/correlogram.py similarity index 58% rename from src/spikeinterface/comparison/correlogramcomparison.py rename to src/spikeinterface/comparison/correlogram.py index 80e881a152..9c5e1e91cf 100644 --- a/src/spikeinterface/comparison/correlogramcomparison.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -1,8 +1,13 @@ -import numpy as np from .paircomparisons import GroundTruthComparison +from .groundtruthstudy import GroundTruthStudy +from .studytools import iter_computed_sorting ## TODO remove this from spikeinterface.postprocessing import compute_correlograms +import numpy as np + + + class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing @@ -108,3 +113,75 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): errors = errors[order, :] return similarities, errors + + + +class CorrelogramGTStudy(GroundTruthStudy): + def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): + self.comparisons = {} + for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): + gt_sorting = self.get_ground_truth(rec_name) + comp = CorrelogramGTComparison( + gt_sorting, + sorting, + exhaustive_gt=exhaustive_gt, + window_ms=window_ms, + bin_ms=bin_ms, + well_detected_score=well_detected_score, + ) + self.comparisons[(rec_name, sorter_name)] = comp + + self.exhaustive_gt = exhaustive_gt + + @property + def time_bins(self): + for key, value in self.comparisons.items(): + return value.time_bins + + def precompute_scores_by_similarities(self, good_only=True): + if not hasattr(self, "_computed"): + import sklearn + + similarity_matrix = {} + for rec_name in self.rec_names: + templates = self.get_templates(rec_name) + flat_templates = 
templates.reshape(templates.shape[0], -1) + similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + + self.all_similarities = {} + self.all_errors = {} + self._computed = True + + for sorter_ind, sorter_name in enumerate(self.sorter_names): + # loop over recordings + all_errors = [] + all_similarities = [] + for rec_name in self.rec_names: + try: + comp = self.comparisons[(rec_name, sorter_name)] + similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name]) + all_similarities.append(similarities) + all_errors.append(errors) + except Exception: + pass + + self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) + self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0) + + def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name): + all_similarities = self.all_similarities[sorter_name] + all_errors = self.all_errors[sorter_name] + + order = np.argsort(all_similarities) + all_similarities = all_similarities[order] + all_errors = all_errors[order, :] + + result = {} + + for i in range(similarity_bins.size - 1): + cmin, cmax = similarity_bins[i], similarity_bins[i + 1] + amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) + mean_errors = np.nanmean(all_errors[amin:amax], axis=0) + result[(cmin, cmax)] = mean_errors + + return result diff --git a/src/spikeinterface/comparison/correlogramstudy.py b/src/spikeinterface/comparison/correlogramstudy.py deleted file mode 100644 index fb00c08157..0000000000 --- a/src/spikeinterface/comparison/correlogramstudy.py +++ /dev/null @@ -1,76 +0,0 @@ -from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting -from .correlogramcomparison import CorrelogramGTComparison - -import numpy as np - - -class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CorrelogramGTComparison( - gt_sorting, - sorting, - exhaustive_gt=exhaustive_gt, - window_ms=window_ms, - bin_ms=bin_ms, - well_detected_score=well_detected_score, - ) - self.comparisons[(rec_name, sorter_name)] = comp - - self.exhaustive_gt = exhaustive_gt - - @property - def time_bins(self): - for key, value in self.comparisons.items(): - return value.time_bins - - def precompute_scores_by_similarities(self, good_only=True): - if not hasattr(self, "_computed"): - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_errors = {} - self._computed = True - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_errors = [] - all_similarities = [] - for rec_name in self.rec_names: - try: - comp = self.comparisons[(rec_name, sorter_name)] - similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name]) - all_similarities.append(similarities) - all_errors.append(errors) - except Exception: - pass - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0) - - def 
get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_errors = self.all_errors[sorter_name] - - order = np.argsort(all_similarities) - all_similarities = all_similarities[order] - all_errors = all_errors[order, :] - - result = {} - - for i in range(similarity_bins.size - 1): - cmin, cmax = similarity_bins[i], similarity_bins[i + 1] - amin, amax = np.searchsorted(all_similarities, [cmin, cmax]) - mean_errors = np.nanmean(all_errors[amin:amax], axis=0) - result[(cmin, cmax)] = mean_errors - - return result From b1297e6aef50aa507415359b773f1c5611230b1f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 13:08:36 +0200 Subject: [PATCH 15/26] Update CollisionGTStudy and CorrelogramGTStudy --- src/spikeinterface/comparison/__init__.py | 9 +- src/spikeinterface/comparison/collision.py | 96 +++++++++---------- src/spikeinterface/comparison/correlogram.py | 85 +++++++--------- .../comparison/groundtruthstudy.py | 4 +- .../_legacy_mpl_widgets/collisioncomp.py | 2 +- 5 files changed, 83 insertions(+), 113 deletions(-) diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index a390bb7689..7ac5b29aa2 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -28,12 +28,11 @@ compare_multiple_templates, MultiTemplateComparison, ) -from .collisioncomparison import CollisionGTComparison -from .correlogramcomparison import CorrelogramGTComparison + from .groundtruthstudy import GroundTruthStudy -from .collisionstudy import CollisionGTStudy -from .correlogramstudy import CorrelogramGTStudy -from .studytools import aggregate_performances_table +from .collision import CollisionGTComparison, CollisionGTStudy +from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy +# from .studytools import aggregate_performances_table from .hybrid import ( HybridSpikesRecording, HybridUnitsRecording, diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 864809b04b..c526c22ae4 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -12,8 +12,9 @@ class CollisionGTComparison(GroundTruthComparison): """ - This class is an extension of GroundTruthComparison by focusing - to benchmark spike in collision + This class is an extension of GroundTruthComparison by focusing to benchmark spike in collision. + + This class needs maintenance and need a bit of refactoring. 
collision_lag: float @@ -166,60 +167,49 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good class CollisionGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CollisionGTComparison( - gt_sorting, sorting, exhaustive_gt=exhaustive_gt, collision_lag=collision_lag, nbins=nbins - ) - self.comparisons[(rec_name, sorter_name)] = comp + def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["collision_lag"] = collision_lag + _kwargs["nbins"] = nbins + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CollisionGTComparison, **_kwargs) self.exhaustive_gt = exhaustive_gt self.collision_lag = collision_lag - def get_lags(self): - fs = self.comparisons[(self.rec_names[0], self.sorter_names[0])].sorting1.get_sampling_frequency() - lags = self.comparisons[(self.rec_names[0], self.sorter_names[0])].bins / fs * 1000 + def get_lags(self, key): + comp = self.comparisons[key] + fs = comp.sorting1.get_sampling_frequency() + lags = comp.bins / fs * 1000. return lags - def precompute_scores_by_similarities(self, good_only=True, min_accuracy=0.9): - if not hasattr(self, "_good_only") or self._good_only != good_only: - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_recall_scores = {} - self.good_only = good_only - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop over recordings - all_similarities = [] - all_recall_scores = [] - - for rec_name in self.rec_names: - if (rec_name, sorter_name) in self.comparisons.keys(): - comp = self.comparisons[(rec_name, sorter_name)] - similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity_matrix[rec_name], good_only=good_only, min_accuracy=min_accuracy - ) - - all_similarities.append(similarities) - all_recall_scores.append(recall_scores) - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_recall_scores[sorter_name] = np.concatenate(all_recall_scores, axis=0) - - def get_mean_over_similarity_range(self, similarity_range, sorter_name): - idx = (self.all_similarities[sorter_name] >= similarity_range[0]) & ( - self.all_similarities[sorter_name] <= similarity_range[1] + def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): + import sklearn + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_recall_scores = {} + self.good_only = good_only + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( + similarity, good_only=good_only, min_accuracy=min_accuracy + ) + self.all_similarities[key] = similarities + self.all_recall_scores[key] = recall_scores + + + def 
get_mean_over_similarity_range(self, similarity_range, key): + idx = (self.all_similarities[key] >= similarity_range[0]) & ( + self.all_similarities[key] <= similarity_range[1] ) - all_similarities = self.all_similarities[sorter_name][idx] - all_recall_scores = self.all_recall_scores[sorter_name][idx] + all_similarities = self.all_similarities[key][idx] + all_recall_scores = self.all_recall_scores[key][idx] order = np.argsort(all_similarities) all_similarities = all_similarities[order] @@ -229,9 +219,9 @@ def get_mean_over_similarity_range(self, similarity_range, sorter_name): return mean_recall_scores - def get_lag_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_recall_scores = self.all_recall_scores[sorter_name] + def get_lag_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_recall_scores = self.all_recall_scores[key] order = np.argsort(all_similarities) all_similarities = all_similarities[order] diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 9c5e1e91cf..b2376cb52d 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -11,11 +11,9 @@ class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing - to benchmark correlation reconstruction + to benchmark correlation reconstruction. - - collision_lag: float - Collision lag in ms. + This class needs maintenance and need a bit of refactoring. """ @@ -110,27 +108,21 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): order = np.argsort(similarities) similarities = similarities[order] - errors = errors[order, :] + errors = errors[order] return similarities, errors class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons(self, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): - self.comparisons = {} - for rec_name, sorter_name, sorting in iter_computed_sorting(self.study_folder): - gt_sorting = self.get_ground_truth(rec_name) - comp = CorrelogramGTComparison( - gt_sorting, - sorting, - exhaustive_gt=exhaustive_gt, - window_ms=window_ms, - bin_ms=bin_ms, - well_detected_score=well_detected_score, - ) - self.comparisons[(rec_name, sorter_name)] = comp - + def run_comparisons(self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): + _kwargs = dict() + _kwargs.update(kwargs) + _kwargs["exhaustive_gt"] = exhaustive_gt + _kwargs["window_ms"] = window_ms + _kwargs["bin_ms"] = bin_ms + _kwargs["well_detected_score"] = well_detected_score + GroundTruthStudy.run_comparisons(self, case_keys=None, comparison_class=CorrelogramGTComparison, **_kwargs) self.exhaustive_gt = exhaustive_gt @property @@ -138,39 +130,28 @@ def time_bins(self): for key, value in self.comparisons.items(): return value.time_bins - def precompute_scores_by_similarities(self, good_only=True): - if not hasattr(self, "_computed"): - import sklearn - - similarity_matrix = {} - for rec_name in self.rec_names: - templates = self.get_templates(rec_name) - flat_templates = templates.reshape(templates.shape[0], -1) - similarity_matrix[rec_name] = sklearn.metrics.pairwise.cosine_similarity(flat_templates) - - self.all_similarities = {} - self.all_errors = {} - self._computed = True - - for sorter_ind, sorter_name in enumerate(self.sorter_names): - # loop 
over recordings - all_errors = [] - all_similarities = [] - for rec_name in self.rec_names: - try: - comp = self.comparisons[(rec_name, sorter_name)] - similarities, errors = comp.compute_correlogram_by_similarity(similarity_matrix[rec_name]) - all_similarities.append(similarities) - all_errors.append(errors) - except Exception: - pass - - self.all_similarities[sorter_name] = np.concatenate(all_similarities, axis=0) - self.all_errors[sorter_name] = np.concatenate(all_errors, axis=0) - - def get_error_profile_over_similarity_bins(self, similarity_bins, sorter_name): - all_similarities = self.all_similarities[sorter_name] - all_errors = self.all_errors[sorter_name] + def precompute_scores_by_similarities(self, case_keys=None, good_only=True): + import sklearn.metrics + + if case_keys is None: + case_keys = self.cases.keys() + + self.all_similarities = {} + self.all_errors = {} + + for key in case_keys: + templates = self.get_templates(key) + flat_templates = templates.reshape(templates.shape[0], -1) + similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) + comp = self.comparisons[key] + similarities, errors = comp.compute_correlogram_by_similarity(similarity) + + self.all_similarities[key] = similarities + self.all_errors[key] = errors + + def get_error_profile_over_similarity_bins(self, similarity_bins, key): + all_similarities = self.all_similarities[key] + all_errors = self.all_errors[key] order = np.argsort(all_similarities) all_similarities = all_similarities[order] diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 9f0039b9cb..0c08318ef4 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -155,7 +155,7 @@ def scan_folder(self): def __repr__(self): - t = f"GroundTruthStudy {self.folder.stem} \n" + t = f"{self.__class__.__name__} {self.folder.stem} \n" t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" t += f" cases: {len(self.cases)} {list(self.cases.keys())}\n" num_computed = sum([1 for sorting in self.sortings.values() if sorting is not None]) @@ -303,7 +303,7 @@ def get_waveform_extractor(self, key): we.set_recording(recording) return we - def get_templates(self, key, mode="mean"): + def get_templates(self, key, mode="average"): we = self.get_waveform_extractor(key) templates = we.get_all_templates(mode=mode) return templates diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 6d981e1fd4..096a5f3933 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -1,7 +1,7 @@ import numpy as np from .basewidget import BaseWidget -from spikeinterface.comparison.collisioncomparison import CollisionGTComparison +from spikeinterface.comparison import CollisionGTComparison class ComparisonCollisionPairByPairWidget(BaseWidget): From 8a7a90e130e3007ad73ae840ee4e889c9a6b146f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 13:35:50 +0200 Subject: [PATCH 16/26] wip --- src/spikeinterface/comparison/groundtruthstudy.py | 5 +---- .../widgets/_legacy_mpl_widgets/collisioncomp.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 0c08318ef4..6898f381b6 100644 --- 
a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -17,10 +17,7 @@ from .paircomparisons import compare_sorter_to_ground_truth, GroundTruthComparison -# TODO : save comparison in folders when COmparison object will be able to serialize -# TODO ??: make an internal optional binary copy when running several external sorter -# on the same dataset to avoid multiple save binary ? even when the recording is float32 (ks need int16) - +# TODO later : save comparison in folders when comparison object will be able to serialize # This is to separate names when the key are tuples when saving folders diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py index 096a5f3933..d25f1ea97b 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/collisioncomp.py @@ -1,7 +1,6 @@ import numpy as np from .basewidget import BaseWidget -from spikeinterface.comparison import CollisionGTComparison class ComparisonCollisionPairByPairWidget(BaseWidget): From fe6f60f45b8ee1f50e81c8d7b5b209965507c1df Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 13:39:31 +0200 Subject: [PATCH 17/26] Re move studytools.py. Not needed anymore. --- src/spikeinterface/comparison/__init__.py | 2 +- src/spikeinterface/comparison/studytools.py | 352 -------------------- 2 files changed, 1 insertion(+), 353 deletions(-) delete mode 100644 src/spikeinterface/comparison/studytools.py diff --git a/src/spikeinterface/comparison/__init__.py b/src/spikeinterface/comparison/__init__.py index 7ac5b29aa2..bff85dde4a 100644 --- a/src/spikeinterface/comparison/__init__.py +++ b/src/spikeinterface/comparison/__init__.py @@ -32,7 +32,7 @@ from .groundtruthstudy import GroundTruthStudy from .collision import CollisionGTComparison, CollisionGTStudy from .correlogram import CorrelogramGTComparison, CorrelogramGTStudy -# from .studytools import aggregate_performances_table + from .hybrid import ( HybridSpikesRecording, HybridUnitsRecording, diff --git a/src/spikeinterface/comparison/studytools.py b/src/spikeinterface/comparison/studytools.py deleted file mode 100644 index 00119c1586..0000000000 --- a/src/spikeinterface/comparison/studytools.py +++ /dev/null @@ -1,352 +0,0 @@ -""" -High level tools to run many ground-truth comparison with -many sorter on many recordings and then collect and aggregate results -in an easy way. 
- -The all mechanism is based on an intrinsic organization -into a "study_folder" with several subfolder: - * raw_files : contain a copy in binary format of recordings - * sorter_folders : contains output of sorters - * ground_truth : contains a copy of sorting ground in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some table in cvs format -""" - -from pathlib import Path -import shutil -import json -import os - - -from spikeinterface.core import load_extractor -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.extractors import NpzSortingExtractor -from spikeinterface.sorters import sorter_dict -from spikeinterface.sorters.basesorter import is_log_ok - - -from .comparisontools import _perf_keys -from .paircomparisons import compare_sorter_to_ground_truth - - - - - -# This is deprecated and will be removed -def iter_working_folder(working_folder): - working_folder = Path(working_folder) - for rec_folder in working_folder.iterdir(): - if not rec_folder.is_dir(): - continue - for output_folder in rec_folder.iterdir(): - if (output_folder / "spikeinterface_job.json").is_file(): - with open(output_folder / "spikeinterface_job.json", "r") as f: - job_dict = json.load(f) - rec_name = job_dict["rec_name"] - sorter_name = job_dict["sorter_name"] - yield rec_name, sorter_name, output_folder - else: - rec_name = rec_folder.name - sorter_name = output_folder.name - if not output_folder.is_dir(): - continue - if not is_log_ok(output_folder): - continue - yield rec_name, sorter_name, output_folder - -# This is deprecated and will be removed -def iter_sorting_output(working_folder): - """Iterator over output_folder to retrieve all triplets of (rec_name, sorter_name, sorting).""" - for rec_name, sorter_name, output_folder in iter_working_folder(working_folder): - SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) - yield rec_name, sorter_name, sorting - - - -def setup_comparison_study(study_folder, gt_dict, **job_kwargs): - """ - Based on a dict of (recording, sorting) create the study folder. - - Parameters - ---------- - study_folder: str - The study folder. - gt_dict : a dict of tuple (recording, sorting_gt) - Dict of tuple that contain recording and sorting ground truth - """ - job_kwargs = fix_job_kwargs(job_kwargs) - study_folder = Path(study_folder) - assert not study_folder.is_dir(), "'study_folder' already exists. Please remove it" - - study_folder.mkdir(parents=True, exist_ok=True) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - log_folder.mkdir(parents=True, exist_ok=True) - tables_folder = study_folder / "tables" - tables_folder.mkdir(parents=True, exist_ok=True) - - for rec_name, (recording, sorting_gt) in gt_dict.items(): - # write recording using save with binary - folder = study_folder / "ground_truth" / rec_name - sorting_gt.save(folder=folder, format="numpy_folder") - folder = study_folder / "raw_files" / rec_name - recording.save(folder=folder, format="binary", **job_kwargs) - - # make an index of recording names - with open(study_folder / "names.txt", mode="w", encoding="utf8") as f: - for rec_name in gt_dict: - f.write(rec_name + "\n") - - -def get_rec_names(study_folder): - """ - Get list of keys of recordings. - Read from the 'names.txt' file in study folder. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - rec_names: list - List of names. 
- """ - study_folder = Path(study_folder) - with open(study_folder / "names.txt", mode="r", encoding="utf8") as f: - rec_names = f.read()[:-1].split("\n") - return rec_names - - -def get_recordings(study_folder): - """ - Get ground recording as a dict. - - They are read from the 'raw_files' folder with binary format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - recording_dict: dict - Dict of recording. - """ - study_folder = Path(study_folder) - - rec_names = get_rec_names(study_folder) - recording_dict = {} - for rec_name in rec_names: - rec = load_extractor(study_folder / "raw_files" / rec_name) - recording_dict[rec_name] = rec - - return recording_dict - - -def get_ground_truths(study_folder): - """ - Get ground truth sorting extractor as a dict. - - They are read from the 'ground_truth' folder with npz format. - - Parameters - ---------- - study_folder: str - The study folder. - - Returns - ------- - ground_truths: dict - Dict of sorting_gt. - """ - study_folder = Path(study_folder) - rec_names = get_rec_names(study_folder) - ground_truths = {} - for rec_name in rec_names: - sorting = load_extractor(study_folder / "ground_truth" / rec_name) - ground_truths[rec_name] = sorting - return ground_truths - - -def iter_computed_names(study_folder): - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - yield rec_name, sorter_name - - -def iter_computed_sorting(study_folder): - """ - Iter over sorting files. - """ - sorting_folder = Path(study_folder) / "sortings" - for filename in os.listdir(sorting_folder): - if filename.endswith(".npz") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".npz", "").split("[#]") - sorting = NpzSortingExtractor(sorting_folder / filename) - yield rec_name, sorter_name, sorting - - -def collect_run_times(study_folder): - """ - Collect run times in a working folder and store it in CVS files. - - The output is list of (rec_name, sorter_name, run_time) - """ - import pandas as pd - - study_folder = Path(study_folder) - sorting_folders = study_folder / "sortings" - log_folder = sorting_folders / "run_log" - tables_folder = study_folder / "tables" - - tables_folder.mkdir(parents=True, exist_ok=True) - - run_times = [] - for filename in os.listdir(log_folder): - if filename.endswith(".json") and "[#]" in filename: - rec_name, sorter_name = filename.replace(".json", "").split("[#]") - with open(log_folder / filename, encoding="utf8", mode="r") as logfile: - log = json.load(logfile) - run_time = log.get("run_time", None) - run_times.append((rec_name, sorter_name, run_time)) - - run_times = pd.DataFrame(run_times, columns=["rec_name", "sorter_name", "run_time"]) - run_times = run_times.set_index(["rec_name", "sorter_name"]) - - return run_times - - -def aggregate_sorting_comparison(study_folder, exhaustive_gt=False): - """ - Loop over output folder in a tree to collect sorting output and run - ground_truth_comparison on them. - - Parameters - ---------- - study_folder: str - The study folder. - exhaustive_gt: bool (default True) - Tell if the ground true is "exhaustive" or not. In other world if the - GT have all possible units. It allows more performance measurement. 
- For instance, MEArec simulated dataset have exhaustive_gt=True - - Returns - ---------- - comparisons: a dict of SortingComparison - - """ - - study_folder = Path(study_folder) - - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - comparisons = {} - for (rec_name, sorter_name), sorting in results.items(): - gt_sorting = ground_truths[rec_name] - sc = compare_sorter_to_ground_truth(gt_sorting, sorting, exhaustive_gt=exhaustive_gt) - comparisons[(rec_name, sorter_name)] = sc - - return comparisons - - -def aggregate_performances_table(study_folder, exhaustive_gt=False, **karg_thresh): - """ - Aggregate some results into dataframe to have a "study" overview on all recordingXsorter. - - Tables are: - * run_times: run times per recordingXsorter - * perf_pooled_with_sum: GroundTruthComparison.see get_performance - * perf_pooled_with_average: GroundTruthComparison.see get_performance - * count_units: given some threshold count how many units : 'well_detected', 'redundant', 'false_postive_units, 'bad' - - Parameters - ---------- - study_folder: str - The study folder. - karg_thresh: dict - Threshold parameters used for the "count_units" table. - - Returns - ------- - dataframes: a dict of DataFrame - Return several useful DataFrame to compare all results. - Note that count_units depend on karg_thresh. - """ - import pandas as pd - - study_folder = Path(study_folder) - sorter_folders = study_folder / "sorter_folders" - tables_folder = study_folder / "tables" - - comparisons = aggregate_sorting_comparison(study_folder, exhaustive_gt=exhaustive_gt) - ground_truths = get_ground_truths(study_folder) - results = collect_study_sorting(study_folder) - - study_folder = Path(study_folder) - - dataframes = {} - - # get run times: - run_times = pd.read_csv(str(tables_folder / "run_times.csv"), sep="\t") - run_times.columns = ["rec_name", "sorter_name", "run_time"] - run_times = run_times.set_index( - [ - "rec_name", - "sorter_name", - ] - ) - dataframes["run_times"] = run_times - - perf_pooled_with_sum = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_sum"] = perf_pooled_with_sum - - perf_pooled_with_average = pd.DataFrame(index=run_times.index, columns=_perf_keys) - dataframes["perf_pooled_with_average"] = perf_pooled_with_average - - count_units = pd.DataFrame( - index=run_times.index, columns=["num_gt", "num_sorter", "num_well_detected", "num_redundant"] - ) - dataframes["count_units"] = count_units - if exhaustive_gt: - count_units["num_false_positive"] = None - count_units["num_bad"] = None - - perf_by_spiketrain = [] - - for (rec_name, sorter_name), comp in comparisons.items(): - gt_sorting = ground_truths[rec_name] - sorting = results[(rec_name, sorter_name)] - - perf = comp.get_performance(method="pooled_with_sum", output="pandas") - perf_pooled_with_sum.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="pooled_with_average", output="pandas") - perf_pooled_with_average.loc[(rec_name, sorter_name), :] = perf - - perf = comp.get_performance(method="by_spiketrain", output="pandas") - perf["rec_name"] = rec_name - perf["sorter_name"] = sorter_name - perf = perf.reset_index() - - perf_by_spiketrain.append(perf) - - count_units.loc[(rec_name, sorter_name), "num_gt"] = len(gt_sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[(rec_name, sorter_name), "num_well_detected"] = 
comp.count_well_detected_units(**karg_thresh) - count_units.loc[(rec_name, sorter_name), "num_redundant"] = comp.count_redundant_units() - if exhaustive_gt: - count_units.loc[(rec_name, sorter_name), "num_false_positive"] = comp.count_false_positive_units() - count_units.loc[(rec_name, sorter_name), "num_bad"] = comp.count_bad_units() - - perf_by_spiketrain = pd.concat(perf_by_spiketrain) - perf_by_spiketrain = perf_by_spiketrain.set_index(["rec_name", "sorter_name", "gt_unit_id"]) - dataframes["perf_by_spiketrain"] = perf_by_spiketrain - - return dataframes From 77505adc76fce228d66347d0aeb66bacce94cc8c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 13:40:53 +0200 Subject: [PATCH 18/26] rm studytools part2 --- src/spikeinterface/comparison/collision.py | 1 - src/spikeinterface/comparison/correlogram.py | 1 - .../comparison/tests/test_studytools.py | 59 ------------------- 3 files changed, 61 deletions(-) delete mode 100644 src/spikeinterface/comparison/tests/test_studytools.py diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index c526c22ae4..01626b34b8 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -1,6 +1,5 @@ from .paircomparisons import GroundTruthComparison from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting ## TODO remove this from .comparisontools import make_collision_events import numpy as np diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index b2376cb52d..150f5afe55 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -1,6 +1,5 @@ from .paircomparisons import GroundTruthComparison from .groundtruthstudy import GroundTruthStudy -from .studytools import iter_computed_sorting ## TODO remove this from spikeinterface.postprocessing import compute_correlograms diff --git a/src/spikeinterface/comparison/tests/test_studytools.py b/src/spikeinterface/comparison/tests/test_studytools.py deleted file mode 100644 index dbc39d5e1d..0000000000 --- a/src/spikeinterface/comparison/tests/test_studytools.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -from pathlib import Path - -import pytest - -from spikeinterface.extractors import toy_example -from spikeinterface.comparison.studytools import ( - setup_comparison_study, - iter_computed_names, - iter_computed_sorting, - get_rec_names, - get_ground_truths, - get_recordings, -) - -if hasattr(pytest, "global_test_folder"): - cache_folder = pytest.global_test_folder / "comparison" -else: - cache_folder = Path("cache_folder") / "comparison" - - -study_folder = cache_folder / "test_studytools" - - -def setup_module(): - if study_folder.is_dir(): - shutil.rmtree(study_folder) - - -def test_setup_comparison_study(): - rec0, gt_sorting0 = toy_example(num_channels=4, duration=30, seed=0, num_segments=1) - rec1, gt_sorting1 = toy_example(num_channels=32, duration=30, seed=0, num_segments=1) - - gt_dict = { - "toy_tetrode": (rec0, gt_sorting0), - "toy_probe32": (rec1, gt_sorting1), - } - setup_comparison_study(study_folder, gt_dict) - - -def test_get_ground_truths(): - names = get_rec_names(study_folder) - d = get_ground_truths(study_folder) - d = get_recordings(study_folder) - - -def test_loops(): - names = list(iter_computed_names(study_folder)) - for rec_name, sorter_name, sorting in iter_computed_sorting(study_folder): - print(rec_name, sorter_name) - 
print(sorting) - - -if __name__ == "__main__": - setup_module() - test_setup_comparison_study() - test_get_ground_truths() - test_loops() From b5376a9b30d84a201a6c8ad7db15c644abe993a9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 14:26:22 +0200 Subject: [PATCH 19/26] Modify doc for gt study --- doc/modules/comparison.rst | 101 +++++++++++------- .../comparison/groundtruthstudy.py | 6 -- 2 files changed, 62 insertions(+), 45 deletions(-) diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index b452307e3c..9b2e701dac 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -248,21 +248,19 @@ An **over-merged** unit has a relatively high agreement (>= 0.2 by default) for We also have a high level class to compare many sorters against ground truth: :py:func:`~spiekinterface.comparison.GroundTruthStudy()` -A study is a systematic performance comparison of several ground truth recordings with several sorters. +A study is a systematic performance comparison of several ground truth recordings with several sorters or several cases +like the different parameter sets. -The study class proposes high-level tool functions to run many ground truth comparisons with many sorters +The study class proposes high-level tool functions to run many ground truth comparisons with many "cases" on many recordings and then collect and aggregate results in an easy way. The all mechanism is based on an intrinsic organization into a "study_folder" with several subfolder: - * raw_files : contain a copy of recordings in binary format - * sorter_folders : contains outputs of sorters - * ground_truth : contains a copy of sorting ground truth in npz format - * sortings: contains light copy of all sorting in npz format - * tables: some tables in csv format - -In order to run and rerun the computation all gt_sorting and recordings are copied to a fast and universal format: -binary (for recordings) and npz (for sortings). + * datasets: contains ground truth datasets + * sorters : contains outputs of sorters + * sortings: contains light copy of all sorting + * metrics: contains metrics + * ... .. code-block:: python @@ -274,28 +272,52 @@ binary (for recordings) and npz (for sortings). import spikeinterface.widgets as sw from spikeinterface.comparison import GroundTruthStudy - # Setup study folder - rec0, gt_sorting0 = se.toy_example(num_channels=4, duration=10, seed=10, num_segments=1) - rec1, gt_sorting1 = se.toy_example(num_channels=4, duration=10, seed=0, num_segments=1) - gt_dict = { - 'rec0': (rec0, gt_sorting0), - 'rec1': (rec1, gt_sorting1), + + # generate 2 simulated datasets (could be also mearec files) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + + datasets = { + "toy0": (rec0, gt_sorting0), + "toy1": (rec1, gt_sorting1), } - study_folder = 'a_study_folder' - study = GroundTruthStudy.create(study_folder, gt_dict) - # all sorters for all recordings in one function. 
-    sorter_list = ['herdingspikes', 'tridesclous', ]
-    study.run_sorters(sorter_list, mode_if_folder_exists="keep")
+
+    # define some "cases" here we want to test tridesclous2 on 2 datasets and spykingcircus on one dataset
+    # so it is a two level study (sorter_name, dataset)
+    # this could be more complicated like (sorter_name, dataset, params)
+    cases = {
+        ("tdc2", "toy0"): {
+            "label": "tridesclous2 on tetrode0",
+            "dataset": "toy0",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+        },
+        #
+        ("tdc2", "toy1"): {
+            "label": "tridesclous2 on tetrode1",
+            "dataset": "toy1",
+            "run_sorter_params": {
+                "sorter_name": "tridesclous2",
+            },
+        },
+
+        ("sc", "toy0"): {
+            "label": "spykingcircus2 on tetrode0",
+            "dataset": "toy0",
+            "run_sorter_params": {
+                "sorter_name": "spykingcircus",
+                "docker_image": True
+            },
+        },
+    }
+    # this initializes a folder
+    study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases,
+                                    levels=["sorter_name", "dataset"])
-    # You can re-run **run_study_sorters** as many times as you want.
-    # By default **mode='keep'** so only uncomputed sorters are re-run.
-    # For instance, just remove the "sorter_folders/rec1/herdingspikes" to re-run
-    # only one sorter on one recording.
-    #
-    # Then we copy the spike sorting outputs into a separate subfolder.
-    # This allow us to remove the "large" sorter_folders.
-    study.copy_sortings()
+
+    # all cases in one function
+    study.run_sorters()

     # Collect comparisons
     #
     # You can collect in one shot all results and run the
     # GroundTruthComparison on it.
     # So you can have fine comparison information in a dict
     #
     # Note: use exhaustive_gt=True when you know exactly how many
     # units in ground truth (for synthetic datasets)
+    # run all comparisons and loop over the results
     study.run_comparisons(exhaustive_gt=True)
-
-    for (rec_name, sorter_name), comp in study.comparisons.items():
+    for key, comp in study.comparisons.items():
         print('*' * 10)
-        print(rec_name, sorter_name)
+        print(key)
         # raw counting of tp/fp/...
         print(comp.count_score)
         # summary
@@ -323,26 +345,27 @@ binary (for recordings) and npz (for sortings).
     # Collect synthetic dataframes and display
     # As shown previously, the performance is returned as a pandas dataframe.
-    # The :py:func:`~spikeinterface.comparison.aggregate_performances_table()` function,
+    # The :py:func:`~spikeinterface.comparison.get_performance_by_unit()` function,
     # gathers all the outputs in the study folder and merges them in a single dataframe.
+    # Same idea for :py:func:`~spikeinterface.comparison.get_count_units()`
-    dataframes = study.aggregate_dataframes()
+
+    # this is a dataframe
+    perfs = study.get_performance_by_unit()
-    # Pandas dataframes can be nicely displayed as tables in the notebook.
- print(dataframes.keys()) + # this is a dataframe + unit_counts = study.get_count_units() # we can also access run times - print(dataframes['run_times']) + run_times = study.get_run_times() + print(run_times) # Easy plot with seaborn - run_times = dataframes['run_times'] fig1, ax1 = plt.subplots() sns.barplot(data=run_times, x='rec_name', y='run_time', hue='sorter_name', ax=ax1) ax1.set_title('Run times') ############################################################################## - perfs = dataframes['perf_by_unit'] fig2, ax2 = plt.subplots() sns.swarmplot(data=perfs, x='sorter_name', y='recall', hue='rec_name', ax=ax2) ax2.set_title('Recall') diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 6898f381b6..6dc9cb30f0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -126,9 +126,6 @@ def scan_folder(self): self.info = json.load(f) self.levels = self.info["levels"] - # if isinstance(self.levels, list): - # # because tuple caoont be stored in json - # self.levels = tuple(self.info["levels"]) for rec_file in (self.folder / "datasets/recordings").glob("*.pickle"): key = rec_file.stem @@ -169,9 +166,6 @@ def key_to_str(self, key): raise ValueError("Keys for cases must str or tuple") def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True, verbose=False): - """ - - """ if case_keys is None: case_keys = self.cases.keys() From d7aaa95e295d16fd1c9e6fe10fd82f93029a5cb1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 21:01:18 +0200 Subject: [PATCH 20/26] gt study widget xlim --- src/spikeinterface/widgets/gtstudy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index bc2c1246b7..438858beae 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -243,10 +243,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): study = dp.study perfs = study.get_performance_by_unit(case_keys=dp.case_keys) + max_metric = 0 for key in dp.case_keys: x = study.get_metrics(key)[dp.metric_name].values y = perfs.xs(key)[dp.performance_name].values label = dp.study.cases[key]["label"] self.ax.scatter(x, y, label=label) + max_metric = max(max_metric, np.max(x)) - self.ax.legend() \ No newline at end of file + self.ax.legend() + self.ax.set_xlim(0, max_metric * 1.05) + self.ax.set_ylim(0, 1.05) \ No newline at end of file From 5029445580bc6274ee8845636dd8d09b07e85826 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 26 Sep 2023 13:25:48 +0200 Subject: [PATCH 21/26] Apply suggestions from code review thanks alessio Co-authored-by: Alessio Buccino --- doc/modules/comparison.rst | 1 - .../comparison/groundtruthstudy.py | 45 +++++++++---------- .../comparison/tests/test_groundtruthstudy.py | 1 - 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index 9b2e701dac..57e9a0b5ba 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -293,7 +293,6 @@ The all mechanism is based on an intrinsic organization into a "study_folder" wi "sorter_name": "tridesclous2", }, }, - # ("tdc2", "toy1"): { "label": "tridesclous2 on tetrode1", "dataset": "toy1", diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 6dc9cb30f0..2d4486bbe4 100644 --- 
a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -28,24 +28,23 @@ class GroundTruthStudy: """ - This class is an helper function to run any comparison on several "cases" for several ground truth dataset. + This class is an helper function to run any comparison on several "cases" for many ground-truth dataset. - "cases" can be: - * several sorter for comparisons + "cases" refer to: + * several sorters for comparisons * same sorter with differents parameters * parameters of comparisons - * any combination of theses + * any combination of these (and more) - For enough flexibility cases key can be a tuple so that we can varify complexity along several - "levels" or "axis" (paremeters or sorter). - - Generated dataframes will have index with several levels optionaly. + For increased flexibility, cases keys can be a tuple so that we can vary complexity along several + "levels" or "axis" (paremeters or sorters). + In this case, the result dataframes will have `MultiIndex` to handle the different levels. - Ground truth dataset need recording+sorting. This can be from mearec file or from the internal generator - :py:fun:`generate_ground_truth_recording()` + A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see + :py:fun:`~spikeinterface.core.generate.generate_ground_truth_recording()`). This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. - Folders structures are not backward compatible at all. + Note that the underlying folder structure is not backward compatible! """ def __init__(self, study_folder): self.folder = Path(study_folder) @@ -85,21 +84,21 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): study_folder.mkdir(exist_ok=False, parents=True) (study_folder / "datasets").mkdir() - (study_folder / "datasets/recordings").mkdir() - (study_folder / "datasets/gt_sortings").mkdir() + (study_folder / "datasets" / "recordings").mkdir() + (study_folder / "datasets" / "gt_sortings").mkdir() (study_folder / "sorters").mkdir() (study_folder / "sortings").mkdir() (study_folder / "sortings" / "run_logs").mkdir() (study_folder / "metrics").mkdir() for key, (rec, gt_sorting) in datasets.items(): - assert "/" not in key - assert "\\" not in key + assert "/" not in key, "'/' cannot be in the key name!" + assert "\\" not in key, "'\\' cannot be in the key name!" 
- # rec are pickle + # recordings are pickled rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") - # sorting are pickle + saved as NumpyFolderSorting + # sortings are pickled + saved as NumpyFolderSorting gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") @@ -108,11 +107,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): info["levels"] = levels (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") - # (study_folder / "cases.jon").write_text( - # json.dumps(cases, indent=4, cls=SIJsonEncoder), - # encoding="utf8", - # ) - # cases is dump to a pickle file, json is not possible because of tuple key + # cases is dumped to a pickle file, json is not possible because of the tuple key (study_folder / "cases.pickle").write_bytes(pickle.dumps(cases)) return cls(study_folder) @@ -127,10 +122,10 @@ def scan_folder(self): self.levels = self.info["levels"] - for rec_file in (self.folder / "datasets/recordings").glob("*.pickle"): + for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): key = rec_file.stem rec = load_extractor(rec_file) - gt_sorting = load_extractor(self.folder / f"datasets/gt_sortings/{key}") + gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) self.datasets[key] = (rec, gt_sorting) with open(self.folder / "cases.pickle", "rb") as f: @@ -304,7 +299,7 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f case_keys = self.cases.keys() for key in case_keys: - filename = self.folder / "metrics" / f"{self.key_to_str(key)}.txt" + filename = self.folder / "metrics" / f"{self.key_to_str(key)}.csv" if filename.exists(): if force: os.remove(filename) diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index a75ac272be..12d764950e 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -16,7 +16,6 @@ study_folder = cache_folder / "test_groundtruthstudy/" -print(study_folder.absolute()) def setup_module(): if study_folder.is_dir(): From 32d3d7a6aebdaed8757fe6ca994c537e6034927c Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 26 Sep 2023 20:52:40 +0200 Subject: [PATCH 22/26] extract_waveforms_gt must be done on dataset key instead of case key. 
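Several cases can share the same ground-truth dataset, so ground-truth waveforms only need to be extracted once per dataset and can then be reused by every case that points to it. A minimal sketch of that deduplication idea, using a hypothetical `cases` mapping rather than the real study folder contents:

    # Hypothetical toy mapping: in the study it is loaded from cases.pickle.
    cases = {
        ("tdc2", "toy0"): {"dataset": "toy0"},
        ("tdc2", "toy1"): {"dataset": "toy1"},
        ("sc", "toy0"): {"dataset": "toy0"},
    }

    case_keys = list(cases.keys())

    # Collapse the case keys down to the unique dataset keys they reference.
    dataset_keys = {cases[key]["dataset"] for key in case_keys}

    # Only two extractions are needed ("toy0" and "toy1") even though there are three cases.
    for dataset_key in sorted(dataset_keys):
        print("extract ground-truth waveforms once for dataset:", dataset_key)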
--- .../comparison/groundtruthstudy.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 2d4486bbe4..8a294a88af 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -267,24 +267,29 @@ def get_run_times(self, case_keys=None): return pd.Series(run_times, name="run_time") def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): - + if case_keys is None: case_keys = self.cases.keys() base_folder = self.folder / "waveforms" base_folder.mkdir(exist_ok=True) - for key in case_keys: - dataset_key = self.cases[key]["dataset"] + dataset_keys = [self.cases[key]["dataset"] for key in case_keys] + dataset_keys = set(dataset_keys) + for dataset_key in dataset_keys: + # the waveforms depend on the dataset key + wf_folder = base_folder / self.key_to_str(dataset_key) recording, gt_sorting = self.datasets[dataset_key] - wf_folder = base_folder / self.key_to_str(key) we = extract_waveforms(recording, gt_sorting, folder=wf_folder) def get_waveform_extractor(self, key): # some recording are not dumpable to json and the waveforms extactor need it! # so we load it with and put after - we = load_waveforms(self.folder / "waveforms" / self.key_to_str(key), with_recording=False) + # this should be fixed in PR 2027 so remove this after + dataset_key = self.cases[key]["dataset"] + wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) + we = load_waveforms(wf_folder, with_recording=False) recording, _ = self.datasets[dataset_key] we.set_recording(recording) return we @@ -298,21 +303,29 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f if case_keys is None: case_keys = self.cases.keys() + done = [] for key in case_keys: - filename = self.folder / "metrics" / f"{self.key_to_str(key)}.csv" + dataset_key = self.cases[key]["dataset"] + if dataset_key in done: + # some case can share the same waveform extractor + continue + done.append(dataset_key) + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" if filename.exists(): if force: os.remove(filename) else: continue - we = self.get_waveform_extractor(key) metrics = compute_quality_metrics(we, metric_names=metric_names) metrics.to_csv(filename, sep="\t", index=True) def get_metrics(self, key): import pandas as pd - filename = self.folder / "metrics" / f"{self.key_to_str(key)}.txt" + + dataset_key = self.cases[key]["dataset"] + + filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" if not filename.exists(): return metrics = pd.read_csv(filename, sep="\t", index_col=0) From d48cd681f97fcee2374b65a97f0ecbc9d10b4588 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 09:02:05 +0200 Subject: [PATCH 23/26] implement some TODOs --- .../comparison/groundtruthstudy.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 8a294a88af..34777c6f20 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -22,8 +22,6 @@ # This is to separate names when the key are tuples when saving folders _key_separator = " ## " -# This would be more funny -# _key_separator = " (°_°) " class GroundTruthStudy: @@ -184,8 +182,12 @@ def run_sorters(self, case_keys=None, 
engine='loop', engine_kwargs={}, keep=True continue if sorting_exists: - # TODO : delete sorting + log - pass + # delete older sorting + log before running sorters + shutil.rmtree(sorting_exists) + log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" + if log_file.exists(): + log_file.unlink() + params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given @@ -201,7 +203,7 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) - # TODO create a list in laucher for engine blocking and non-blocking + # TODO later create a list in laucher for engine blocking and non-blocking if engine not in ("slurm", ): self.copy_sortings(case_keys) @@ -223,8 +225,10 @@ def copy_sortings(self, case_keys=None, force=True): if sorting is not None: if sorting_folder.exists(): if force: - # TODO delete folder + log + # delete folder + log shutil.rmtree(sorting_folder) + if log_file.exists(): + log_file.unlink() else: continue From af72fbcaa040c4216e2f2b60465197b484e2d2c9 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 11:25:20 +0200 Subject: [PATCH 24/26] oups --- src/spikeinterface/comparison/groundtruthstudy.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index 34777c6f20..fcebb356a0 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -31,7 +31,6 @@ class GroundTruthStudy: "cases" refer to: * several sorters for comparisons * same sorter with differents parameters - * parameters of comparisons * any combination of these (and more) For increased flexibility, cases keys can be a tuple so that we can vary complexity along several @@ -403,11 +402,11 @@ def get_count_units( count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units( well_detected_score ) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( + overmerged_score + ) + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) if comp.exhaustive_gt: - count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units( redundant_score ) From 6c561f214b02716e8da41a7ac198a94081f056a4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 15:54:14 +0200 Subject: [PATCH 25/26] more fix after merge with main and the new pickle to file mechanism --- .../comparison/groundtruthstudy.py | 21 +++++++++++-------- src/spikeinterface/sorters/basesorter.py | 10 ++++++--- src/spikeinterface/sorters/launcher.py | 8 ++++++- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index fcebb356a0..eb430f69bd 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -194,10 +194,12 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True sorter_name = params.pop("sorter_name") job = dict(sorter_name=sorter_name, recording=recording, - output_folder=sorter_folder) + output_folder=sorter_folder, + ) job.update(params) # the 
verbose is overwritten and global to all run_sorters job["verbose"] = verbose + job["with_output"] = False job_list.append(job) run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) @@ -217,7 +219,8 @@ def copy_sortings(self, case_keys=None, force=True): if (sorter_folder / "spikeinterface_log.json").exists(): - sorting = read_sorter_folder(sorter_folder, raise_error=False) + sorting = read_sorter_folder(sorter_folder, raise_error=False, + register_recording=False, sorting_info=False) else: sorting = None @@ -383,13 +386,12 @@ def get_count_units( index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - columns = ["num_gt", "num_sorter", "num_well_detected", "num_redundant", "num_overmerged"] + columns = ["num_gt", "num_sorter", "num_well_detected"] comp = self.comparisons[case_keys[0]] if comp.exhaustive_gt: - columns.extend(["num_false_positive", "num_bad"]) + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) count_units = pd.DataFrame(index=index, columns=columns, dtype=int) - for key in case_keys: comp = self.comparisons.get(key, None) assert comp is not None, "You need to do study.run_comparisons() first" @@ -402,11 +404,12 @@ def get_count_units( count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units( well_detected_score ) - count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( + overmerged_score + ) count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units( redundant_score ) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8d87558191..a956f8c811 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -202,7 +202,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False): recording = None else: recording = load_extractor(json_file, base_folder=output_folder) - elif pickle_file.exits(): + elif pickle_file.exists(): recording = load_extractor(pickle_file) return recording @@ -324,8 +324,12 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_ if sorting_info: # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) + if (output_folder / "spikeinterface_recording.json").exists(): + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + else: + rec_dict = None + with open(output_folder / "spikeinterface_params.json", "r") as f: params_dict = json.load(f) with open(output_folder / "spikeinterface_log.json", "r") as f: diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index f32a468a22..12c59cbe45 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -66,7 +66,8 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal engine_kwargs: dict return_output: bool, dfault False - Return a sorting or None. + Return a sortings or None. 
+ This also overwrite kwargs in in run_sorter(with_sorting=True/False) Returns ------- @@ -88,8 +89,13 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal "processpoolexecutor", ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." out = [] + for kwargs in job_list: + kwargs['with_output'] = True else: out = None + for kwargs in job_list: + kwargs['with_output'] = False + if engine == "loop": # simple loop in main process From cb9a2289cf1aab818307265aefa1abfcf2a0329c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 13:55:09 +0000 Subject: [PATCH 26/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/modules/comparison.rst | 2 +- src/spikeinterface/comparison/collision.py | 17 +--- src/spikeinterface/comparison/correlogram.py | 6 +- .../comparison/groundtruthstudy.py | 92 ++++++++----------- .../comparison/tests/test_groundtruthstudy.py | 28 ++---- src/spikeinterface/sorters/launcher.py | 5 +- src/spikeinterface/widgets/gtstudy.py | 31 +++---- src/spikeinterface/widgets/widget_list.py | 2 +- 8 files changed, 74 insertions(+), 109 deletions(-) diff --git a/doc/modules/comparison.rst b/doc/modules/comparison.rst index 57e9a0b5ba..76ab7855c6 100644 --- a/doc/modules/comparison.rst +++ b/doc/modules/comparison.rst @@ -314,7 +314,7 @@ The all mechanism is based on an intrinsic organization into a "study_folder" wi study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "dataset"]) - + # all cases in one function study.run_sorters() diff --git a/src/spikeinterface/comparison/collision.py b/src/spikeinterface/comparison/collision.py index 01626b34b8..dd04b2c72d 100644 --- a/src/spikeinterface/comparison/collision.py +++ b/src/spikeinterface/comparison/collision.py @@ -5,10 +5,6 @@ import numpy as np - - - - class CollisionGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing to benchmark spike in collision. @@ -164,7 +160,6 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good return similarities, recall_scores, pair_names - class CollisionGTStudy(GroundTruthStudy): def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, nbins=11, **kwargs): _kwargs = dict() @@ -179,11 +174,12 @@ def run_comparisons(self, case_keys=None, exhaustive_gt=True, collision_lag=2.0, def get_lags(self, key): comp = self.comparisons[key] fs = comp.sorting1.get_sampling_frequency() - lags = comp.bins / fs * 1000. 
+ lags = comp.bins / fs * 1000.0 return lags def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min_accuracy=0.9): import sklearn + if case_keys is None: case_keys = self.cases.keys() @@ -197,16 +193,13 @@ def precompute_scores_by_similarities(self, case_keys=None, good_only=False, min similarity = sklearn.metrics.pairwise.cosine_similarity(flat_templates) comp = self.comparisons[key] similarities, recall_scores, pair_names = comp.compute_collision_by_similarity( - similarity, good_only=good_only, min_accuracy=min_accuracy - ) + similarity, good_only=good_only, min_accuracy=min_accuracy + ) self.all_similarities[key] = similarities self.all_recall_scores[key] = recall_scores - def get_mean_over_similarity_range(self, similarity_range, key): - idx = (self.all_similarities[key] >= similarity_range[0]) & ( - self.all_similarities[key] <= similarity_range[1] - ) + idx = (self.all_similarities[key] >= similarity_range[0]) & (self.all_similarities[key] <= similarity_range[1]) all_similarities = self.all_similarities[key][idx] all_recall_scores = self.all_recall_scores[key][idx] diff --git a/src/spikeinterface/comparison/correlogram.py b/src/spikeinterface/comparison/correlogram.py index 150f5afe55..aaffef1887 100644 --- a/src/spikeinterface/comparison/correlogram.py +++ b/src/spikeinterface/comparison/correlogram.py @@ -6,7 +6,6 @@ import numpy as np - class CorrelogramGTComparison(GroundTruthComparison): """ This class is an extension of GroundTruthComparison by focusing @@ -112,9 +111,10 @@ def compute_correlogram_by_similarity(self, similarity_matrix, window_ms=None): return similarities, errors - class CorrelogramGTStudy(GroundTruthStudy): - def run_comparisons(self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs): + def run_comparisons( + self, case_keys=None, exhaustive_gt=True, window_ms=100.0, bin_ms=1.0, well_detected_score=0.8, **kwargs + ): _kwargs = dict() _kwargs.update(kwargs) _kwargs["exhaustive_gt"] = exhaustive_gt diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index eb430f69bd..d43727cb44 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -32,17 +32,18 @@ class GroundTruthStudy: * several sorters for comparisons * same sorter with differents parameters * any combination of these (and more) - + For increased flexibility, cases keys can be a tuple so that we can vary complexity along several "levels" or "axis" (paremeters or sorters). In this case, the result dataframes will have `MultiIndex` to handle the different levels. - - A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see + + A ground-truth dataset is made of a `Recording` and a `Sorting` object. For example, it can be a simulated dataset with MEArec or internally generated (see :py:fun:`~spikeinterface.core.generate.generate_ground_truth_recording()`). - + This GroundTruthStudy have been refactor in version 0.100 to be more flexible than previous versions. Note that the underlying folder structure is not backward compatible! 
""" + def __init__(self, study_folder): self.folder = Path(study_folder) @@ -55,7 +56,6 @@ def __init__(self, study_folder): @classmethod def create(cls, study_folder, datasets={}, cases={}, levels=None): - # check that cases keys are homogeneous key0 = list(cases.keys())[0] if isinstance(key0, str): @@ -67,7 +67,9 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): elif isinstance(key0, tuple): assert all(isinstance(key, tuple) for key in cases.keys()), "Keys for cases are not homogeneous" num_levels = len(key0) - assert all(len(key) == num_levels for key in cases.keys()), "Keys for cases are not homogeneous, tuple negth differ" + assert all( + len(key) == num_levels for key in cases.keys() + ), "Keys for cases are not homogeneous, tuple negth differ" if levels is None: levels = [f"level{i}" for i in range(num_levels)] else: @@ -76,7 +78,6 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): else: raise ValueError("Keys for cases must str or tuple") - study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) @@ -98,8 +99,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): # sortings are pickled + saved as NumpyFolderSorting gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") - - + info = {} info["levels"] = levels (study_folder / "info.json").write_text(json.dumps(info, indent=4), encoding="utf8") @@ -109,14 +109,13 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): return cls(study_folder) - def scan_folder(self): if not (self.folder / "datasets").exists(): raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") with open(self.folder / "info.json", "r") as f: self.info = json.load(f) - + self.levels = self.info["levels"] for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): @@ -124,7 +123,7 @@ def scan_folder(self): rec = load_extractor(rec_file) gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) self.datasets[key] = (rec, gt_sorting) - + with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) @@ -139,7 +138,6 @@ def scan_folder(self): sorting = None self.sortings[key] = sorting - def __repr__(self): t = f"{self.__class__.__name__} {self.folder.stem} \n" t += f" datasets: {len(self.datasets)} {list(self.datasets.keys())}\n" @@ -157,7 +155,7 @@ def key_to_str(self, key): else: raise ValueError("Keys for cases must str or tuple") - def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True, verbose=False): + def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False): if case_keys is None: case_keys = self.cases.keys() @@ -187,15 +185,15 @@ def run_sorters(self, case_keys=None, engine='loop', engine_kwargs={}, keep=True if log_file.exists(): log_file.unlink() - params = self.cases[key]["run_sorter_params"].copy() # this ensure that sorter_name is given recording, _ = self.datasets[self.cases[key]["dataset"]] sorter_name = params.pop("sorter_name") - job = dict(sorter_name=sorter_name, - recording=recording, - output_folder=sorter_folder, - ) + job = dict( + sorter_name=sorter_name, + recording=recording, + output_folder=sorter_folder, + ) job.update(params) # the verbose is overwritten and global to all run_sorters job["verbose"] = verbose @@ -205,25 +203,25 @@ def run_sorters(self, case_keys=None, engine='loop', 
engine_kwargs={}, keep=True run_sorter_jobs(job_list, engine=engine, engine_kwargs=engine_kwargs, return_output=False) # TODO later create a list in laucher for engine blocking and non-blocking - if engine not in ("slurm", ): + if engine not in ("slurm",): self.copy_sortings(case_keys) def copy_sortings(self, case_keys=None, force=True): if case_keys is None: case_keys = self.cases.keys() - + for key in case_keys: sorting_folder = self.folder / "sortings" / self.key_to_str(key) sorter_folder = self.folder / "sorters" / self.key_to_str(key) log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" - if (sorter_folder / "spikeinterface_log.json").exists(): - sorting = read_sorter_folder(sorter_folder, raise_error=False, - register_recording=False, sorting_info=False) + sorting = read_sorter_folder( + sorter_folder, raise_error=False, register_recording=False, sorting_info=False + ) else: sorting = None - + if sorting is not None: if sorting_folder.exists(): if force: @@ -241,7 +239,6 @@ def copy_sortings(self, case_keys=None, force=True): shutil.copyfile(sorter_folder / "spikeinterface_log.json", log_file) def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison, **kwargs): - if case_keys is None: case_keys = self.cases.keys() @@ -250,18 +247,19 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison _, gt_sorting = self.datasets[dataset_key] sorting = self.sortings[key] if sorting is None: - self.comparisons[key] = None + self.comparisons[key] = None continue comp = comparison_class(gt_sorting, sorting, **kwargs) self.comparisons[key] = comp def get_run_times(self, case_keys=None): import pandas as pd + if case_keys is None: case_keys = self.cases.keys() log_folder = self.folder / "sortings" / "run_logs" - + run_times = {} for key in case_keys: log_file = log_folder / f"{self.key_to_str(key)}.json" @@ -273,7 +271,6 @@ def get_run_times(self, case_keys=None): return pd.Series(run_times, name="run_time") def extract_waveforms_gt(self, case_keys=None, **extract_kwargs): - if case_keys is None: case_keys = self.cases.keys() @@ -292,11 +289,11 @@ def get_waveform_extractor(self, key): # some recording are not dumpable to json and the waveforms extactor need it! 
# so we load it with and put after # this should be fixed in PR 2027 so remove this after - + dataset_key = self.cases[key]["dataset"] wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key) we = load_waveforms(wf_folder, with_recording=False) - recording, _ = self.datasets[dataset_key] + recording, _ = self.datasets[dataset_key] we.set_recording(recording) return we @@ -308,7 +305,7 @@ def get_templates(self, key, mode="average"): def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], force=False): if case_keys is None: case_keys = self.cases.keys() - + done = [] for key in case_keys: dataset_key = self.cases[key]["dataset"] @@ -327,7 +324,7 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f metrics.to_csv(filename, sep="\t", index=True) def get_metrics(self, key): - import pandas as pd + import pandas as pd dataset_key = self.cases[key]["dataset"] @@ -336,17 +333,15 @@ def get_metrics(self, key): return metrics = pd.read_csv(filename, sep="\t", index_col=0) dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] + recording, gt_sorting = self.datasets[dataset_key] metrics.index = gt_sorting.unit_ids return metrics def get_units_snr(self, key): - """ - """ + """ """ return self.get_metrics(key)["snr"] def get_performance_by_unit(self, case_keys=None): - import pandas as pd if case_keys is None: @@ -363,7 +358,7 @@ def get_performance_by_unit(self, case_keys=None): elif isinstance(key, tuple): for col, k in zip(self.levels, key): perf[col] = k - + perf = perf.reset_index() perf_by_unit.append(perf) @@ -371,10 +366,7 @@ def get_performance_by_unit(self, case_keys=None): perf_by_unit = perf_by_unit.set_index(self.levels) return perf_by_unit - def get_count_units( - self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None - ): - + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): import pandas as pd if case_keys is None: @@ -385,7 +377,6 @@ def get_count_units( else: index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) - columns = ["num_gt", "num_sorter", "num_well_detected"] comp = self.comparisons[case_keys[0]] if comp.exhaustive_gt: @@ -401,19 +392,12 @@ def get_count_units( count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) - count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units( - well_detected_score - ) - + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + if comp.exhaustive_gt: count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) - count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units( - overmerged_score - ) - count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units( - redundant_score - ) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) count_units.loc[key, "num_bad"] = comp.count_bad_units() return count_units - diff --git a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py index 12d764950e..91c8c640e0 100644 --- a/src/spikeinterface/comparison/tests/test_groundtruthstudy.py +++ b/src/spikeinterface/comparison/tests/test_groundtruthstudy.py @@ -7,7 +7,6 @@ 
from spikeinterface.comparison import GroundTruthStudy - if hasattr(pytest, "global_test_folder"): cache_folder = pytest.global_test_folder / "comparison" else: @@ -28,8 +27,8 @@ def simple_preprocess(rec): def create_a_study(study_folder): - rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=42) - rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.], seed=91) + rec0, gt_sorting0 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=42) + rec1, gt_sorting1 = generate_ground_truth_recording(num_channels=4, durations=[30.0], seed=91) datasets = { "toy_tetrode": (rec0, gt_sorting0), @@ -46,9 +45,7 @@ def create_a_study(study_folder): "run_sorter_params": { "sorter_name": "tridesclous2", }, - "comparison_params": { - - }, + "comparison_params": {}, }, # ("tdc2", "with-preprocess", "probe32"): { @@ -57,11 +54,9 @@ def create_a_study(study_folder): "run_sorter_params": { "sorter_name": "tridesclous2", }, - "comparison_params": { - - }, + "comparison_params": {}, }, - # we comment this at the moement because SC2 is quite slow for testing + # we comment this at the moement because SC2 is quite slow for testing # ("sc2", "no-preprocess", "tetrode"): { # "label": "spykingcircus2 without preprocessing standar params", # "dataset": "toy_tetrode", @@ -69,16 +64,16 @@ def create_a_study(study_folder): # "sorter_name": "spykingcircus2", # }, # "comparison_params": { - # }, # }, } - study = GroundTruthStudy.create(study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"]) + study = GroundTruthStudy.create( + study_folder, datasets=datasets, cases=cases, levels=["sorter_name", "processing", "probe_type"] + ) # print(study) - def test_GroundTruthStudy(): study = GroundTruthStudy(study_folder) print(study) @@ -98,14 +93,11 @@ def test_GroundTruthStudy(): for key in study.cases: metrics = study.get_metrics(key) print(metrics) - + study.get_performance_by_unit() study.get_count_units() - if __name__ == "__main__": setup_module() - test_GroundTruthStudy() - - \ No newline at end of file + test_GroundTruthStudy() diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 12c59cbe45..704f6843f2 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -90,12 +90,11 @@ def run_sorter_jobs(job_list, engine="loop", engine_kwargs={}, return_output=Fal ), "Only 'loop', 'joblib', and 'processpoolexecutor' support return_output=True." 
out = [] for kwargs in job_list: - kwargs['with_output'] = True + kwargs["with_output"] = True else: out = None for kwargs in job_list: - kwargs['with_output'] = False - + kwargs["with_output"] = False if engine == "loop": # simple loop in main process diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 438858beae..6a27b78dec 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -29,7 +29,6 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: case_keys = list(study.cases.keys()) @@ -53,9 +52,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label = dp.study.cases[key]["label"] rt = dp.run_times.loc[key] self.ax.bar(i, rt, width=0.8, label=label) - - self.ax.legend() + self.ax.legend() # TODO : plot optionally average on some levels using group by @@ -80,13 +78,12 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: case_keys = list(study.cases.keys()) plot_data = dict( study=study, - count_units = study.get_count_units(case_keys=case_keys), + count_units=study.get_count_units(case_keys=case_keys), case_keys=case_keys, ) @@ -107,8 +104,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ncol = len(columns) - colors = get_some_colors(columns, color_engine="auto", - map_name="hot") + colors = get_some_colors(columns, color_engine="auto", map_name="hot") colors["num_well_detected"] = "green" xticklabels = [] @@ -118,7 +114,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): y = dp.count_units.loc[key, col] if not "well_detected" in col: y = -y - + if i == 0: label = col.replace("num_", "").replace("_", " ").title() else: @@ -158,7 +154,6 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: case_keys = list(study.cases.keys()) @@ -186,11 +181,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.mode == "swarm": levels = perfs.index.names - df = pd.melt(perfs.reset_index(), id_vars=levels, var_name='Metric', value_name='Score', - value_vars=('accuracy','precision', 'recall')) - df['x'] = df.apply(lambda r: ' '.join([r[col] for col in levels]), axis=1) - sns.swarmplot(data=df, x='x', y='Score', hue='Metric', dodge=True) - + df = pd.melt( + perfs.reset_index(), + id_vars=levels, + var_name="Metric", + value_name="Score", + value_vars=("accuracy", "precision", "recall"), + ) + df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1) + sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True) class StudyPerformancesVsMetrics(BaseWidget): @@ -218,7 +217,6 @@ def __init__( backend=None, **backend_kwargs, ): - if case_keys is None: case_keys = list(study.cases.keys()) @@ -239,7 +237,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - study = dp.study perfs = study.get_performance_by_unit(case_keys=dp.case_keys) @@ -253,4 +250,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.legend() self.ax.set_xlim(0, max_metric * 1.05) - self.ax.set_ylim(0, 1.05) \ No newline at end of file + self.ax.set_ylim(0, 1.05) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index ce853f16bf..ed77de6128 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -53,7 +53,7 @@ StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, - StudyPerformancesVsMetrics + 
StudyPerformancesVsMetrics, ]
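
Note for reviewers: below is a minimal usage sketch of the refactored API exercised by the comparison and widget hunks above. It is an assumption-laden illustration, not part of the patch: it presumes the study folder was already created with GroundTruthStudy.create(...) and that every case has been fully run (sorters, waveforms, comparisons); the folder path and the 0.8 threshold are invented for the example.

# Minimal sketch (assumptions: "my_study_folder" is a hypothetical, already-created
# study folder whose cases have been run and compared; 0.8 is an illustrative threshold).
from spikeinterface.comparison import GroundTruthStudy

study = GroundTruthStudy("my_study_folder")

# Quality metrics are computed per dataset and cached on disk; get_metrics() reindexes
# the table on the ground-truth unit ids of the case's dataset.
study.compute_metrics(metric_names=["snr", "firing_rate"])
for key in study.cases:
    metrics = study.get_metrics(key)
    print(key, metrics["snr"].median())

# Aggregated tables across cases, indexed by the case-key levels
# (e.g. ["sorter_name", "processing", "probe_type"] as in the test above).
perf = study.get_performance_by_unit()          # accuracy / precision / recall per GT unit
counts = study.get_count_units(well_detected_score=0.8)
print(perf)
print(counts)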