diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 8625ec8..4955f28 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -102,6 +102,13 @@ def preprocess( # Motion correction if preprocessing_params.motion_correction.compute: preset = preprocessing_params.motion_correction.preset + kwargs_model_name = preprocessing_params.motion_correction.motion_kwargs.__class__.__name__ + if preset == "nonrigid_accurate" and kwargs_model_name != "MCNonrigidAccurate": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCNonrigidAccurate, got {kwargs_model_name}") + elif preset == "rigid_fast" and kwargs_model_name != "MCRigidFast": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCRigidFast, got {kwargs_model_name}") + elif preset == "kilosort_like" and kwargs_model_name != "MCKilosortLike": + raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCKilosortLike, got {kwargs_model_name}") logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" recording_corrected = spre.correct_motion(