diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 7aa4bb5b38..e248d43aea 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -459,24 +459,39 @@ def get_channel_locations(self) -> np.ndarray: class SharedMemoryTemplates(Templates): - def __init__(self, shm_name, shape, dtype, sampling_frequency, nbefore, sparsity_mask, - channel_ids, unit_ids, probe, is_scaled, main_shm_owner=True): + def __init__( + self, + shm_name, + shape, + dtype, + sampling_frequency, + nbefore, + sparsity_mask, + channel_ids, + unit_ids, + probe, + is_scaled, + main_shm_owner=True, + ): assert len(shape) == 3 assert shape[0] > 0, "SharedMemoryTemplates only supported with no empty templates" self.shm = SharedMemory(shm_name, create=False) templates_array = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) - - Templates.__init__(self, templates_array=templates_array, - sampling_frequency=sampling_frequency, - nbefore=nbefore, - sparsity_mask=sparsity_mask, - channel_ids=channel_ids, - unit_ids=unit_ids, - probe=probe, - is_scaled=is_scaled) - + + Templates.__init__( + self, + templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + sparsity_mask=sparsity_mask, + channel_ids=channel_ids, + unit_ids=unit_ids, + probe=probe, + is_scaled=is_scaled, + ) + # self._serializability["memory"] = True # self._serializability["json"] = False # self._serializability["pickle"] = False @@ -524,4 +539,4 @@ def from_templates(templates): main_shm_owner=True, ) shm.close() - return shared_templates \ No newline at end of file + return shared_templates diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index ee694892a1..b38d3cc2ca 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -170,6 +170,7 @@ def test_select_channels(template_type, is_scaled): if template.sparsity_mask is not None: assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices]) + @pytest.mark.parametrize("is_scaled", [True, False]) @pytest.mark.parametrize("template_type", ["dense"]) def test_shm_templates(template_type, is_scaled): @@ -179,13 +180,12 @@ def test_shm_templates(template_type, is_scaled): # Verify that the channel ids match assert np.array_equal(shm_templates.channel_ids, template.channel_ids) # Verify that the templates data matches - assert np.array_equal( - shm_templates.templates_array, template.templates_array - ) + assert np.array_equal(shm_templates.templates_array, template.templates_array) if template.sparsity_mask is not None: assert np.array_equal(shm_templates.sparsity_mask, template.sparsity_mask) + if __name__ == "__main__": # test_json_serialization("sparse") test_json_serialization("dense")