diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 415ca42548..cefd7bd950 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -118,6 +118,14 @@ def __repr__(self): txt = f"ChannelSparsity - units: {self.num_units} - channels: {self.num_channels} - density, P(x=1): {density:0.2f}" return txt + def __eq__(self, other): + return ( + isinstance(other, ChannelSparsity) + and np.array_equal(self.channel_ids, other.channel_ids) + and np.array_equal(self.unit_ids, other.unit_ids) + and np.array_equal(self.mask, other.mask) + ) + @property def unit_id_to_channel_ids(self): if self._unit_id_to_channel_ids is None: diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index faed5161c6..13e01c32da 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -155,10 +155,15 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): data = sorting_analyzer2.get_extension("dummy").data assert "result_one" in data + assert isinstance(data["result_one"], str) + assert isinstance(data["result_two"], np.ndarray) assert data["result_two"].size == original_sorting.to_spike_vector().size + assert np.array_equal(data["result_two"], sorting_analyzer.get_extension("dummy").data["result_two"]) assert sorting_analyzer2.return_scaled == sorting_analyzer.return_scaled + assert sorting_analyzer2.sparsity == sorting_analyzer.sparsity + # select unit_ids to several format for format in ("memory", "binary_folder", "zarr"): if format != "memory":