Skip to content

Commit

Permalink
Add extra_requirements after init
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Mar 28, 2024
1 parent 09b7b70 commit ac48bc7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/spikeinterface/extractors/nwbextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,6 @@ def __init__(
segment_data,
times_kwargs,
) = self._fetch_recording_segment_info_pynwb(file, cache, load_time_vector, samples_for_rate_estimation)
self.extra_requirements.append("pynwb")
else:
(
channel_ids,
Expand All @@ -549,7 +548,6 @@ def __init__(
segment_data,
times_kwargs,
) = self._fetch_recording_segment_info_backend(file, cache, load_time_vector, samples_for_rate_estimation)
self.extra_requirements.append("h5py")
BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype)
recording_segment = NwbRecordingSegment(
electrical_series_data=segment_data,
Expand All @@ -560,8 +558,10 @@ def __init__(
# fetch and add main recording properties
if use_pynwb:
gains, offsets, locations, groups = self._fetch_main_properties_pynwb()
self.extra_requirements.append("pynwb")
else:
gains, offsets, locations, groups = self._fetch_main_properties_backend()
self.extra_requirements.append("h5py")
self.set_channel_gains(gains)
self.set_channel_offsets(offsets)
if locations is not None:
Expand Down Expand Up @@ -990,12 +990,10 @@ def __init__(
unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_pynwb(
unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache
)
self.extra_requirements.append("pynwb")
else:
unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_backend(
unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache
)
self.extra_requirements.append("h5py")

BaseSorting.__init__(
self, sampling_frequency=self.provided_or_electrical_series_sampling_frequency, unit_ids=unit_ids
Expand All @@ -1013,8 +1011,10 @@ def __init__(
if load_unit_properties:
if use_pynwb:
columns = [c.name for c in self.units_table.columns]
self.extra_requirements.append("pynwb")
else:
columns = list(self.units_table.keys())
self.extra_requirements.append("h5py")
properties = self._fetch_properties(columns)
for property_name, property_values in properties.items():
values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values]
Expand Down

0 comments on commit ac48bc7

Please sign in to comment.