Skip to content

Commit

Permalink
Merge pull request #196 from catalystneuro/remove_spikeextractors
Browse files Browse the repository at this point in the history
Remove spikeextractors dependency
  • Loading branch information
CodyCBakerPhD authored Aug 22, 2022
2 parents 304da82 + 00d666d commit d6610c2
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 37 deletions.
14 changes: 8 additions & 6 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ jobs:
git config --global user.email "[email protected]"
git config --global user.name "CI Almighty"
pip install wheel # needed for scanimage
- name: Test minimal installation
run: pip install .
- name: Test full installation
run: pip install .[full]
- name: Install testing requirements (-e needed for codecov report)
run: pip install -e .[test]
- name: Install roiextractors with minimal requirements
run: pip install .[test]
- name: Run minimal tests
run: pytest tests/test_internals -n auto --dist loadscope

- name: Test full installation (-e needed for codecov report)
run: pip install -e .[full]

- name: Get ophys_testing_data current head hash
id: ophys
Expand Down
1 change: 0 additions & 1 deletion requirements-minimal.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
h5py>=2.10.0
pynwb>=2.0.1
spikeextractors>=0.9.0
tqdm>=4.48.2
lazy_ops>=0.2.0
dill>=0.3.2
Expand Down
1 change: 1 addition & 0 deletions requirements-testing.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest
pytest-cov
parameterized==0.8.1
spikeextractors>=0.9.10
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
from setuptools import setup, find_packages
from copy import copy
from shutil import copy as copy_file


Expand All @@ -11,9 +10,8 @@
install_requires = f.readlines()
with open(root / "requirements-full.txt") as f:
full_dependencies = f.readlines()
testing_dependencies = copy(full_dependencies)
with open(root / "requirements-testing.txt") as f:
testing_dependencies.extend(f.readlines())
testing_dependencies = f.readlines()
extras_require = dict(full=full_dependencies, test=testing_dependencies)

# Create a local copy for the gin test configuration file based on the master file `base_gin_test_config.json`
Expand Down
3 changes: 2 additions & 1 deletion src/roiextractors/example_datasets/toy_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import spikeextractors as se

from ..extractors.numpyextractors import (
NumpyImagingExtractor,
Expand Down Expand Up @@ -110,6 +109,8 @@ def toy_example(
The output segmentation extractor
"""
import spikeextractors as se

# generate ROIs
num_rois = int(num_rois)
roi_pixels, im, means = _generate_rois(
Expand Down
31 changes: 28 additions & 3 deletions src/roiextractors/extraction_tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from functools import wraps
from pathlib import Path
from typing import Union, Tuple
from dataclasses import dataclass, field
from dataclasses import dataclass

import lazy_ops
import scipy
import numpy as np
from numpy.typing import ArrayLike, DTypeLike
from tqdm import tqdm
from spikeextractors.extraction_tools import cast_start_end_frame

try:
import h5py
Expand All @@ -22,6 +21,13 @@
else:
from scipy.io.matlab.mio5_params import mat_struct

HAVE_Scipy = True
except AttributeError:
if hasattr(scipy, "io") and hasattr(scipy.io.matlab, "mat_struct"):
from scipy.io import mat_struct
else:
from scipy.io.matlab.mio5_params import mat_struct

HAVE_Scipy = True
except ImportError:
HAVE_Scipy = False
Expand Down Expand Up @@ -306,6 +312,25 @@ def corrected_args(imaging, frame_idxs, channel=0):
return corrected_args


def _cast_start_end_frame(start_frame, end_frame):
if isinstance(start_frame, float):
start_frame = int(start_frame)
elif isinstance(start_frame, (int, np.integer, type(None))):
start_frame = start_frame
else:
raise ValueError("start_frame must be an int, float (not infinity), or None")
if isinstance(end_frame, float) and np.isfinite(end_frame):
end_frame = int(end_frame)
elif isinstance(end_frame, (int, np.integer, type(None))):
end_frame = end_frame
# else end_frame is infinity (accepted for get_unit_spike_train)
if start_frame is not None:
start_frame = int(start_frame)
if end_frame is not None and np.isfinite(end_frame):
end_frame = int(end_frame)
return start_frame, end_frame


def check_get_videos_args(func):
@wraps(func)
def corrected_args(imaging, start_frame=None, end_frame=None, channel=0):
Expand All @@ -325,7 +350,7 @@ def corrected_args(imaging, start_frame=None, end_frame=None, channel=0):
end_frame = imaging.get_num_frames()
assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!"

start_frame, end_frame = cast_start_end_frame(start_frame, end_frame)
start_frame, end_frame = _cast_start_end_frame(start_frame, end_frame)
channel = int(channel)
get_videos_correct_arg = func(imaging, start_frame=start_frame, end_frame=end_frame, channel=channel)

Expand Down
25 changes: 9 additions & 16 deletions src/roiextractors/imagingextractor.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
"""Base class definitions for all ImagingExtractors."""
from abc import ABC, abstractmethod
from typing import Union, Optional, Tuple
import numpy as np
from copy import deepcopy

from spikeextractors.baseextractor import BaseExtractor
import numpy as np

from .extraction_tools import (
ArrayType,
PathType,
DtypeType,
FloatType,
check_get_videos_args,
)
from .extraction_tools import ArrayType, PathType, DtypeType, FloatType


class ImagingExtractor(ABC, BaseExtractor):
class ImagingExtractor(ABC):
"""Abstract class that contains all the meta-data and input data from the imaging data."""

def __init__(self) -> None:
BaseExtractor.__init__(self)
assert self.installed, self.installation_mesg
self._memmapped = False
def __init__(self, *args, **kwargs) -> None:
self._args = args
self._kwargs = kwargs
self._times = None

@abstractmethod
def get_image_size(self) -> Tuple[int, int]:
Expand Down Expand Up @@ -119,12 +112,12 @@ def set_times(self, times: ArrayType) -> None:
assert len(times) == self.get_num_frames(), "'times' should have the same length of the number of frames!"
self._times = np.array(times).astype("float64")

def copy_times(self, extractor: BaseExtractor) -> None:
def copy_times(self, extractor) -> None:
"""This function copies times from another extractor.
Parameters
----------
extractor: BaseExtractor
extractor
The extractor from which the epochs will be copied
"""
if extractor._times is not None:
Expand Down
9 changes: 2 additions & 7 deletions src/roiextractors/segmentationextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
from typing import Union

import numpy as np
from spikeextractors.baseextractor import BaseExtractor

from .extraction_tools import ArrayType, IntType, FloatType
from .extraction_tools import _pixel_mask_extractor


class SegmentationExtractor(ABC, BaseExtractor):
class SegmentationExtractor(ABC):
"""
An abstract class that contains all the meta-data and output data from
the ROI segmentation operation when applied to the pre-processed data.
Expand All @@ -18,13 +17,9 @@ class SegmentationExtractor(ABC, BaseExtractor):
format specific classes that inherit from this.
"""

installed = True
installation_mesg = ""

def __init__(self):
assert self.installed, self.installation_mesg
BaseExtractor.__init__(self)
self._sampling_frequency = None
self._times = None
self._channel_names = ["OpticalChannel"]
self._num_planes = 1
self._roi_response_raw = None
Expand Down

0 comments on commit d6610c2

Please sign in to comment.