diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index aad7613d01..b38222391c 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -684,3 +684,20 @@ def measure_memory_allocation(measure_in_process: bool = True) -> float: memory = mem_info.total - mem_info.available return memory + + +def is_path_remote(path: str | Path) -> bool: + """ + Returns True if the path is a remote path (e.g., s3:// or gcs://). + + Parameters + ---------- + path : str or Path + The path to check. + + Returns + ------- + bool + Whether the path is a remote path. + """ + return "s3://" in str(path) or "gcs://" in str(path) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index ac142405ab..45f1f881b4 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from .base import load_extractor from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match -from .core_tools import check_json, retrieve_importing_provenance +from .core_tools import check_json, retrieve_importing_provenance, is_path_remote from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -195,6 +195,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, + storage_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -204,6 +205,7 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled + self.storage_options = storage_options # this is used to store temporary recording self._temporary_recording = None @@ -276,17 +278,15 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto"): + def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. Otherwise the recording is loaded when possible. """ - folder = Path(folder) - assert folder.is_dir(), "Waveform folder does not exists" if format == "auto": # make better assumption and check for auto guess format - if folder.suffix == ".zarr": + if Path(folder).suffix == ".zarr": format = "zarr" else: format = "binary_folder" @@ -294,12 +294,18 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto"): if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) elif format == "zarr": - sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_zarr( + folder, recording=recording, storage_options=storage_options + ) - sorting_analyzer.folder = folder + if is_path_remote(str(folder)): + sorting_analyzer.folder = folder + # in this case we only load extensions when needed + else: + sorting_analyzer.folder = Path(folder) - if load_extensions: - sorting_analyzer.load_all_saved_extension() + if load_extensions: + sorting_analyzer.load_all_saved_extension() return sorting_analyzer @@ -470,7 +476,9 @@ def load_from_binary_folder(cls, folder, recording=None): def _get_zarr_root(self, mode="r+"): import zarr - zarr_root = zarr.open(self.folder, mode=mode) + if is_path_remote(str(self.folder)): + mode = "r" + zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -552,25 +560,22 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") @classmethod - def load_from_zarr(cls, folder, recording=None): + def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - folder = Path(folder) - assert folder.is_dir(), f"This folder does not exist {folder}" - - zarr_root = zarr.open(folder, mode="r") + zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory - # TODO propagate storage_options sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting"), with_metadata=True, copy_spike_vector=True + ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + with_metadata=True, + copy_spike_vector=True, ) # load recording if possible if recording is None: rec_dict = zarr_root["recording"][0] try: - recording = load_extractor(rec_dict, base_folder=folder) except: recording = None