From 403d1df30c18eb63f84b200ea8a861c59d9d6ac5 Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Thu, 9 May 2024 18:57:31 +0200
Subject: [PATCH 01/16] update `postprocessing` logic

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 94f12f84..4c90337e 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -286,7 +286,7 @@ def make(self, key):
         _ = si.postprocessing.compute_correlograms(we)
 
         metric_names = si.qualitymetrics.get_quality_metric_list()
-        metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list())
+        metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list())
 
         # To compute commonly used cluster quality metrics.
         qc_metrics = si.qualitymetrics.compute_quality_metrics(
@@ -308,7 +308,7 @@ def make(self, key):
         metrics = pd.DataFrame()
         metrics = pd.concat([qc_metrics, template_metrics], axis=1)
 
-        # Save the output (metrics.csv to the output dir)
+        # Save metrics.csv to the output dir
         metrics_output_dir = output_dir / sorter_name / "metrics"
         metrics_output_dir.mkdir(parents=True, exist_ok=True)
         metrics.to_csv(metrics_output_dir / "metrics.csv")

From c934e67ea6e5de2e30b35dbc10ab547e49917159 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 11:00:06 -0500
Subject: [PATCH 02/16] feat: prototyping with the new `sorting_analyzer`

---
 .../spike_sorting/si_spike_sorting.py | 27 +++++++++++++++++--
 setup.py                              |  2 +-
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index ab803490..f7cb1e57 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -255,6 +255,29 @@ def make(self, key):
                 sorting_file, base_folder=output_dir
             )
 
+        # Sorting Analyzer
+        analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+        if analyzer_output_dir.exists():
+            sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
+        else:
+            sorting_analyzer = si.create_sorting_analyzer(
+                sorting=si_sorting,
+                recording=si_recording,
+                format="binary_folder",
+                folder=analyzer_output_dir,
+                sparse=True,
+                overwrite=True,
+            )
+
+        job_kwargs = params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_duration": "1s"})
+        all_computable_extensions = ['random_spikes', 'waveforms', 'templates', 'noise_levels', 'amplitude_scalings', 'correlograms', 'isi_histograms', 'principal_components', 'spike_amplitudes', 'spike_locations', 'template_metrics', 'template_similarity', 'unit_locations', 'quality_metrics']
+        extensions_to_compute = ['random_spikes', 'waveforms', 'templates', 'noise_levels',
+                                 'spike_amplitudes', 'spike_locations', 'unit_locations',
+                                 'principal_components',
+                                 'template_metrics', 'quality_metrics']
+
+        sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
         # Extract waveforms
         we: si.WaveformExtractor = si.extract_waveforms(
             si_recording,
@@ -287,7 +310,7 @@ def make(self, key):
         _ = si.postprocessing.compute_correlograms(we)
 
         metric_names = si.qualitymetrics.get_quality_metric_list()
-        metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list())
+        # metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list())  # TODO: temporarily removed
 
         # To compute commonly used cluster quality metrics.
         qc_metrics = si.qualitymetrics.compute_quality_metrics(
@@ -297,7 +320,7 @@ def make(self, key):
 
         # To compute commonly used waveform/template metrics.
         template_metric_names = si.postprocessing.get_template_metric_names()
-        template_metric_names.extend(["amplitude", "duration"])
+        template_metric_names.extend(["amplitude", "duration"])  # TODO: does this do anything?
 
         template_metrics = si.postprocessing.compute_template_metrics(
             waveform_extractor=we,
diff --git a/setup.py b/setup.py
index 52cd38b1..e62719d8 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
         "openpyxl",
         "plotly",
         "seaborn",
-        "spikeinterface",
+        "spikeinterface>=0.101.0",
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",
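Patch 02 above is the first cut over to SpikeInterface's SortingAnalyzer, which in versions 0.101.0 and later replaces the WaveformExtractor-centric postprocessing. A minimal sketch of that workflow outside the pipeline, assuming `si_recording` and `si_sorting` objects are already in hand and using an illustrative output path:

    import spikeinterface as si

    # Create a disk-backed analyzer so computed extensions persist across sessions.
    sorting_analyzer = si.create_sorting_analyzer(
        sorting=si_sorting,
        recording=si_recording,
        format="binary_folder",
        folder="sorting_analyzer",  # illustrative path
        sparse=True,  # keep only each unit's neighborhood of channels
        overwrite=True,
    )

    # Extensions depend on one another (waveforms need random_spikes, metrics
    # need templates), so they are computed in dependency order.
    sorting_analyzer.compute(
        ["random_spikes", "waveforms", "templates", "noise_levels", "quality_metrics"],
        n_jobs=-1,
        chunk_duration="1s",
    )

    qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()  # DataFrame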
From 3666cda077448cc40d7b7e9c219c9c489396cbd6 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 14:34:05 -0500
Subject: [PATCH 03/16] feat: update ingestion to be compatible with spikeinterface 0.101+

---
 element_array_ephys/ephys_no_curation.py | 209 ++++++++----------
 .../spike_sorting/si_spike_sorting.py    |  93 ++------
 2 files changed, 116 insertions(+), 186 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index b0a8bc26..413868da 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1037,98 +1037,69 @@ def make(self, key):
 
         # Get sorter method and create output directory.
         sorter_name = clustering_method.replace(".", "_")
-        si_waveform_dir = output_dir / sorter_name / "waveform"
-        si_sorting_dir = output_dir / sorter_name / "spike_sorting"
+        si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
-        if si_waveform_dir.exists():  # Read from spikeinterface outputs
-            we: si.WaveformExtractor = si.load_waveforms(
-                si_waveform_dir, with_recording=False
-            )
-            si_sorting: si.sorters.BaseSorter = si.load_extractor(
-                si_sorting_dir / "si_sorting.pkl", base_folder=output_dir
-            )
+        if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+            si_sorting = sorting_analyzer.sorting
 
-            unit_peak_channel: dict[int, int] = si.get_template_extremum_channel(
-                we, outputs="index"
-            )  # {unit: peak_channel_id}
+            # Find representative channel for each unit
+            unit_peak_channel: dict[int, np.ndarray] = (
+                si.ChannelSparsity.from_best_channels(
+                    sorting_analyzer, 1, peak_sign="neg"
+                ).unit_id_to_channel_indices
+            )  # {unit: peak_channel_index}
+            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            spikes = si_sorting.to_spike_vector()
-
             # reorder channel2electrode_map according to recording channel ids
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in we.channel_ids_to_indices(we.channel_ids)
+                for chn_idx in sorting_analyzer.channel_ids_to_indices(
+                    sorting_analyzer.channel_ids
+                )
             }
 
             # Get unit id to quality label mapping
-            try:
-                cluster_quality_label_map = pd.read_csv(
-                    si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
-                    delimiter="\t",
+            cluster_quality_label_map = {
+                int(unit_id): (
+                    si_sorting.get_unit_property(unit_id, "KSLabel")
+                    if "KSLabel" in si_sorting.get_property_keys()
+                    else "n.a."
+                )
-            except FileNotFoundError:
-                cluster_quality_label_map = {}
-            else:
-                cluster_quality_label_map: dict[
-                    int, str
-                ] = cluster_quality_label_map.set_index("cluster_id")[
-                    "KSLabel"
-                ].to_dict()  # {unit: quality_label}
-
-            # Get electrode where peak unit activity is recorded
-            peak_electrode_ind = np.array(
-                [
-                    channel2electrode_map[unit_peak_channel[unit_id]]["electrode"]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Get channel depth
-            channel_depth_ind = np.array(
-                [
-                    we.get_probe().contact_positions[unit_peak_channel[unit_id]][1]
-                    for unit_id in si_sorting.unit_ids
-                ]
-            )
-
-            # Assign electrode and depth for each spike
-            new_spikes = np.empty(
-                spikes.shape,
-                spikes.dtype.descr + [("electrode", "
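On the ingestion side, patch 03 reads everything back from the analyzer folder instead of from a WaveformExtractor. A sketch of that read path, with an illustrative folder name; the calls mirror the ones in the hunk above:

    import spikeinterface as si

    sorting_analyzer = si.load_sorting_analyzer(folder="sorting_analyzer")  # illustrative path
    si_sorting = sorting_analyzer.sorting

    # One representative (peak) channel index per unit, as in the patch.
    unit_peak_channel = si.ChannelSparsity.from_best_channels(
        sorting_analyzer, 1, peak_sign="neg"
    ).unit_id_to_channel_indices
    unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}

    spike_count_dict = si_sorting.count_num_spikes_per_unit()  # {unit_id: count}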
Date: Fri, 24 May 2024 14:52:45 -0500
Subject: [PATCH 04/16] format: black formatting

---
 element_array_ephys/ephys_no_curation.py | 10 +++++++---
 .../spike_sorting/si_spike_sorting.py    | 19 ++++++++++++-------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 413868da..99247e35 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1256,7 +1256,9 @@ def make(self, key):
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             # reorder channel2electrode_map according to recording channel ids
-            channel_indices = sorting_analyzer.channel_ids_to_indices(sorting_analyzer.channel_ids).tolist()
+            channel_indices = sorting_analyzer.channel_ids_to_indices(
+                sorting_analyzer.channel_ids
+            ).tolist()
             channel2electrode_map = {
                 chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
             }
@@ -1500,7 +1502,9 @@ def make(self, key):
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
-            template_metrics = sorting_analyzer.get_extension("template_metrics").get_data()
+            template_metrics = sorting_analyzer.get_extension(
+                "template_metrics"
+            ).get_data()
 
             metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
             metrics_df.rename(
@@ -1514,7 +1518,7 @@ def make(self, key):
                     "drift_mad": "cumulative_drift",
                     "half_width": "halfwidth",
                     "peak_trough_ratio": "pt_ratio",
-                    "peak_to_valley": "duration"
+                    "peak_to_valley": "duration",
                 },
                 inplace=True,
             )
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 55c6efdd..33201d86 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -270,28 +270,33 @@ def make(self, key):
                 overwrite=True,
             )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get("job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"})
+        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
         extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
         # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
-        extensions_to_compute = {ext_name: extensions_params[ext_name]
-                                 for ext_name in sorting_analyzer.get_computable_extensions()
-                                 if ext_name in extensions_params}
+        extensions_to_compute = {
+            ext_name: extensions_params[ext_name]
+            for ext_name in sorting_analyzer.get_computable_extensions()
+            if ext_name in extensions_params
+        }
 
         sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
 
         # Save to phy format
         if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False):
             si.exporters.export_to_phy(
-                sorting_analyzer=sorting_analyzer, output_folder=output_dir / sorter_name / "phy",
-                **job_kwargs
+                sorting_analyzer=sorting_analyzer,
+                output_folder=output_dir / sorter_name / "phy",
+                **job_kwargs,
             )
         # Generate spike interface report
         if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True):
             si.exporters.export_report(
                 sorting_analyzer=sorting_analyzer,
                 output_folder=output_dir / sorter_name / "spikeinterface_report",
-                **job_kwargs
+                **job_kwargs,
             )
 
         self.insert1(

From 07a09f6152b9632ce713287a85dedd0ad1bf8e9b Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 15:28:52 -0500
Subject: [PATCH 05/16] chore: code clean up

---
 .../spike_sorting/si_spike_sorting.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 33201d86..a0ff2035 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -80,11 +80,9 @@ def make(self, key):
         sorter_name = clustering_method.replace(".", "_")
 
         for required_key in (
-            "SI_SORTING_PARAMS",
             "SI_PREPROCESSING_METHOD",
+            "SI_SORTING_PARAMS",
             "SI_POSTPROCESSING_PARAMS",
-            "SI_WAVEFORM_EXTRACTION_PARAMS",
-            "SI_QUALITY_METRICS_PARAMS",
         ):
             if required_key not in params:
                 raise ValueError(
@@ -256,6 +254,10 @@ def make(self, key):
                 sorting_file, base_folder=output_dir
             )
 
+        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
+
         # Sorting Analyzer
         analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
         if (analyzer_output_dir / "extensions").exists():
@@ -268,14 +270,12 @@ def make(self, key):
                 folder=analyzer_output_dir,
                 sparse=True,
                 overwrite=True,
+                **job_kwargs
             )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
-            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
-        )
-        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
         # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
         extensions_to_compute = {
             ext_name: extensions_params[ext_name]
             for ext_name in sorting_analyzer.get_computable_extensions()

From 3fcf542d1435f4f891f2bbf93eaa3668da1986ea Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Fri, 24 May 2024 15:29:09 -0500
Subject: [PATCH 06/16] update: update requirements to install `SpikeInterface` from github (latest version)

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index e62719d8..f1ba9c90 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
         "openpyxl",
         "plotly",
         "seaborn",
-        "spikeinterface>=0.101.0",
+        "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git",
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",

From 76dfc94568bf28296da18905d0b187588bc99397 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 10:32:19 -0500
Subject: [PATCH 07/16] fix: minor bug in spikes ingestion

---
 element_array_ephys/ephys_no_curation.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 99247e35..9222ccd2 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1048,8 +1048,8 @@ def make(self, key):
                 si.ChannelSparsity.from_best_channels(
                     sorting_analyzer, 1, peak_sign="neg"
                 ).unit_id_to_channel_indices
-            )  # {unit: peak_channel_index}
-            unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            )
+            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1076,9 +1076,9 @@ def make(self, key):
             spikes_df = pd.DataFrame(spike_locations.spikes)
 
             units = []
-            for unit_id in si_sorting.unit_ids:
+            for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
                 unit_id = int(unit_id)
-                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_id]
+                unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
                 spike_sites = np.array(
                     [
                         channel2electrode_map[chn_idx]["electrode"]
@@ -1087,6 +1087,9 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
+                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+
+                assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
                 units.append(
                     {
@@ -1094,9 +1097,7 @@ def make(self, key):
                         **channel2electrode_map[unit_peak_channel[unit_id]],
                         "unit": unit_id,
                         "cluster_quality_label": cluster_quality_label_map[unit_id],
-                        "spike_times": si_sorting.get_unit_spike_train(
-                            unit_id, return_times=True
-                        ),
+                        "spike_times": spike_times,
                         "spike_count": spike_count_dict[unit_id],
                         "spike_sites": spike_sites,
                         "spike_depths": spike_depths,
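The bug patch 07 fixes is easy to miss: the `unit_index` column of the spike-locations table is a positional index into `unit_ids`, not a unit id, so filtering with `unit_id` selects the wrong rows whenever ids are non-contiguous (for example, after units are dropped). A toy illustration with made-up values:

    import numpy as np
    import pandas as pd

    unit_ids = np.array([3, 7, 12])  # non-contiguous unit ids
    spikes_df = pd.DataFrame(
        {"unit_index": [0, 1, 1, 2], "sample_index": [100, 220, 340, 410]}
    )

    for unit_idx, unit_id in enumerate(unit_ids):
        correct = spikes_df[spikes_df.unit_index == unit_idx]  # position-based: right rows
        wrong = spikes_df[spikes_df.unit_index == unit_id]  # id-based: empty or mismatched
        print(unit_id, len(correct), len(wrong))

The added assert makes the same invariant explicit: every spike time must have a matching site and depth.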
From 9094754b6f23bd65a71390094ac509e06d22b34c Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 10:38:59 -0500
Subject: [PATCH 08/16] update: bump version

---
 CHANGELOG.md                   | 5 +++++
 element_array_ephys/version.py | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e45e427..5d81dcba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,11 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [0.4.0] - 2024-05-28
+
++ Add - support for SpikeInterface version >= 0.101.0 (updated API)
+
+
 ## [0.3.4] - 2024-03-22
 
 + Add - pytest
diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 148bac24..2e6de55a 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
 """Package metadata."""
 
-__version__ = "0.3.4"
+__version__ = "0.4.0"

From 51e2ced3f36fa1b69bacf69ea1fbf295c84eaf16 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 13:14:00 -0500
Subject: [PATCH 09/16] feat: add `memoized_result` on spike sorting

---
 CHANGELOG.md                          |   1 +
 .../spike_sorting/si_spike_sorting.py | 103 ++++++++++--------
 setup.py                              |   2 +-
 3 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5d81dcba..cd8bb5b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 ## [0.4.0] - 2024-05-28
 
 + Add - support for SpikeInterface version >= 0.101.0 (updated API)
++ Add - feature for memoization of spike sorting results (prevent duplicated runs)
 
 
 ## [0.3.4] - 2024-03-22
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index a0ff2035..dff74dd7 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -8,7 +8,7 @@
 import pandas as pd
 import spikeinterface as si
 from element_array_ephys import probe, readers
-from element_interface.utils import find_full_path
+from element_interface.utils import find_full_path, memoized_result
 from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
 
 from . import si_preprocessing
@@ -192,23 +192,29 @@ def make(self, key):
             recording_file, base_folder=output_dir
         )
 
+        sorting_params = params["SI_SORTING_PARAMS"]
+        sorting_output_dir = output_dir / sorter_name / "spike_sorting"
+
         # Run sorting
-        # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
-        si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
-            sorter_name=sorter_name,
-            recording=si_recording,
-            output_folder=output_dir / sorter_name / "spike_sorting",
-            remove_existing_folder=True,
-            verbose=True,
-            docker_image=sorter_name not in si.sorters.installed_sorters(),
-            **params.get("SI_SORTING_PARAMS", {}),
+        @memoized_result(
+            uniqueness_dict=sorting_params,
+            output_directory=sorting_output_dir,
         )
+        def _run_sorter():
+            # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
+            si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
+                sorter_name=sorter_name,
+                recording=si_recording,
+                output_folder=sorting_output_dir,
+                remove_existing_folder=True,
+                verbose=True,
+                docker_image=sorter_name not in si.sorters.installed_sorters(),
+                **sorting_params,
+            )
 
-        # Save sorting object
-        sorting_save_path = (
-            output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
-        )
-        si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
+            # Save sorting object
+            sorting_save_path = sorting_output_dir / "si_sorting.pkl"
+            si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
 
         self.insert1(
             {
@@ -254,15 +260,20 @@ def make(self, key):
             sorting_file, base_folder=output_dir
         )
 
-        job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get(
+        postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+        job_kwargs = postprocessing_params.get(
             "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
         )
-        # Sorting Analyzer
         analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
-        if (analyzer_output_dir / "extensions").exists():
-            sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
-        else:
+
+        @memoized_result(
+            uniqueness_dict=postprocessing_params,
+            output_directory=analyzer_output_dir,
+        )
+        def _sorting_analyzer_compute():
+            # Sorting Analyzer
             sorting_analyzer = si.create_sorting_analyzer(
                 sorting=si_sorting,
                 recording=si_recording,
@@ -273,31 +284,33 @@ def make(self, key):
                 folder=analyzer_output_dir,
                 sparse=True,
                 overwrite=True,
                 **job_kwargs
             )
 
-        # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
-        # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
-        extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {})
-        extensions_to_compute = {
-            ext_name: extensions_params[ext_name]
-            for ext_name in sorting_analyzer.get_computable_extensions()
-            if ext_name in extensions_params
-        }
-
-        sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
-
-        # Save to phy format
-        if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False):
-            si.exporters.export_to_phy(
-                sorting_analyzer=sorting_analyzer,
-                output_folder=output_dir / sorter_name / "phy",
-                **job_kwargs,
-            )
-        # Generate spike interface report
-        if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True):
-            si.exporters.export_report(
-                sorting_analyzer=sorting_analyzer,
-                output_folder=output_dir / sorter_name / "spikeinterface_report",
-                **job_kwargs,
-            )
+            # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
+            # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+            extensions_params = postprocessing_params.get("extensions", {})
+            extensions_to_compute = {
+                ext_name: extensions_params[ext_name]
+                for ext_name in sorting_analyzer.get_computable_extensions()
+                if ext_name in extensions_params
+            }
+
+            sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
+            # Save to phy format
+            if postprocessing_params.get("export_to_phy", False):
+                si.exporters.export_to_phy(
+                    sorting_analyzer=sorting_analyzer,
+                    output_folder=analyzer_output_dir / "phy",
+                    **job_kwargs,
+                )
+            # Generate spike interface report
+            if postprocessing_params.get("export_report", True):
+                si.exporters.export_report(
+                    sorting_analyzer=sorting_analyzer,
+                    output_folder=analyzer_output_dir / "spikeinterface_report",
+                    **job_kwargs,
+                )
+
+        _sorting_analyzer_compute()
 
         self.insert1(
             {
diff --git a/setup.py b/setup.py
index f1ba9c90..66789740 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@
         "scikit-image",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",
-        "element-interface @ git+https://github.com/datajoint/element-interface.git",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results",
         "numba",
     ],
     extras_require={
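`memoized_result` is imported from `element-interface` (pinned to the `dev_memoized_results` branch in this patch's setup.py change), and the intent is that a decorated job is skipped when the same parameter set has already populated the output directory. The real implementation lives in that package; purely as a hypothetical sketch of the idea:

    import hashlib
    import json
    from pathlib import Path

    def memoized_result(uniqueness_dict: dict, output_directory):
        """Skip re-running fn if output_directory already holds results for these params."""
        def decorator(fn):
            def wrapper(*args, **kwargs):
                out_dir = Path(output_directory)
                stamp = out_dir / ".uniqueness_hash"
                digest = hashlib.md5(
                    json.dumps(uniqueness_dict, sort_keys=True, default=str).encode()
                ).hexdigest()
                if stamp.exists() and stamp.read_text() == digest:
                    return  # same params, results already on disk
                result = fn(*args, **kwargs)
                out_dir.mkdir(parents=True, exist_ok=True)
                stamp.write_text(digest)
                return result
            return wrapper
        return decorator

One consequence of this pattern: decorating turns the sorting step into a function that does nothing until it is actually called, which is exactly the oversight patch 12 below repairs with the missing `_run_sorter()` call.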
From 0afb4529de262fbee6b21461e5aec58765fd0e12 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 14:22:20 -0500
Subject: [PATCH 10/16] chore: minor code cleanup

---
 element_array_ephys/ephys_no_curation.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 9222ccd2..b49d4422 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -8,14 +8,12 @@
 import datajoint as dj
 import numpy as np
 import pandas as pd
-import spikeinterface as si
 from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-from spikeinterface import exporters, postprocessing, qualitymetrics, sorters
 
 from . import ephys_report, probe
 from .readers import kilosort, openephys, spikeglx
 
-log = dj.logger
+logger = dj.logger
 
 schema = dj.schema()
 
@@ -824,7 +822,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
 
         if mkdir:
             output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
+            logger.info(f"{output_dir} created!")
 
         return output_dir.relative_to(processed_dir) if relative else output_dir
 
@@ -1040,6 +1038,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # Read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             si_sorting = sorting_analyzer.sorting
 
@@ -1246,6 +1246,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
 
             # Find representative channel for each unit
@@ -1501,6 +1503,8 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
 
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
+            import spikeinterface as si
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
             qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
             template_metrics = sorting_analyzer.get_extension(
From e8f445c3b4b532b3159638e71d231e2048939a90 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 16:47:22 -0500
Subject: [PATCH 11/16] fix: merge fix & formatting

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index dff74dd7..9e14f636 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -248,7 +248,6 @@ def make(self, key):
         ).fetch1("clustering_method", "clustering_output_dir", "params")
         output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
         sorter_name = clustering_method.replace(".", "_")
-        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
 
         recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
         sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
@@ -281,7 +280,7 @@ def make(self, key):
                 folder=analyzer_output_dir,
                 sparse=True,
                 overwrite=True,
-                **job_kwargs
+                **job_kwargs,
             )
 

From 6155f13fd755ac76ec79fdd1594b0e96ef8d550b Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 17:01:10 -0500
Subject: [PATCH 12/16] fix: calling `_run_sorter()`

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 9e14f636..5c1d6567 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -216,6 +216,8 @@ def make(self, key):
             sorting_save_path = sorting_output_dir / "si_sorting.pkl"
             si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
 
+        _run_sorter()
+
         self.insert1(
             {
                 **key,
From f6a52d9d3f31b7ebe2853da4545551898cfa50ae Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Tue, 28 May 2024 20:07:27 -0500
Subject: [PATCH 13/16] chore: more robust channel mapping

---
 element_array_ephys/ephys_no_curation.py | 29 ++++++++----------
 1 file changed, 10 insertions(+), 19 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index b49d4422..142f350b 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1028,9 +1028,8 @@ def make(self, key):
 
         # Get channel and electrode-site mapping
         electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
-        channel2electrode_map = electrode_query.fetch(as_dict=True)
         channel2electrode_map: dict[int, dict] = {
-            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
         }
 
         # Get sorter method and create output directory.
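Patch 13 (continuing below) re-keys `channel2electrode_map` by position along the probe rather than by recording channel index: `sorting_analyzer.get_probe()` returns a probeinterface Probe whose `contact_ids` enumerate contacts in the same order as the analyzer's channel axis. A sketch of the remap, continuing the read-back sketch above; the int() cast reflects that contact ids arrive as strings:

    probe = sorting_analyzer.get_probe()
    channel2electrode_map = {
        idx: channel2electrode_map[int(contact_id)]
        for idx, contact_id in enumerate(probe.contact_ids)
    }
    # After this, unit_waveforms[:, idx] and channel2electrode_map[idx] describe
    # the same recording site, so the separate channel_indices lookup table
    # removed by the patch is no longer needed.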
@@ -1054,12 +1053,10 @@ def make(self, key):
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
 
-            # reorder channel2electrode_map according to recording channel ids
+            # update channel2electrode_map to match with probe's channel index
             channel2electrode_map = {
-                chn_idx: channel2electrode_map[chn_idx]
-                for chn_idx in sorting_analyzer.channel_ids_to_indices(
-                    sorting_analyzer.channel_ids
-                )
+                idx: channel2electrode_map[int(chn_idx)]
+                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
             }
 
             # Get unit id to quality label mapping
@@ -1239,9 +1236,8 @@ def make(self, key):
 
         # Get channel and electrode-site mapping
         electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
-        channel2electrode_map = electrode_query.fetch(as_dict=True)
         channel2electrode_map: dict[int, dict] = {
-            chn.pop("channel_idx"): chn for chn in channel2electrode_map
+            chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
         }
 
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
@@ -1258,12 +1254,10 @@ def make(self, key):
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
 
-            # reorder channel2electrode_map according to recording channel ids
-            channel_indices = sorting_analyzer.channel_ids_to_indices(
-                sorting_analyzer.channel_ids
-            ).tolist()
+            # update channel2electrode_map to match with probe's channel index
             channel2electrode_map = {
-                chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices
+                idx: channel2electrode_map[int(chn_idx)]
+                for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
             }
 
             templates = sorting_analyzer.get_extension("templates")
@@ -1276,12 +1270,9 @@ def yield_unit_waveforms():
                 unit_waveforms = templates.get_unit_template(
                     unit_id=unit["unit"], operator="average"
                 )
-                peak_chn_idx = channel_indices.index(
-                    unit_peak_channel[unit["unit"]]
-                )
                 unit_peak_waveform = {
                     **unit,
-                    "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx],
+                    "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
                 }
 
                 unit_electrode_waveforms = [
@@ -1290,7 +1281,7 @@ def yield_unit_waveforms():
                         **channel2electrode_map[chn_idx],
                         "waveform_mean": unit_waveforms[:, chn_idx],
                     }
-                    for chn_idx in channel_indices
+                    for chn_idx in channel2electrode_map
                 ]
                 yield unit_peak_waveform, unit_electrode_waveforms

From 1ff92dd15db6ff9e8458f53ec96fdffb6b93305d Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Wed, 29 May 2024 16:09:16 -0500
Subject: [PATCH 14/16] fix: use relative path for phy output

---
 element_array_ephys/spike_sorting/si_spike_sorting.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 5c1d6567..93619303 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -301,6 +301,7 @@ def make(self, key):
                 si.exporters.export_to_phy(
                     sorting_analyzer=sorting_analyzer,
                     output_folder=analyzer_output_dir / "phy",
+                    use_relative_path=True,
                     **job_kwargs,
                 )
             # Generate spike interface report

From b45970974df001319a4ebae182bf291313f5e39a Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Wed, 29 May 2024 16:16:21 -0500
Subject: [PATCH 15/16] feat: in data ingestion, set peak_sign="both"

---
 element_array_ephys/ephys_no_curation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git
 a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 142f350b..8eadba49 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1045,7 +1045,7 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="neg"
+                    sorting_analyzer, 1, peak_sign="both"
                 ).unit_id_to_channel_indices
             )
             unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1249,7 +1249,7 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="neg"
+                    sorting_analyzer, 1, peak_sign="both"
                 ).unit_id_to_channel_indices
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}

From 1a1b18f8a52b83298bffc8d82555ccc147151dd1 Mon Sep 17 00:00:00 2001
From: Thinh Nguyen
Date: Mon, 3 Jun 2024 13:22:49 -0500
Subject: [PATCH 16/16] feat: replace `output_folder` with `folder` when calling `run_sorter`, use default value for `peak_sign`

---
 element_array_ephys/ephys_no_curation.py | 21 ++++++++++++-------
 .../spike_sorting/si_spike_sorting.py    |  2 +-
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 8eadba49..891cee0f 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1045,10 +1045,13 @@ def make(self, key):
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer,
+                    1,
                 ).unit_id_to_channel_indices
             )
-            unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()}
+            unit_peak_channel: dict[int, int] = {
+                u: chn[0] for u, chn in unit_peak_channel.items()
+            }
 
             spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
             # {unit: spike_count}
@@ -1084,7 +1087,9 @@ def make(self, key):
                 )
                 unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
                 _, spike_depths = zip(*unit_spikes_loc)  # x-coordinates, y-coordinates
-                spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True)
+                spike_times = si_sorting.get_unit_spike_train(
+                    unit_id, return_times=True
+                )
 
                 assert len(spike_times) == len(spike_sites) == len(spike_depths)
 
@@ -1243,13 +1248,13 @@ def make(self, key):
         si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
         if si_sorting_analyzer_dir.exists():  # read from spikeinterface outputs
             import spikeinterface as si
-
+
             sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
 
             # Find representative channel for each unit
             unit_peak_channel: dict[int, np.ndarray] = (
                 si.ChannelSparsity.from_best_channels(
-                    sorting_analyzer, 1, peak_sign="both"
+                    sorting_analyzer, 1
                 ).unit_id_to_channel_indices
             )  # {unit: peak_channel_index}
             unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
@@ -1272,7 +1277,9 @@ def yield_unit_waveforms():
                 )
                 unit_peak_waveform = {
                     **unit,
-                    "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]],
+                    "peak_electrode_waveform": unit_waveforms[
+                        :, unit_peak_channel[unit["unit"]]
+                    ],
                 }
 
                 unit_electrode_waveforms = [
sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs import spikeinterface as si - + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() template_metrics = sorting_analyzer.get_extension( diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 93619303..57aa0ba1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -205,7 +205,7 @@ def _run_sorter(): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=sorting_output_dir, + folder=sorting_output_dir, remove_existing_folder=True, verbose=True, docker_image=sorter_name not in si.sorters.installed_sorters(),