Skip to content

Commit

Permalink
Merge pull request #2896 from samuelgarcia/some_fix
Browse files Browse the repository at this point in the history
Important fixes
  • Loading branch information
samuelgarcia authored May 23, 2024
2 parents 29ad02b + 586fab6 commit 74e9dcf
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,9 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
self.params["operators"] += [(operator, percentile)]
templates_array = self.data[key]

if save:
if not self.sorting_analyzer.is_read_only():
self.save()
if save:
if not self.sorting_analyzer.is_read_only():
self.save()

if unit_ids is not None:
unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids)
Expand Down
14 changes: 6 additions & 8 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,10 @@ def load_from_zarr(cls, folder, recording=None):
rec_attributes["probegroup"] = None

# sparsity
if "sparsity_mask" in zarr_root.attrs:
# sparsity = zarr_root.attrs["sparsity"]
sparsity = ChannelSparsity(zarr_root["sparsity_mask"], cls.unit_ids, rec_attributes["channel_ids"])
if "sparsity_mask" in zarr_root:
sparsity = ChannelSparsity(
np.array(zarr_root["sparsity_mask"]), sorting.unit_ids, rec_attributes["channel_ids"]
)
else:
sparsity = None

Expand Down Expand Up @@ -1596,10 +1597,6 @@ def load_data(self):
self.data[ext_data_name] = ext_data

elif self.format == "zarr":
# Alessio
# TODO: we need decide if we make a copy to memory or keep the lazy loading. For binary_folder it used to be lazy with memmap
# but this make the garbage complicated when a data is hold by a plot but the o SortingAnalyzer is delete
# lets talk
extension_group = self._get_zarr_extension_group(mode="r")
for ext_data_name in extension_group.keys():
ext_data_ = extension_group[ext_data_name]
Expand All @@ -1615,7 +1612,8 @@ def load_data(self):
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
ext_data = ext_data_
# this loads the data into memory
ext_data = np.array(ext_data_)
self.data[ext_data_name] = ext_data

def copy(self, new_sorting_analyzer, unit_ids=None):
Expand Down
8 changes: 8 additions & 0 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def __repr__(self):
txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}"
return txt

def __eq__(self, other):
    """Return True when *other* is a ChannelSparsity with identical
    channel_ids, unit_ids, and boolean mask; False otherwise."""
    if not isinstance(other, ChannelSparsity):
        return False
    # Compare each identifying array independently for readability.
    same_channels = np.array_equal(self.channel_ids, other.channel_ids)
    same_units = np.array_equal(self.unit_ids, other.unit_ids)
    same_mask = np.array_equal(self.mask, other.mask)
    return same_channels and same_units and same_mask

@property
def unit_id_to_channel_ids(self):
if self._unit_id_to_channel_ids is None:
Expand Down
5 changes: 5 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,15 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):

data = sorting_analyzer2.get_extension("dummy").data
assert "result_one" in data
assert isinstance(data["result_one"], str)
assert isinstance(data["result_two"], np.ndarray)
assert data["result_two"].size == original_sorting.to_spike_vector().size
assert np.array_equal(data["result_two"], sorting_analyzer.get_extension("dummy").data["result_two"])

assert sorting_analyzer2.return_scaled == sorting_analyzer.return_scaled

assert sorting_analyzer2.sparsity == sorting_analyzer.sparsity

# select unit_ids to several format
for format in ("memory", "binary_folder", "zarr"):
if format != "memory":
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/utils_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def make_mpl_figure(figure=None, ax=None, axes=None, ncols=None, num_axes=None,
if "ipympl" not in matplotlib.get_backend():
ax = figure.add_subplot(111)
else:
ax = figure.add_subplot(111, layout="constrained")
ax = figure.add_subplot(111)
axes = np.array([[ax]])
else:
assert ncols is not None
Expand Down

0 comments on commit 74e9dcf

Please sign in to comment.