Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 13, 2024
1 parent 5ed4f03 commit 21300a1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
41 changes: 28 additions & 13 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,24 +459,39 @@ def get_channel_locations(self) -> np.ndarray:

class SharedMemoryTemplates(Templates):

def __init__(self, shm_name, shape, dtype, sampling_frequency, nbefore, sparsity_mask,
channel_ids, unit_ids, probe, is_scaled, main_shm_owner=True):
def __init__(
self,
shm_name,
shape,
dtype,
sampling_frequency,
nbefore,
sparsity_mask,
channel_ids,
unit_ids,
probe,
is_scaled,
main_shm_owner=True,
):

assert len(shape) == 3
assert shape[0] > 0, "SharedMemoryTemplates only supported with no empty templates"

self.shm = SharedMemory(shm_name, create=False)
templates_array = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)

Templates.__init__(self, templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=sparsity_mask,
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
is_scaled=is_scaled)


Templates.__init__(
self,
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=sparsity_mask,
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
is_scaled=is_scaled,
)

# self._serializability["memory"] = True
# self._serializability["json"] = False
# self._serializability["pickle"] = False
Expand Down Expand Up @@ -524,4 +539,4 @@ def from_templates(templates):
main_shm_owner=True,
)
shm.close()
return shared_templates
return shared_templates
6 changes: 3 additions & 3 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def test_select_channels(template_type, is_scaled):
if template.sparsity_mask is not None:
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices])


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense"])
def test_shm_templates(template_type, is_scaled):
Expand All @@ -179,13 +180,12 @@ def test_shm_templates(template_type, is_scaled):
# Verify that the channel ids match
assert np.array_equal(shm_templates.channel_ids, template.channel_ids)
# Verify that the templates data matches
assert np.array_equal(
shm_templates.templates_array, template.templates_array
)
assert np.array_equal(shm_templates.templates_array, template.templates_array)

if template.sparsity_mask is not None:
assert np.array_equal(shm_templates.sparsity_mask, template.sparsity_mask)


if __name__ == "__main__":
# test_json_serialization("sparse")
test_json_serialization("dense")

0 comments on commit 21300a1

Please sign in to comment.