Commit 7f1f424

Merge pull request #2971 from alejoe91/fix-deepinterpolation-tests

Fix deepinterpolation tests

samuelgarcia authored Jun 5, 2024
2 parents c71d454 + a72e8b4 commit 7f1f424

Showing 4 changed files with 40 additions and 61 deletions.
@@ -2,13 +2,11 @@
 
 import numpy as np
 from typing import Optional
+from packaging.version import parse
 
 from .tf_utils import has_tf, import_tf
 from ...core import BaseRecording
 from ...core.core_tools import define_function_from_class
 from ..basepreprocessor import BasePreprocessor, BasePreprocessorSegment
 from ..zero_channel_pad import ZeroChannelPaddedRecording
-from spikeinterface.core import get_random_data_chunks
 
 
 class DeepInterpolatedRecording(BasePreprocessor):
@@ -66,6 +64,11 @@ def __init__(
         disable_tf_logger: bool = True,
         memory_gpu: Optional[int] = None,
     ):
+        import deepinterpolation
+
+        if parse(deepinterpolation.__version__) < parse("0.2.0"):
+            raise ImportError("DeepInterpolation version must be at least 0.2.0")
+
         assert has_tf(
             use_gpu, disable_tf_logger, memory_gpu
         ), "To use DeepInterpolation, you first need to install `tensorflow`."
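Note: the runtime gate added above uses `packaging.version.parse` rather than plain string comparison, which mis-orders multi-digit version components. A minimal standalone sketch of the pattern (numpy stands in here as an example package; it is not what the PR checks):

```python
from packaging.version import parse

# Lexicographic string comparison gets multi-digit components wrong...
assert "0.10.0" < "0.2.0"
# ...while parse() compares versions semantically, as the check above relies on:
assert parse("0.10.0") > parse("0.2.0")

# the same gate pattern as the new __init__, with numpy as a stand-in
import numpy

if parse(numpy.__version__) < parse("1.20.0"):
    raise ImportError("numpy version must be at least 1.20.0")
```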
34 changes: 7 additions & 27 deletions src/spikeinterface/preprocessing/deepinterpolation/generators.py
@@ -1,15 +1,8 @@
 from __future__ import annotations
-import tempfile
-import json
 from typing import Optional
 import numpy as np
-import os
 
+from ...core import load_extractor, concatenate_recordings, BaseRecording, BaseRecordingSegment
 
-from .tf_utils import has_tf, import_tf
-
-from ...core import load_extractor
-from ...core import concatenate_recordings, BaseRecording, BaseRecordingSegment
 
+from deepinterpolation.generator_collection import SequentialGenerator

@@ -41,16 +34,16 @@ def __init__(
             r.get_num_channels() == recordings[0].get_num_channels() for r in recordings[1:]
         ), "All recordings must have the same number of channels"
 
-        total_num_samples = np.sum(r.get_total_samples() for r in recordings)
+        total_num_samples = sum([r.get_total_samples() for r in recordings])
         # In case of multiple recordings and/or multiple segments, we calculate the frame periods to be excluded (borders)
         exclude_intervals = []
         pre_extended = pre_frame + pre_post_omission
         post_extended = post_frame + pre_post_omission
         for i, recording in enumerate(recordings):
-            total_samples_pre = np.sum(r.get_total_samples() for r in recordings[:i])
+            total_samples_pre = sum([r.get_total_samples() for r in recordings[:i]])
             for segment_index in range(recording.get_num_segments()):
                 # exclude samples at the border of the recordings
-                num_samples_segment_pre = np.sum(recording.get_num_samples(s) for s in np.arange(segment_index))
+                num_samples_segment_pre = sum([recording.get_num_samples(s) for s in np.arange(segment_index)])
                 if num_samples_segment_pre > 0:
                     exclude_intervals.append(
                         (
@@ -63,7 +56,7 @@ def __init__(
             exclude_intervals.append((total_samples_pre - pre_extended - 1, total_samples_pre + post_extended))
 
         total_valid_samples = (
-            total_num_samples - np.sum([end - start for start, end in exclude_intervals]) - pre_extended - post_extended
+            total_num_samples - sum([end - start for start, end in exclude_intervals]) - pre_extended - post_extended
         )
         self.total_samples = int(total_valid_samples) if total_samples == -1 else total_samples
         assert len(desired_shape) == 2, "desired_shape should be 2D"
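Note: the `np.sum` → `sum` changes in these hunks sidestep NumPy's long-standing deprecation of calling `np.sum` on a generator (it silently falls back to the Python builtin and emits a `DeprecationWarning`). A standalone sketch of the equivalent forms, with hypothetical sample counts:

```python
import numpy as np

num_samples = [30000, 60000, 90000]  # hypothetical per-segment sample counts

total = sum([n for n in num_samples])  # form adopted in this PR
# NumPy-native alternative if an array reduction is really wanted:
total_np = np.sum(np.fromiter((n for n in num_samples), dtype=np.int64))
assert total == total_np == 180000
```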
@@ -86,12 +79,7 @@ def __init__(
         sequential_generator_params["total_samples"] = self.total_samples
         sequential_generator_params["pre_post_omission"] = pre_post_omission
 
-        with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False, dir="/tmp") as f:
-            json.dump(sequential_generator_params, f)
-            f.flush()
-            json_path = f.name
-
-        super().__init__(json_path)
+        super().__init__(sequential_generator_params)
 
         self._update_end_frame(total_num_samples)
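Note: both generator classes previously round-tripped their parameters through a temporary JSON file (hard-coded to `dir="/tmp"`, which does not exist on Windows, and leaked files via `delete=False`). With deepinterpolation >= 0.2.0 the parameter dict is handed to `SequentialGenerator` directly, which is why the version gate above was added. A minimal before/after sketch, with a hypothetical `consume()` standing in for `SequentialGenerator`:

```python
import json
import tempfile

params = {"total_samples": 1000, "pre_post_omission": 1}

def consume(cfg: dict) -> int:
    # stands in for SequentialGenerator.__init__, which now accepts a dict
    return cfg["total_samples"]

# old pattern (removed above): serialize, then pass a file path around
with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f:
    json.dump(params, f)
    json_path = f.name
with open(json_path) as f:
    assert consume(json.load(f)) == 1000

# new pattern: no filesystem round-trip at all
assert consume(params) == 1000
```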

@@ -206,9 +194,6 @@ def reshape_output(self, output):
         return output.squeeze().reshape(-1, self.recording.get_num_channels())
 
 
-from deepinterpolation.generator_collection import SequentialGenerator
-
-
 class SpikeInterfaceRecordingSegmentGenerator(SequentialGenerator):
     """This generator is used when dealing with a SpikeInterface recording.
     The desired shape controls the reshaping of the input data before convolutions."""
@@ -246,12 +231,7 @@ def __init__(
         sequential_generator_params["total_samples"] = self.total_samples
         sequential_generator_params["pre_post_omission"] = pre_post_omission
 
-        with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False, dir="/tmp") as f:
-            json.dump(sequential_generator_params, f)
-            f.flush()
-            json_path = f.name
-
-        super().__init__(json_path)
+        super().__init__(sequential_generator_params)
 
         self._update_end_frame(num_segment_samples)
         # IMPORTANT: this is used for inference, so we don't want to shuffle
@@ -1,21 +1,26 @@
 import pytest
 import numpy as np
 from pathlib import Path
+from packaging.version import parse
+from warnings import warn
 
 import probeinterface
-from spikeinterface import download_dataset, generate_recording, append_recordings, concatenate_recordings
-from spikeinterface.extractors import read_mearec, read_spikeglx, read_openephys
+from spikeinterface import generate_recording, append_recordings
 from spikeinterface.preprocessing import depth_order, zscore
 
 from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation, deepinterpolate
 from spikeinterface.preprocessing.deepinterpolation.train import train_deepinterpolation_process
 
 
 try:
     import tensorflow
     import deepinterpolation
 
-    HAVE_DEEPINTERPOLATION = True
+    if parse(deepinterpolation.__version__) >= parse("0.2.0"):
+        HAVE_DEEPINTERPOLATION = True
+    else:
+        warn("DeepInterpolation version >=0.2.0 is required for the tests. Skipping...")
+        HAVE_DEEPINTERPOLATION = False
 except ImportError:
     HAVE_DEEPINTERPOLATION = False

@@ -67,6 +72,7 @@ def test_deepinterpolation_generator_borders(recording_and_shape_fixture):
 
 
 @pytest.mark.skipif(not HAVE_DEEPINTERPOLATION, reason="requires deepinterpolation")
+@pytest.mark.dependency()
 def test_deepinterpolation_training(recording_and_shape_fixture):
     recording, desired_shape = recording_and_shape_fixture
 
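Note: `@pytest.mark.dependency()` comes from the pytest-dependency plugin; it registers the training test so that later tests can declare a dependency on it and be skipped automatically when it fails. A minimal sketch of the mechanism (the dependent test below is illustrative, not part of this diff):

```python
import pytest

@pytest.mark.dependency()
def test_training():
    assert True  # would train the model

# hypothetical dependent test: skipped automatically if test_training fails
@pytest.mark.dependency(depends=["test_training"])
def test_inference():
    assert True  # would run inference with the trained model
```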
@@ -87,6 +93,7 @@ def test_deepinterpolation_training(recording_and_shape_fixture):
         run_uid="si_test",
         pre_post_omission=1,
         desired_shape=desired_shape,
+        nb_workers=1,
     )
     print(model_path)
 
@@ -173,6 +180,6 @@ def test_deepinterpolation_inference_multi_job(recording_and_shape_fixture):
     recording_shape = recording_and_shape()
     test_deepinterpolation_training(recording_shape)
     # test_deepinterpolation_transfer()
-    test_deepinterpolation_inference(recording_shape)
+    # test_deepinterpolation_inference(recording_shape)
     # test_deepinterpolation_inference_multi_job()
     # test_deepinterpolation_generator_borders(recording_shape)
41 changes: 15 additions & 26 deletions src/spikeinterface/preprocessing/deepinterpolation/train.py
@@ -1,17 +1,13 @@
 from __future__ import annotations
-import json
-import os
 import warnings
 from pathlib import Path
-from typing import Optional
+from typing import Callable, Optional
 
 from concurrent.futures import ProcessPoolExecutor
 import multiprocessing as mp
 
 from .tf_utils import import_tf
 
-# from .generators import define_recording_generator_class
-
 from ...core import BaseRecording


@@ -45,7 +41,7 @@ def train_deepinterpolation(
     nb_workers: int = -1,
     caching_validation: bool = False,
     run_uid: str = "si",
-    network_name: str = "unet_single_ephys_1024",
+    network: Callable | None = None,
     use_gpu: bool = True,
     disable_tf_logger: bool = True,
     memory_gpu: Optional[int] = None,
@@ -106,8 +102,10 @@ def train_deepinterpolation(
         Whether to cache the validation data
     run_uid : str, default: "si"
         Unique identifier for the training
-    network_name : str, default: "unet_single_ephys_1024"
-        Name of the network to be used
+    network : Callable or None, default: None
+        The deepinterpolation network to use. If None, the "unet_single_ephys_1024" network is used.
+        The network should be a callable that takes a dictionary as input and returns a deepinterpolation network.
+        See deepinterpolation.network_collection for examples.
     use_gpu : bool, default: True
         Whether to use GPU
     disable_tf_logger : bool, default: True
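A hedged usage sketch of the new `network` parameter, with most required arguments elided (see the test changes above for a fuller call). Passing the default explicitly is equivalent to leaving `network=None`, and any callable with the same dict-in/model-out contract could be substituted:

```python
from deepinterpolation.network_collection import unet_single_ephys_1024
from spikeinterface.preprocessing.deepinterpolation import train_deepinterpolation

# `recording` is assumed to be a BaseRecording defined elsewhere;
# other required arguments (model folder, shapes, ...) are elided here
model_path = train_deepinterpolation(
    recording,
    network=unet_single_ephys_1024,  # explicit default; any dict -> model callable works
    run_uid="si_test",
    pre_post_omission=1,
    nb_workers=1,
)
```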
@@ -151,7 +149,7 @@ def train_deepinterpolation(
         nb_workers,
         caching_validation,
         run_uid,
-        network_name,
+        network,
         use_gpu,
         disable_tf_logger,
         memory_gpu,
@@ -191,13 +189,12 @@ def train_deepinterpolation_process(
     nb_workers: int = -1,
     caching_validation: bool = False,
     run_uid: str = "training",
-    network_name: str = "unet_single_ephys_1024",
+    network: Callable | None = None,
     use_gpu: bool = True,
     disable_tf_logger: bool = True,
     memory_gpu: Optional[int] = None,
 ):
     from deepinterpolation.trainor_collection import core_trainer
-    from deepinterpolation.generic import ClassLoader
     from .generators import SpikeInterfaceRecordingGenerator
 
     # initialize TF
@@ -232,11 +229,7 @@ def train_deepinterpolation_process(
     ):
         warnings.warn("Training and testing overlap. This is not recommended.")
 
-    # Those are parameters used for the network topology
-    network_params = dict()
-    network_params["type"] = "network"
-    # Name of network topology in the collection
-    network_params["name"] = network_name if network_name is not None else "unet_single_ephys_1024"
+    # Those are parameters used for the training
     training_params = dict()
     training_params["output_dir"] = str(trained_model_folder)
     # We pass on the uid
@@ -280,18 +273,14 @@ def train_deepinterpolation_process(
         total_samples=total_samples_testing,
     )
 
-    network_json_path = trained_model_folder / "network_params.json"
-    with open(network_json_path, "w") as f:
-        json.dump(network_params, f)
-
-    network_obj = ClassLoader(network_json_path)
-    data_network = network_obj.find_and_build()(network_json_path)
-
-    training_json_path = trained_model_folder / "training_params.json"
-    with open(training_json_path, "w") as f:
-        json.dump(training_params, f)
+    if network is None:
+        from deepinterpolation.network_collection import unet_single_ephys_1024
 
-    training_class = core_trainer(training_data_generator, test_data_generator, data_network, training_json_path)
+        network_obj = unet_single_ephys_1024({})
+    else:
+        network_obj = network({})
 
+    training_class = core_trainer(training_data_generator, test_data_generator, network_obj, training_params)
 
     if verbose:
         print("Created objects for training. Running training job")
