diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index d213126f34..9aa8b1b907 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -401,7 +401,40 @@ def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, elect return electrodes_indices -class NwbRecordingExtractor(BaseRecording): +class _BaseNWBExtractor: + "A class for common methods for NWB extractors." + + def _close_hdf5_file(self): + has_hdf5_backend = hasattr(self, "_file") + if has_hdf5_backend: + import h5py + + main_file_id = self._file.id + open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) + for object_id in open_object_ids_main: + object_name = h5py.h5i.get_name(object_id).decode("utf-8") + try: + object_id.close() + except: + import warnings + + warnings.warn(f"Error closing object {object_name}") + + def __del__(self): + # backend mode + if hasattr(self, "_file"): + if hasattr(self._file, "store"): + self._file.store.close() + else: + self._close_hdf5_file() + # pynwb mode + elif hasattr(self, "_nwbfile"): + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + + +class NwbRecordingExtractor(BaseRecording, _BaseNWBExtractor): """Load an NWBFile as a RecordingExtractor. Parameters @@ -626,35 +659,6 @@ def __init__( "file": file, } - def _close_hdf5_file(self): - has_hdf5_backend = hasattr(self, "_file") - if has_hdf5_backend: - import h5py - - main_file_id = self._file.id - open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) - for object_id in open_object_ids_main: - object_name = h5py.h5i.get_name(object_id).decode("utf-8") - try: - object_id.close() - except: - import warnings - - warnings.warn(f"Error closing object {object_name}") - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._close_hdf5_file() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation): self._nwbfile = read_nwbfile( backend=self.backend, @@ -968,7 +972,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): return traces -class NwbSortingExtractor(BaseSorting): +class NwbSortingExtractor(BaseSorting, _BaseNWBExtractor): """Load an NWBFile as a SortingExtractor. Parameters ---------- @@ -1127,41 +1131,6 @@ def __init__( "t_start": self.t_start, } - def _close_hdf5_file(self): - has_hdf5_backend = hasattr(self, "_file") - if has_hdf5_backend: - import h5py - - main_file_id = self._file.id - open_object_ids_main = h5py.h5f.get_obj_ids(main_file_id, types=h5py.h5f.OBJ_ALL) - for object_id in open_object_ids_main: - object_name = h5py.h5i.get_name(object_id).decode("utf-8") - try: - object_id.close() - except: - import warnings - - warnings.warn(f"Error closing object {object_name}") - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._close_hdf5_file() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - - # pynwb mode - elif hasattr(self, "_nwbfile"): # hdf - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - def _fetch_sorting_segment_info_pynwb( self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False ):