diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index 25a2317d10..7fdbe13753 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -43,10 +43,12 @@ def test_deepinterpolation_training(recording_and_shape_fixture): recording, model_folder=model_folder, model_name="training", - train_start_s=0, - train_end_s=0.1, - test_start_s=1, - test_end_s=1.005, + train_start_s=1, + train_end_s=10, + train_duration_s=0.1, + test_start_s=0, + test_end_s=1, + test_duration_s=0.05, pre_frame=20, post_frame=20, run_uid="si_test", @@ -69,10 +71,12 @@ def test_deepinterpolation_transfer(recording_and_shape_fixture, tmp_path): model_folder=model_folder, model_name="si_test_transfer", existing_model_path=existing_model_path, - train_start_s=0, - train_end_s=0.1, - test_start_s=1, - test_end_s=1.005, + train_start_s=1, + train_end_s=10, + train_duration_s=0.1, + test_start_s=0, + test_end_s=1, + test_duration_s=0.05, pre_frame=20, post_frame=20, pre_post_omission=1, diff --git a/src/spikeinterface/preprocessing/deepinterpolation/train.py b/src/spikeinterface/preprocessing/deepinterpolation/train.py index 005d2e7eee..5865153393 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/train.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/train.py @@ -23,11 +23,11 @@ def train_deepinterpolation( model_name: str, train_start_s: float, train_end_s: float, - train_duration_s: float, test_start_s: float, test_end_s: float, - test_duration_s: float, desired_shape: tuple[int, int], + train_duration_s: Optional[float] = None, + test_duration_s: Optional[float] = None, pre_frame: int = 30, post_frame: int = 30, pre_post_omission: int = 1,