Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add template generate_recording_from_template_database #2769

Merged
5 changes: 5 additions & 0 deletions src/spikeinterface/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
generate_displacement_vector,
generate_drifting_recording,
)

from ._template_database import (
h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved
fetch_templates_from_database,
generate_recording_from_template_database,
)
47 changes: 47 additions & 0 deletions src/spikeinterface/generation/_template_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from spikeinterface.core.template import Templates
from spikeinterface.core import generate_sorting, InjectTemplatesRecording
import zarr


def fetch_templates_from_database(dataset="test_templates"):

import s3fs

s3 = s3fs.S3FileSystem(anon=False, client_kwargs={"region_name": "us-east-2"})

# Specify the S3 bucket and path where your Zarr dataset is stored
store = s3fs.S3Map(root=f"spikeinterface-template-database/{dataset}", s3=s3)

# Load the Zarr group from S3
zarr_group = zarr.open(store, mode="r")

templates_object = Templates.from_zarr_group(zarr_group)

return templates_object


def generate_recording_from_template_database(selected_unit_inidces=None, dataset="test_templates", durations=None):
h-mayorquin marked this conversation as resolved.
Show resolved Hide resolved

durations = durations or [10.0]

templates_object = fetch_templates_from_database(dataset=dataset)

if selected_unit_inidces:
selected_templates = templates_object.templates_array[selected_unit_inidces, :, :]
else:
selected_templates = templates_object.templates_array

num_units = selected_templates.shape[0]
sampling_frequency = templates_object.sampling_frequency
sorting = generate_sorting(num_units=num_units, sampling_frequency=sampling_frequency, durations=durations)

nbefore = templates_object.nbefore
num_samples = durations[0] * sampling_frequency
recording = InjectTemplatesRecording(
sorting=sorting,
templates=selected_templates,
nbefore=nbefore,
num_samples=[num_samples],
)

return recording
Loading