From c99d57f4fa71e5ddc94ec5a8d0ba9bb742ea9fb3 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 26 Apr 2024 11:39:47 -0600 Subject: [PATCH 1/6] add template generation function --- src/spikeinterface/core/template.py | 5 +- src/spikeinterface/generation/__init__.py | 5 ++ .../generation/_template_database.py | 47 +++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 src/spikeinterface/generation/_template_database.py diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 51688709b2..b0afaef7ce 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -229,12 +229,13 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: The `templates_array` dataset is saved with a chunk size that has a single unit per chunk to optimize read/write operations for individual units. """ + import numcodecs # Saves one chunk per unit arrays_chunk = (1, None, None) zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) - zarr_group.create_dataset("channel_ids", data=self.channel_ids) - zarr_group.create_dataset("unit_ids", data=self.unit_ids) + zarr_group.create_dataset("channel_ids", data=self.channel_ids, object_codec=numcodecs.MsgPack()) + zarr_group.create_dataset("unit_ids", data=self.unit_ids, object_codec=numcodecs.MsgPack()) zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 4015f3f75e..f82d2ddf20 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -11,3 +11,8 @@ generate_displacement_vector, generate_drifting_recording, ) + +from ._template_database import ( + fetch_templates_from_database, + generate_recording_from_template_database, +) diff --git a/src/spikeinterface/generation/_template_database.py b/src/spikeinterface/generation/_template_database.py new file mode 100644 index 0000000000..134ce1d81e --- /dev/null +++ b/src/spikeinterface/generation/_template_database.py @@ -0,0 +1,47 @@ +from spikeinterface.core.template import Templates +from spikeinterface.core import generate_sorting, InjectTemplatesRecording +import zarr + + +def fetch_templates_from_database(dataset="test_templates"): + + import s3fs + + s3 = s3fs.S3FileSystem(anon=False, client_kwargs={"region_name": "us-east-2"}) + + # Specify the S3 bucket and path where your Zarr dataset is stored + store = s3fs.S3Map(root=f"spikeinterface-template-database/{dataset}", s3=s3) + + # Load the Zarr group from S3 + zarr_group = zarr.open(store, mode="r") + + templates_object = Templates.from_zarr_group(zarr_group) + + return templates_object + + +def generate_recording_from_template_database(selected_unit_inidces=None, dataset="test_templates", durations=None): + + durations = durations or [10.0] + + templates_object = fetch_templates_from_database(dataset=dataset) + + if selected_unit_inidces: + selected_templates = templates_object.templates_array[selected_unit_inidces, :, :] + else: + selected_templates = templates_object.templates_array + + num_units = selected_templates.shape[0] + sampling_frequency = templates_object.sampling_frequency + sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) + + nbefore = templates_object.nbefore + num_samples = durations[0] * sampling_frequency + recording = InjectTemplatesRecording( + sorting=sorting, + templates=selected_templates, + nbefore=nbefore, + num_samples=[num_samples], + ) + + return recording From 800d8802e5a27bfd3661e68f3498c3fbee6a4246 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 26 Apr 2024 12:03:09 -0600 Subject: [PATCH 2/6] revert modifications to template --- src/spikeinterface/core/template.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index b0afaef7ce..51688709b2 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -229,13 +229,12 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: The `templates_array` dataset is saved with a chunk size that has a single unit per chunk to optimize read/write operations for individual units. """ - import numcodecs # Saves one chunk per unit arrays_chunk = (1, None, None) zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) - zarr_group.create_dataset("channel_ids", data=self.channel_ids, object_codec=numcodecs.MsgPack()) - zarr_group.create_dataset("unit_ids", data=self.unit_ids, object_codec=numcodecs.MsgPack()) + zarr_group.create_dataset("channel_ids", data=self.channel_ids) + zarr_group.create_dataset("unit_ids", data=self.unit_ids) zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore From 0b2695000331767a7fbda7ce7a7af0dcd5ff4783 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 30 Apr 2024 09:28:13 -0600 Subject: [PATCH 3/6] use open consolidated --- src/spikeinterface/generation/__init__.py | 2 +- .../generation/_template_database.py | 47 ------------------- 2 files changed, 1 insertion(+), 48 deletions(-) delete mode 100644 src/spikeinterface/generation/_template_database.py diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index f82d2ddf20..91f9094f6b 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -12,7 +12,7 @@ generate_drifting_recording, ) -from ._template_database import ( +from .template_database import ( fetch_templates_from_database, generate_recording_from_template_database, ) diff --git a/src/spikeinterface/generation/_template_database.py b/src/spikeinterface/generation/_template_database.py deleted file mode 100644 index 134ce1d81e..0000000000 --- a/src/spikeinterface/generation/_template_database.py +++ /dev/null @@ -1,47 +0,0 @@ -from spikeinterface.core.template import Templates -from spikeinterface.core import generate_sorting, InjectTemplatesRecording -import zarr - - -def fetch_templates_from_database(dataset="test_templates"): - - import s3fs - - s3 = s3fs.S3FileSystem(anon=False, client_kwargs={"region_name": "us-east-2"}) - - # Specify the S3 bucket and path where your Zarr dataset is stored - store = s3fs.S3Map(root=f"spikeinterface-template-database/{dataset}", s3=s3) - - # Load the Zarr group from S3 - zarr_group = zarr.open(store, mode="r") - - templates_object = Templates.from_zarr_group(zarr_group) - - return templates_object - - -def generate_recording_from_template_database(selected_unit_inidces=None, dataset="test_templates", durations=None): - - durations = durations or [10.0] - - templates_object = fetch_templates_from_database(dataset=dataset) - - if selected_unit_inidces: - selected_templates = templates_object.templates_array[selected_unit_inidces, :, :] - else: - selected_templates = templates_object.templates_array - - num_units = selected_templates.shape[0] - sampling_frequency = templates_object.sampling_frequency - sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations) - - nbefore = templates_object.nbefore - num_samples = durations[0] * sampling_frequency - recording = InjectTemplatesRecording( - sorting=sorting, - templates=selected_templates, - nbefore=nbefore, - num_samples=[num_samples], - ) - - return recording From 02b185c1704d702eb43af2968406da77cad580a6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 30 Apr 2024 09:28:43 -0600 Subject: [PATCH 4/6] use open consolidated --- src/spikeinterface/generation/template_database.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 src/spikeinterface/generation/template_database.py diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py new file mode 100644 index 0000000000..addb401ed8 --- /dev/null +++ b/src/spikeinterface/generation/template_database.py @@ -0,0 +1,13 @@ +from spikeinterface.core.template import Templates +from spikeinterface.core import generate_sorting, InjectTemplatesRecording +import zarr + + +def fetch_templates_from_database(dataset="test_templates"): + + s3_path = f"s3://spikeinterface-template-database/{dataset}/" + zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": False}) + + templates_object = Templates.from_zarr_group(zarr_group) + + return templates_object From 31d8a34f7899702b2cb180079fd1bc1a3badc619 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 30 Apr 2024 10:06:50 -0600 Subject: [PATCH 5/6] add typing and improve the test --- src/spikeinterface/generation/__init__.py | 1 - src/spikeinterface/generation/template_database.py | 3 +-- .../generation/tests/test_template_fetch.py | 13 +++++++++++++ 3 files changed, 14 insertions(+), 3 deletions(-) create mode 100644 src/spikeinterface/generation/tests/test_template_fetch.py diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 91f9094f6b..d521f9dd9b 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -14,5 +14,4 @@ from .template_database import ( fetch_templates_from_database, - generate_recording_from_template_database, ) diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index addb401ed8..f455b7c57d 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -1,9 +1,8 @@ from spikeinterface.core.template import Templates -from spikeinterface.core import generate_sorting, InjectTemplatesRecording import zarr -def fetch_templates_from_database(dataset="test_templates"): +def fetch_templates_from_database(dataset="test_templates") -> Templates: s3_path = f"s3://spikeinterface-template-database/{dataset}/" zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": False}) diff --git a/src/spikeinterface/generation/tests/test_template_fetch.py b/src/spikeinterface/generation/tests/test_template_fetch.py new file mode 100644 index 0000000000..a7cc31af44 --- /dev/null +++ b/src/spikeinterface/generation/tests/test_template_fetch.py @@ -0,0 +1,13 @@ +import pytest +from spikeinterface.generation import fetch_templates_from_database +from spikeinterface.core.template import Templates + + +def test_basic_call(): + + templates = fetch_templates_from_database() + + assert isinstance(templates, Templates) + + assert templates.num_units == 100 + assert templates.num_channels == 384 From 0af5db1007d203d5fc677a25b120957b00d6ef5d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 30 Apr 2024 19:38:33 +0200 Subject: [PATCH 6/6] Update src/spikeinterface/generation/template_database.py --- src/spikeinterface/generation/template_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py index f455b7c57d..2bfcc32245 100644 --- a/src/spikeinterface/generation/template_database.py +++ b/src/spikeinterface/generation/template_database.py @@ -5,7 +5,7 @@ def fetch_templates_from_database(dataset="test_templates") -> Templates: s3_path = f"s3://spikeinterface-template-database/{dataset}/" - zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": False}) + zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True}) templates_object = Templates.from_zarr_group(zarr_group)