diff --git a/src/spikeinterface/generation/__init__.py b/src/spikeinterface/generation/__init__.py index 4015f3f75e..d521f9dd9b 100644 --- a/src/spikeinterface/generation/__init__.py +++ b/src/spikeinterface/generation/__init__.py @@ -11,3 +11,7 @@ generate_displacement_vector, generate_drifting_recording, ) + +from .template_database import ( + fetch_templates_from_database, +) diff --git a/src/spikeinterface/generation/template_database.py b/src/spikeinterface/generation/template_database.py new file mode 100644 index 0000000000..2bfcc32245 --- /dev/null +++ b/src/spikeinterface/generation/template_database.py @@ -0,0 +1,12 @@ +from spikeinterface.core.template import Templates +import zarr + + +def fetch_templates_from_database(dataset="test_templates") -> Templates: + + s3_path = f"s3://spikeinterface-template-database/{dataset}/" + zarr_group = zarr.open_consolidated(s3_path, storage_options={"anon": True}) + + templates_object = Templates.from_zarr_group(zarr_group) + + return templates_object diff --git a/src/spikeinterface/generation/tests/test_template_fetch.py b/src/spikeinterface/generation/tests/test_template_fetch.py new file mode 100644 index 0000000000..a7cc31af44 --- /dev/null +++ b/src/spikeinterface/generation/tests/test_template_fetch.py @@ -0,0 +1,13 @@ +import pytest +from spikeinterface.generation import fetch_templates_from_database +from spikeinterface.core.template import Templates + + +def test_basic_call(): + + templates = fetch_templates_from_database() + + assert isinstance(templates, Templates) + + assert templates.num_units == 100 + assert templates.num_channels == 384