Skip to content

Commit

Permalink
fix validation
Browse files Browse the repository at this point in the history
  • Loading branch information
luiztauffer committed Jan 12, 2024
1 parent 78dbf9f commit d494fc7
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions src/spikeinterface_pipelines/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import spikeinterface.preprocessing as spre

from ..logger import logger
from .params import PreprocessingParams
from .params import PreprocessingParams, MCNonrigidAccurate, MCRigidFast, MCKilosortLike


warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -102,24 +102,23 @@ def preprocess(
# Motion correction
if preprocessing_params.motion_correction.compute:
preset = preprocessing_params.motion_correction.preset
kwargs_model_name = preprocessing_params.motion_correction.motion_kwargs.__class__.__name__
if preset == "nonrigid_accurate" and kwargs_model_name != "MCNonrigidAccurate":
raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCNonrigidAccurate, got {kwargs_model_name}")
elif preset == "rigid_fast" and kwargs_model_name != "MCRigidFast":
raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCRigidFast, got {kwargs_model_name}")
elif preset == "kilosort_like" and kwargs_model_name != "MCKilosortLike":
raise ValueError(f"Motion correction preset {preset} requires motion_kwargs of type MCKilosortLike, got {kwargs_model_name}")
if preset == "nonrigid_accurate":
motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "rigid_fast":
motion_correction_kwargs = MCRigidFast(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "kilosort_like":
motion_correction_kwargs = MCKilosortLike(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
motion_folder = results_folder / "motion_correction"
recording_corrected = spre.correct_motion(
recording_processed,
preset=preset,
folder=motion_folder,
verbose=False,
detect_kwargs=preprocessing_params.motion_correction.motion_kwargs.detect_kwargs.model_dump(),
localize_peaks_kwargs=preprocessing_params.motion_correction.motion_kwargs.localize_peaks_kwargs.model_dump(),
estimate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.estimate_motion_kwargs.model_dump(),
interpolate_motion_kwargs=preprocessing_params.motion_correction.motion_kwargs.interpolate_motion_kwargs.model_dump(),
detect_kwargs=motion_correction_kwargs.detect_kwargs.model_dump(),
localize_peaks_kwargs=motion_correction_kwargs.localize_peaks_kwargs.model_dump(),
estimate_motion_kwargs=motion_correction_kwargs.estimate_motion_kwargs.model_dump(),
interpolate_motion_kwargs=motion_correction_kwargs.interpolate_motion_kwargs.model_dump(),
)
if preprocessing_params.motion_correction.apply:
logger.info("[Preprocessing] \tApplying motion correction")
Expand Down

0 comments on commit d494fc7

Please sign in to comment.