Skip to content

Commit

Permalink
Merge pull request #3314 from alejoe91/load-cloud-sorting-analyzer
Browse files Browse the repository at this point in the history
Enable cloud-loading for analyzer Zarr
  • Loading branch information
alejoe91 authored Aug 21, 2024
2 parents 73f6151 + 33e27b1 commit 3921806
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
17 changes: 17 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float:
memory = mem_info.total - mem_info.available

return memory


def is_path_remote(path: str | Path) -> bool:
"""
Returns True if the path is a remote path (e.g., s3:// or gcs://).
Parameters
----------
path : str or Path
The path to check.
Returns
-------
bool
Whether the path is a remote path.
"""
return "s3://" in str(path) or "gcs://" in str(path)
41 changes: 23 additions & 18 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match
from .core_tools import check_json, retrieve_importing_provenance
from .core_tools import check_json, retrieve_importing_provenance, is_path_remote
from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging
from .job_tools import split_job_kwargs
from .numpyextractors import NumpySorting
Expand Down Expand Up @@ -195,6 +195,7 @@ def __init__(
format=None,
sparsity=None,
return_scaled=True,
storage_options=None,
):
# very fast init because checks are done in load and create
self.sorting = sorting
Expand All @@ -204,6 +205,7 @@ def __init__(
self.format = format
self.sparsity = sparsity
self.return_scaled = return_scaled
self.storage_options = storage_options
# this is used to store temporary recording
self._temporary_recording = None

Expand Down Expand Up @@ -276,30 +278,34 @@ def create(
return sorting_analyzer

@classmethod
def load(cls, folder, recording=None, load_extensions=True, format="auto"):
def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None):
"""
Load folder or zarr.
The recording can be given if the recording location has changed.
Otherwise the recording is loaded when possible.
"""
folder = Path(folder)
assert folder.is_dir(), "Waveform folder does not exists"
if format == "auto":
# make better assumption and check for auto guess format
if folder.suffix == ".zarr":
if Path(folder).suffix == ".zarr":
format = "zarr"
else:
format = "binary_folder"

if format == "binary_folder":
sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording)
elif format == "zarr":
sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording)
sorting_analyzer = SortingAnalyzer.load_from_zarr(
folder, recording=recording, storage_options=storage_options
)

sorting_analyzer.folder = folder
if is_path_remote(str(folder)):
sorting_analyzer.folder = folder
# in this case we only load extensions when needed
else:
sorting_analyzer.folder = Path(folder)

if load_extensions:
sorting_analyzer.load_all_saved_extension()
if load_extensions:
sorting_analyzer.load_all_saved_extension()

return sorting_analyzer

Expand Down Expand Up @@ -470,7 +476,9 @@ def load_from_binary_folder(cls, folder, recording=None):
def _get_zarr_root(self, mode="r+"):
import zarr

zarr_root = zarr.open(self.folder, mode=mode)
if is_path_remote(str(self.folder)):
mode = "r"
zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options)
return zarr_root

@classmethod
Expand Down Expand Up @@ -552,25 +560,22 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at
recording_info = zarr_root.create_group("extensions")

@classmethod
def load_from_zarr(cls, folder, recording=None):
def load_from_zarr(cls, folder, recording=None, storage_options=None):
import zarr

folder = Path(folder)
assert folder.is_dir(), f"This folder does not exist {folder}"

zarr_root = zarr.open(folder, mode="r")
zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options)

# load internal sorting in memory
# TODO propagate storage_options
sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True, copy_spike_vector=True
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
with_metadata=True,
copy_spike_vector=True,
)

# load recording if possible
if recording is None:
rec_dict = zarr_root["recording"][0]
try:

recording = load_extractor(rec_dict, base_folder=folder)
except:
recording = None
Expand Down

0 comments on commit 3921806

Please sign in to comment.