From d1ce4e745f715aec142dc549f6de12bb4fb695a9 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Oct 2023 11:38:21 +0200 Subject: [PATCH 01/16] preprocessing --- pyproject.toml | 21 ++ requirements.txt | 1 + src/spikeinterface_pipelines/__init__.py | 0 .../preprocessing/models.py | 60 +++++ .../preprocessing/preprocessing.py | 249 ++++++++++++++++++ 5 files changed, 331 insertions(+) create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 src/spikeinterface_pipelines/__init__.py create mode 100644 src/spikeinterface_pipelines/preprocessing/models.py create mode 100644 src/spikeinterface_pipelines/preprocessing/preprocessing.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0322d3c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[project] +name = "spikeinterface_pipelines" +description = "Collection of standardized analysis pipelines based on SpikeInterfacee." +readme = "README.md" +authors = [{ name = "My Name", email = "me@email.com" }] +requires-python = ">=3.8" +dependencies = ["spikeinterface[full]"] +keywords = [ + "spikeinterface", + "spike sorting", + "electrophysiology", + "neuroscience", +] + +[project.urls] +Homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines" +Documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines" + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0c660d7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +spikeinterface[full] \ No newline at end of file diff --git a/src/spikeinterface_pipelines/__init__.py b/src/spikeinterface_pipelines/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/spikeinterface_pipelines/preprocessing/models.py b/src/spikeinterface_pipelines/preprocessing/models.py new file mode 100644 index 0000000..566419e --- /dev/null +++ b/src/spikeinterface_pipelines/preprocessing/models.py @@ -0,0 +1,60 @@ +from pydantic import BaseModel, Field +from typing import Optional +from enum import Enum + + +class PreprocessingStrategy(str, Enum): + cmr = "cmr" + destripe = "destripe" + + +class HighpassFilter(BaseModel): + freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter") + margin_ms: float = Field(default=5.0, description="Margin in milliseconds") + + +class PhaseShift(BaseModel): + margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift") + + +class DetectBadChannels(BaseModel): + method: str = Field(default="coherence+psd", description="Method to detect bad channels") + dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channel") + noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channel") + outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside channel") + n_neighbors: int = Field(default=11, description="Number of neighbors") + seed: int = Field(default=0, description="Seed value") + + +class CommonReference(BaseModel): + reference: str = Field(default="global", description="Type of reference") + operator: str = Field(default="median", description="Operator used for common reference") + + +class HighpassSpatialFilter(BaseModel): + n_channel_pad: int = Field(default=60, description="Number of channels to pad") + n_channel_taper: Optional[int] = Field(default=None, description="Number of channels to taper") + direction: 
str = Field(default="y", description="Direction for the spatial filter") + apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control") + agc_window_length_s: float = Field(default=0.01, description="Window length in seconds for AGC") + highpass_butter_order: int = Field(default=3, description="Order for the Butterworth filter") + highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") + + +class MotionCorrection(BaseModel): + compute: bool = Field(default=True, description="Whether to compute motion correction") + apply: bool = Field(default=False, description="Whether to apply motion correction") + preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction") + + +class PreprocessingParamsModel(BaseModel): + preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing") + highpass_filter: HighpassFilter + phase_shift: PhaseShift + detect_bad_channels: DetectBadChannels + remove_out_channels: bool = Field(default=True, description="Flag to remove out channels") + remove_bad_channels: bool = Field(default=True, description="Flag to remove bad channels") + max_bad_channel_fraction_to_remove: float = Field(default=0.5, description="Maximum fraction of bad channels to remove") + common_reference: CommonReference + highpass_spatial_filter: HighpassSpatialFilter + motion_correction: MotionCorrection diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py new file mode 100644 index 0000000..cdb4b8b --- /dev/null +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -0,0 +1,249 @@ +import warnings +import os +import json +import time +import numpy as np +from pathlib import Path +from datetime import datetime +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.preprocessing as spre +from spikeinterface.core.core_tools import check_json + +from .models import PreprocessingParamsModel + + +warnings.filterwarnings("ignore") + +n_jobs_co = os.getenv('CO_CPUS') +n_jobs = int(n_jobs_co) if n_jobs_co is not None else -1 + +job_kwargs = { + 'n_jobs': n_jobs, + 'chunk_duration': '1s', + 'progress_bar': True +} + +data_folder = Path("../data/") +results_folder = Path("../results/") + + +def preprocessing( + data_folder: Path, + results_folder: Path, + job_kwargs: dict, + preprocessing_params: PreprocessingParamsModel, + debug: bool = False, + duration_s: float = 1. +) -> None: + """ + Preprocessing pipeline for ephys data. 
+ """ + + data_process_prefix = "data_process_preprocessing" + + if debug: + print(f"DEBUG ENABLED - Only running with {duration_s} seconds") + + si.set_global_job_kwargs(**job_kwargs) + + # load job json files + job_config_json_files = [p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name] + print(f"Found {len(job_config_json_files)} json configurations") + + if len(job_config_json_files) > 0: + print("\n\nPREPROCESSING") + t_preprocessing_start_all = time.perf_counter() + preprocessing_vizualization_data = {} + print(f"Preprocessing strategy: {preprocessing_params.preprocessing_strategy}") + + for job_config_file in job_config_json_files: + datetime_start_preproc = datetime.now() + t_preprocessing_start = time.perf_counter() + preprocessing_notes = "" + + with open(job_config_file, "r") as f: + job_config = json.load(f) + session_name = job_config["session_name"] + session_folder_path = job_config["session_folder_path"] + + session = data_folder / session_folder_path + assert session.is_dir(), ( + f"Could not find {session_name} in {str((data_folder / session_folder_path).resolve())}." + f"Make sure mapping is correct!" + ) + + ecephys_full_folder = session / "ecephys" + ecephys_compressed_folder = session / "ecephys_compressed" + compressed = False + if ecephys_compressed_folder.is_dir(): + compressed = True + ecephys_folder = session / "ecephys_clipped" + else: + ecephys_folder = ecephys_full_folder + + experiment_name = job_config["experiment_name"] + stream_name = job_config["stream_name"] + block_index = job_config["block_index"] + segment_index = job_config["segment_index"] + recording_name = job_config["recording_name"] + + preprocessing_vizualization_data[recording_name] = {} + preprocessing_output_process_json = results_folder / f"{data_process_prefix}_{recording_name}.json" + preprocessing_output_folder = results_folder / f"preprocessed_{recording_name}" + preprocessingviz_output_file = results_folder / f"preprocessedviz_{recording_name}.json" + preprocessing_output_json = results_folder / f"preprocessed_{recording_name}.json" + + exp_stream_name = f"{experiment_name}_{stream_name}" + if not compressed: + recording = se.read_openephys( + ecephys_folder, + stream_name=stream_name, + block_index=block_index + ) + else: + recording = si.read_zarr(ecephys_compressed_folder / f"{exp_stream_name}.zarr") + + if debug: + recording_list = [] + for segment_index in range(recording.get_num_segments()): + recording_one = si.split_recording(recording)[segment_index] + recording_one = recording_one.frame_slice( + start_frame=0, + end_frame=int(duration_s*recording.sampling_frequency) + ) + recording_list.append(recording_one) + recording = si.append_recordings(recording_list) + + if segment_index is not None: + recording = si.split_recording(recording)[segment_index] + + print(f"Preprocessing recording: {recording_name}") + print(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") + + recording_ps_full = spre.phase_shift( + recording, + **preprocessing_params.phase_shift.model_dump() + ) + recording_hp_full = spre.highpass_filter( + recording_ps_full, + **preprocessing_params.highpass_filter.model_dump() + ) + preprocessing_vizualization_data[recording_name]["timeseries"] = {} + preprocessing_vizualization_data[recording_name]["timeseries"]["full"] = dict( + raw=recording.to_dict(relative_to=data_folder, recursive=True), + phase_shift=recording_ps_full.to_dict(relative_to=data_folder, recursive=True), + 
highpass=recording_hp_full.to_dict(relative_to=data_folder, recursive=True) + ) + + # IBL bad channel detection + _, channel_labels = spre.detect_bad_channels( + recording_hp_full, + **preprocessing_params.detect_bad_channels.model_dump() + ) + dead_channel_mask = channel_labels == "dead" + noise_channel_mask = channel_labels == "noise" + out_channel_mask = channel_labels == "out" + print(f"\tBad channel detection:") + print(f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") + dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] + noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] + out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] + + all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) + + skip_processing = False + max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove + if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): + print(f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " + f"Skipping further processing for this recording.") + preprocessing_notes += f"\n- Found {len(all_bad_channel_ids)} bad channels. Skipping further processing\n" + skip_processing = True + # in this case, processed timeseries will not be visualized + preprocessing_vizualization_data[recording_name]["timeseries"]["proc"] = None + recording_drift = recording_hp_full + else: + if preprocessing_params.remove_out_channels: + print(f"\tRemoving {len(out_channel_ids)} out channels") + recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) + preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
+ else: + recording_rm_out = recording_hp_full + + recording_processed_cmr = spre.common_reference( + recording_rm_out, + **preprocessing_params.common_reference.model_dump() + ) + + bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) + recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) + recording_hp_spatial = spre.highpass_spatial_filter( + recording_interp, + **preprocessing_params.highpass_spatial_filter.model_dump() + ) + preprocessing_vizualization_data[recording_name]["timeseries"]["proc"] = dict( + highpass=recording_rm_out.to_dict(relative_to=data_folder, recursive=True), + cmr=recording_processed_cmr.to_dict(relative_to=data_folder, recursive=True), + highpass_spatial=recording_hp_spatial.to_dict(relative_to=data_folder, recursive=True) + ) + + if preprocessing_params.preprocessing_strategy == "cmr": + recording_processed = recording_processed_cmr + else: + recording_processed = recording_hp_spatial + + if preprocessing_params.remove_bad_channels: + print(f"\tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") + recording_processed = recording_processed.remove_channels(bad_channel_ids) + preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" + + # motion correction + if preprocessing_params.motion_correction.compute: + preset = preprocessing_params.motion_correction.preset + print(f"\tComputing motion correction with preset: {preset}") + motion_folder = results_folder / f"motion_{recording_name}" + recording_corrected = spre.correct_motion( + recording_processed, preset=preset, + folder=motion_folder, + **job_kwargs + ) + if preprocessing_params.motion_correction.apply: + print("\tApplying motion correction") + recording_processed = recording_corrected + + recording_saved = recording_processed.save(folder=preprocessing_output_folder) + recording_processed.dump_to_json(preprocessing_output_json, relative_to=data_folder) + recording_drift = recording_saved + + # store recording for drift visualization + preprocessing_vizualization_data[recording_name]["drift"] = dict( + recording=recording_drift.to_dict(relative_to=data_folder, recursive=True) + ) + with open(preprocessingviz_output_file, "w") as f: + json.dump(check_json(preprocessing_vizualization_data), f, indent=4) + + t_preprocessing_end = time.perf_counter() + elapsed_time_preprocessing = np.round(t_preprocessing_end - t_preprocessing_start, 2) + + # save params in output + preprocessing_params["recording_name"] = recording_name + preprocessing_outputs = dict(channel_labels=channel_labels.tolist()) + # preprocessing_process = DataProcess( + # name="Ephys preprocessing", + # version=VERSION, # either release or git commit + # start_date_time=datetime_start_preproc, + # end_date_time=datetime_start_preproc + timedelta(seconds=np.floor(elapsed_time_preprocessing)), + # input_location=str(data_folder), + # output_location=str(results_folder), + # code_url=URL, + # parameters=preprocessing_params, + # outputs=preprocessing_outputs, + # notes=preprocessing_notes + # ) + # with open(preprocessing_output_process_json, "w") as f: + # f.write(preprocessing_process.json(indent=3)) + + t_preprocessing_end_all = time.perf_counter() + elapsed_time_preprocessing_all = np.round(t_preprocessing_end_all - t_preprocessing_start_all, 2) + + print(f"PREPROCESSING time: {elapsed_time_preprocessing_all}s") \ No newline at end of file From 0a88040ce9eb8818bf9567f284d97f973b5319c9 Mon Sep 17 00:00:00 
2001 From: luiz Date: Fri, 27 Oct 2023 12:06:49 +0200 Subject: [PATCH 02/16] postprocessing - wip --- .../postprocessing/__init__.py | 1 + .../postprocessing/models.py | 142 +++++++++++++++ .../postprocessing/postprocessing.py | 162 ++++++++++++++++++ .../preprocessing/__init__.py | 1 + 4 files changed, 306 insertions(+) create mode 100644 src/spikeinterface_pipelines/postprocessing/__init__.py create mode 100644 src/spikeinterface_pipelines/postprocessing/models.py create mode 100644 src/spikeinterface_pipelines/postprocessing/postprocessing.py create mode 100644 src/spikeinterface_pipelines/preprocessing/__init__.py diff --git a/src/spikeinterface_pipelines/postprocessing/__init__.py b/src/spikeinterface_pipelines/postprocessing/__init__.py new file mode 100644 index 0000000..c01852b --- /dev/null +++ b/src/spikeinterface_pipelines/postprocessing/__init__.py @@ -0,0 +1 @@ +from .postprocessing import postprocessing \ No newline at end of file diff --git a/src/spikeinterface_pipelines/postprocessing/models.py b/src/spikeinterface_pipelines/postprocessing/models.py new file mode 100644 index 0000000..5ef6d99 --- /dev/null +++ b/src/spikeinterface_pipelines/postprocessing/models.py @@ -0,0 +1,142 @@ +from pydantic import BaseModel, Field +from typing import Optional, List, Tuple +from enum import Enum + + +class PresenceRatio(BaseModel): + bin_duration_s: float = Field(60, description="Duration of the bin in seconds.") + + +class SNR(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak.") + peak_mode: str = Field("extremum", description="Mode of the peak.") + random_chunk_kwargs_dict: Optional[dict] = Field(None, description="Random chunk arguments.") + + +class ISIViolation(BaseModel): + isi_threshold_ms: float = Field(1.5, description="ISI threshold in milliseconds.") + min_isi_ms: float = Field(0., description="Minimum ISI in milliseconds.") + + +class RPViolation(BaseModel): + refractory_period_ms: float = Field(1., description="Refractory period in milliseconds.") + censored_period_ms: float = Field(0.0, description="Censored period in milliseconds.") + + +class SlidingRPViolation(BaseModel): + bin_size_ms: float = Field(0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") + window_size_s: float = Field(1, description="Window in seconds to compute correlogram, by default 1.") + exclude_ref_period_below_ms: float = Field(0.5, description="Refractory periods below this value are excluded, by default 0.5") + max_ref_period_ms: float = Field(10, description="Maximum refractory period to test in ms, by default 10 ms.") + contamination_values: Optional[list] = Field(None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") + + +class PeakSign(str, Enum): + neg = "neg" + pos = "pos" + both = "both" + + +class AmplitudeCutoff(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + num_histogram_bins: int = Field(100, description="The number of bins to use to compute the amplitude histogram.") + histogram_smoothing_value: int = Field(3, description="Controls the smoothing applied to the amplitude histogram.") + amplitudes_bins_min_ratio: int = Field(5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. 
If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") + + +class AmplitudeMedian(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + + +class NearestNeighbor(BaseModel): + max_spikes: int = Field(10000, description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.") + min_spikes: int = Field(10, description="Minimum number of spikes.") + n_neighbors: int = Field(4, description="The number of neighbors to use.") + + +class NNIsolation(NearestNeighbor): + n_components: int = Field(10, description="The number of PC components to use to project the snippets to.") + radius_um: int = Field(100, description="The radius, in um, that channels need to be within the peak channel to be included.") + + +class QMParams(BaseModel): + presence_ratio: PresenceRatio + snr: SNR + isi_violation: ISIViolation + rp_violation: RPViolation + sliding_rp_violation: SlidingRPViolation + amplitude_cutoff: AmplitudeCutoff + amplitude_median: AmplitudeMedian + nearest_neighbor: NearestNeighbor + nn_isolation: NNIsolation + nn_noise_overlap: NNIsolation + + +class QualityMetrics(BaseModel): + qm_params: QMParams = Field(..., description="Quality metric parameters.") + metric_names: List[str] = Field(..., description="List of metric names to compute.") + n_jobs: int = Field(1, description="Number of jobs.") + + +class Sparsity(BaseModel): + method: str = Field("radius", description="Method for determining sparsity.") + radius_um: int = Field(100, description="Radius in micrometers for sparsity.") + + +class Waveforms(BaseModel): + ms_before: float = Field(3.0, description="Milliseconds before") + ms_after: float = Field(4.0, description="Milliseconds after") + max_spikes_per_unit: int = Field(500, description="Maximum spikes per unit") + return_scaled: bool = Field(True, description="Flag to determine if results should be scaled") + dtype: Optional[str] = Field(None, description="Data type for the waveforms") + precompute_template: Tuple[str, str] = Field(("average", "std"), description="Precomputation template method") + use_relative_path: bool = Field(True, description="Use relative paths") + + +class SpikeAmplitudes(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak") + return_scaled: bool = Field(True, description="Flag to determine if amplitudes should be scaled") + outputs: str = Field("concatenated", description="Output format for the spike amplitudes") + + +class Similarity(BaseModel): + method: str = Field("cosine_similarity", description="Method to compute similarity") + + +class Correlograms(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(2.0, description="Size of the bin in milliseconds") + + +class ISIS(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(5.0, description="Size of the bin in milliseconds") + + +class Locations(BaseModel): + method: str = Field("monopolar_triangulation", description="Method to determine locations") + + +class TemplateMetrics(BaseModel): + upsampling_factor: int = Field(10, description="Upsampling factor") + sparsity: Optional[str] = Field(None, description="Sparsity method") + + +class PrincipalComponents(BaseModel): + n_components: int = Field(5, description="Number of principal components") + mode: str = Field("by_channel_local", description="Mode of principal component 
analysis") + whiten: bool = Field(True, description="Whiten the components") + + +class PostprocessingParamsModel(BaseModel): + sparsity: Sparsity + waveforms_deduplicate: Waveforms + waveforms: Waveforms + spike_amplitudes: SpikeAmplitudes + similarity: Similarity + correlograms: Correlograms + isis: ISIS + locations: Locations + template_metrics: TemplateMetrics + principal_components: PrincipalComponents + quality_metrics: QualityMetrics diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py new file mode 100644 index 0000000..21c3ce5 --- /dev/null +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -0,0 +1,162 @@ +import warnings +import os +import numpy as np +from pathlib import Path +import shutil +import json +import time +from datetime import datetime +import spikeinterface as si +import spikeinterface.postprocessing as spost +import spikeinterface.qualitymetrics as sqm +import spikeinterface.curation as sc + +from .models import PostprocessingParamsModel + + +warnings.filterwarnings("ignore") + +n_jobs_co = os.getenv('CO_CPUS') +n_jobs = int(n_jobs_co) if n_jobs_co is not None else -1 + +job_kwargs = { + 'n_jobs': n_jobs, + 'chunk_duration': '1s', + 'progress_bar': True +} + +data_folder = Path("../data/") +results_folder = Path("../results/") +tmp_folder = results_folder / "tmp" +tmp_folder.mkdir() + + +def postprocessing( + data_folder: Path, + results_folder: Path, + job_kwargs: dict, + postprocessing_params: PostprocessingParamsModel, +) -> None: + data_process_prefix = "data_process_postprocessing" + si.set_global_job_kwargs(**job_kwargs) + print("\nPOSTPROCESSING") + t_postprocessing_start_all = time.perf_counter() + + # check if test + if (data_folder / "preprocessing_pipeline_output_test").is_dir(): + print("\n*******************\n**** TEST MODE ****\n*******************\n") + preprocessed_folder = data_folder / "preprocessing_pipeline_output_test" + spikesorted_folder = data_folder / "spikesorting_pipeline_output_test" + else: + preprocessed_folder = data_folder + spikesorted_folder = data_folder + + preprocessed_folders = [p for p in preprocessed_folder.iterdir() if p.is_dir() and "preprocessed_" in p.name] + + # load job json files + job_config_json_files = [p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name] + print(f"Found {len(job_config_json_files)} json configurations") + + if len(job_config_json_files) > 0: + recording_names = [] + for json_file in job_config_json_files: + with open(json_file, "r") as f: + config = json.load(f) + recording_name = config["recording_name"] + assert (preprocessed_folder / f"preprocessed_{recording_name}").is_dir(), f"Preprocessed folder for {recording_name} not found!" 
+ recording_names.append(recording_name) + else: + recording_names = [("_").join(p.name.split("_")[1:]) for p in preprocessed_folders] + + for recording_name in recording_names: + datetime_start_postprocessing = datetime.now() + t_postprocessing_start = time.perf_counter() + postprocessing_notes = "" + + print(f"\tProcessing {recording_name}") + postprocessing_output_process_json = results_folder / f"{data_process_prefix}_{recording_name}.json" + postprocessing_output_folder = results_folder / f"postprocessed_{recording_name}" + postprocessing_sorting_output_folder = results_folder / f"postprocessed-sorting_{recording_name}" + + recording = si.load_extractor(preprocessed_folder / f"preprocessed_{recording_name}") + # make sure we have spikesorted output for the block-stream + sorted_folder = spikesorted_folder / f"spikesorted_{recording_name}" + if not sorted_folder.is_dir(): + raise FileNotFoundError(f"Spike sorted data for {recording_name} not found!") + + sorting = si.load_extractor(sorted_folder) + + # first extract some raw waveforms in memory to deduplicate based on peak alignment + wf_dedup_folder = tmp_folder / "postprocessed" / recording_name + we_raw = si.extract_waveforms( + recording, + sorting, + folder=wf_dedup_folder, + **postprocessing_params.waveforms_deduplicate.model_dump() + ) + + # de-duplication + sorting_deduplicated = sc.remove_redundant_units( + we_raw, + duplicate_threshold=postprocessing_params.duplicate_threshold + ) + print(f"\tNumber of original units: {len(we_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}") + n_duplicated = int(len(sorting.unit_ids) - len(sorting_deduplicated.unit_ids)) + postprocessing_notes += f"\n- Removed {n_duplicated} duplicated units.\n" + deduplicated_unit_ids = sorting_deduplicated.unit_ids + + # use existing deduplicated waveforms to compute sparsity + sparsity_raw = si.compute_sparsity(we_raw, **sparsity_params) + sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] + sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids) + shutil.rmtree(wf_dedup_folder) + del we_raw + + # this is a trick to make the postprocessed folder "self-contained + sorting_deduplicated = sorting_deduplicated.save(folder=postprocessing_sorting_output_folder) + + # now extract waveforms on de-duplicated units + print("\tSaving sparse de-duplicated waveform extractor folder") + we = si.extract_waveforms( + recording, + sorting_deduplicated, + folder=postprocessing_output_folder, + sparsity=sparsity, + sparse=True, + overwrite=True, + **postprocessing_params.waveforms.model_dump() + ) + + print("\tComputing spike amplitides") + amps = spost.compute_spike_amplitudes(we, **postprocessing_params.spike_amplitudes.model_dump()) + + print("\tComputing unit locations") + unit_locs = spost.compute_unit_locations(we, **postprocessing_params.locations.model_dump()) + + print("\tComputing spike locations") + spike_locs = spost.compute_spike_locations(we, **postprocessing_params.locations.model_dump()) + + print("\tComputing correlograms") + corr = spost.compute_correlograms(we, **postprocessing_params.correlograms.model_dump()) + + print("\tComputing ISI histograms") + tm = spost.compute_isi_histograms(we, **postprocessing_params.isis.model_dump()) + + print("\tComputing template similarity") + sim = spost.compute_template_similarity(we, **postprocessing_params.similarity.model_dump()) + + print("\tComputing template metrics") + 
tm = spost.compute_template_metrics(we, **postprocessing_params.template_metrics.model_dump()) + + print("\tComputing PCA") + pc = spost.compute_principal_components(we, **postprocessing_params.principal_components.model_dump()) + + print("\tComputing quality metrics") + qm = sqm.compute_quality_metrics(we, **postprocessing_params.quality_metrics.model_dump()) + + t_postprocessing_end = time.perf_counter() + elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2) + + t_postprocessing_end_all = time.perf_counter() + elapsed_time_postprocessing_all = np.round(t_postprocessing_end_all - t_postprocessing_start_all, 2) + print(f"POSTPROCESSING time: {elapsed_time_postprocessing_all}s") \ No newline at end of file diff --git a/src/spikeinterface_pipelines/preprocessing/__init__.py b/src/spikeinterface_pipelines/preprocessing/__init__.py new file mode 100644 index 0000000..38557eb --- /dev/null +++ b/src/spikeinterface_pipelines/preprocessing/__init__.py @@ -0,0 +1 @@ +from .preprocessing import preprocessing From c8ec32bf04f4ccf3c27e97ff0e0578cb6758250f Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Oct 2023 13:44:08 +0200 Subject: [PATCH 03/16] common model --- src/spikeinterface_pipelines/models.py | 7 ++++ .../postprocessing/models.py | 1 + .../postprocessing/postprocessing.py | 37 ++++++++++++------- .../preprocessing/preprocessing.py | 25 +++++++------ 4 files changed, 45 insertions(+), 25 deletions(-) create mode 100644 src/spikeinterface_pipelines/models.py diff --git a/src/spikeinterface_pipelines/models.py b/src/spikeinterface_pipelines/models.py new file mode 100644 index 0000000..ed5956f --- /dev/null +++ b/src/spikeinterface_pipelines/models.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel, Field + + +class JobKwargs(BaseModel): + n_jobs: int = Field(-1, description="The number of jobs to run in parallel.") + chunk_duration: str = Field("1s", description="The duration of the chunks to process.") + progress_bar: bool = Field(True, description="Whether to display a progress bar.") \ No newline at end of file diff --git a/src/spikeinterface_pipelines/postprocessing/models.py b/src/spikeinterface_pipelines/postprocessing/models.py index 5ef6d99..1774c16 100644 --- a/src/spikeinterface_pipelines/postprocessing/models.py +++ b/src/spikeinterface_pipelines/postprocessing/models.py @@ -140,3 +140,4 @@ class PostprocessingParamsModel(BaseModel): template_metrics: TemplateMetrics principal_components: PrincipalComponents quality_metrics: QualityMetrics + duplicate_threshold: float = Field(0.9, description="Duplicate threshold") diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 21c3ce5..6f0acb6 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -11,20 +11,12 @@ import spikeinterface.qualitymetrics as sqm import spikeinterface.curation as sc +from ..models import JobKwargs from .models import PostprocessingParamsModel warnings.filterwarnings("ignore") -n_jobs_co = os.getenv('CO_CPUS') -n_jobs = int(n_jobs_co) if n_jobs_co is not None else -1 - -job_kwargs = { - 'n_jobs': n_jobs, - 'chunk_duration': '1s', - 'progress_bar': True -} - data_folder = Path("../data/") results_folder = Path("../results/") tmp_folder = results_folder / "tmp" @@ -34,11 +26,26 @@ def postprocessing( data_folder: Path, results_folder: Path, - job_kwargs: dict, + job_kwargs: JobKwargs, 
postprocessing_params: PostprocessingParamsModel, ) -> None: + """ + Postprocessing pipeline + + Parameters + ---------- + data_folder: Path + Path to the data folder + results_folder: Path + Path to the results folder + job_kwargs: JobKwargs + Job kwargs + postprocessing_params: PostprocessingParamsModel + Postprocessing parameters + """ + si.set_global_job_kwargs(**job_kwargs.model_dump()) + data_process_prefix = "data_process_postprocessing" - si.set_global_job_kwargs(**job_kwargs) print("\nPOSTPROCESSING") t_postprocessing_start_all = time.perf_counter() @@ -106,9 +113,13 @@ def postprocessing( deduplicated_unit_ids = sorting_deduplicated.unit_ids # use existing deduplicated waveforms to compute sparsity - sparsity_raw = si.compute_sparsity(we_raw, **sparsity_params) + sparsity_raw = si.compute_sparsity(we_raw, **postprocessing_params.sparsity.model_dump()) sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] - sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids) + sparsity = si.ChannelSparsity( + mask=sparsity_mask, + unit_ids=deduplicated_unit_ids, + channel_ids=recording.channel_ids + ) shutil.rmtree(wf_dedup_folder) del we_raw diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index cdb4b8b..b8bf551 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -10,20 +10,12 @@ import spikeinterface.preprocessing as spre from spikeinterface.core.core_tools import check_json +from ..models import JobKwargs from .models import PreprocessingParamsModel warnings.filterwarnings("ignore") -n_jobs_co = os.getenv('CO_CPUS') -n_jobs = int(n_jobs_co) if n_jobs_co is not None else -1 - -job_kwargs = { - 'n_jobs': n_jobs, - 'chunk_duration': '1s', - 'progress_bar': True -} - data_folder = Path("../data/") results_folder = Path("../results/") @@ -31,13 +23,22 @@ def preprocessing( data_folder: Path, results_folder: Path, - job_kwargs: dict, + job_kwargs: JobKwargs, preprocessing_params: PreprocessingParamsModel, debug: bool = False, duration_s: float = 1. ) -> None: """ Preprocessing pipeline for ephys data. + + Parameters + ---------- + data_folder: Path + Path to the data folder. + results_folder: Path + Path to the results folder. + job_kwargs: dict + Job kwargs. 
""" data_process_prefix = "data_process_preprocessing" @@ -45,7 +46,7 @@ def preprocessing( if debug: print(f"DEBUG ENABLED - Only running with {duration_s} seconds") - si.set_global_job_kwargs(**job_kwargs) + si.set_global_job_kwargs(**job_kwargs.model_dump()) # load job json files job_config_json_files = [p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name] @@ -205,7 +206,7 @@ def preprocessing( recording_corrected = spre.correct_motion( recording_processed, preset=preset, folder=motion_folder, - **job_kwargs + **job_kwargs.model_dump() ) if preprocessing_params.motion_correction.apply: print("\tApplying motion correction") From 34f6fb38374f3756b64e26e0ba7f210ff8d71ca4 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Oct 2023 15:23:27 +0200 Subject: [PATCH 04/16] further simplifying --- .../postprocessing/postprocessing.py | 14 +- .../preprocessing/preprocessing.py | 318 ++++++------------ 2 files changed, 103 insertions(+), 229 deletions(-) diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 6f0acb6..3ab75d4 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -17,17 +17,12 @@ warnings.filterwarnings("ignore") -data_folder = Path("../data/") -results_folder = Path("../results/") -tmp_folder = results_folder / "tmp" -tmp_folder.mkdir() - def postprocessing( - data_folder: Path, - results_folder: Path, job_kwargs: JobKwargs, postprocessing_params: PostprocessingParamsModel, + data_folder: Path = Path("../data/"), + results_folder: Path = Path("../results/"), ) -> None: """ Postprocessing pipeline @@ -45,6 +40,9 @@ def postprocessing( """ si.set_global_job_kwargs(**job_kwargs.model_dump()) + tmp_folder = results_folder / "tmp" + tmp_folder.mkdir() + data_process_prefix = "data_process_postprocessing" print("\nPOSTPROCESSING") t_postprocessing_start_all = time.perf_counter() @@ -170,4 +168,4 @@ def postprocessing( t_postprocessing_end_all = time.perf_counter() elapsed_time_postprocessing_all = np.round(t_postprocessing_end_all - t_postprocessing_start_all, 2) - print(f"POSTPROCESSING time: {elapsed_time_postprocessing_all}s") \ No newline at end of file + print(f"POSTPROCESSING time: {elapsed_time_postprocessing_all}s") diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index b8bf551..e949dcb 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -1,14 +1,8 @@ import warnings -import os -import json -import time import numpy as np from pathlib import Path -from datetime import datetime import spikeinterface as si -import spikeinterface.extractors as se import spikeinterface.preprocessing as spre -from spikeinterface.core.core_tools import check_json from ..models import JobKwargs from .models import PreprocessingParamsModel @@ -16,235 +10,117 @@ warnings.filterwarnings("ignore") -data_folder = Path("../data/") -results_folder = Path("../results/") - def preprocessing( - data_folder: Path, - results_folder: Path, job_kwargs: JobKwargs, + recording: si.BaseRecording, preprocessing_params: PreprocessingParamsModel, + results_path: Path = Path("./results/"), debug: bool = False, duration_s: float = 1. -) -> None: +) -> None | si.BaseRecording: """ Preprocessing pipeline for ephys data. 
Parameters ---------- - data_folder: Path - Path to the data folder. - results_folder: Path + recording: si.BaseRecording + Recording extractor. + preprocessing_params: PreprocessingParamsModel + Preprocessing parameters. + results_path: Path Path to the results folder. - job_kwargs: dict - Job kwargs. + debug: bool + Flag to run in debug mode. + duration_s: float + Duration in seconds to use in the debug mode. """ - - data_process_prefix = "data_process_preprocessing" + si.set_global_job_kwargs(**job_kwargs.model_dump()) if debug: print(f"DEBUG ENABLED - Only running with {duration_s} seconds") - si.set_global_job_kwargs(**job_kwargs.model_dump()) - - # load job json files - job_config_json_files = [p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name] - print(f"Found {len(job_config_json_files)} json configurations") - - if len(job_config_json_files) > 0: - print("\n\nPREPROCESSING") - t_preprocessing_start_all = time.perf_counter() - preprocessing_vizualization_data = {} - print(f"Preprocessing strategy: {preprocessing_params.preprocessing_strategy}") - - for job_config_file in job_config_json_files: - datetime_start_preproc = datetime.now() - t_preprocessing_start = time.perf_counter() - preprocessing_notes = "" - - with open(job_config_file, "r") as f: - job_config = json.load(f) - session_name = job_config["session_name"] - session_folder_path = job_config["session_folder_path"] - - session = data_folder / session_folder_path - assert session.is_dir(), ( - f"Could not find {session_name} in {str((data_folder / session_folder_path).resolve())}." - f"Make sure mapping is correct!" - ) - - ecephys_full_folder = session / "ecephys" - ecephys_compressed_folder = session / "ecephys_compressed" - compressed = False - if ecephys_compressed_folder.is_dir(): - compressed = True - ecephys_folder = session / "ecephys_clipped" - else: - ecephys_folder = ecephys_full_folder - - experiment_name = job_config["experiment_name"] - stream_name = job_config["stream_name"] - block_index = job_config["block_index"] - segment_index = job_config["segment_index"] - recording_name = job_config["recording_name"] - - preprocessing_vizualization_data[recording_name] = {} - preprocessing_output_process_json = results_folder / f"{data_process_prefix}_{recording_name}.json" - preprocessing_output_folder = results_folder / f"preprocessed_{recording_name}" - preprocessingviz_output_file = results_folder / f"preprocessedviz_{recording_name}.json" - preprocessing_output_json = results_folder / f"preprocessed_{recording_name}.json" - - exp_stream_name = f"{experiment_name}_{stream_name}" - if not compressed: - recording = se.read_openephys( - ecephys_folder, - stream_name=stream_name, - block_index=block_index - ) - else: - recording = si.read_zarr(ecephys_compressed_folder / f"{exp_stream_name}.zarr") - - if debug: - recording_list = [] - for segment_index in range(recording.get_num_segments()): - recording_one = si.split_recording(recording)[segment_index] - recording_one = recording_one.frame_slice( - start_frame=0, - end_frame=int(duration_s*recording.sampling_frequency) - ) - recording_list.append(recording_one) - recording = si.append_recordings(recording_list) - - if segment_index is not None: - recording = si.split_recording(recording)[segment_index] - - print(f"Preprocessing recording: {recording_name}") - print(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") - - recording_ps_full = spre.phase_shift( - recording, - **preprocessing_params.phase_shift.model_dump() - ) - 
recording_hp_full = spre.highpass_filter( - recording_ps_full, - **preprocessing_params.highpass_filter.model_dump() - ) - preprocessing_vizualization_data[recording_name]["timeseries"] = {} - preprocessing_vizualization_data[recording_name]["timeseries"]["full"] = dict( - raw=recording.to_dict(relative_to=data_folder, recursive=True), - phase_shift=recording_ps_full.to_dict(relative_to=data_folder, recursive=True), - highpass=recording_hp_full.to_dict(relative_to=data_folder, recursive=True) - ) - - # IBL bad channel detection - _, channel_labels = spre.detect_bad_channels( - recording_hp_full, - **preprocessing_params.detect_bad_channels.model_dump() - ) - dead_channel_mask = channel_labels == "dead" - noise_channel_mask = channel_labels == "noise" - out_channel_mask = channel_labels == "out" - print(f"\tBad channel detection:") - print(f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") - dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] - noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] - out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] - - all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) - - skip_processing = False - max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove - if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): - print(f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " - f"Skipping further processing for this recording.") - preprocessing_notes += f"\n- Found {len(all_bad_channel_ids)} bad channels. Skipping further processing\n" - skip_processing = True - # in this case, processed timeseries will not be visualized - preprocessing_vizualization_data[recording_name]["timeseries"]["proc"] = None - recording_drift = recording_hp_full - else: - if preprocessing_params.remove_out_channels: - print(f"\tRemoving {len(out_channel_ids)} out channels") - recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) - preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
- else: - recording_rm_out = recording_hp_full - - recording_processed_cmr = spre.common_reference( - recording_rm_out, - **preprocessing_params.common_reference.model_dump() - ) - - bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) - recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) - recording_hp_spatial = spre.highpass_spatial_filter( - recording_interp, - **preprocessing_params.highpass_spatial_filter.model_dump() - ) - preprocessing_vizualization_data[recording_name]["timeseries"]["proc"] = dict( - highpass=recording_rm_out.to_dict(relative_to=data_folder, recursive=True), - cmr=recording_processed_cmr.to_dict(relative_to=data_folder, recursive=True), - highpass_spatial=recording_hp_spatial.to_dict(relative_to=data_folder, recursive=True) - ) - - if preprocessing_params.preprocessing_strategy == "cmr": - recording_processed = recording_processed_cmr - else: - recording_processed = recording_hp_spatial - - if preprocessing_params.remove_bad_channels: - print(f"\tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") - recording_processed = recording_processed.remove_channels(bad_channel_ids) - preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" - - # motion correction - if preprocessing_params.motion_correction.compute: - preset = preprocessing_params.motion_correction.preset - print(f"\tComputing motion correction with preset: {preset}") - motion_folder = results_folder / f"motion_{recording_name}" - recording_corrected = spre.correct_motion( - recording_processed, preset=preset, - folder=motion_folder, - **job_kwargs.model_dump() - ) - if preprocessing_params.motion_correction.apply: - print("\tApplying motion correction") - recording_processed = recording_corrected - - recording_saved = recording_processed.save(folder=preprocessing_output_folder) - recording_processed.dump_to_json(preprocessing_output_json, relative_to=data_folder) - recording_drift = recording_saved - - # store recording for drift visualization - preprocessing_vizualization_data[recording_name]["drift"] = dict( - recording=recording_drift.to_dict(relative_to=data_folder, recursive=True) - ) - with open(preprocessingviz_output_file, "w") as f: - json.dump(check_json(preprocessing_vizualization_data), f, indent=4) - - t_preprocessing_end = time.perf_counter() - elapsed_time_preprocessing = np.round(t_preprocessing_end - t_preprocessing_start, 2) - - # save params in output - preprocessing_params["recording_name"] = recording_name - preprocessing_outputs = dict(channel_labels=channel_labels.tolist()) - # preprocessing_process = DataProcess( - # name="Ephys preprocessing", - # version=VERSION, # either release or git commit - # start_date_time=datetime_start_preproc, - # end_date_time=datetime_start_preproc + timedelta(seconds=np.floor(elapsed_time_preprocessing)), - # input_location=str(data_folder), - # output_location=str(results_folder), - # code_url=URL, - # parameters=preprocessing_params, - # outputs=preprocessing_outputs, - # notes=preprocessing_notes - # ) - # with open(preprocessing_output_process_json, "w") as f: - # f.write(preprocessing_process.json(indent=3)) - - t_preprocessing_end_all = time.perf_counter() - elapsed_time_preprocessing_all = np.round(t_preprocessing_end_all - t_preprocessing_start_all, 2) - - print(f"PREPROCESSING time: {elapsed_time_preprocessing_all}s") \ No newline at end of file + recording_name = recording.name + 
preprocessing_notes = "" + preprocessing_output_process_json = results_path / f"{data_process_prefix}_{recording_name}.json" + preprocessing_output_folder = results_path / f"preprocessed_{recording_name}" + preprocessing_output_json = results_path / f"preprocessed_{recording_name}.json" + + print(f"Preprocessing recording: {recording_name}") + print(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") + + recording_ps_full = spre.phase_shift( + recording, + **preprocessing_params.phase_shift.model_dump() + ) + recording_hp_full = spre.highpass_filter( + recording_ps_full, + **preprocessing_params.highpass_filter.model_dump() + ) + + # Detect bad channels + _, channel_labels = spre.detect_bad_channels( + recording_hp_full, + **preprocessing_params.detect_bad_channels.model_dump() + ) + dead_channel_mask = channel_labels == "dead" + noise_channel_mask = channel_labels == "noise" + out_channel_mask = channel_labels == "out" + print("\tBad channel detection:") + print(f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") + dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] + noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] + out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] + all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) + + max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove + if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): + print(f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). ") + print("Skipping further processing for this recording.") + preprocessing_notes += f"\n- Found {len(all_bad_channel_ids)} bad channels. Skipping further processing\n" + return None + + if preprocessing_params.remove_out_channels: + print(f"\tRemoving {len(out_channel_ids)} out channels") + recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) + preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
+ else: + recording_rm_out = recording_hp_full + + bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) + + if preprocessing_params.preprocessing_strategy == "cmr": + recording_processed = spre.common_reference( + recording_rm_out, + **preprocessing_params.common_reference.model_dump() + ) + else: + recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) + recording_processed = spre.highpass_spatial_filter( + recording_interp, + **preprocessing_params.highpass_spatial_filter.model_dump() + ) + + if preprocessing_params.remove_bad_channels: + print(f"\tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") + recording_processed = recording_processed.remove_channels(bad_channel_ids) + preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" + + # motion correction + if preprocessing_params.motion_correction.compute: + preset = preprocessing_params.motion_correction.preset + print(f"\tComputing motion correction with preset: {preset}") + motion_folder = output_path / f"motion_{recording_name}" + recording_corrected = spre.correct_motion( + recording_processed, preset=preset, + folder=motion_folder, + **job_kwargs.model_dump() + ) + if preprocessing_params.motion_correction.apply: + print("\tApplying motion correction") + recording_processed = recording_corrected + + # recording_saved = recording_processed.save(folder=preprocessing_output_folder) + # recording_processed.dump_to_json(preprocessing_output_json, relative_to=data_folder) + + return recording_processed From 228eced2f90fe1891aa9e9306326be7d7bf31c17 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Oct 2023 16:17:22 +0200 Subject: [PATCH 05/16] pypi action --- .github/workflows/pypi_release.yaml | 31 +++++++++++++++++++ pyproject.toml | 20 +++++++++--- requirements.txt | 1 - .../postprocessing/postprocessing.py | 4 +-- .../preprocessing/preprocessing.py | 26 ++++++++-------- 5 files changed, 62 insertions(+), 20 deletions(-) create mode 100644 .github/workflows/pypi_release.yaml delete mode 100644 requirements.txt diff --git a/.github/workflows/pypi_release.yaml b/.github/workflows/pypi_release.yaml new file mode 100644 index 0000000..080bcaf --- /dev/null +++ b/.github/workflows/pypi_release.yaml @@ -0,0 +1,31 @@ +name: Build and publish spikeinterface_pipelines Pyton package + +on: + push: + branches: + - main + paths: + - src/spikeinterface_pipelines/** + +jobs: + pypi-release: + name: PyPI release + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install build + pip install twine + - name: Build and publish to PyPI + run: | + python -m build + twine upload dist/* + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0322d3c..46dafd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,15 @@ [project] name = "spikeinterface_pipelines" +version = "0.1.0" description = "Collection of standardized analysis pipelines based on SpikeInterfacee." 
readme = "README.md" -authors = [{ name = "My Name", email = "me@email.com" }] +authors = [ + { name = "Alessio Buccino", email = "alessiop.buccino@gmail.com" }, + { name = "Jeremy Magland", email = "jmagland@flatironinstitute.org" }, + { name = "Luiz Tauffer", email = "luiz.tauffer@catalystneuro.com" }, +] requires-python = ">=3.8" -dependencies = ["spikeinterface[full]"] +dependencies = ["spikeinterface[full]", "neo>=0.12.0"] keywords = [ "spikeinterface", "spike sorting", @@ -13,9 +18,16 @@ keywords = [ ] [project.urls] -Homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines" -Documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines" +homepage = "https://github.com/SpikeInterface/spikeinterface_pipelines" +documentation = "https://github.com/SpikeInterface/spikeinterface_pipelines" +repository = "https://github.com/SpikeInterface/spikeinterface_pipelines" [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + +[tool.setuptools] +package-dir = { "" = "src" } + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0c660d7..0000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -spikeinterface[full] \ No newline at end of file diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 3ab75d4..1a096ed 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -22,7 +22,7 @@ def postprocessing( job_kwargs: JobKwargs, postprocessing_params: PostprocessingParamsModel, data_folder: Path = Path("../data/"), - results_folder: Path = Path("../results/"), + results_path: Path = Path("./results/"), ) -> None: """ Postprocessing pipeline @@ -40,7 +40,7 @@ def postprocessing( """ si.set_global_job_kwargs(**job_kwargs.model_dump()) - tmp_folder = results_folder / "tmp" + tmp_folder = results_path / "tmp" tmp_folder.mkdir() data_process_prefix = "data_process_postprocessing" diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index e949dcb..7c05cee 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -15,9 +15,7 @@ def preprocessing( job_kwargs: JobKwargs, recording: si.BaseRecording, preprocessing_params: PreprocessingParamsModel, - results_path: Path = Path("./results/"), - debug: bool = False, - duration_s: float = 1. + results_path: Path = Path("./results/") ) -> None | si.BaseRecording: """ Preprocessing pipeline for ephys data. 
@@ -37,22 +35,23 @@ def preprocessing( """ si.set_global_job_kwargs(**job_kwargs.model_dump()) - if debug: - print(f"DEBUG ENABLED - Only running with {duration_s} seconds") - - recording_name = recording.name preprocessing_notes = "" - preprocessing_output_process_json = results_path / f"{data_process_prefix}_{recording_name}.json" - preprocessing_output_folder = results_path / f"preprocessed_{recording_name}" - preprocessing_output_json = results_path / f"preprocessed_{recording_name}.json" - print(f"Preprocessing recording: {recording_name}") + # recording_name = recording.name + # preprocessing_output_process_json = results_path / f"{data_process_prefix}_{recording_name}.json" + # preprocessing_output_folder = results_path / f"preprocessed_{recording_name}" + # preprocessing_output_json = results_path / f"preprocessed_{recording_name}.json" + + print("Preprocessing recording") print(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") + # Phase shift correction recording_ps_full = spre.phase_shift( recording, **preprocessing_params.phase_shift.model_dump() ) + + # Highpass filter recording_hp_full = spre.highpass_filter( recording_ps_full, **preprocessing_params.highpass_filter.model_dump() @@ -89,6 +88,7 @@ def preprocessing( bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) + # Strategy: CMR or destripe if preprocessing_params.preprocessing_strategy == "cmr": recording_processed = spre.common_reference( recording_rm_out, @@ -106,11 +106,11 @@ def preprocessing( recording_processed = recording_processed.remove_channels(bad_channel_ids) preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" - # motion correction + # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset print(f"\tComputing motion correction with preset: {preset}") - motion_folder = output_path / f"motion_{recording_name}" + motion_folder = results_path / "motion_correction" recording_corrected = spre.correct_motion( recording_processed, preset=preset, folder=motion_folder, From a46c6ea3259649c3a86e6175a787558adf6d1dbf Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 30 Oct 2023 11:44:00 +0100 Subject: [PATCH 06/16] default values, wip pipelining --- src/spikeinterface_pipelines/__init__.py | 1 + src/spikeinterface_pipelines/logger.py | 5 ++ src/spikeinterface_pipelines/models.py | 6 +- src/spikeinterface_pipelines/pipeline.py | 46 ++++++++++++++++ .../postprocessing/__init__.py | 3 +- .../preprocessing/__init__.py | 1 + .../preprocessing/models.py | 12 ++-- .../preprocessing/preprocessing.py | 9 +-- .../sorting/__init__.py | 1 + .../sorting/models.py | 55 +++++++++++++++++++ 10 files changed, 125 insertions(+), 14 deletions(-) create mode 100644 src/spikeinterface_pipelines/logger.py create mode 100644 src/spikeinterface_pipelines/pipeline.py create mode 100644 src/spikeinterface_pipelines/sorting/__init__.py create mode 100644 src/spikeinterface_pipelines/sorting/models.py diff --git a/src/spikeinterface_pipelines/__init__.py b/src/spikeinterface_pipelines/__init__.py index e69de29..b211130 100644 --- a/src/spikeinterface_pipelines/__init__.py +++ b/src/spikeinterface_pipelines/__init__.py @@ -0,0 +1 @@ +from .pipeline import pipeline \ No newline at end of file diff --git a/src/spikeinterface_pipelines/logger.py b/src/spikeinterface_pipelines/logger.py new file mode 100644 index 0000000..d098728 --- /dev/null +++ b/src/spikeinterface_pipelines/logger.py @@ -0,0 +1,5 @@ +import 
logging + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) \ No newline at end of file diff --git a/src/spikeinterface_pipelines/models.py b/src/spikeinterface_pipelines/models.py index ed5956f..1a64bee 100644 --- a/src/spikeinterface_pipelines/models.py +++ b/src/spikeinterface_pipelines/models.py @@ -2,6 +2,6 @@ class JobKwargs(BaseModel): - n_jobs: int = Field(-1, description="The number of jobs to run in parallel.") - chunk_duration: str = Field("1s", description="The duration of the chunks to process.") - progress_bar: bool = Field(True, description="Whether to display a progress bar.") \ No newline at end of file + n_jobs: int = Field(default=-1, description="The number of jobs to run in parallel.") + chunk_duration: str = Field(default="1s", description="The duration of the chunks to process.") + progress_bar: bool = Field(default=True, description="Whether to display a progress bar.") diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py new file mode 100644 index 0000000..5fa3c0f --- /dev/null +++ b/src/spikeinterface_pipelines/pipeline.py @@ -0,0 +1,46 @@ +from pathlib import Path +import shutil +import spikeinterface as si +import spikeinterface.sorters as ss + +from .logger import logger +from .preprocessing import preprocessing, PreprocessingParamsModel +from .sorting import SortingParamsModel + + +# TODO - WIP +def pipeline( + recording: si.BaseRecording, + results_path: Path = Path("./results/"), + preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), + sorting_params: SortingParamsModel = SortingParamsModel(), +): + # Preprocessing + results_path_preprocessing = results_path / "preprocessing" + recording_preprocessed = preprocessing( + recording=recording, + preprocessing_params=preprocessing_params, + results_path=results_path_preprocessing, + ) + if recording_preprocessed is None: + return None + + # Spike Sorting + results_path_sorting = results_path / "sorting" + try: + sorting = ss.run_sorter( + recording=recording_preprocessed, + output_folder=str(results_path_sorting), + verbose=False, + delete_output_folder=True, + **sorting_params.model_dump(), + ) + # remove empty units + sorting = sorting.remove_empty_units() + except Exception as e: + # save log to results + results_path_sorting.mkdir() + shutil.copy(spikesorted_raw_output_folder / "spikeinterface_log.json", sorting_output_folder) + raise e + + return recording diff --git a/src/spikeinterface_pipelines/postprocessing/__init__.py b/src/spikeinterface_pipelines/postprocessing/__init__.py index c01852b..6e4b53b 100644 --- a/src/spikeinterface_pipelines/postprocessing/__init__.py +++ b/src/spikeinterface_pipelines/postprocessing/__init__.py @@ -1 +1,2 @@ -from .postprocessing import postprocessing \ No newline at end of file +from .postprocessing import postprocessing +from .models import PostprocessingParamsModel \ No newline at end of file diff --git a/src/spikeinterface_pipelines/preprocessing/__init__.py b/src/spikeinterface_pipelines/preprocessing/__init__.py index 38557eb..445fde2 100644 --- a/src/spikeinterface_pipelines/preprocessing/__init__.py +++ b/src/spikeinterface_pipelines/preprocessing/__init__.py @@ -1 +1,2 @@ from .preprocessing import preprocessing +from .models import PreprocessingParamsModel \ No newline at end of file diff --git a/src/spikeinterface_pipelines/preprocessing/models.py b/src/spikeinterface_pipelines/preprocessing/models.py index 566419e..bd4dcdd 100644 --- 
a/src/spikeinterface_pipelines/preprocessing/models.py +++ b/src/spikeinterface_pipelines/preprocessing/models.py @@ -49,12 +49,12 @@ class MotionCorrection(BaseModel): class PreprocessingParamsModel(BaseModel): preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing") - highpass_filter: HighpassFilter - phase_shift: PhaseShift - detect_bad_channels: DetectBadChannels + highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") + phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") + common_reference: CommonReference = Field(default=CommonReference(), description="Common reference") + highpass_spatial_filter: HighpassSpatialFilter = Field(default=HighpassSpatialFilter(), description="Highpass spatial filter") + motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction") + detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels") remove_out_channels: bool = Field(default=True, description="Flag to remove out channels") remove_bad_channels: bool = Field(default=True, description="Flag to remove bad channels") max_bad_channel_fraction_to_remove: float = Field(default=0.5, description="Maximum fraction of bad channels to remove") - common_reference: CommonReference - highpass_spatial_filter: HighpassSpatialFilter - motion_correction: MotionCorrection diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 7c05cee..ac2aef0 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -1,6 +1,7 @@ import warnings import numpy as np from pathlib import Path +from typing import Union import spikeinterface as si import spikeinterface.preprocessing as spre @@ -12,11 +13,11 @@ def preprocessing( - job_kwargs: JobKwargs, recording: si.BaseRecording, - preprocessing_params: PreprocessingParamsModel, - results_path: Path = Path("./results/") -) -> None | si.BaseRecording: + preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), + results_path: Path = Path("./results/"), + job_kwargs: JobKwargs = JobKwargs(), +) -> si.BaseRecording | None: """ Preprocessing pipeline for ephys data. 
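For reference, a minimal usage sketch of the parameter defaults introduced above (import paths assumed from this patch series; the snippet only exercises the pydantic models and does not run any preprocessing):

    from spikeinterface_pipelines.preprocessing import PreprocessingParamsModel
    from spikeinterface_pipelines.preprocessing.models import HighpassFilter

    # With defaults on every sub-model, an empty call yields a complete parameter set.
    params = PreprocessingParamsModel()
    assert params.preprocessing_strategy == "cmr"
    assert params.highpass_filter.freq_min == 300.0

    # Individual fields can still be overridden selectively.
    custom = PreprocessingParamsModel(
        preprocessing_strategy="destripe",
        highpass_filter=HighpassFilter(freq_min=400.0),
    )

    # model_dump() yields plain kwargs, matching how preprocessing() unpacks them into spre calls.
    print(custom.highpass_filter.model_dump())  # {'freq_min': 400.0, 'margin_ms': 5.0}
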
diff --git a/src/spikeinterface_pipelines/sorting/__init__.py b/src/spikeinterface_pipelines/sorting/__init__.py new file mode 100644 index 0000000..c4bfb2a --- /dev/null +++ b/src/spikeinterface_pipelines/sorting/__init__.py @@ -0,0 +1 @@ +from .models import SortingParamsModel \ No newline at end of file diff --git a/src/spikeinterface_pipelines/sorting/models.py b/src/spikeinterface_pipelines/sorting/models.py new file mode 100644 index 0000000..eff36b9 --- /dev/null +++ b/src/spikeinterface_pipelines/sorting/models.py @@ -0,0 +1,55 @@ +from pydantic import BaseModel, Field +from typing import Union, List +from enum import Enum + + +class SorterName(str, Enum): + ironclust = "ironclust" + kilosort25 = "kilosort25" + kilosort3 = "kilosort3" + mountainsort5 = "mountainsort5" + + +class Kilosort25Model(BaseModel): + detect_threshold: float = Field(default=6, description="Threshold for spike detection") + projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections") + preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)") + car: bool = Field(default=True, description="Enable or disable common reference") + minFR: float = Field(default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed") + minfr_goodchannels: float = Field(default=0.1, description="Minimum firing rate on a 'good' channel") + nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.") + sig: float = Field(default=20, description="spatial smoothness constant for registration") + freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") + sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") + nPCs: int = Field(default=3, description="Number of PCA dimensions") + ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") + nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") + NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)') + AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step") + do_correction: bool = Field(default=True, description="If True drift registration is applied") + wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)") + keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") + skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing") + scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.") + + +class Kilosort3Model(BaseModel): + pass + + +class IronClustModel(BaseModel): + pass + + +class MountainSort5Model(BaseModel): + pass + + +class SortingParamsModel(BaseModel): + sorter_name: SorterName = Field(default="kilosort25", description="Name of the sorter to use.") + sorter_kwargs: Union[ + Kilosort25Model, + Kilosort3Model, + IronClustModel, + MountainSort5Model + ] = Field(default=Kilosort25Model(), description="Sorter specific kwargs.") \ No newline at end of file From 016151b835314fac3056552f04d3b431c60fedde 
Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 30 Oct 2023 15:06:42 +0100 Subject: [PATCH 07/16] sorting --- pyproject.toml | 2 +- src/spikeinterface_pipelines/pipeline.py | 52 ++++++++----------- .../preprocessing/preprocessing.py | 37 ++++++------- .../sorting/__init__.py | 3 +- .../sorting/models.py | 4 +- .../sorting/sorting.py | 38 ++++++++++++++ 6 files changed, 85 insertions(+), 51 deletions(-) create mode 100644 src/spikeinterface_pipelines/sorting/sorting.py diff --git a/pyproject.toml b/pyproject.toml index 46dafd1..9e7956b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ { name = "Luiz Tauffer", email = "luiz.tauffer@catalystneuro.com" }, ] requires-python = ">=3.8" -dependencies = ["spikeinterface[full]", "neo>=0.12.0"] +dependencies = ["spikeinterface[full]", "neo>=0.12.0", "pydantic>=2.4.2"] keywords = [ "spikeinterface", "spike sorting", diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 5fa3c0f..087762e 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,11 +1,9 @@ from pathlib import Path -import shutil import spikeinterface as si -import spikeinterface.sorters as ss from .logger import logger from .preprocessing import preprocessing, PreprocessingParamsModel -from .sorting import SortingParamsModel +from .sorting import sorting, SortingParamsModel # TODO - WIP @@ -14,33 +12,29 @@ def pipeline( results_path: Path = Path("./results/"), preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), sorting_params: SortingParamsModel = SortingParamsModel(), -): - # Preprocessing + run_preprocessing: bool = True, +) -> None: + # Paths results_path_preprocessing = results_path / "preprocessing" - recording_preprocessed = preprocessing( - recording=recording, - preprocessing_params=preprocessing_params, - results_path=results_path_preprocessing, - ) - if recording_preprocessed is None: - return None - - # Spike Sorting results_path_sorting = results_path / "sorting" - try: - sorting = ss.run_sorter( - recording=recording_preprocessed, - output_folder=str(results_path_sorting), - verbose=False, - delete_output_folder=True, - **sorting_params.model_dump(), + + # Preprocessing + if run_preprocessing: + logger.info("Preprocessing recording") + recording_preprocessed = preprocessing( + recording=recording, + preprocessing_params=preprocessing_params, + results_path=results_path_preprocessing, ) - # remove empty units - sorting = sorting.remove_empty_units() - except Exception as e: - # save log to results - results_path_sorting.mkdir() - shutil.copy(spikesorted_raw_output_folder / "spikeinterface_log.json", sorting_output_folder) - raise e + if recording_preprocessed is None: + raise Exception("Preprocessing failed") + else: + logger.info("Skipping preprocessing") + recording_preprocessed = recording - return recording + # Spike Sorting + sorter = sorting( + recording=recording_preprocessed, + sorting_params=sorting_params, + results_path=results_path_sorting, + ) diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index ac2aef0..0476e89 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -1,10 +1,10 @@ import warnings import numpy as np from pathlib import Path -from typing import Union import spikeinterface as si import spikeinterface.preprocessing as spre +from 
..logger import logger from ..models import JobKwargs from .models import PreprocessingParamsModel @@ -15,7 +15,7 @@ def preprocessing( recording: si.BaseRecording, preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), - results_path: Path = Path("./results/"), + results_path: Path = Path("./results/preprocessing/"), job_kwargs: JobKwargs = JobKwargs(), ) -> si.BaseRecording | None: """ @@ -34,6 +34,7 @@ def preprocessing( duration_s: float Duration in seconds to use in the debug mode. """ + logger.info("[Preprocessing] \tRunning Preprocessing stage") si.set_global_job_kwargs(**job_kwargs.model_dump()) preprocessing_notes = "" @@ -43,18 +44,18 @@ def preprocessing( # preprocessing_output_folder = results_path / f"preprocessed_{recording_name}" # preprocessing_output_json = results_path / f"preprocessed_{recording_name}.json" - print("Preprocessing recording") - print(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") + logger.info(f"[Preprocessing] \tDuration: {np.round(recording.get_total_duration(), 2)} s") - # Phase shift correction - recording_ps_full = spre.phase_shift( - recording, - **preprocessing_params.phase_shift.model_dump() - ) + # # TODO - Phase shift correction + # recording = spre.phase_shift( + # recording, + # **preprocessing_params.phase_shift.model_dump() + # ) # Highpass filter + recording_hp_full = spre.highpass_filter( - recording_ps_full, + recording, **preprocessing_params.highpass_filter.model_dump() ) @@ -66,8 +67,8 @@ def preprocessing( dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" out_channel_mask = channel_labels == "out" - print("\tBad channel detection:") - print(f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") + logger.info("[Preprocessing] \tBad channel detection:") + logger.info(f"[Preprocessing] \tdead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] @@ -75,13 +76,13 @@ def preprocessing( max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): - print(f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). ") - print("Skipping further processing for this recording.") + logger.info(f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). ") + logger.info("[Preprocessing] \tSkipping further processing for this recording.") preprocessing_notes += f"\n- Found {len(all_bad_channel_ids)} bad channels. Skipping further processing\n" return None if preprocessing_params.remove_out_channels: - print(f"\tRemoving {len(out_channel_ids)} out channels") + logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
else: @@ -103,14 +104,14 @@ def preprocessing( ) if preprocessing_params.remove_bad_channels: - print(f"\tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") + logger.info(f"[Preprocessing] \tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") recording_processed = recording_processed.remove_channels(bad_channel_ids) preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset - print(f"\tComputing motion correction with preset: {preset}") + logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_path / "motion_correction" recording_corrected = spre.correct_motion( recording_processed, preset=preset, @@ -118,7 +119,7 @@ def preprocessing( **job_kwargs.model_dump() ) if preprocessing_params.motion_correction.apply: - print("\tApplying motion correction") + logger.info("[Preprocessing] \tApplying motion correction") recording_processed = recording_corrected # recording_saved = recording_processed.save(folder=preprocessing_output_folder) diff --git a/src/spikeinterface_pipelines/sorting/__init__.py b/src/spikeinterface_pipelines/sorting/__init__.py index c4bfb2a..c9ffd2c 100644 --- a/src/spikeinterface_pipelines/sorting/__init__.py +++ b/src/spikeinterface_pipelines/sorting/__init__.py @@ -1 +1,2 @@ -from .models import SortingParamsModel \ No newline at end of file +from .models import SortingParamsModel +from .sorting import sorting \ No newline at end of file diff --git a/src/spikeinterface_pipelines/sorting/models.py b/src/spikeinterface_pipelines/sorting/models.py index eff36b9..fb0c21a 100644 --- a/src/spikeinterface_pipelines/sorting/models.py +++ b/src/spikeinterface_pipelines/sorting/models.py @@ -5,7 +5,7 @@ class SorterName(str, Enum): ironclust = "ironclust" - kilosort25 = "kilosort25" + kilosort25 = "kilosort2_5" kilosort3 = "kilosort3" mountainsort5 = "mountainsort5" @@ -46,7 +46,7 @@ class MountainSort5Model(BaseModel): class SortingParamsModel(BaseModel): - sorter_name: SorterName = Field(default="kilosort25", description="Name of the sorter to use.") + sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.") sorter_kwargs: Union[ Kilosort25Model, Kilosort3Model, diff --git a/src/spikeinterface_pipelines/sorting/sorting.py b/src/spikeinterface_pipelines/sorting/sorting.py new file mode 100644 index 0000000..652da14 --- /dev/null +++ b/src/spikeinterface_pipelines/sorting/sorting.py @@ -0,0 +1,38 @@ +import spikeinterface.sorters as ss +import spikeinterface as si +from pathlib import Path +import shutil + +from ..logger import logger +from ..models import JobKwargs +from .models import SortingParamsModel + + +def sorting( + recording: si.BaseRecording, + sorting_params: SortingParamsModel = SortingParamsModel(), + results_path: Path = Path("./results/sorting/"), + job_kwargs: JobKwargs = JobKwargs(), +) -> si.BaseSorting | None: + try: + sorter = ss.run_sorter( + recording=recording, + output_folder=str(results_path / "tmp"), + verbose=False, + delete_output_folder=True, + **sorting_params.model_dump(), + ) + # remove empty units + sorter = sorter.remove_empty_units() + # save results + logger.info(f"\tSaving results to {results_path}") + sorter = sorter.save(folder=results_path) + return sorter + except Exception 
as e: + # save log to results + results_path.mkdir() + if (results_path / "tmp").exists(): + shutil.copy(results_path / "tmp/spikeinterface_log.json", results_path) + raise e + finally: + pass From 20fb4b9003ffb5f2056a70c02e09258432905d90 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 30 Oct 2023 15:15:55 +0100 Subject: [PATCH 08/16] postprocessing --- src/spikeinterface_pipelines/pipeline.py | 5 + .../postprocessing/models.py | 142 +++++++++--------- .../postprocessing/postprocessing.py | 8 +- 3 files changed, 80 insertions(+), 75 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 087762e..4a992ca 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -4,6 +4,7 @@ from .logger import logger from .preprocessing import preprocessing, PreprocessingParamsModel from .sorting import sorting, SortingParamsModel +from .postprocessing import postprocessing, PostprocessingParamsModel # TODO - WIP @@ -12,6 +13,7 @@ def pipeline( results_path: Path = Path("./results/"), preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), sorting_params: SortingParamsModel = SortingParamsModel(), + postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), run_preprocessing: bool = True, ) -> None: # Paths @@ -38,3 +40,6 @@ def pipeline( sorting_params=sorting_params, results_path=results_path_sorting, ) + + # # TODO - Postprocessing + # postprocessing(postprocessing_params=postprocessing_params) diff --git a/src/spikeinterface_pipelines/postprocessing/models.py b/src/spikeinterface_pipelines/postprocessing/models.py index 1774c16..2c5cd8f 100644 --- a/src/spikeinterface_pipelines/postprocessing/models.py +++ b/src/spikeinterface_pipelines/postprocessing/models.py @@ -4,31 +4,31 @@ class PresenceRatio(BaseModel): - bin_duration_s: float = Field(60, description="Duration of the bin in seconds.") + bin_duration_s: float = Field(default=60, description="Duration of the bin in seconds.") class SNR(BaseModel): - peak_sign: str = Field("neg", description="Sign of the peak.") - peak_mode: str = Field("extremum", description="Mode of the peak.") - random_chunk_kwargs_dict: Optional[dict] = Field(None, description="Random chunk arguments.") + peak_sign: str = Field(default="neg", description="Sign of the peak.") + peak_mode: str = Field(default="extremum", description="Mode of the peak.") + random_chunk_kwargs_dict: Optional[dict] = Field(default=None, description="Random chunk arguments.") class ISIViolation(BaseModel): - isi_threshold_ms: float = Field(1.5, description="ISI threshold in milliseconds.") - min_isi_ms: float = Field(0., description="Minimum ISI in milliseconds.") + isi_threshold_ms: float = Field(default=1.5, description="ISI threshold in milliseconds.") + min_isi_ms: float = Field(default=0., description="Minimum ISI in milliseconds.") class RPViolation(BaseModel): - refractory_period_ms: float = Field(1., description="Refractory period in milliseconds.") - censored_period_ms: float = Field(0.0, description="Censored period in milliseconds.") + refractory_period_ms: float = Field(default=1., description="Refractory period in milliseconds.") + censored_period_ms: float = Field(default=0.0, description="Censored period in milliseconds.") class SlidingRPViolation(BaseModel): - bin_size_ms: float = Field(0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") - window_size_s: float = Field(1, description="Window in seconds 
to compute correlogram, by default 1.") - exclude_ref_period_below_ms: float = Field(0.5, description="Refractory periods below this value are excluded, by default 0.5") - max_ref_period_ms: float = Field(10, description="Maximum refractory period to test in ms, by default 10 ms.") - contamination_values: Optional[list] = Field(None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") + bin_size_ms: float = Field(default=0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") + window_size_s: float = Field(default=1, description="Window in seconds to compute correlogram, by default 1.") + exclude_ref_period_below_ms: float = Field(default=0.5, description="Refractory periods below this value are excluded, by default 0.5") + max_ref_period_ms: float = Field(default=10, description="Maximum refractory period to test in ms, by default 10 ms.") + contamination_values: Optional[list] = Field(default=None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") class PeakSign(str, Enum): @@ -38,106 +38,106 @@ class PeakSign(str, Enum): class AmplitudeCutoff(BaseModel): - peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") - num_histogram_bins: int = Field(100, description="The number of bins to use to compute the amplitude histogram.") - histogram_smoothing_value: int = Field(3, description="Controls the smoothing applied to the amplitude histogram.") - amplitudes_bins_min_ratio: int = Field(5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") + peak_sign: PeakSign = Field(default="neg", description="The sign of the peaks.") + num_histogram_bins: int = Field(default=100, description="The number of bins to use to compute the amplitude histogram.") + histogram_smoothing_value: int = Field(default=3, description="Controls the smoothing applied to the amplitude histogram.") + amplitudes_bins_min_ratio: int = Field(default=5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") class AmplitudeMedian(BaseModel): - peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + peak_sign: PeakSign = Field(default="neg", description="The sign of the peaks.") class NearestNeighbor(BaseModel): - max_spikes: int = Field(10000, description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.") - min_spikes: int = Field(10, description="Minimum number of spikes.") - n_neighbors: int = Field(4, description="The number of neighbors to use.") + max_spikes: int = Field(default=10000, description="The number of spikes to use, per cluster. 
Note that the calculation can be very slow when this number is >20000.") + min_spikes: int = Field(default=10, description="Minimum number of spikes.") + n_neighbors: int = Field(default=4, description="The number of neighbors to use.") class NNIsolation(NearestNeighbor): - n_components: int = Field(10, description="The number of PC components to use to project the snippets to.") - radius_um: int = Field(100, description="The radius, in um, that channels need to be within the peak channel to be included.") + n_components: int = Field(default=10, description="The number of PC components to use to project the snippets to.") + radius_um: int = Field(default=100, description="The radius, in um, that channels need to be within the peak channel to be included.") class QMParams(BaseModel): - presence_ratio: PresenceRatio - snr: SNR - isi_violation: ISIViolation - rp_violation: RPViolation - sliding_rp_violation: SlidingRPViolation - amplitude_cutoff: AmplitudeCutoff - amplitude_median: AmplitudeMedian - nearest_neighbor: NearestNeighbor - nn_isolation: NNIsolation - nn_noise_overlap: NNIsolation + presence_ratio: PresenceRatio = Field(default=PresenceRatio(), description="Presence ratio.") + snr: SNR = Field(default=SNR(), description="Signal to noise ratio.") + isi_violation: ISIViolation = Field(default=ISIViolation(), description="ISI violation.") + rp_violation: RPViolation = Field(default=RPViolation(), description="Refractory period violation.") + sliding_rp_violation: SlidingRPViolation = Field(default=SlidingRPViolation(), description="Sliding refractory period violation.") + amplitude_cutoff: AmplitudeCutoff = Field(default=AmplitudeCutoff(), description="Amplitude cutoff.") + amplitude_median: AmplitudeMedian = Field(default=AmplitudeMedian(), description="Amplitude median.") + nearest_neighbor: NearestNeighbor = Field(default=NearestNeighbor(), description="Nearest neighbor.") + nn_isolation: NNIsolation = Field(default=NNIsolation(), description="Nearest neighbor isolation.") + nn_noise_overlap: NNIsolation = Field(default=NNIsolation(), description="Nearest neighbor noise overlap.") class QualityMetrics(BaseModel): - qm_params: QMParams = Field(..., description="Quality metric parameters.") - metric_names: List[str] = Field(..., description="List of metric names to compute.") - n_jobs: int = Field(1, description="Number of jobs.") + qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.") + metric_names: List[str] = Field(default=[], description="List of metric names to compute.") + n_jobs: int = Field(default=1, description="Number of jobs.") class Sparsity(BaseModel): - method: str = Field("radius", description="Method for determining sparsity.") - radius_um: int = Field(100, description="Radius in micrometers for sparsity.") + method: str = Field(default="radius", description="Method for determining sparsity.") + radius_um: int = Field(default=100, description="Radius in micrometers for sparsity.") class Waveforms(BaseModel): - ms_before: float = Field(3.0, description="Milliseconds before") - ms_after: float = Field(4.0, description="Milliseconds after") - max_spikes_per_unit: int = Field(500, description="Maximum spikes per unit") - return_scaled: bool = Field(True, description="Flag to determine if results should be scaled") - dtype: Optional[str] = Field(None, description="Data type for the waveforms") - precompute_template: Tuple[str, str] = Field(("average", "std"), description="Precomputation template method") - use_relative_path: bool = 
Field(True, description="Use relative paths") + ms_before: float = Field(default=3.0, description="Milliseconds before") + ms_after: float = Field(default=4.0, description="Milliseconds after") + max_spikes_per_unit: int = Field(default=500, description="Maximum spikes per unit") + return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled") + dtype: Optional[str] = Field(default=None, description="Data type for the waveforms") + precompute_template: Tuple[str, str] = Field(default=("average", "std"), description="Precomputation template method") + use_relative_path: bool = Field(default=True, description="Use relative paths") class SpikeAmplitudes(BaseModel): - peak_sign: str = Field("neg", description="Sign of the peak") - return_scaled: bool = Field(True, description="Flag to determine if amplitudes should be scaled") - outputs: str = Field("concatenated", description="Output format for the spike amplitudes") + peak_sign: str = Field(default="neg", description="Sign of the peak") + return_scaled: bool = Field(default=True, description="Flag to determine if amplitudes should be scaled") + outputs: str = Field(default="concatenated", description="Output format for the spike amplitudes") class Similarity(BaseModel): - method: str = Field("cosine_similarity", description="Method to compute similarity") + method: str = Field(default="cosine_similarity", description="Method to compute similarity") class Correlograms(BaseModel): - window_ms: float = Field(100.0, description="Size of the window in milliseconds") - bin_ms: float = Field(2.0, description="Size of the bin in milliseconds") + window_ms: float = Field(default=100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(default=2.0, description="Size of the bin in milliseconds") class ISIS(BaseModel): - window_ms: float = Field(100.0, description="Size of the window in milliseconds") - bin_ms: float = Field(5.0, description="Size of the bin in milliseconds") + window_ms: float = Field(default=100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(default=5.0, description="Size of the bin in milliseconds") class Locations(BaseModel): - method: str = Field("monopolar_triangulation", description="Method to determine locations") + method: str = Field(default="monopolar_triangulation", description="Method to determine locations") class TemplateMetrics(BaseModel): - upsampling_factor: int = Field(10, description="Upsampling factor") - sparsity: Optional[str] = Field(None, description="Sparsity method") + upsampling_factor: int = Field(default=10, description="Upsampling factor") + sparsity: Optional[str] = Field(default=None, description="Sparsity method") class PrincipalComponents(BaseModel): - n_components: int = Field(5, description="Number of principal components") - mode: str = Field("by_channel_local", description="Mode of principal component analysis") - whiten: bool = Field(True, description="Whiten the components") + n_components: int = Field(default=5, description="Number of principal components") + mode: str = Field(default="by_channel_local", description="Mode of principal component analysis") + whiten: bool = Field(default=True, description="Whiten the components") class PostprocessingParamsModel(BaseModel): - sparsity: Sparsity - waveforms_deduplicate: Waveforms - waveforms: Waveforms - spike_amplitudes: SpikeAmplitudes - similarity: Similarity - correlograms: Correlograms - isis: ISIS - locations: Locations - template_metrics: 
TemplateMetrics - principal_components: PrincipalComponents - quality_metrics: QualityMetrics - duplicate_threshold: float = Field(0.9, description="Duplicate threshold") + sparsity: Sparsity = Field(default=Sparsity(), description="Sparsity") + waveforms_deduplicate: Waveforms = Field(default=Waveforms(), description="Waveforms deduplicate") + waveforms: Waveforms = Field(default=Waveforms(), description="Waveforms") + spike_amplitudes: SpikeAmplitudes = Field(default=SpikeAmplitudes(), description="Spike amplitudes") + similarity: Similarity = Field(default=Similarity(), description="Similarity") + correlograms: Correlograms = Field(default=Correlograms(), description="Correlograms") + isis: ISIS = Field(default=ISIS(), description="ISIS") + locations: Locations = Field(default=Locations(), description="Locations") + template_metrics: TemplateMetrics = Field(default=TemplateMetrics(), description="Template metrics") + principal_components: PrincipalComponents = Field(default=PrincipalComponents(), description="Principal components") + quality_metrics: QualityMetrics = Field(default=QualityMetrics(), description="Quality metrics") + duplicate_threshold: float = Field(default=0.9, description="Duplicate threshold") diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 1a096ed..82a8b09 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -18,11 +18,11 @@ warnings.filterwarnings("ignore") +# TODO - WIP def postprocessing( - job_kwargs: JobKwargs, - postprocessing_params: PostprocessingParamsModel, - data_folder: Path = Path("../data/"), - results_path: Path = Path("./results/"), + postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), + results_path: Path = Path("./results/postprocessing/"), + job_kwargs: JobKwargs = JobKwargs(), ) -> None: """ Postprocessing pipeline From de576c3023cad3663294bb54347069752a09fb68 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Oct 2023 11:43:53 +0100 Subject: [PATCH 09/16] fix args --- src/spikeinterface_pipelines/pipeline.py | 4 +++- src/spikeinterface_pipelines/sorting/models.py | 2 +- src/spikeinterface_pipelines/sorting/sorting.py | 5 +++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 4a992ca..8105365 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -15,7 +15,7 @@ def pipeline( sorting_params: SortingParamsModel = SortingParamsModel(), postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), run_preprocessing: bool = True, -) -> None: +) -> si.BaseSorting: # Paths results_path_preprocessing = results_path / "preprocessing" results_path_sorting = results_path / "sorting" @@ -41,5 +41,7 @@ def pipeline( results_path=results_path_sorting, ) + return sorter + # # TODO - Postprocessing # postprocessing(postprocessing_params=postprocessing_params) diff --git a/src/spikeinterface_pipelines/sorting/models.py b/src/spikeinterface_pipelines/sorting/models.py index fb0c21a..cbf169e 100644 --- a/src/spikeinterface_pipelines/sorting/models.py +++ b/src/spikeinterface_pipelines/sorting/models.py @@ -52,4 +52,4 @@ class SortingParamsModel(BaseModel): Kilosort3Model, IronClustModel, MountainSort5Model - ] = Field(default=Kilosort25Model(), description="Sorter specific kwargs.") \ 
No newline at end of file + ] = Field(default=Kilosort25Model(), description="Sorter specific kwargs.") diff --git a/src/spikeinterface_pipelines/sorting/sorting.py b/src/spikeinterface_pipelines/sorting/sorting.py index 652da14..9833e0b 100644 --- a/src/spikeinterface_pipelines/sorting/sorting.py +++ b/src/spikeinterface_pipelines/sorting/sorting.py @@ -17,10 +17,11 @@ def sorting( try: sorter = ss.run_sorter( recording=recording, + sorter_name=sorting_params.sorter_name, output_folder=str(results_path / "tmp"), verbose=False, delete_output_folder=True, - **sorting_params.model_dump(), + **sorting_params.sorter_kwargs.model_dump(), ) # remove empty units sorter = sorter.remove_empty_units() @@ -30,7 +31,7 @@ def sorting( return sorter except Exception as e: # save log to results - results_path.mkdir() + results_path.mkdir(exist_ok=True, parents=True) if (results_path / "tmp").exists(): shutil.copy(results_path / "tmp/spikeinterface_log.json", results_path) raise e From f50c5de250455a46929caefab676ac1e907a0aa3 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 3 Nov 2023 12:06:55 +0100 Subject: [PATCH 10/16] version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e7956b..b377437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface_pipelines" -version = "0.1.0" +version = "0.0.2" description = "Collection of standardized analysis pipelines based on SpikeInterfacee." readme = "README.md" authors = [ From a733006cd8d6f28ed6698fec2485366db75327c9 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 6 Nov 2023 13:39:59 +0100 Subject: [PATCH 11/16] return objects --- src/spikeinterface_pipelines/pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 8105365..a19b726 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Tuple import spikeinterface as si from .logger import logger @@ -15,7 +16,7 @@ def pipeline( sorting_params: SortingParamsModel = SortingParamsModel(), postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), run_preprocessing: bool = True, -) -> si.BaseSorting: +) -> Tuple[si.BaseRecording, si.BaseSorting]: # Paths results_path_preprocessing = results_path / "preprocessing" results_path_sorting = results_path / "sorting" @@ -41,7 +42,7 @@ def pipeline( results_path=results_path_sorting, ) - return sorter - # # TODO - Postprocessing # postprocessing(postprocessing_params=postprocessing_params) + + return (recording_preprocessed, sorter) From 616acafbdbf848e94086bf2ac44e1721f5ac3173 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 6 Nov 2023 13:44:31 +0100 Subject: [PATCH 12/16] comments --- src/spikeinterface_pipelines/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index a19b726..7d16911 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -5,7 +5,7 @@ from .logger import logger from .preprocessing import preprocessing, PreprocessingParamsModel from .sorting import sorting, SortingParamsModel -from .postprocessing import postprocessing, PostprocessingParamsModel +# from .postprocessing import postprocessing, PostprocessingParamsModel # TODO - WIP @@ -14,7 +14,7 
@@ def pipeline( results_path: Path = Path("./results/"), preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), sorting_params: SortingParamsModel = SortingParamsModel(), - postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), + # postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), run_preprocessing: bool = True, ) -> Tuple[si.BaseRecording, si.BaseSorting]: # Paths From 4334de6cb1503cccf7960a3db415fedf079c1f95 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 13:59:11 +0100 Subject: [PATCH 13/16] Update pipeline: renaming, postprocessing, tests, and CI --- in_progress/hd_shank_preprocessing.py | 151 ----------- pyproject.toml | 3 + src/spikeinterface_pipelines/__init__.py | 2 +- .../{models.py => global_params.py} | 2 +- src/spikeinterface_pipelines/logger.py | 3 +- src/spikeinterface_pipelines/pipeline.py | 66 +++-- .../postprocessing/__init__.py | 4 +- .../postprocessing/{models.py => params.py} | 15 +- .../postprocessing/postprocessing.py | 243 +++++++----------- .../preprocessing/__init__.py | 4 +- .../preprocessing/{models.py => params.py} | 2 +- .../preprocessing/preprocessing.py | 79 +++--- .../sorting/__init__.py | 2 - .../sorting/sorting.py | 39 --- .../spikesorting/__init__.py | 2 + .../models.py => spikesorting/params.py} | 4 +- .../spikesorting/spikesorting.py | 63 +++++ tests/test_pipeline.py | 132 ++++++++++ 18 files changed, 396 insertions(+), 420 deletions(-) delete mode 100644 in_progress/hd_shank_preprocessing.py rename src/spikeinterface_pipelines/{models.py => global_params.py} (72%) rename src/spikeinterface_pipelines/postprocessing/{models.py => params.py} (89%) rename src/spikeinterface_pipelines/preprocessing/{models.py => params.py} (98%) delete mode 100644 src/spikeinterface_pipelines/sorting/__init__.py delete mode 100644 src/spikeinterface_pipelines/sorting/sorting.py create mode 100644 src/spikeinterface_pipelines/spikesorting/__init__.py rename src/spikeinterface_pipelines/{sorting/models.py => spikesorting/params.py} (95%) create mode 100644 src/spikeinterface_pipelines/spikesorting/spikesorting.py create mode 100644 tests/test_pipeline.py diff --git a/in_progress/hd_shank_preprocessing.py b/in_progress/hd_shank_preprocessing.py deleted file mode 100644 index 45eabc3..0000000 --- a/in_progress/hd_shank_preprocessing.py +++ /dev/null @@ -1,151 +0,0 @@ -from pathlib import Path -import numpy as np -from pydantic import BaseModel, Field -import spikeinterface as si -import spikeinterface.preprocessing as spre - - -######################################################################################################################## -# Preprocessing parameters - -class PhaseShiftParameters(BaseModel): - margin_ms: float = Field(default=100.0, description="Margin in ms to use for phase shift") - -class HighpassFilterParameters(BaseModel): - freq_min: float = Field(default=300.0, description="Minimum frequency in Hz") - margin_ms: float = Field(default=5.0, description="Margin in ms to use for highpass filter") - -class DetectBadChannelsParameters(BaseModel): - method: str = Field(default="coherence+psd", description="Method to use for detecting bad channels: 'coherence+psd' or ...") - dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channels") - noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channels") - outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside 
channels") - n_neighbors: int = Field(default=11, description="Number of neighbors to use for bad channel detection") - seed: int = Field(default=0, description="Seed for random number generator") - -class CommonReferenceParameters(BaseModel): - reference: str = Field(default="global", description="Reference to use for common reference: 'global' or ...") - operator: str = Field(default="median", description="Operator to use for common reference: 'median' or ...") - -class HighpassSpatialFilterParameters(BaseModel): - n_channel_pad: int = Field(default=60, description="Number of channels to pad") - n_channel_taper: int = Field(default=None, description="Number of channels to taper") - direction: str = Field(default="y", description="Direction to use for highpass spatial filter: 'y' or ...") - apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control") - agc_window_length_s: float = Field(default=0.01, description="Window length in seconds for automatic gain control") - highpass_butter_order: int = Field(default=3, description="Butterworth order for highpass filter") - highpass_butter_wn: float = Field(default=0.01, description="Butterworth wn for highpass filter") - -class HDShankPreprocessingParameters(BaseModel): - preprocessing_strategy: str = Field(default="cmr", description="Preprocessing strategy to use: destripe or cmr") - highpass_filter: HighpassFilterParameters = Field(default_factory=HighpassFilterParameters, description="Highpass filter parameters") - phase_shift: PhaseShiftParameters = Field(default_factory=PhaseShiftParameters, description="Phase shift parameters") - detect_bad_channels: DetectBadChannelsParameters = Field(default_factory=DetectBadChannelsParameters, description="Detect bad channels parameters") - remove_out_channels: bool = Field(default=True, description="Whether to remove out channels") - remove_bad_channels: bool = Field(default=True, description="Whether to remove bad channels") - max_bad_channel_fraction_to_remove: float = Field(default=0.5, description="Maximum fraction of bad channels to remove") - common_reference: CommonReferenceParameters = Field(default_factory=CommonReferenceParameters, description="Common reference parameters") - highpass_spatial_filter: HighpassSpatialFilterParameters = Field(default_factory=HighpassSpatialFilterParameters, description="Highpass spatial filter parameters") - -######################################################################################################################## - -def hd_shank_preprocessing( - recording: si.BaseRecording, - params: HDShankPreprocessingParameters, - preprocessed_output_folder: Path, - verbose: bool = False -): - if "inter_sample_shift" in recording.get_property_keys(): - recording_ps_full = spre.phase_shift( - recording, - margin_ms=params.phase_shift.margin_ms - ) - else: - recording_ps_full = recording - - recording_hp_full = spre.highpass_filter( - recording_ps_full, - freq_min=params.highpass_filter.freq_min, - margin_ms=params.highpass_filter.margin_ms - ) - - # IBL bad channel detection - _, channel_labels = spre.detect_bad_channels( - recording_hp_full, - method=params.detect_bad_channels.method, - dead_channel_threshold=params.detect_bad_channels.dead_channel_threshold, - noisy_channel_threshold=params.detect_bad_channels.noisy_channel_threshold, - outside_channel_threshold=params.detect_bad_channels.outside_channel_threshold, - n_neighbors=params.detect_bad_channels.n_neighbors, - seed=params.detect_bad_channels.seed - ) - - 
dead_channel_mask = channel_labels == "dead" - noise_channel_mask = channel_labels == "noise" - out_channel_mask = channel_labels == "out" - - if verbose: - print("\tBad channel detection:") - print( - f"\t\t- dead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}" - ) - dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] - noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] - out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] - - all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) - - max_bad_channel_fraction_to_remove = params.max_bad_channel_fraction_to_remove - - # skip_processing = False - if len(all_bad_channel_ids) >= int( - max_bad_channel_fraction_to_remove * recording.get_num_channels() - ): - # always print this message even if verbose is False? - print( - f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " - f"Skipping further processing for this recording." - ) - # skip_processing = True - recording_ret = recording_hp_full - else: - if params.remove_out_channels: - if verbose: - print(f"\tRemoving {len(out_channel_ids)} out channels") - recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) - else: - recording_rm_out = recording_hp_full - - recording_processed_cmr = spre.common_reference( - recording_rm_out, - reference=params.common_reference.reference, - operator=params.common_reference.operator - ) - - bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) - recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) - recording_hp_spatial = spre.highpass_spatial_filter( - recording_interp, - n_channel_pad=params.highpass_spatial_filter.n_channel_pad, - n_channel_taper=params.highpass_spatial_filter.n_channel_taper, - direction=params.highpass_spatial_filter.direction, - apply_agc=params.highpass_spatial_filter.apply_agc, - agc_window_length_s=params.highpass_spatial_filter.agc_window_length_s, - highpass_butter_order=params.highpass_spatial_filter.highpass_butter_order, - highpass_butter_wn=params.highpass_spatial_filter.highpass_butter_wn, - ) - - preproc_strategy = params.preprocessing_strategy - if preproc_strategy == "cmr": - recording_processed = recording_processed_cmr - else: - recording_processed = recording_hp_spatial - - if params.remove_bad_channels: - if verbose: - print(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing") - recording_processed = recording_processed.remove_channels(bad_channel_ids) - recording_name = 'recording' # not sure what this should be - recording_saved = recording_processed.save(folder=preprocessed_output_folder / recording_name) - recording_ret = recording_saved - return recording_ret diff --git a/pyproject.toml b/pyproject.toml index b377437..827aec2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,3 +31,6 @@ package-dir = { "" = "src" } [tool.setuptools.packages.find] where = ["src"] + +[tool.black] +line-length = 120 diff --git a/src/spikeinterface_pipelines/__init__.py b/src/spikeinterface_pipelines/__init__.py index b211130..2ae7cba 100644 --- a/src/spikeinterface_pipelines/__init__.py +++ b/src/spikeinterface_pipelines/__init__.py @@ -1 +1 @@ -from .pipeline import pipeline \ No newline at end of file +from .pipeline import run_pipeline \ No newline at end of file diff --git 
a/src/spikeinterface_pipelines/models.py b/src/spikeinterface_pipelines/global_params.py similarity index 72% rename from src/spikeinterface_pipelines/models.py rename to src/spikeinterface_pipelines/global_params.py index 1a64bee..73c79f6 100644 --- a/src/spikeinterface_pipelines/models.py +++ b/src/spikeinterface_pipelines/global_params.py @@ -4,4 +4,4 @@ class JobKwargs(BaseModel): n_jobs: int = Field(default=-1, description="The number of jobs to run in parallel.") chunk_duration: str = Field(default="1s", description="The duration of the chunks to process.") - progress_bar: bool = Field(default=True, description="Whether to display a progress bar.") + progress_bar: bool = Field(default=False, description="Whether to display a progress bar.") diff --git a/src/spikeinterface_pipelines/logger.py b/src/spikeinterface_pipelines/logger.py index d098728..0c8a7ea 100644 --- a/src/spikeinterface_pipelines/logger.py +++ b/src/spikeinterface_pipelines/logger.py @@ -1,5 +1,4 @@ import logging - logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) \ No newline at end of file +logger = logging.getLogger(__name__) diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 7d16911..6eb869b 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,33 +1,48 @@ from pathlib import Path +import re from typing import Tuple + import spikeinterface as si from .logger import logger -from .preprocessing import preprocessing, PreprocessingParamsModel -from .sorting import sorting, SortingParamsModel -# from .postprocessing import postprocessing, PostprocessingParamsModel +from .global_params import JobKwargs +from .preprocessing import preprocess, PreprocessingParams +from .spikesorting import spikesort, SpikeSortingParams +from .postprocessing import postprocess, PostprocessingParams # TODO - WIP -def pipeline( +def run_pipeline( recording: si.BaseRecording, - results_path: Path = Path("./results/"), - preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), - sorting_params: SortingParamsModel = SortingParamsModel(), - # postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), + scratch_folder: Path = Path("./scratch/"), + results_folder: Path = Path("./results/"), + job_kwargs: JobKwargs = JobKwargs(), + preprocessing_params: PreprocessingParams = PreprocessingParams(), + spikesorting_params: SpikeSortingParams = SpikeSortingParams(), + postprocessing_params: PostprocessingParams = PostprocessingParams(), run_preprocessing: bool = True, -) -> Tuple[si.BaseRecording, si.BaseSorting]: +) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: + + # Create folders + scratch_folder.mkdir(exist_ok=True, parents=True) + results_folder.mkdir(exist_ok=True, parents=True) + # Paths - results_path_preprocessing = results_path / "preprocessing" - results_path_sorting = results_path / "sorting" + results_folder_preprocessing = results_folder / "preprocessing" + results_folder_spikesorting = results_folder / "spikesorting" + results_folder_postprocessing = results_folder / "postprocessing" + + # set global job kwargs + si.set_global_job_kwargs(**job_kwargs.model_dump()) # Preprocessing if run_preprocessing: logger.info("Preprocessing recording") - recording_preprocessed = preprocessing( + recording_preprocessed = preprocess( recording=recording, preprocessing_params=preprocessing_params, - results_path=results_path_preprocessing, + 
scratch_folder=scratch_folder, + results_folder=results_folder_preprocessing, ) if recording_preprocessed is None: raise Exception("Preprocessing failed") @@ -36,13 +51,26 @@ def pipeline( recording_preprocessed = recording # Spike Sorting - sorter = sorting( + sorting = spikesort( recording=recording_preprocessed, - sorting_params=sorting_params, - results_path=results_path_sorting, + scratch_folder=scratch_folder, + spikesorting_params=spikesorting_params, + results_folder=results_folder_spikesorting, ) + if sorting is None: + raise Exception("Spike sorting failed") + + # Postprocessing + waveform_extractor = postprocess( + recording=recording_preprocessed, + sorting=sorting, + postprocessing_params=postprocessing_params, + scratch_folder=scratch_folder, + results_folder=results_folder_postprocessing, + ) + + # TODO: Curation - # # TODO - Postprocessing - # postprocessing(postprocessing_params=postprocessing_params) + # TODO: Visualization - return (recording_preprocessed, sorter) + return (recording_preprocessed, sorting, waveform_extractor) diff --git a/src/spikeinterface_pipelines/postprocessing/__init__.py b/src/spikeinterface_pipelines/postprocessing/__init__.py index 6e4b53b..e94f311 100644 --- a/src/spikeinterface_pipelines/postprocessing/__init__.py +++ b/src/spikeinterface_pipelines/postprocessing/__init__.py @@ -1,2 +1,2 @@ -from .postprocessing import postprocessing -from .models import PostprocessingParamsModel \ No newline at end of file +from .postprocessing import postprocess +from .params import PostprocessingParams \ No newline at end of file diff --git a/src/spikeinterface_pipelines/postprocessing/models.py b/src/spikeinterface_pipelines/postprocessing/params.py similarity index 89% rename from src/spikeinterface_pipelines/postprocessing/models.py rename to src/spikeinterface_pipelines/postprocessing/params.py index 2c5cd8f..fa013ad 100644 --- a/src/spikeinterface_pipelines/postprocessing/models.py +++ b/src/spikeinterface_pipelines/postprocessing/params.py @@ -74,7 +74,7 @@ class QMParams(BaseModel): class QualityMetrics(BaseModel): qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.") - metric_names: List[str] = Field(default=[], description="List of metric names to compute.") + metric_names: List[str] = Field(default=None, description="List of metric names to compute.") n_jobs: int = Field(default=1, description="Number of jobs.") @@ -83,6 +83,15 @@ class Sparsity(BaseModel): radius_um: int = Field(default=100, description="Radius in micrometers for sparsity.") +class WaveformsRaw(BaseModel): + ms_before: float = Field(default=1.0, description="Milliseconds before") + ms_after: float = Field(default=2.0, description="Milliseconds after") + max_spikes_per_unit: int = Field(default=100, description="Maximum spikes per unit") + return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled") + dtype: Optional[str] = Field(default=None, description="Data type for the waveforms") + precompute_template: Tuple[str, str] = Field(default=("average", "std"), description="Precomputation template method") + use_relative_path: bool = Field(default=True, description="Use relative paths") + class Waveforms(BaseModel): ms_before: float = Field(default=3.0, description="Milliseconds before") ms_after: float = Field(default=4.0, description="Milliseconds after") @@ -128,9 +137,9 @@ class PrincipalComponents(BaseModel): whiten: bool = Field(default=True, description="Whiten the components") -class 
PostprocessingParamsModel(BaseModel): +class PostprocessingParams(BaseModel): sparsity: Sparsity = Field(default=Sparsity(), description="Sparsity") - waveforms_deduplicate: Waveforms = Field(default=Waveforms(), description="Waveforms deduplicate") + waveforms_raw: WaveformsRaw = Field(default=WaveformsRaw(), description="Waveforms raw") waveforms: Waveforms = Field(default=Waveforms(), description="Waveforms") spike_amplitudes: SpikeAmplitudes = Field(default=SpikeAmplitudes(), description="Spike amplitudes") similarity: Similarity = Field(default=Similarity(), description="Similarity") diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index 82a8b09..e2079eb 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -1,171 +1,112 @@ import warnings -import os -import numpy as np from pathlib import Path import shutil -import json -import time -from datetime import datetime + import spikeinterface as si import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm import spikeinterface.curation as sc -from ..models import JobKwargs -from .models import PostprocessingParamsModel +from .params import PostprocessingParams +from ..logger import logger warnings.filterwarnings("ignore") -# TODO - WIP -def postprocessing( - postprocessing_params: PostprocessingParamsModel = PostprocessingParamsModel(), - results_path: Path = Path("./results/postprocessing/"), - job_kwargs: JobKwargs = JobKwargs(), -) -> None: +def postprocess( + recording: si.BaseRecording, + sorting: si.BaseSorting, + postprocessing_params: PostprocessingParams = PostprocessingParams(), + scratch_folder: Path = Path("./scratch/"), + results_folder: Path = Path("./results/postprocessing/"), +) -> si.WaveformExtractor: """ - Postprocessing pipeline + Postprocess preprocessed and spike sorting output Parameters ---------- - data_folder: Path - Path to the data folder + recording: si.BaseRecording + The input recording + sorting: si.BaseSorting + The input sorting + postprocessing_params: PostprocessingParams + Postprocessing parameters results_folder: Path Path to the results folder - job_kwargs: JobKwargs - Job kwargs - postprocessing_params: PostprocessingParamsModel - Postprocessing parameters + + Returns + ------- + si.WaveformExtractor + The waveform extractor """ - si.set_global_job_kwargs(**job_kwargs.model_dump()) - - tmp_folder = results_path / "tmp" - tmp_folder.mkdir() - - data_process_prefix = "data_process_postprocessing" - print("\nPOSTPROCESSING") - t_postprocessing_start_all = time.perf_counter() - - # check if test - if (data_folder / "preprocessing_pipeline_output_test").is_dir(): - print("\n*******************\n**** TEST MODE ****\n*******************\n") - preprocessed_folder = data_folder / "preprocessing_pipeline_output_test" - spikesorted_folder = data_folder / "spikesorting_pipeline_output_test" - else: - preprocessed_folder = data_folder - spikesorted_folder = data_folder - - preprocessed_folders = [p for p in preprocessed_folder.iterdir() if p.is_dir() and "preprocessed_" in p.name] - - # load job json files - job_config_json_files = [p for p in data_folder.iterdir() if p.suffix == ".json" and "job" in p.name] - print(f"Found {len(job_config_json_files)} json configurations") - - if len(job_config_json_files) > 0: - recording_names = [] - for json_file in job_config_json_files: - with open(json_file, "r") 
as f: - config = json.load(f) - recording_name = config["recording_name"] - assert (preprocessed_folder / f"preprocessed_{recording_name}").is_dir(), f"Preprocessed folder for {recording_name} not found!" - recording_names.append(recording_name) - else: - recording_names = [("_").join(p.name.split("_")[1:]) for p in preprocessed_folders] - - for recording_name in recording_names: - datetime_start_postprocessing = datetime.now() - t_postprocessing_start = time.perf_counter() - postprocessing_notes = "" - - print(f"\tProcessing {recording_name}") - postprocessing_output_process_json = results_folder / f"{data_process_prefix}_{recording_name}.json" - postprocessing_output_folder = results_folder / f"postprocessed_{recording_name}" - postprocessing_sorting_output_folder = results_folder / f"postprocessed-sorting_{recording_name}" - - recording = si.load_extractor(preprocessed_folder / f"preprocessed_{recording_name}") - # make sure we have spikesorted output for the block-stream - sorted_folder = spikesorted_folder / f"spikesorted_{recording_name}" - if not sorted_folder.is_dir(): - raise FileNotFoundError(f"Spike sorted data for {recording_name} not found!") - - sorting = si.load_extractor(sorted_folder) - - # first extract some raw waveforms in memory to deduplicate based on peak alignment - wf_dedup_folder = tmp_folder / "postprocessed" / recording_name - we_raw = si.extract_waveforms( - recording, - sorting, - folder=wf_dedup_folder, - **postprocessing_params.waveforms_deduplicate.model_dump() - ) - - # de-duplication - sorting_deduplicated = sc.remove_redundant_units( - we_raw, - duplicate_threshold=postprocessing_params.duplicate_threshold - ) - print(f"\tNumber of original units: {len(we_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}") - n_duplicated = int(len(sorting.unit_ids) - len(sorting_deduplicated.unit_ids)) - postprocessing_notes += f"\n- Removed {n_duplicated} duplicated units.\n" - deduplicated_unit_ids = sorting_deduplicated.unit_ids - - # use existing deduplicated waveforms to compute sparsity - sparsity_raw = si.compute_sparsity(we_raw, **postprocessing_params.sparsity.model_dump()) - sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] - sparsity = si.ChannelSparsity( - mask=sparsity_mask, - unit_ids=deduplicated_unit_ids, - channel_ids=recording.channel_ids - ) - shutil.rmtree(wf_dedup_folder) - del we_raw - - # this is a trick to make the postprocessed folder "self-contained - sorting_deduplicated = sorting_deduplicated.save(folder=postprocessing_sorting_output_folder) - - # now extract waveforms on de-duplicated units - print("\tSaving sparse de-duplicated waveform extractor folder") - we = si.extract_waveforms( - recording, - sorting_deduplicated, - folder=postprocessing_output_folder, - sparsity=sparsity, - sparse=True, - overwrite=True, - **postprocessing_params.waveforms.model_dump() - ) - - print("\tComputing spike amplitides") - amps = spost.compute_spike_amplitudes(we, **postprocessing_params.spike_amplitudes.model_dump()) - - print("\tComputing unit locations") - unit_locs = spost.compute_unit_locations(we, **postprocessing_params.locations.model_dump()) - - print("\tComputing spike locations") - spike_locs = spost.compute_spike_locations(we, **postprocessing_params.locations.model_dump()) - - print("\tComputing correlograms") - corr = spost.compute_correlograms(we, **postprocessing_params.correlograms.model_dump()) - - print("\tComputing ISI histograms") - tm = 
spost.compute_isi_histograms(we, **postprocessing_params.isis.model_dump()) - - print("\tComputing template similarity") - sim = spost.compute_template_similarity(we, **postprocessing_params.similarity.model_dump()) - - print("\tComputing template metrics") - tm = spost.compute_template_metrics(we, **postprocessing_params.template_metrics.model_dump()) - - print("\tComputing PCA") - pc = spost.compute_principal_components(we, **postprocessing_params.principal_components.model_dump()) - - print("\tComputing quality metrics") - qm = sqm.compute_quality_metrics(we, **postprocessing_params.quality_metrics.model_dump()) - - t_postprocessing_end = time.perf_counter() - elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2) - - t_postprocessing_end_all = time.perf_counter() - elapsed_time_postprocessing_all = np.round(t_postprocessing_end_all - t_postprocessing_start_all, 2) - print(f"POSTPROCESSING time: {elapsed_time_postprocessing_all}s") + + tmp_folder = scratch_folder / "tmp_postprocessing" + tmp_folder.mkdir(parents=True, exist_ok=True) + + # first extract some raw waveforms in memory to deduplicate based on peak alignment + wf_dedup_folder = tmp_folder / "waveforms_dense" + waveform_extractor_raw = si.extract_waveforms( + recording, + sorting, + folder=wf_dedup_folder, + sparse=False, + **postprocessing_params.waveforms_raw.model_dump() + ) + + # de-duplication + sorting_deduplicated = sc.remove_redundant_units( + waveform_extractor_raw, + duplicate_threshold=postprocessing_params.duplicate_threshold + ) + logger.info(f"[Postprocessing] \tNumber of original units: {len(waveform_extractor_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}") + deduplicated_unit_ids = sorting_deduplicated.unit_ids + + # use existing deduplicated waveforms to compute sparsity + sparsity_raw = si.compute_sparsity(waveform_extractor_raw, **postprocessing_params.sparsity.model_dump()) + sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] + sparsity = si.ChannelSparsity( + mask=sparsity_mask, + unit_ids=deduplicated_unit_ids, + channel_ids=recording.channel_ids + ) + + # this is a trick to make the postprocessed folder "self-contained + sorting_folder = results_folder / "sorting" + sorting_deduplicated = sorting_deduplicated.save(folder=sorting_folder) + + # now extract waveforms on de-duplicated units + logger.info("[Postprocessing] \tSaving sparse de-duplicated waveform extractor folder") + waveform_extractor = si.extract_waveforms( + recording, + sorting_deduplicated, + folder=results_folder / "waveforms", + sparsity=sparsity, + sparse=True, + overwrite=True, + **postprocessing_params.waveforms.model_dump() + ) + + logger.info("[Postprocessing] \tComputing spike amplitides") + _ = spost.compute_spike_amplitudes(waveform_extractor, **postprocessing_params.spike_amplitudes.model_dump()) + logger.info("[Postprocessing] \tComputing unit locations") + _ = spost.compute_unit_locations(waveform_extractor, **postprocessing_params.locations.model_dump()) + logger.info("[Postprocessing] \tComputing spike locations") + _ = spost.compute_spike_locations(waveform_extractor, **postprocessing_params.locations.model_dump()) + logger.info("[Postprocessing] \tComputing correlograms") + _ = spost.compute_correlograms(waveform_extractor, **postprocessing_params.correlograms.model_dump()) + logger.info("[Postprocessing] \tComputing ISI histograms") + _ = spost.compute_isi_histograms(waveform_extractor, 
**postprocessing_params.isis.model_dump()) + logger.info("[Postprocessing] \tComputing template similarity") + _ = spost.compute_template_similarity(waveform_extractor, **postprocessing_params.similarity.model_dump()) + logger.info("[Postprocessing] \tComputing template metrics") + _ = spost.compute_template_metrics(waveform_extractor, **postprocessing_params.template_metrics.model_dump()) + logger.info("[Postprocessing] \tComputing PCA") + _ = spost.compute_principal_components(waveform_extractor, **postprocessing_params.principal_components.model_dump()) + logger.info("[Postprocessing] \tComputing quality metrics") + _ = sqm.compute_quality_metrics(waveform_extractor, **postprocessing_params.quality_metrics.model_dump()) + + # cleanup + shutil.rmtree(tmp_folder) + + return waveform_extractor diff --git a/src/spikeinterface_pipelines/preprocessing/__init__.py b/src/spikeinterface_pipelines/preprocessing/__init__.py index 445fde2..8dbd5e0 100644 --- a/src/spikeinterface_pipelines/preprocessing/__init__.py +++ b/src/spikeinterface_pipelines/preprocessing/__init__.py @@ -1,2 +1,2 @@ -from .preprocessing import preprocessing -from .models import PreprocessingParamsModel \ No newline at end of file +from .preprocessing import preprocess +from .params import PreprocessingParams \ No newline at end of file diff --git a/src/spikeinterface_pipelines/preprocessing/models.py b/src/spikeinterface_pipelines/preprocessing/params.py similarity index 98% rename from src/spikeinterface_pipelines/preprocessing/models.py rename to src/spikeinterface_pipelines/preprocessing/params.py index bd4dcdd..315f390 100644 --- a/src/spikeinterface_pipelines/preprocessing/models.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -47,7 +47,7 @@ class MotionCorrection(BaseModel): preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction") -class PreprocessingParamsModel(BaseModel): +class PreprocessingParams(BaseModel): preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing") highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 0476e89..d8b1578 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -1,65 +1,63 @@ import warnings import numpy as np from pathlib import Path + import spikeinterface as si import spikeinterface.preprocessing as spre from ..logger import logger -from ..models import JobKwargs -from .models import PreprocessingParamsModel +from .params import PreprocessingParams warnings.filterwarnings("ignore") -def preprocessing( +def preprocess( recording: si.BaseRecording, - preprocessing_params: PreprocessingParamsModel = PreprocessingParamsModel(), - results_path: Path = Path("./results/preprocessing/"), - job_kwargs: JobKwargs = JobKwargs(), -) -> si.BaseRecording | None: + preprocessing_params: PreprocessingParams = PreprocessingParams(), + scratch_folder: Path = Path("./scratch/"), + results_folder: Path = Path("./results/preprocessing/"), +) -> si.BaseRecording: """ - Preprocessing pipeline for ephys data. + Apply preprocessing to recording. Parameters ---------- recording: si.BaseRecording - Recording extractor. 
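For orientation, the new postprocess() flow above (dense raw waveforms, redundant-unit removal, sparsity computed on the dense extractor, then a sparse re-extraction) can be sketched in isolation. The snippet below is illustrative only: the ground-truth generator arguments, folder names, and the 0.9 duplicate threshold are assumptions, not values taken from this patch.

from pathlib import Path

import spikeinterface as si
import spikeinterface.curation as sc

# toy recording/sorting pair (arguments are illustrative)
recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)
scratch = Path("./scratch_demo")

# 1) dense raw waveforms, used only to detect redundant units
we_raw = si.extract_waveforms(
    recording, sorting, folder=scratch / "waveforms_dense",
    sparse=False, ms_before=1.0, ms_after=2.0, max_spikes_per_unit=100,
)

# 2) drop duplicated units based on the dense templates
sorting_dedup = sc.remove_redundant_units(we_raw, duplicate_threshold=0.9)

# 3) reuse the dense sparsity for the de-duplicated unit set
sparsity_raw = si.compute_sparsity(we_raw, method="radius", radius_um=100)
mask = sparsity_raw.mask[sorting.ids_to_indices(sorting_dedup.unit_ids), :]
sparsity = si.ChannelSparsity(mask=mask, unit_ids=sorting_dedup.unit_ids, channel_ids=recording.channel_ids)

# 4) sparse waveform extractor on the cleaned sorting, ready for metrics
we = si.extract_waveforms(
    recording, sorting_dedup, folder=scratch / "waveforms_sparse",
    sparsity=sparsity, sparse=True, overwrite=True,
)
print(f"{len(sorting.unit_ids)} units -> {len(sorting_dedup.unit_ids)} after de-duplication")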
- preprocessing_params: PreprocessingParamsModel - Preprocessing parameters. - results_path: Path - Path to the results folder. - debug: bool - Flag to run in debug mode. - duration_s: float - Duration in seconds to use in the debug mode. + The input recording + preprocessing_params: PreprocessingParams + Preprocessing parameters + scratch_folder: Path + Path to the scratch folder + results_folder: Path + Path to the results folder + + Returns + ------- + si.BaseRecording | None + Preprocessed recording. If more than `max_bad_channel_fraction_to_remove` channels are detected as bad, + returns None. """ logger.info("[Preprocessing] \tRunning Preprocessing stage") - si.set_global_job_kwargs(**job_kwargs.model_dump()) - - preprocessing_notes = "" - - # recording_name = recording.name - # preprocessing_output_process_json = results_path / f"{data_process_prefix}_{recording_name}.json" - # preprocessing_output_folder = results_path / f"preprocessed_{recording_name}" - # preprocessing_output_json = results_path / f"preprocessed_{recording_name}.json" - logger.info(f"[Preprocessing] \tDuration: {np.round(recording.get_total_duration(), 2)} s") - # # TODO - Phase shift correction - # recording = spre.phase_shift( - # recording, - # **preprocessing_params.phase_shift.model_dump() - # ) + # Phase shift correction + if "inter_sample_shift" in recording.get_property_keys(): + logger.info(f"[Preprocessing] \tPhase shift") + recording = spre.phase_shift( + recording, + **preprocessing_params.phase_shift.model_dump() + ) + else: + logger.info(f"[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") # Highpass filter - recording_hp_full = spre.highpass_filter( recording, **preprocessing_params.highpass_filter.model_dump() ) - # Detect bad channels + # Detect and remove bad channels _, channel_labels = spre.detect_bad_channels( recording_hp_full, **preprocessing_params.detect_bad_channels.model_dump() @@ -67,8 +65,7 @@ def preprocessing( dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" out_channel_mask = channel_labels == "out" - logger.info("[Preprocessing] \tBad channel detection:") - logger.info(f"[Preprocessing] \tdead channels - {np.sum(dead_channel_mask)}\n\t\t- noise channels - {np.sum(noise_channel_mask)}\n\t\t- out channels - {np.sum(out_channel_mask)}") + logger.info(f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels") dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] @@ -78,19 +75,17 @@ def preprocessing( if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): logger.info(f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). ") logger.info("[Preprocessing] \tSkipping further processing for this recording.") - preprocessing_notes += f"\n- Found {len(all_bad_channel_ids)} bad channels. Skipping further processing\n" return None if preprocessing_params.remove_out_channels: logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) - preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
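The bad-channel step in this hunk can also be previewed on its own: label channels on the highpass-filtered recording and count the categories before deciding whether to continue. A rough sketch; the generated recording and the 0.5 fraction (the PreprocessingParams default) are used here only for illustration.

import numpy as np
import spikeinterface as si
import spikeinterface.preprocessing as spre

recording, _ = si.generate_ground_truth_recording(durations=[10], num_channels=64, seed=0)

# highpass first, as preprocess() does, then label every channel
recording_hp = spre.highpass_filter(recording, freq_min=300.0)
_, channel_labels = spre.detect_bad_channels(recording_hp, method="coherence+psd", seed=0)

dead_ids = recording_hp.channel_ids[channel_labels == "dead"]
noise_ids = recording_hp.channel_ids[channel_labels == "noise"]
out_ids = recording_hp.channel_ids[channel_labels == "out"]
print(f"dead={len(dead_ids)}, noise={len(noise_ids)}, out={len(out_ids)}")

# mirror of the max_bad_channel_fraction_to_remove guard
all_bad_ids = np.concatenate((dead_ids, noise_ids, out_ids))
if len(all_bad_ids) >= int(0.5 * recording.get_num_channels()):
    print("too many bad channels, the pipeline returns None here")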
else: recording_rm_out = recording_hp_full bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) - # Strategy: CMR or destripe + # Denoise: CMR or destripe if preprocessing_params.preprocessing_strategy == "cmr": recording_processed = spre.common_reference( recording_rm_out, @@ -106,23 +101,19 @@ def preprocessing( if preprocessing_params.remove_bad_channels: logger.info(f"[Preprocessing] \tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") recording_processed = recording_processed.remove_channels(bad_channel_ids) - preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") - motion_folder = results_path / "motion_correction" + motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion( recording_processed, preset=preset, folder=motion_folder, - **job_kwargs.model_dump() + verbose=False ) if preprocessing_params.motion_correction.apply: logger.info("[Preprocessing] \tApplying motion correction") recording_processed = recording_corrected - # recording_saved = recording_processed.save(folder=preprocessing_output_folder) - # recording_processed.dump_to_json(preprocessing_output_json, relative_to=data_folder) - return recording_processed diff --git a/src/spikeinterface_pipelines/sorting/__init__.py b/src/spikeinterface_pipelines/sorting/__init__.py deleted file mode 100644 index c9ffd2c..0000000 --- a/src/spikeinterface_pipelines/sorting/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .models import SortingParamsModel -from .sorting import sorting \ No newline at end of file diff --git a/src/spikeinterface_pipelines/sorting/sorting.py b/src/spikeinterface_pipelines/sorting/sorting.py deleted file mode 100644 index 9833e0b..0000000 --- a/src/spikeinterface_pipelines/sorting/sorting.py +++ /dev/null @@ -1,39 +0,0 @@ -import spikeinterface.sorters as ss -import spikeinterface as si -from pathlib import Path -import shutil - -from ..logger import logger -from ..models import JobKwargs -from .models import SortingParamsModel - - -def sorting( - recording: si.BaseRecording, - sorting_params: SortingParamsModel = SortingParamsModel(), - results_path: Path = Path("./results/sorting/"), - job_kwargs: JobKwargs = JobKwargs(), -) -> si.BaseSorting | None: - try: - sorter = ss.run_sorter( - recording=recording, - sorter_name=sorting_params.sorter_name, - output_folder=str(results_path / "tmp"), - verbose=False, - delete_output_folder=True, - **sorting_params.sorter_kwargs.model_dump(), - ) - # remove empty units - sorter = sorter.remove_empty_units() - # save results - logger.info(f"\tSaving results to {results_path}") - sorter = sorter.save(folder=results_path) - return sorter - except Exception as e: - # save log to results - results_path.mkdir(exist_ok=True, parents=True) - if (results_path / "tmp").exists(): - shutil.copy(results_path / "tmp/spikeinterface_log.json", results_path) - raise e - finally: - pass diff --git a/src/spikeinterface_pipelines/spikesorting/__init__.py b/src/spikeinterface_pipelines/spikesorting/__init__.py new file mode 100644 index 0000000..9bdd5fa --- /dev/null +++ b/src/spikeinterface_pipelines/spikesorting/__init__.py @@ -0,0 +1,2 @@ +from .spikesorting import spikesort +from .params import SpikeSortingParams diff --git 
a/src/spikeinterface_pipelines/sorting/models.py b/src/spikeinterface_pipelines/spikesorting/params.py similarity index 95% rename from src/spikeinterface_pipelines/sorting/models.py rename to src/spikeinterface_pipelines/spikesorting/params.py index cbf169e..a964514 100644 --- a/src/spikeinterface_pipelines/sorting/models.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -24,7 +24,7 @@ class Kilosort25Model(BaseModel): nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") - NT: int = Field(default=-1, description='Batch size (if -1 it is automatically computed)') + NT: int = Field(default=None, description='Batch size (if None it is automatically computed)') AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step") do_correction: bool = Field(default=True, description="If True drift registration is applied") wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)") @@ -45,7 +45,7 @@ class MountainSort5Model(BaseModel): pass -class SortingParamsModel(BaseModel): +class SpikeSortingParams(BaseModel): sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.") sorter_kwargs: Union[ Kilosort25Model, diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py new file mode 100644 index 0000000..831e846 --- /dev/null +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -0,0 +1,63 @@ +from pathlib import Path +import shutil + +import spikeinterface as si +import spikeinterface.sorters as ss +import spikeinterface.curation as sc + +from ..logger import logger +from .params import SpikeSortingParams + + +def spikesort( + recording: si.BaseRecording, + spikesorting_params: SpikeSortingParams = SpikeSortingParams(), + scratch_folder: Path = Path("./scratch/"), + results_folder: Path = Path("./results/spikesorting/"), +) -> si.BaseSorting | None: + """ + Apply spike sorting to recording + + Parameters + ---------- + recording: si.BaseRecording + The input recording + sorting_params: SpikeSortingParams + Spike sorting parameters + scratch_folder: Path + Path to the scratch folder + results_folder: Path + Path to the results folder + + Returns + ------- + si.BaseSorting | None + Spike sorted sorting. 
If spike sorting fails, None is returned + """ + output_folder = scratch_folder / "tmp_spikesorting" + + try: + logger.info(f"[Spikesorting] \tStarting {spikesorting_params.sorter_name} spike sorter") + sorting = ss.run_sorter( + recording=recording, + sorter_name=spikesorting_params.sorter_name, + output_folder=str(output_folder), + verbose=False, + delete_output_folder=True, + remove_existing_folder=True, + **spikesorting_params.sorter_kwargs.model_dump(), + ) + logger.info(f"[Spikesorting] \tFound {len(sorting.unit_ids)} raw units") + # remove spikes beyond num_Samples (if any) + sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording) + # save results + logger.info(f"[Spikesorting]\tSaving results to {results_folder}") + return sorting + except Exception as e: + # save log to results + results_folder.mkdir(exist_ok=True, parents=True) + if (output_folder).is_dir(): + shutil.copy(output_folder / "spikeinterface_log.json", results_folder) + shutil.rmtree(output_folder) + logger.info(f"Spike sorting error:\n{e}") + return None diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..f87ddb7 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,132 @@ +import shutil +import pytest +import numpy as np +from pathlib import Path + +import spikeinterface as si +import spikeinterface.sorters as ss + +from spikeinterface_pipelines import pipeline + +from spikeinterface_pipelines.preprocessing import preprocess +from spikeinterface_pipelines.spikesorting import spikesort +from spikeinterface_pipelines.postprocessing import postprocess + +from spikeinterface_pipelines.preprocessing.params import PreprocessingParams +from spikeinterface_pipelines.spikesorting.params import Kilosort25Model, SpikeSortingParams +from spikeinterface_pipelines.postprocessing.params import PostprocessingParams + + +def _generate_gt_recording(): + recording, sorting = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0) + # add inter sample shift (but fake) + inter_sample_shifts = np.zeros(recording.get_num_channels()) + recording.set_property("inter_sample_shift", inter_sample_shifts) + + return recording, sorting + + +@pytest.fixture +def generate_recording(): + return _generate_gt_recording() + + +def test_preprocessing(tmp_path, generate_recording): + recording, _ = generate_recording + + results_folder = Path(tmp_path) / "results_preprocessing" + scratch_folder = Path(tmp_path) / "scratch_prepocessing" + + recording_processed = preprocess( + recording=recording, + preprocessing_params=PreprocessingParams(), + results_folder=results_folder, + scratch_folder=scratch_folder + ) + + assert isinstance(recording_processed, si.BaseRecording) + + +@pytest.mark.skipif(not "kilosort2_5" in ss.installed_sorters(), reason="kilosort2_5 not installed") +def test_spikesorting(tmp_path, generate_recording): + recording, _ = generate_recording + if "inter_sample_shift" in recording.get_property_keys(): + recording.delete_property("inter_sample_shift") + + results_folder = Path(tmp_path) / "results_spikesorting" + scratch_folder = Path(tmp_path) / "scratch_spikesorting" + + sorting = spikesort( + recording=recording, + spikesorting_params=SpikeSortingParams(), + results_folder=results_folder, + scratch_folder=scratch_folder + ) + + assert isinstance(sorting, si.BaseSorting) + + +def test_postprocessing(tmp_path, generate_recording): + recording, sorting = generate_recording + if "inter_sample_shift" in recording.get_property_keys(): + 
recording.delete_property("inter_sample_shift") + + results_folder = Path(tmp_path) / "results_postprocessing" + scratch_folder = Path(tmp_path) / "scratch_postprocessing" + + waveform_extractor = postprocess( + recording=recording, + sorting=sorting, + postprocessing_params=PostprocessingParams(), + results_folder=results_folder, + scratch_folder=scratch_folder + ) + + assert isinstance(waveform_extractor, si.WaveformExtractor) + + +@pytest.mark.skipif(not "kilosort2_5" in ss.installed_sorters(), reason="kilosort2_5 not installed") +def test_pipeline(tmp_path, generate_recording): + recording, _ = generate_recording + if "inter_sample_shift" in recording.get_property_keys(): + recording.delete_property("inter_sample_shift") + + results_folder = Path(tmp_path) / "results" + scratch_folder = Path(tmp_path) / "scratch" + + ks25_params = Kilosort25Model(do_correction=False) + spikesorting_params = SpikeSortingParams( + sorter_name="kilosort2_5", + sorter_kwargs=ks25_params, + ) + + recording_processed, sorting, waveform_extractor = pipeline.run_pipeline( + recording=recording, + results_folder=results_folder, + scratch_folder=scratch_folder, + spikesorting_params=spikesorting_params + ) + + assert isinstance(recording_processed, si.BaseRecording) + assert isinstance(sorting, si.BaseSorting) + assert isinstance(waveform_extractor, si.WaveformExtractor) + + +if __name__ == "__main__": + tmp_folder = Path("./tmp_pipeline_output") + if tmp_folder.is_dir(): + shutil.rmtree(tmp_folder) + tmp_folder.mkdir() + + recording, sorting = _generate_gt_recording() + + print("TEST PREPROCESSING") + test_preprocessing(tmp_folder, (recording, sorting)) + print("TEST SPIKESORTING") + test_spikesorting(tmp_folder, (recording, sorting)) + print("TEST POSTPROCESSING") + test_postprocessing(tmp_folder, (recording, sorting)) + + print("TEST PIPELINE") + test_pipeline(tmp_folder, (recording, sorting)) + \ No newline at end of file From 487c6678502faa0caa16ac7de584192999578d9d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 13:59:52 +0100 Subject: [PATCH 14/16] Reformat --- src/spikeinterface_pipelines/__init__.py | 2 +- src/spikeinterface_pipelines/pipeline.py | 1 - .../postprocessing/__init__.py | 2 +- .../postprocessing/params.py | 55 ++++++++++++++----- .../postprocessing/postprocessing.py | 29 ++++------ .../preprocessing/__init__.py | 2 +- .../preprocessing/params.py | 8 ++- .../preprocessing/preprocessing.py | 35 +++++------- .../spikesorting/params.py | 37 ++++++++----- tests/test_pipeline.py | 9 ++- 10 files changed, 105 insertions(+), 75 deletions(-) diff --git a/src/spikeinterface_pipelines/__init__.py b/src/spikeinterface_pipelines/__init__.py index 2ae7cba..8d9a59d 100644 --- a/src/spikeinterface_pipelines/__init__.py +++ b/src/spikeinterface_pipelines/__init__.py @@ -1 +1 @@ -from .pipeline import run_pipeline \ No newline at end of file +from .pipeline import run_pipeline diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 6eb869b..44d140e 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -22,7 +22,6 @@ def run_pipeline( postprocessing_params: PostprocessingParams = PostprocessingParams(), run_preprocessing: bool = True, ) -> Tuple[si.BaseRecording, si.BaseSorting, si.WaveformExtractor]: - # Create folders scratch_folder.mkdir(exist_ok=True, parents=True) results_folder.mkdir(exist_ok=True, parents=True) diff --git a/src/spikeinterface_pipelines/postprocessing/__init__.py 
b/src/spikeinterface_pipelines/postprocessing/__init__.py index e94f311..5359475 100644 --- a/src/spikeinterface_pipelines/postprocessing/__init__.py +++ b/src/spikeinterface_pipelines/postprocessing/__init__.py @@ -1,2 +1,2 @@ from .postprocessing import postprocess -from .params import PostprocessingParams \ No newline at end of file +from .params import PostprocessingParams diff --git a/src/spikeinterface_pipelines/postprocessing/params.py b/src/spikeinterface_pipelines/postprocessing/params.py index fa013ad..1f81a8f 100644 --- a/src/spikeinterface_pipelines/postprocessing/params.py +++ b/src/spikeinterface_pipelines/postprocessing/params.py @@ -15,20 +15,28 @@ class SNR(BaseModel): class ISIViolation(BaseModel): isi_threshold_ms: float = Field(default=1.5, description="ISI threshold in milliseconds.") - min_isi_ms: float = Field(default=0., description="Minimum ISI in milliseconds.") + min_isi_ms: float = Field(default=0.0, description="Minimum ISI in milliseconds.") class RPViolation(BaseModel): - refractory_period_ms: float = Field(default=1., description="Refractory period in milliseconds.") + refractory_period_ms: float = Field(default=1.0, description="Refractory period in milliseconds.") censored_period_ms: float = Field(default=0.0, description="Censored period in milliseconds.") class SlidingRPViolation(BaseModel): - bin_size_ms: float = Field(default=0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") + bin_size_ms: float = Field( + default=0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25." + ) window_size_s: float = Field(default=1, description="Window in seconds to compute correlogram, by default 1.") - exclude_ref_period_below_ms: float = Field(default=0.5, description="Refractory periods below this value are excluded, by default 0.5") - max_ref_period_ms: float = Field(default=10, description="Maximum refractory period to test in ms, by default 10 ms.") - contamination_values: Optional[list] = Field(default=None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") + exclude_ref_period_below_ms: float = Field( + default=0.5, description="Refractory periods below this value are excluded, by default 0.5" + ) + max_ref_period_ms: float = Field( + default=10, description="Maximum refractory period to test in ms, by default 10 ms." + ) + contamination_values: Optional[list] = Field( + default=None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %" + ) class PeakSign(str, Enum): @@ -39,9 +47,16 @@ class PeakSign(str, Enum): class AmplitudeCutoff(BaseModel): peak_sign: PeakSign = Field(default="neg", description="The sign of the peaks.") - num_histogram_bins: int = Field(default=100, description="The number of bins to use to compute the amplitude histogram.") - histogram_smoothing_value: int = Field(default=3, description="Controls the smoothing applied to the amplitude histogram.") - amplitudes_bins_min_ratio: int = Field(default=5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") + num_histogram_bins: int = Field( + default=100, description="The number of bins to use to compute the amplitude histogram." + ) + histogram_smoothing_value: int = Field( + default=3, description="Controls the smoothing applied to the amplitude histogram." 
+ ) + amplitudes_bins_min_ratio: int = Field( + default=5, + description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.", + ) class AmplitudeMedian(BaseModel): @@ -49,14 +64,19 @@ class AmplitudeMedian(BaseModel): class NearestNeighbor(BaseModel): - max_spikes: int = Field(default=10000, description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.") + max_spikes: int = Field( + default=10000, + description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.", + ) min_spikes: int = Field(default=10, description="Minimum number of spikes.") n_neighbors: int = Field(default=4, description="The number of neighbors to use.") class NNIsolation(NearestNeighbor): n_components: int = Field(default=10, description="The number of PC components to use to project the snippets to.") - radius_um: int = Field(default=100, description="The radius, in um, that channels need to be within the peak channel to be included.") + radius_um: int = Field( + default=100, description="The radius, in um, that channels need to be within the peak channel to be included." + ) class QMParams(BaseModel): @@ -64,7 +84,9 @@ class QMParams(BaseModel): snr: SNR = Field(default=SNR(), description="Signal to noise ratio.") isi_violation: ISIViolation = Field(default=ISIViolation(), description="ISI violation.") rp_violation: RPViolation = Field(default=RPViolation(), description="Refractory period violation.") - sliding_rp_violation: SlidingRPViolation = Field(default=SlidingRPViolation(), description="Sliding refractory period violation.") + sliding_rp_violation: SlidingRPViolation = Field( + default=SlidingRPViolation(), description="Sliding refractory period violation." 
+ ) amplitude_cutoff: AmplitudeCutoff = Field(default=AmplitudeCutoff(), description="Amplitude cutoff.") amplitude_median: AmplitudeMedian = Field(default=AmplitudeMedian(), description="Amplitude median.") nearest_neighbor: NearestNeighbor = Field(default=NearestNeighbor(), description="Nearest neighbor.") @@ -89,16 +111,21 @@ class WaveformsRaw(BaseModel): max_spikes_per_unit: int = Field(default=100, description="Maximum spikes per unit") return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled") dtype: Optional[str] = Field(default=None, description="Data type for the waveforms") - precompute_template: Tuple[str, str] = Field(default=("average", "std"), description="Precomputation template method") + precompute_template: Tuple[str, str] = Field( + default=("average", "std"), description="Precomputation template method" + ) use_relative_path: bool = Field(default=True, description="Use relative paths") + class Waveforms(BaseModel): ms_before: float = Field(default=3.0, description="Milliseconds before") ms_after: float = Field(default=4.0, description="Milliseconds after") max_spikes_per_unit: int = Field(default=500, description="Maximum spikes per unit") return_scaled: bool = Field(default=True, description="Flag to determine if results should be scaled") dtype: Optional[str] = Field(default=None, description="Data type for the waveforms") - precompute_template: Tuple[str, str] = Field(default=("average", "std"), description="Precomputation template method") + precompute_template: Tuple[str, str] = Field( + default=("average", "std"), description="Precomputation template method" + ) use_relative_path: bool = Field(default=True, description="Use relative paths") diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index e2079eb..60d21b5 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -47,32 +47,25 @@ def postprocess( # first extract some raw waveforms in memory to deduplicate based on peak alignment wf_dedup_folder = tmp_folder / "waveforms_dense" waveform_extractor_raw = si.extract_waveforms( - recording, - sorting, - folder=wf_dedup_folder, - sparse=False, - **postprocessing_params.waveforms_raw.model_dump() + recording, sorting, folder=wf_dedup_folder, sparse=False, **postprocessing_params.waveforms_raw.model_dump() ) # de-duplication sorting_deduplicated = sc.remove_redundant_units( - waveform_extractor_raw, - duplicate_threshold=postprocessing_params.duplicate_threshold + waveform_extractor_raw, duplicate_threshold=postprocessing_params.duplicate_threshold + ) + logger.info( + f"[Postprocessing] \tNumber of original units: {len(waveform_extractor_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}" ) - logger.info(f"[Postprocessing] \tNumber of original units: {len(waveform_extractor_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}") deduplicated_unit_ids = sorting_deduplicated.unit_ids - + # use existing deduplicated waveforms to compute sparsity sparsity_raw = si.compute_sparsity(waveform_extractor_raw, **postprocessing_params.sparsity.model_dump()) sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] - sparsity = si.ChannelSparsity( - mask=sparsity_mask, - unit_ids=deduplicated_unit_ids, - channel_ids=recording.channel_ids 
- ) + sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids) # this is a trick to make the postprocessed folder "self-contained - sorting_folder = results_folder / "sorting" + sorting_folder = results_folder / "sorting" sorting_deduplicated = sorting_deduplicated.save(folder=sorting_folder) # now extract waveforms on de-duplicated units @@ -84,7 +77,7 @@ def postprocess( sparsity=sparsity, sparse=True, overwrite=True, - **postprocessing_params.waveforms.model_dump() + **postprocessing_params.waveforms.model_dump(), ) logger.info("[Postprocessing] \tComputing spike amplitides") @@ -102,7 +95,9 @@ def postprocess( logger.info("[Postprocessing] \tComputing template metrics") _ = spost.compute_template_metrics(waveform_extractor, **postprocessing_params.template_metrics.model_dump()) logger.info("[Postprocessing] \tComputing PCA") - _ = spost.compute_principal_components(waveform_extractor, **postprocessing_params.principal_components.model_dump()) + _ = spost.compute_principal_components( + waveform_extractor, **postprocessing_params.principal_components.model_dump() + ) logger.info("[Postprocessing] \tComputing quality metrics") _ = sqm.compute_quality_metrics(waveform_extractor, **postprocessing_params.quality_metrics.model_dump()) diff --git a/src/spikeinterface_pipelines/preprocessing/__init__.py b/src/spikeinterface_pipelines/preprocessing/__init__.py index 8dbd5e0..ad522e5 100644 --- a/src/spikeinterface_pipelines/preprocessing/__init__.py +++ b/src/spikeinterface_pipelines/preprocessing/__init__.py @@ -1,2 +1,2 @@ from .preprocessing import preprocess -from .params import PreprocessingParams \ No newline at end of file +from .params import PreprocessingParams diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index 315f390..b5bc566 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -52,9 +52,13 @@ class PreprocessingParams(BaseModel): highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") common_reference: CommonReference = Field(default=CommonReference(), description="Common reference") - highpass_spatial_filter: HighpassSpatialFilter = Field(default=HighpassSpatialFilter(), description="Highpass spatial filter") + highpass_spatial_filter: HighpassSpatialFilter = Field( + default=HighpassSpatialFilter(), description="Highpass spatial filter" + ) motion_correction: MotionCorrection = Field(default=MotionCorrection(), description="Motion correction") detect_bad_channels: DetectBadChannels = Field(default=DetectBadChannels(), description="Detect bad channels") remove_out_channels: bool = Field(default=True, description="Flag to remove out channels") remove_bad_channels: bool = Field(default=True, description="Flag to remove bad channels") - max_bad_channel_fraction_to_remove: float = Field(default=0.5, description="Maximum fraction of bad channels to remove") + max_bad_channel_fraction_to_remove: float = Field( + default=0.5, description="Maximum fraction of bad channels to remove" + ) diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index d8b1578..ba05529 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ 
b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -44,28 +44,23 @@ def preprocess( # Phase shift correction if "inter_sample_shift" in recording.get_property_keys(): logger.info(f"[Preprocessing] \tPhase shift") - recording = spre.phase_shift( - recording, - **preprocessing_params.phase_shift.model_dump() - ) + recording = spre.phase_shift(recording, **preprocessing_params.phase_shift.model_dump()) else: logger.info(f"[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") # Highpass filter - recording_hp_full = spre.highpass_filter( - recording, - **preprocessing_params.highpass_filter.model_dump() - ) + recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump()) # Detect and remove bad channels _, channel_labels = spre.detect_bad_channels( - recording_hp_full, - **preprocessing_params.detect_bad_channels.model_dump() + recording_hp_full, **preprocessing_params.detect_bad_channels.model_dump() ) dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" out_channel_mask = channel_labels == "out" - logger.info(f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels") + logger.info( + f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels" + ) dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] @@ -73,7 +68,9 @@ def preprocess( max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): - logger.info(f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). ") + logger.info( + f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). 
" + ) logger.info("[Preprocessing] \tSkipping further processing for this recording.") return None @@ -88,18 +85,18 @@ def preprocess( # Denoise: CMR or destripe if preprocessing_params.preprocessing_strategy == "cmr": recording_processed = spre.common_reference( - recording_rm_out, - **preprocessing_params.common_reference.model_dump() + recording_rm_out, **preprocessing_params.common_reference.model_dump() ) else: recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) recording_processed = spre.highpass_spatial_filter( - recording_interp, - **preprocessing_params.highpass_spatial_filter.model_dump() + recording_interp, **preprocessing_params.highpass_spatial_filter.model_dump() ) if preprocessing_params.remove_bad_channels: - logger.info(f"[Preprocessing] \tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing") + logger.info( + f"[Preprocessing] \tRemoving {len(bad_channel_ids)} channels after {preprocessing_params.preprocessing_strategy} preprocessing" + ) recording_processed = recording_processed.remove_channels(bad_channel_ids) # Motion correction @@ -108,9 +105,7 @@ def preprocess( logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion( - recording_processed, preset=preset, - folder=motion_folder, - verbose=False + recording_processed, preset=preset, folder=motion_folder, verbose=False ) if preprocessing_params.motion_correction.apply: logger.info("[Preprocessing] \tApplying motion correction") diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index a964514..ab5fd57 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -13,23 +13,37 @@ class SorterName(str, Enum): class Kilosort25Model(BaseModel): detect_threshold: float = Field(default=6, description="Threshold for spike detection") projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections") - preclust_threshold: float = Field(default=8, description="Threshold crossings for pre-clustering (in PCA projection space)") + preclust_threshold: float = Field( + default=8, description="Threshold crossings for pre-clustering (in PCA projection space)" + ) car: bool = Field(default=True, description="Enable or disable common reference") - minFR: float = Field(default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed") + minFR: float = Field( + default=0.1, description="Minimum spike rate (Hz), if a cluster falls below this for too long it gets removed" + ) minfr_goodchannels: float = Field(default=0.1, description="Minimum firing rate on a 'good' channel") - nblocks: int = Field(default=5, description="blocks for registration. 0 turns it off, 1 does rigid registration. Replaces 'datashift' option.") + nblocks: int = Field( + default=5, + description="blocks for registration. 0 turns it off, 1 does rigid registration. 
Replaces 'datashift' option.", + ) sig: float = Field(default=20, description="spatial smoothness constant for registration") freq_min: float = Field(default=150, description="High-pass filter cutoff frequency") sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike") nPCs: int = Field(default=3, description="Number of PCA dimensions") ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection") nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4") - NT: int = Field(default=None, description='Batch size (if None it is automatically computed)') - AUCsplit: float = Field(default=0.9, description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step") + NT: int = Field(default=None, description="Batch size (if None it is automatically computed)") + AUCsplit: float = Field( + default=0.9, + description="Threshold on the area under the curve (AUC) criterion for performing a split in the final step", + ) do_correction: bool = Field(default=True, description="If True drift registration is applied") - wave_length: float = Field(default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)") + wave_length: float = Field( + default=61, description="size of the waveform extracted around each detected peak, (Default 61, maximum 81)" + ) keep_good_only: bool = Field(default=True, description="If True only 'good' units are returned") - skip_kilosort_preprocessing: bool = Field(default=False, description="Can optionaly skip the internal kilosort preprocessing") + skip_kilosort_preprocessing: bool = Field( + default=False, description="Can optionaly skip the internal kilosort preprocessing" + ) scaleproc: int = Field(default=-1, description="int16 scaling of whitened data, if -1 set to 200.") @@ -47,9 +61,6 @@ class MountainSort5Model(BaseModel): class SpikeSortingParams(BaseModel): sorter_name: SorterName = Field(default="kilosort2_5", description="Name of the sorter to use.") - sorter_kwargs: Union[ - Kilosort25Model, - Kilosort3Model, - IronClustModel, - MountainSort5Model - ] = Field(default=Kilosort25Model(), description="Sorter specific kwargs.") + sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field( + default=Kilosort25Model(), description="Sorter specific kwargs." 
+ ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f87ddb7..59428d0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -41,7 +41,7 @@ def test_preprocessing(tmp_path, generate_recording): recording=recording, preprocessing_params=PreprocessingParams(), results_folder=results_folder, - scratch_folder=scratch_folder + scratch_folder=scratch_folder, ) assert isinstance(recording_processed, si.BaseRecording) @@ -60,7 +60,7 @@ def test_spikesorting(tmp_path, generate_recording): recording=recording, spikesorting_params=SpikeSortingParams(), results_folder=results_folder, - scratch_folder=scratch_folder + scratch_folder=scratch_folder, ) assert isinstance(sorting, si.BaseSorting) @@ -79,7 +79,7 @@ def test_postprocessing(tmp_path, generate_recording): sorting=sorting, postprocessing_params=PostprocessingParams(), results_folder=results_folder, - scratch_folder=scratch_folder + scratch_folder=scratch_folder, ) assert isinstance(waveform_extractor, si.WaveformExtractor) @@ -104,7 +104,7 @@ def test_pipeline(tmp_path, generate_recording): recording=recording, results_folder=results_folder, scratch_folder=scratch_folder, - spikesorting_params=spikesorting_params + spikesorting_params=spikesorting_params, ) assert isinstance(recording_processed, si.BaseRecording) @@ -129,4 +129,3 @@ def test_pipeline(tmp_path, generate_recording): print("TEST PIPELINE") test_pipeline(tmp_folder, (recording, sorting)) - \ No newline at end of file From 09dbd0e915c0c0fb2bc7bb937e5cf9297b4e6b15 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 14:01:07 +0100 Subject: [PATCH 15/16] Add ci-test workflow --- .github/workflows/ci-test.yml | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/ci-test.yml diff --git a/.github/workflows/ci-test.yml b/.github/workflows/ci-test.yml new file mode 100644 index 0000000..66dcf88 --- /dev/null +++ b/.github/workflows/ci-test.yml @@ -0,0 +1,33 @@ +name: Testing pipeline + +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + +concurrency: # Cancel previous workflows on the same pull request + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + name: Test on ${{ matrix.os }} OS + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest", "macos-latest", "windows-latest"] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install dependencies + run: | + python -m pip install -U pip # Official recommended way + pip install -e . + - name: Test pipeline with pytest + run: | + pytest -v + shell: bash # Necessary for pipeline to work on windows From 7d7c26d51b51d93cf35a0107225e6892c4a6ad9d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 Nov 2023 14:02:40 +0100 Subject: [PATCH 16/16] Install pytest --- .github/workflows/ci-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci-test.yml b/.github/workflows/ci-test.yml index 66dcf88..6524d3c 100644 --- a/.github/workflows/ci-test.yml +++ b/.github/workflows/ci-test.yml @@ -26,6 +26,7 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip # Official recommended way + pip install pytest pip install -e . - name: Test pipeline with pytest run: |
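The pytest suite that this workflow installs and runs ultimately drives the whole pipeline end to end. Below is a condensed sketch of that path, based on the tests added earlier in the series; the scratch/results folders are illustrative, do_correction=False mirrors the test configuration, and kilosort2_5 must be available for the sorting step.

from pathlib import Path

import spikeinterface as si

from spikeinterface_pipelines import run_pipeline
from spikeinterface_pipelines.spikesorting.params import Kilosort25Model, SpikeSortingParams

# toy recording, as in the tests (arguments are illustrative)
recording, _ = si.generate_ground_truth_recording(durations=[30], num_channels=64, seed=0)

spikesorting_params = SpikeSortingParams(
    sorter_name="kilosort2_5",
    sorter_kwargs=Kilosort25Model(do_correction=False),
)

recording_processed, sorting, waveform_extractor = run_pipeline(
    recording=recording,
    scratch_folder=Path("./scratch"),
    results_folder=Path("./results"),
    spikesorting_params=spikesorting_params,
)
print(recording_processed, sorting, waveform_extractor)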