diff --git a/src/roiextractors/extractors/nwbextractors/nwbextractors.py b/src/roiextractors/extractors/nwbextractors/nwbextractors.py
index 20fb3be6..00c0d5f1 100644
--- a/src/roiextractors/extractors/nwbextractors/nwbextractors.py
+++ b/src/roiextractors/extractors/nwbextractors/nwbextractors.py
@@ -6,7 +6,7 @@
 try:
     from pynwb import NWBHDF5IO
-    from pynwb.ophys import TwoPhotonSeries
+    from pynwb.ophys import TwoPhotonSeries, OnePhotonSeries
 
     HAVE_NWB = True
 except ImportError:
@@ -55,8 +55,8 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
         ----------
         file_path: str
             The location of the folder containing dataset.nwb file
-        optical_series_name: str (optional)
-            optical series to extract data from
+        optical_series_name: string, optional
+            The name of the optical series to extract data from.
         """
         ImagingExtractor.__init__(self)
         self._path = file_path
@@ -73,33 +73,34 @@ def __init__(self, file_path: PathType, optical_series_name: Optional[str] = "Tw
                 raise ValueError("No acquisitions found in the .nwb file.")
             self._optical_series_name = a_names[0]
 
-        self.two_photon_series = self.nwbfile.acquisition[self._optical_series_name]
-        assert isinstance(
-            self.two_photon_series, TwoPhotonSeries
-        ), "The optical series must be of type pynwb.TwoPhotonSeries"
+        self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
+        valid_photon_series_types = [OnePhotonSeries, TwoPhotonSeries]
+        assert any(
+            [isinstance(self.photon_series, photon_series_type) for photon_series_type in valid_photon_series_types]
+        ), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."
 
         # TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
-        assert self.two_photon_series.external_file is None, "Only 'raw' format is currently supported"
+        assert self.photon_series.external_file is None, "Only 'raw' format is currently supported"
 
         # Load the two video structures that TwoPhotonSeries supports.
         self._data_has_channels_axis = True
-        if len(self.two_photon_series.data.shape) == 3:
+        if len(self.photon_series.data.shape) == 3:
             self._num_channels = 1
-            self._num_frames, self._columns, self._num_rows = self.two_photon_series.data.shape
+            self._num_frames, self._columns, self._num_rows = self.photon_series.data.shape
         else:
             raise_multi_channel_or_depth_not_implemented(extractor_name=self.extractor_name)
 
         # Set channel names (This should disambiguate which optical channel)
-        self._channel_names = [i.name for i in self.two_photon_series.imaging_plane.optical_channel]
+        self._channel_names = [i.name for i in self.photon_series.imaging_plane.optical_channel]
 
         # Set sampling frequency
-        if hasattr(self.two_photon_series, "timestamps") and self.two_photon_series.timestamps:
-            self._sampling_frequency = 1.0 / np.median(np.diff(self.two_photon_series.timestamps))
-            self._imaging_start_time = self.two_photon_series.timestamps[0]
-            self.set_times(np.array(self.two_photon_series.timestamps))
+        if hasattr(self.photon_series, "timestamps") and self.photon_series.timestamps:
+            self._sampling_frequency = 1.0 / np.median(np.diff(self.photon_series.timestamps))
+            self._imaging_start_time = self.photon_series.timestamps[0]
+            self.set_times(np.array(self.photon_series.timestamps))
         else:
-            self._sampling_frequency = self.two_photon_series.rate
-            self._imaging_start_time = self.two_photon_series.fields.get("starting_time", 0.0)
+            self._sampling_frequency = self.photon_series.rate
+            self._imaging_start_time = self.photon_series.fields.get("starting_time", 0.0)
 
         # Fill epochs dictionary
         self._epochs = {}
@@ -158,7 +159,7 @@ def get_frames(self, frame_idxs: ArrayType, channel: Optional[int] = 0):
             slice_start = 0
             slice_stop = self.get_num_frames()
 
-        data = self.two_photon_series.data
+        data = self.photon_series.data
         frames = data[slice_start:slice_stop, ...].transpose([0, 2, 1])
 
         if isinstance(frame_idxs, int):
@@ -169,7 +170,7 @@ def get_video(self, start_frame=None, end_frame=None, channel: Optional[int] = 0
         start_frame = start_frame if start_frame is not None else 0
         end_frame = end_frame if end_frame is not None else self.get_num_frames()
 
-        video = self.two_photon_series.data
+        video = self.photon_series.data
         video = video[start_frame:end_frame].transpose([0, 2, 1])
         return video
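
For reviewers, a minimal end-to-end sketch of what this change enables (not part of the patch): write an NWB file whose acquisition holds a OnePhotonSeries and open it with the extractor, which previously rejected anything but TwoPhotonSeries. The sketch assumes pynwb >= 2.2 (where OnePhotonSeries was introduced) and that NwbImagingExtractor is importable from the roiextractors package root; the file name, series name, and metadata values are illustrative, not taken from this diff.

# Sketch only: not part of this patch. Names and metadata below are made up for illustration.
from datetime import datetime

import numpy as np
from pynwb import NWBFile, NWBHDF5IO
from pynwb.ophys import OnePhotonSeries, OpticalChannel

from roiextractors import NwbImagingExtractor  # assumes the usual package-level export

# Build a minimal NWB file containing a OnePhotonSeries acquisition.
nwbfile = NWBFile(
    session_description="one-photon demo",
    identifier="demo-id",
    session_start_time=datetime.now().astimezone(),
)
device = nwbfile.create_device(name="Microscope")
optical_channel = OpticalChannel(name="OpticalChannel", description="demo channel", emission_lambda=500.0)
imaging_plane = nwbfile.create_imaging_plane(
    name="ImagingPlane",
    optical_channel=optical_channel,
    description="demo plane",
    device=device,
    excitation_lambda=475.0,
    indicator="GCaMP",
    location="unknown",
)
one_photon_series = OnePhotonSeries(
    name="OnePhotonSeries",
    data=np.zeros((10, 32, 32), dtype="uint16"),
    imaging_plane=imaging_plane,
    rate=30.0,
    unit="n.a.",
)
nwbfile.add_acquisition(one_photon_series)
with NWBHDF5IO("one_photon_demo.nwb", mode="w") as io:
    io.write(nwbfile)

# Before this change the isinstance check accepted only TwoPhotonSeries; with the patch
# the same call accepts the OnePhotonSeries acquisition as well.
imaging_extractor = NwbImagingExtractor(file_path="one_photon_demo.nwb", optical_series_name="OnePhotonSeries")
print(imaging_extractor.get_num_frames(), imaging_extractor.get_sampling_frequency())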