Skip to content

Commit

Permalink
Merge pull request #2889 from alejoe91/0.100-fix-waveforms-save
Browse files Browse the repository at this point in the history
[0.100.x-bug-fixes] Fix waveforms save in recordingless mode
  • Loading branch information
alejoe91 authored Jun 3, 2024
2 parents f21a7b7 + 0c4be33 commit a75f1d7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"numpy",
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=0.2.16",
"zarr>=2.16,<2.18",
"neo>=0.13.0",
"probeinterface>=0.2.21",
]
Expand Down
20 changes: 14 additions & 6 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,9 +911,11 @@ def save(
)
if self.recording.get_probegroup() is not None:
probegroup = self.recording.get_probegroup()
with_recording = True
else:
rec_attributes = deepcopy(self._rec_attributes)
probegroup = rec_attributes["probegroup"]
probegroup = rec_attributes.pop("probegroup")
with_recording = False

if self.is_sparse():
assert sparsity is None, "WaveformExtractor is already sparse!"
Expand Down Expand Up @@ -1026,7 +1028,7 @@ def save(
name=f"sampled_index_{unit_id}", data=sampled_indices, compressor=compressor
)

new_we = WaveformExtractor.load(folder)
new_we = WaveformExtractor.load(folder, with_recording=with_recording)

# save waveform extensions
for ext_name in self.get_available_extension_names():
Expand Down Expand Up @@ -1926,10 +1928,16 @@ def load(cls, folder, waveform_extractor):
params = cls.load_params_from_folder(folder)

if "sparsity" in params and params["sparsity"] is not None:
params["sparsity"] = ChannelSparsity.from_dict(params["sparsity"])

# if waveform_extractor is None:
# waveform_extractor = WaveformExtractor.load(folder)
sparsity_params = params["sparsity"]
# handle old sparsity version
if "unit_ids" not in params["sparsity"]:
sparsity_params = {}
sparsity_params["unit_ids"] = waveform_extractor.unit_ids
sparsity_params["channel_ids"] = waveform_extractor.channel_ids
sparsity_params["unit_id_to_channel_ids"] = params["sparsity"]
else:
sparsity_params = params["sparsity"]
params["sparsity"] = ChannelSparsity.from_dict(sparsity_params)

# make instance with params
ext = cls(waveform_extractor)
Expand Down

0 comments on commit a75f1d7

Please sign in to comment.