From b6cdeb96b000e829957ce78ce908dfa1fdc1b5d6 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 14:18:43 +0100 Subject: [PATCH 01/10] models --- .../postprocessing/params.py | 3 +- .../preprocessing/params.py | 117 +++++++++++++++++- .../preprocessing/preprocessing.py | 10 +- 3 files changed, 126 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface_pipelines/postprocessing/params.py b/src/spikeinterface_pipelines/postprocessing/params.py index 1f81a8f..ecb56ff 100644 --- a/src/spikeinterface_pipelines/postprocessing/params.py +++ b/src/spikeinterface_pipelines/postprocessing/params.py @@ -94,9 +94,10 @@ class QMParams(BaseModel): nn_noise_overlap: NNIsolation = Field(default=NNIsolation(), description="Nearest neighbor noise overlap.") +# TODO - fill in metric_names defauklt value with all the metric names class QualityMetrics(BaseModel): qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.") - metric_names: List[str] = Field(default=None, description="List of metric names to compute.") + metric_names: List[str] = Field(default=[], description="List of metric names to compute.") n_jobs: int = Field(default=1, description="Number of jobs.") diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index b5bc566..c6d126b 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Optional +from typing import Optional, Union, List from enum import Enum @@ -41,12 +41,125 @@ class HighpassSpatialFilter(BaseModel): highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") +# Motion correction --------------------------------------------------------------- +class MCDetectKwargs(BaseModel): + method: str = Field(default="locally_exclusive", description="") + peak_sign: str = Field(default="neg", description="") + detect_threshold: float = Field(default=8.0, description="") + exclude_sweep_ms: float = Field(default=0.1, description="") + radius_um: float = Field(default=50.0, description="") + + +class MCLocalizeCenterOfMass(BaseModel): + radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.") + feature: str = Field(default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation") + + +class MCLocalizeMonopolarTriangulation(BaseModel): + radius_um: float = Field(default=75.0, description="For channel sparsity.") + max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.") + optimizer: str = Field(default="minimize_with_log_penality", description="") + enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors") + feature: str = Field(default="ptp", description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)") + + +class MCLocalizeGridConvolution(BaseModel): + radius_um: float = Field(default=40.0, description="Radius in um for channel sparsity.") + upsampling_um: float = Field(default=5.0, description="Upsampling resolution for the grid of templates.") + sigma_um: List[float] = Field(default=[5.0, 25.0, 5], description="Spatial decays of the fake templates.") + sigma_ms: float = Field(default=0.25, description="The temporal decay of the fake templates.") + margin_um: float = Field(default=30.0, description="The margin for the grid of fake templates.") + percentile: float = Field(default=10.0, description="The percentage in [0, 100] of the best scalar products kept to estimate the position.") + sparsity_threshold: float = Field(default=0.01, description="The sparsity threshold (in [0, 1]) below which weights should be considered as 0.") + + +class MCEstimateMotionDecentralized(BaseModel): + method: str = Field(default="decentralized", description="") + direction: str = Field(default="y", description="") + bin_duration_s: float = Field(default=2.0, description="") + rigid: bool = Field(default=False, description="") + bin_um: float = Field(default=5.0, description="") + margin_um: float = Field(default=0.0, description="") + win_shape: str = Field(default="gaussian", description="") + win_step_um: float = Field(default=100.0, description="") + win_sigma_um: float = Field(default=200.0, description="") + histogram_depth_smooth_um: float = Field(default=5.0, description="") + histogram_time_smooth_s: Optional[float] = Field(default=None, description="") + pairwise_displacement_method: str = Field(default="conv", description="") + max_displacement_um: float = Field(default=100.0, description="") + weight_scale: str = Field(default="linear", description="") + error_sigma: float = Field(default=0.2, description="") + conv_engine: Optional[str] = Field(default=None, description="") + torch_device: Optional[str] = Field(default=None, description="") + batch_size: int = Field(default=1, description="") + corr_threshold: float = Field(default=0.0, description="") + time_horizon_s: Optional[float] = Field(default=None, description="") + convergence_method: str = Field(default="lsmr", description="") + soft_weights: bool = Field(default=False, description="") + normalized_xcorr: bool = Field(default=True, description="") + centered_xcorr: bool = Field(default=True, description="") + temporal_prior: bool = Field(default=True, description="") + spatial_prior: bool = Field(default=False, description="") + force_spatial_median_continuity: bool = Field(default=False, description="") + reference_displacement: str = Field(default="median", description="") + reference_displacement_time_s: float = Field(default=0, description="") + robust_regression_sigma: int = Field(default=2, description="") + weight_with_amplitude: bool = Field(default=False, description="") + + +class MCEstimateMotionIterativeTemplate(BaseModel): + bin_duration_s: float = Field(default=2.0, description="") + rigid: bool = Field(default=False, description="") + win_step_um: float = Field(default=50.0, description="") + win_sigma_um: float = Field(default=150.0, description="") + margin_um: float = Field(default=0.0, description="") + win_shape: str = Field(default="rect", description="") + + +class MCInterpolateMotionKwargs(BaseModel): + direction: int = Field(default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z).") + border_mode: str = Field(default="remove_channels", description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.") + spatial_interpolation_method: str = Field(default="idw", description="The spatial interpolation method used to interpolate the channel locations.") + sigma_um: float = Field(default=20.0, description="Used in the 'kriging' formula") + p: int = Field(default=1, description="Used in the 'kriging' formula") + num_closest: int = Field(default=3, description="Number of closest channels used by 'idw' method for interpolation.") + + +class MCNonrigidAccurate(BaseModel): + detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") + localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(default=MCLocalizeMonopolarTriangulation(), description="") + estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(), description="") + interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="") + + +class MCRigidFast(BaseModel): + detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") + localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="") + estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description="") + interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="") + + +class MCKilosortLike(BaseModel): + detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") + localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="") + estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(default=MCEstimateMotionIterativeTemplate(), description="") + interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"), description="") + + +class MCPreset(str, Enum): + nonrigid_accurate = "nonrigid_accurate" + rigid_fast = "rigid_fast" + kilosort_like = "kilosort_like" + + class MotionCorrection(BaseModel): compute: bool = Field(default=True, description="Whether to compute motion correction") apply: bool = Field(default=False, description="Whether to apply motion correction") - preset: str = Field(default="nonrigid_accurate", description="Preset for motion correction") + preset: MCPreset = Field(default=MCPreset.nonrigid_accurate.value, description="Preset for motion correction") + motion_kwargs: Union[MCNonrigidAccurate, MCRigidFast, MCKilosortLike] = Field(default=MCNonrigidAccurate(), description="Motion correction parameters") +# Preprocessing params --------------------------------------------------------------- class PreprocessingParams(BaseModel): preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing") highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index e65a808..ed07929 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -105,7 +105,15 @@ def preprocess( logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion( - recording_processed, preset=preset, folder=motion_folder, verbose=False + recording_processed, + preset=preset, + folder=motion_folder, + verbose=False, + detect_kwargs=preprocessing_params.motion_correction.motion_kwargs.detect_kwargs.model_dump(), + select_kwargs=dict(), + localize_peaks_kwargs=preprocessing_params.motion_correction.motion_kwargs.localize_peaks_kwargs.model_dump(), + estimate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.estimate_motion_kwargs.model_dump(), + interpolate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.interpolate_motion_kwargs.model_dump(), ) if preprocessing_params.motion_correction.apply: logger.info("[Preprocessing] \tApplying motion correction") From aab0d429b43cfce79e5132b794e0b4c593b34de4 Mon Sep 17 00:00:00 2001 From: Luiz Tauffer Date: Fri, 12 Jan 2024 15:18:21 +0100 Subject: [PATCH 02/10] Update src/spikeinterface_pipelines/preprocessing/preprocessing.py Co-authored-by: Alessio Buccino --- src/spikeinterface_pipelines/preprocessing/preprocessing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index ed07929..8625ec8 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -110,7 +110,6 @@ def preprocess( folder=motion_folder, verbose=False, detect_kwargs=preprocessing_params.motion_correction.motion_kwargs.detect_kwargs.model_dump(), - select_kwargs=dict(), localize_peaks_kwargs=preprocessing_params.motion_correction.motion_kwargs.localize_peaks_kwargs.model_dump(), estimate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.estimate_motion_kwargs.model_dump(), interpolate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.interpolate_motion_kwargs.model_dump(), From 775d6283b7d52bb4cbcce21e2c79af85ed2a2b8b Mon Sep 17 00:00:00 2001 From: Luiz Tauffer Date: Fri, 12 Jan 2024 15:19:07 +0100 Subject: [PATCH 03/10] Update src/spikeinterface_pipelines/preprocessing/params.py Co-authored-by: Alessio Buccino --- src/spikeinterface_pipelines/preprocessing/params.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index c6d126b..5824483 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -43,11 +43,11 @@ class HighpassSpatialFilter(BaseModel): # Motion correction --------------------------------------------------------------- class MCDetectKwargs(BaseModel): - method: str = Field(default="locally_exclusive", description="") - peak_sign: str = Field(default="neg", description="") - detect_threshold: float = Field(default=8.0, description="") - exclude_sweep_ms: float = Field(default=0.1, description="") - radius_um: float = Field(default=50.0, description="") + method: str = Field(default="locally_exclusive", description="The method for peak detection.") + peak_sign: Literal["pos", "neg", "both] = Field(default="neg", description="The peak sign to detect peaks.") + detect_threshold: float = Field(default=8.0, description="The detection threshold in MAD units.") + exclude_sweep_ms: float = Field(default=0.1, description="The time sweep to exclude for time de-duplication.") + radius_um: float = Field(default=50.0, description="The radius in um for channel de-duplication.") class MCLocalizeCenterOfMass(BaseModel): From 2e6ae9b44558f9e55ee9f111fbabab4f1aa88dac Mon Sep 17 00:00:00 2001 From: Luiz Tauffer Date: Fri, 12 Jan 2024 15:19:17 +0100 Subject: [PATCH 04/10] Update src/spikeinterface_pipelines/preprocessing/params.py Co-authored-by: Alessio Buccino --- src/spikeinterface_pipelines/preprocessing/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index 5824483..34ef33e 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -56,7 +56,7 @@ class MCLocalizeCenterOfMass(BaseModel): class MCLocalizeMonopolarTriangulation(BaseModel): - radius_um: float = Field(default=75.0, description="For channel sparsity.") + radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.") max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.") optimizer: str = Field(default="minimize_with_log_penality", description="") enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors") From 426c9c399a57c6fd163b344e4df386ed90f4d9e9 Mon Sep 17 00:00:00 2001 From: Luiz Tauffer Date: Fri, 12 Jan 2024 15:19:42 +0100 Subject: [PATCH 05/10] Update src/spikeinterface_pipelines/postprocessing/params.py Co-authored-by: Alessio Buccino --- src/spikeinterface_pipelines/postprocessing/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface_pipelines/postprocessing/params.py b/src/spikeinterface_pipelines/postprocessing/params.py index ecb56ff..cafba32 100644 --- a/src/spikeinterface_pipelines/postprocessing/params.py +++ b/src/spikeinterface_pipelines/postprocessing/params.py @@ -97,7 +97,7 @@ class QMParams(BaseModel): # TODO - fill in metric_names defauklt value with all the metric names class QualityMetrics(BaseModel): qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.") - metric_names: List[str] = Field(default=[], description="List of metric names to compute.") + metric_names: Optional[List[str]] = Field(default=None, description="List of metric names to compute. If None, all available metrics are computed.") n_jobs: int = Field(default=1, description="Number of jobs.") From 975693de0f52c522c05332135d73ba31ab2d900e Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 15:29:45 +0100 Subject: [PATCH 06/10] fix --- src/spikeinterface_pipelines/preprocessing/params.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index 34ef33e..3458f68 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Optional, Union, List +from typing import Optional, Union, List, Literal from enum import Enum @@ -44,7 +44,7 @@ class HighpassSpatialFilter(BaseModel): # Motion correction --------------------------------------------------------------- class MCDetectKwargs(BaseModel): method: str = Field(default="locally_exclusive", description="The method for peak detection.") - peak_sign: Literal["pos", "neg", "both] = Field(default="neg", description="The peak sign to detect peaks.") + peak_sign: Literal["pos", "neg", "both"] = Field(default="neg", description="The peak sign to detect peaks.") detect_threshold: float = Field(default=8.0, description="The detection threshold in MAD units.") exclude_sweep_ms: float = Field(default=0.1, description="The time sweep to exclude for time de-duplication.") radius_um: float = Field(default=50.0, description="The radius in um for channel de-duplication.") From 78dbf9f4864a31d1494a995b60fe6a6a58e933e6 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 16:10:34 +0100 Subject: [PATCH 07/10] check model for presets --- .../preprocessing/preprocessing.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 8625ec8..4955f28 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -102,6 +102,13 @@ def preprocess( # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset + kwargs_model_name = preprocessing_params.motion_correction.motion_kwargs.__class__.__name__ + if preset == "nonrigid_accurate" and kwargs_model_name != "MCNonrigidAccurate": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCNonrigidAccurate, got {kwargs_model_name}") + elif preset == "rigid_fast" and kwargs_model_name != "MCRigidFast": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCRigidFast, got {kwargs_model_name}") + elif preset == "kilosort_like" and kwargs_model_name != "MCKilosortLike": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCKilosortLike, got {kwargs_model_name}") logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion( From d494fc7d5c34db4ec67b39436df4e526fa419d93 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 16:35:58 +0100 Subject: [PATCH 08/10] fix validation --- .../preprocessing/preprocessing.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 4955f28..a917fac 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -6,7 +6,7 @@ import spikeinterface.preprocessing as spre from ..logger import logger -from .params import PreprocessingParams +from .params import PreprocessingParams, MCNonrigidAccurate, MCRigidFast, MCKilosortLike warnings.filterwarnings("ignore") @@ -102,13 +102,12 @@ def preprocess( # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset - kwargs_model_name = preprocessing_params.motion_correction.motion_kwargs.__class__.__name__ - if preset == "nonrigid_accurate" and kwargs_model_name != "MCNonrigidAccurate": - raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCNonrigidAccurate, got {kwargs_model_name}") - elif preset == "rigid_fast" and kwargs_model_name != "MCRigidFast": - raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCRigidFast, got {kwargs_model_name}") - elif preset == "kilosort_like" and kwargs_model_name != "MCKilosortLike": - raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCKilosortLike, got {kwargs_model_name}") + if preset == "nonrigid_accurate": + motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump()) + elif preset == "rigid_fast": + motion_correction_kwargs = MCRigidFast(**preprocessing_params.motion_correction.motion_kwargs.model_dump()) + elif preset == "kilosort_like": + motion_correction_kwargs = MCKilosortLike(**preprocessing_params.motion_correction.motion_kwargs.model_dump()) logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion( @@ -116,10 +115,10 @@ def preprocess( preset=preset, folder=motion_folder, verbose=False, - detect_kwargs=preprocessing_params.motion_correction.motion_kwargs.detect_kwargs.model_dump(), - localize_peaks_kwargs=preprocessing_params.motion_correction.motion_kwargs.localize_peaks_kwargs.model_dump(), - estimate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.estimate_motion_kwargs.model_dump(), - interpolate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.interpolate_motion_kwargs.model_dump(), + detect_kwargs=motion_correction_kwargs.detect_kwargs.model_dump(), + localize_peaks_kwargs=motion_correction_kwargs.localize_peaks_kwargs.model_dump(), + estimate_motion_kwargs=motion_correction_kwargs.estimate_motion_kwargs.model_dump(), + interpolate_motion_kwargs=motion_correction_kwargs.interpolate_motion_kwargs.model_dump(), ) if preprocessing_params.motion_correction.apply: logger.info("[Preprocessing] \tApplying motion correction") From c8afccb28f139cb9673b450d0b98880ed225c8eb Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 16:46:36 +0100 Subject: [PATCH 09/10] strategy --- src/spikeinterface_pipelines/preprocessing/params.py | 3 +-- src/spikeinterface_pipelines/preprocessing/preprocessing.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index 3458f68..f7f4876 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -153,8 +153,7 @@ class MCPreset(str, Enum): class MotionCorrection(BaseModel): - compute: bool = Field(default=True, description="Whether to compute motion correction") - apply: bool = Field(default=False, description="Whether to apply motion correction") + strategy: Literal["skip", "compute", "apply"] = Field(default="compute", description="What strategy to use for motion correction") preset: MCPreset = Field(default=MCPreset.nonrigid_accurate.value, description="Preset for motion correction") motion_kwargs: Union[MCNonrigidAccurate, MCRigidFast, MCKilosortLike] = Field(default=MCNonrigidAccurate(), description="Motion correction parameters") diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index a917fac..e11bb9c 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -100,7 +100,7 @@ def preprocess( recording_processed = recording_processed.remove_channels(bad_channel_ids) # Motion correction - if preprocessing_params.motion_correction.compute: + if preprocessing_params.motion_correction.strategy != "skip": preset = preprocessing_params.motion_correction.preset if preset == "nonrigid_accurate": motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump()) @@ -120,7 +120,7 @@ def preprocess( estimate_motion_kwargs=motion_correction_kwargs.estimate_motion_kwargs.model_dump(), interpolate_motion_kwargs=motion_correction_kwargs.interpolate_motion_kwargs.model_dump(), ) - if preprocessing_params.motion_correction.apply: + if preprocessing_params.motion_correction.strategy == "apply": logger.info("[Preprocessing] \tApplying motion correction") recording_processed = recording_corrected From 2f674a0d119f8fe710110337ab392b55e92bc9e6 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 12 Jan 2024 16:52:22 +0100 Subject: [PATCH 10/10] version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 375176e..989df2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface_pipelines" -version = "0.0.3" +version = "0.0.4" description = "Collection of standardized analysis pipelines based on SpikeInterfacee." readme = "README.md" authors = [