Skip to content

Commit

Permalink
Merge pull request #229 from catalystneuro/support_OnePhotonSeries
Browse files Browse the repository at this point in the history
Refactor `NwbImagingExtractor` to support extracting data from `OnePhotonSeries`
  • Loading branch information
CodyCBakerPhD authored Jun 21, 2023
2 parents 1b2eb7d + 475bd02 commit 8ca7ef3
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/roiextractors/extractors/nwbextractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

try:
from pynwb import NWBHDF5IO
from pynwb.ophys import TwoPhotonSeries
from pynwb.ophys import TwoPhotonSeries, OnePhotonSeries

HAVE_NWB = True
except ImportError:
Expand Down Expand Up @@ -55,8 +55,8 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
----------
file_path: str
The location of the folder containing dataset.nwb file
optical_series_name: str (optional)
optical series to extract data from
optical_series_name: string, optional
The name of the optical series to extract data from.
"""
ImagingExtractor.__init__(self)
self._path = file_path
Expand All @@ -73,33 +73,34 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
raise ValueError("No acquisitions found in the .nwb file.")
self._optical_series_name = a_names[0]

self.two_photon_series = self.nwbfile.acquisition[self._optical_series_name]
assert isinstance(
self.two_photon_series, TwoPhotonSeries
), "The optical series must be of type pynwb.TwoPhotonSeries"
self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
valid_photon_series_types = [OnePhotonSeries, TwoPhotonSeries]
assert any(
[isinstance(self.photon_series, photon_series_type) for photon_series_type in valid_photon_series_types]
), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."

# TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
assert self.two_photon_series.external_file is None, "Only 'raw' format is currently supported"
assert self.photon_series.external_file is None, "Only 'raw' format is currently supported"

# Load the two video structures that TwoPhotonSeries supports.
self._data_has_channels_axis = True
if len(self.two_photon_series.data.shape) == 3:
if len(self.photon_series.data.shape) == 3:
self._num_channels = 1
self._num_frames, self._columns, self._num_rows = self.two_photon_series.data.shape
self._num_frames, self._columns, self._num_rows = self.photon_series.data.shape
else:
raise_multi_channel_or_depth_not_implemented(extractor_name=self.extractor_name)

# Set channel names (This should disambiguate which optical channel)
self._channel_names = [i.name for i in self.two_photon_series.imaging_plane.optical_channel]
self._channel_names = [i.name for i in self.photon_series.imaging_plane.optical_channel]

# Set sampling frequency
if hasattr(self.two_photon_series, "timestamps") and self.two_photon_series.timestamps:
self._sampling_frequency = 1.0 / np.median(np.diff(self.two_photon_series.timestamps))
self._imaging_start_time = self.two_photon_series.timestamps[0]
self.set_times(np.array(self.two_photon_series.timestamps))
if hasattr(self.photon_series, "timestamps") and self.photon_series.timestamps:
self._sampling_frequency = 1.0 / np.median(np.diff(self.photon_series.timestamps))
self._imaging_start_time = self.photon_series.timestamps[0]
self.set_times(np.array(self.photon_series.timestamps))
else:
self._sampling_frequency = self.two_photon_series.rate
self._imaging_start_time = self.two_photon_series.fields.get("starting_time", 0.0)
self._sampling_frequency = self.photon_series.rate
self._imaging_start_time = self.photon_series.fields.get("starting_time", 0.0)

# Fill epochs dictionary
self._epochs = {}
Expand Down Expand Up @@ -158,7 +159,7 @@ def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
slice_start = 0
slice_stop = self.get_num_frames()

data = self.two_photon_series.data
data = self.photon_series.data
frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1])

if isinstance(frame_idxs, int):
Expand All @@ -169,7 +170,7 @@ def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0
start_frame = start_frame if start_frame is not None else 0
end_frame = end_frame if end_frame is not None else self.get_num_frames()

video = self.two_photon_series.data
video = self.photon_series.data
video = video[start_frame:end_frame].transpose([0, 2, 1])
return video

Expand Down

0 comments on commit 8ca7ef3

Please sign in to comment.