diff --git a/README.md b/README.md index 87457ef..672b816 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Traditional Zarr directory stores have some limitations. First, Zarr archives of HDF5 is not well-suited for cloud environments because accessing a remote HDF5 file often requires a large number of small requests to retrieve metadata before larger data chunks can be downloaded. LINDI addresses this by storing the entire group structure in a single JSON file, which can be downloaded in one request. Additionally, HDF5 lacks a built-in mechanism for referencing data chunks in external datasets. Furthermore, HDF5 does not support custom Python codecs, a feature available in both Zarr and LINDI. -**Is tar format really cloud-friendly** +**Is tar format really cloud-friendly?** With LINDI, yes. See [docs/tar.md](docs/tar.md) for details. @@ -101,7 +101,6 @@ with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='r') as f: With LINDI, it is easy to load an NWB file stored on DANDI. The following example demonstrates how to load an NWB file from DANDI, view it using the pynwb library, and save it as a relatively smaller .lindi.json file. The LINDI JSON file can then be read directly to access the NWB file. ```python -import json import pynwb import lindi diff --git a/examples/DANDI/nwbextractors.py b/examples/DANDI/nwbextractors.py deleted file mode 100644 index 37c5af2..0000000 --- a/examples/DANDI/nwbextractors.py +++ /dev/null @@ -1,1427 +0,0 @@ -from __future__ import annotations -from pathlib import Path -from typing import List, Optional, Literal, Dict, BinaryIO -import warnings - -import numpy as np -import h5py - -from spikeinterface import get_global_tmp_folder -from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment -from spikeinterface.core.core_tools import define_function_from_class - - -def read_file_from_backend( - *, - file_path: str | Path | None, - file: BinaryIO | None = None, - h5py_file: h5py.File | None = None, - stream_mode: Literal["ffspec", "remfile"] | None = None, - cache: bool = False, - stream_cache_path: str | Path | None = None, - storage_options: dict | None = None, -): - """ - Reads a file from a hdf5 or zarr backend. - """ - if stream_mode == "fsspec": - import h5py - import fsspec - from fsspec.implementations.cached import CachingFileSystem - - assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'" - - fsspec_file_system = fsspec.filesystem("http") - - if cache: - stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) - caching_file_system = CachingFileSystem( - fs=fsspec_file_system, - cache_storage=str(stream_cache_path), - ) - ffspec_file = caching_file_system.open(path=file_path, mode="rb") - else: - ffspec_file = fsspec_file_system.open(file_path, "rb") - - if _is_hdf5_file(ffspec_file): - open_file = h5py.File(ffspec_file, "r") - else: - raise RuntimeError(f"{file_path} is not a valid HDF5 file!") - - elif stream_mode == "ros3": - import h5py - - assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" - - drivers = h5py.registered_drivers() - assertion_msg = "ROS3 support not enabled, use: install -c conda-forge h5py>=3.2 to enable streaming" - assert "ros3" in drivers, assertion_msg - open_file = h5py.File(name=file_path, mode="r", driver="ros3") - - elif stream_mode == "remfile": - import remfile - import h5py - - assert file_path is not None, "file_path must be specified when using stream_mode='remfile'" - rfile = remfile.File(file_path) - if _is_hdf5_file(rfile): - open_file = h5py.File(rfile, "r") - else: - raise RuntimeError(f"{file_path} is not a valid HDF5 file!") - - elif stream_mode == "zarr": - import zarr - - open_file = zarr.open(file_path, mode="r", storage_options=storage_options) - - elif file_path is not None: # local - file_path = str(Path(file_path).resolve()) - backend = _get_backend_from_local_file(file_path) - if backend == "zarr": - import zarr - - open_file = zarr.open(file_path, mode="r") - else: - import h5py - - open_file = h5py.File(name=file_path, mode="r") - elif file is not None: - import h5py - open_file = h5py.File(file, "r") - return open_file - elif h5py_file is not None: - return h5py_file - else: - raise ValueError("Provide either file_path or file or h5py_file") - - -def read_nwbfile( - *, - backend: Literal["hdf5", "zarr"], - file_path: str | Path | None, - file: BinaryIO | None = None, - h5py_file: h5py.File | None = None, - stream_mode: Literal["ffspec", "remfile", "zarr"] | None = None, - cache: bool = False, - stream_cache_path: str | Path | None = None, - storage_options: dict | None = None, -) -> "NWBFile": - """ - Read an NWB file and return the NWBFile object. - - Parameters - ---------- - file_path : Path, str or None - The path to the NWB file. Either provide this or file. - file : file-like object or None - The file-like object to read from. Either provide this or file_path. - stream_mode : "fsspec" | "remfile" | None, default: None - The streaming mode to use. If None it assumes the file is on the local disk. - cache : bool, default: False - If True, the file is cached in the file passed to stream_cache_path - if False, the file is not cached. - stream_cache_path : str or None, default: None - The path to the cache storage, when default to None it uses the a temporary - folder. - Returns - ------- - nwbfile : NWBFile - The NWBFile object. - - Notes - ----- - This function can stream data from the "fsspec", and "rem" protocols. - - - Examples - -------- - >>> nwbfile = read_nwbfile(file_path="data.nwb", backend="hdf5", stream_mode="fsspec") - """ - - if file_path is not None and file is not None: - raise ValueError("Provide either file_path or file, not both") - if file_path is None and h5py_file is not None: - raise ValueError("Provide either h5py_file or file_path, not both") - if file is not None and h5py_file is not None: - raise ValueError("Provide either h5py_file or file, not both") - if file_path is None and file is None and h5py_file is None: - raise ValueError("Provide either file_path or file or h5py_file") - - open_file = read_file_from_backend( - file_path=file_path, - file=file, - h5py_file=h5py_file, - stream_mode=stream_mode, - cache=cache, - stream_cache_path=stream_cache_path, - storage_options=storage_options, - ) - if backend == "hdf5": - from pynwb import NWBHDF5IO - - io = NWBHDF5IO(file=open_file, mode="r", load_namespaces=True) - else: - from hdmf_zarr import NWBZarrIO - - io = NWBZarrIO(path=open_file.store, mode="r", load_namespaces=True) - - nwbfile = io.read() - return nwbfile - - -def _retrieve_electrical_series_pynwb( - nwbfile: "NWBFile", electrical_series_path: Optional[str] = None -) -> "ElectricalSeries": - """ - Get an ElectricalSeries object from an NWBFile. - - Parameters - ---------- - nwbfile : NWBFile - The NWBFile object from which to extract the ElectricalSeries. - electrical_series_path : str, default: None - The name of the ElectricalSeries to extract. If not specified, it will return the first found ElectricalSeries - if there's only one; otherwise, it raises an error. - - Returns - ------- - ElectricalSeries - The requested ElectricalSeries object. - - Raises - ------ - ValueError - If no acquisitions are found in the NWBFile or if multiple acquisitions are found but no electrical_series_path - is provided. - AssertionError - If the specified electrical_series_path is not present in the NWBFile. - """ - from pynwb.ecephys import ElectricalSeries - - electrical_series_dict: Dict[str, ElectricalSeries] = {} - - for item in nwbfile.all_children(): - if isinstance(item, ElectricalSeries): - # remove data and skip first "/" - electrical_series_key = item.data.name.replace("/data", "")[1:] - electrical_series_dict[electrical_series_key] = item - - if electrical_series_path is not None: - if electrical_series_path not in electrical_series_dict: - raise ValueError(f"{electrical_series_path} not found in the NWBFile. ") - electrical_series = electrical_series_dict[electrical_series_path] - else: - electrical_series_list = list(electrical_series_dict.keys()) - if len(electrical_series_list) > 1: - raise ValueError( - f"More than one acquisition found! You must specify 'electrical_series_path'. \n" - f"Options in current file are: {[e for e in electrical_series_list]}" - ) - if len(electrical_series_list) == 0: - raise ValueError("No acquisitions found in the .nwb file.") - electrical_series = electrical_series_dict[electrical_series_list[0]] - - return electrical_series - - -def _retrieve_unit_table_pynwb(nwbfile: "NWBFile", unit_table_path: Optional[str] = None) -> "Units": - """ - Get an Units object from an NWBFile. - Units tables can be either the main unit table (nwbfile.units) or in the processing module. - - Parameters - ---------- - nwbfile : NWBFile - The NWBFile object from which to extract the Units. - unit_table_path : str, default: None - The path of the Units to extract. If not specified, it will return the first found Units - if there's only one; otherwise, it raises an error. - - Returns - ------- - Units - The requested Units object. - - Raises - ------ - ValueError - If no unit tables are found in the NWBFile or if multiple unit tables are found but no unit_table_path - is provided. - AssertionError - If the specified unit_table_path is not present in the NWBFile. - """ - from pynwb.misc import Units - - unit_table_dict: Dict[str:Units] = {} - - for item in nwbfile.all_children(): - if isinstance(item, Units): - # retrieve name of "id" column and skip first "/" - unit_table_key = item.id.data.name.replace("/id", "")[1:] - unit_table_dict[unit_table_key] = item - - if unit_table_path is not None: - if unit_table_path not in unit_table_dict: - raise ValueError(f"{unit_table_path} not found in the NWBFile. ") - unit_table = unit_table_dict[unit_table_path] - else: - unit_table_list: List[Units] = list(unit_table_dict.keys()) - - if len(unit_table_list) > 1: - raise ValueError( - f"More than one unit table found! You must specify 'unit_table_list_name'. \n" - f"Options in current file are: {[e for e in unit_table_list]}" - ) - if len(unit_table_list) == 0: - raise ValueError("No unit table found in the .nwb file.") - unit_table = unit_table_dict[unit_table_list[0]] - - return unit_table - - -def _is_hdf5_file(filename_or_file): - if isinstance(filename_or_file, (str, Path)): - import h5py - - filename = str(filename_or_file) - is_hdf5 = h5py.h5f.is_hdf5(filename.encode("utf-8")) - else: - file_signature = filename_or_file.read(8) - # Source of the magic number https://docs.hdfgroup.org/hdf5/develop/_f_m_t3.html - is_hdf5 = file_signature == b"\x89HDF\r\n\x1a\n" - - return is_hdf5 - - -def _get_backend_from_local_file(file_path: str | Path) -> str: - """ - Returns the file backend from a file path ("hdf5", "zarr") - - Parameters - ---------- - file_path : str or Path - The path to the file. - - Returns - ------- - backend : str - The file backend ("hdf5", "zarr") - """ - file_path = Path(file_path) - if file_path.is_file(): - if _is_hdf5_file(file_path): - backend = "hdf5" - else: - raise RuntimeError(f"{file_path} is not a valid HDF5 file!") - elif file_path.is_dir(): - try: - import zarr - - with zarr.open(file_path, "r") as f: - backend = "zarr" - except: - raise RuntimeError(f"{file_path} is not a valid Zarr folder!") - else: - raise RuntimeError(f"File {file_path} is not an existing file or folder!") - return backend - - -def _find_neurodata_type_from_backend(group, path="", result=None, neurodata_type="ElectricalSeries", backend="hdf5"): - """ - Recursively searches for groups with the specified neurodata_type hdf5 or zarr object, - and returns a list with their paths. - """ - if backend == "hdf5": - import h5py - - group_class = h5py.Group - else: - import zarr - - group_class = zarr.Group - - if result is None: - result = [] - - for neurodata_name, value in group.items(): - # Check if it's a group and if it has the neurodata_type - if isinstance(value, group_class): - current_path = f"{path}/{neurodata_name}" if path else neurodata_name - if value.attrs.get("neurodata_type") == neurodata_type: - result.append(current_path) - _find_neurodata_type_from_backend( - value, current_path, result, neurodata_type, backend - ) # Recursive call for sub-groups - return result - - -def _fetch_time_info_pynwb(electrical_series, samples_for_rate_estimation, load_time_vector=False): - """ - Extracts the sampling frequency and the time vector from an ElectricalSeries object. - """ - sampling_frequency = None - if hasattr(electrical_series, "rate"): - sampling_frequency = electrical_series.rate - - if hasattr(electrical_series, "starting_time"): - t_start = electrical_series.starting_time - else: - t_start = None - - timestamps = None - if hasattr(electrical_series, "timestamps"): - if electrical_series.timestamps is not None: - timestamps = electrical_series.timestamps - t_start = electrical_series.timestamps[0] - - # TimeSeries need to have either timestamps or rate - if sampling_frequency is None: - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=electrical_series.timestamps) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - return sampling_frequency, times_kwargs - - -def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, electrical_series, backend="hdf5"): - """ - Retrieves the indices of the electrodes from the electrical series. - For the Zarr backend, the electrodes are stored in the electrical_series.attrs["zarr_link"]. - """ - if "electrodes" not in electrical_series: - if backend == "zarr": - import zarr - - # links must be resolved - zarr_links = electrical_series.attrs["zarr_link"] - electrodes_path = None - for zarr_link in zarr_links: - if zarr_link["name"] == "electrodes": - electrodes_path = zarr_link["path"] - assert electrodes_path is not None, "electrodes must be present in the electrical series" - electrodes_indices = open_file[electrodes_path][:] - else: - raise ValueError("electrodes must be present in the electrical series") - else: - electrodes_indices = electrical_series["electrodes"][:] - return electrodes_indices - - -class NwbRecordingExtractor(BaseRecording): - """Load an NWBFile as a RecordingExtractor. - - Parameters - ---------- - file_path : str, Path or None - Path to the NWB file or an s3 URL. Use this parameter to specify the file location - if not using the `file` or `h5py_file` parameter. - electrical_series_name : str or None, default: None - Deprecated, use `electrical_series_path` instead. - electrical_series_path : str or None, default: None - The name of the ElectricalSeries object within the NWB file. This parameter is crucial - when the NWB file contains multiple ElectricalSeries objects. It helps in identifying - which specific series to extract data from. If there is only one ElectricalSeries and - this parameter is not set, that unique series will be used by default. - If multiple ElectricalSeries are present and this parameter is not set, an error is raised. - The `electrical_series_path` corresponds to the path within the NWB file, e.g., - 'acquisition/MyElectricalSeries`. - load_time_vector : bool, default: False - If set to True, the time vector is also loaded into the recording object. Useful for - cases where precise timing information is required. - samples_for_rate_estimation : int, default: 1000 - The number of timestamp samples used for estimating the sampling rate. This is relevant - when the 'rate' attribute is not available in the ElectricalSeries. - stream_mode : "fsspec" | "remfile" | "zarr" | None, default: None - Determines the streaming mode for reading the file. Use this for optimized reading from - different sources, such as local disk or remote servers. - load_channel_properties : bool, default: True - If True, all the channel properties are loaded from the NWB file and stored as properties. - For streaming purposes, it can be useful to set this to False to speed up streaming. - file : file-like object or None, default: None - A file-like object representing the NWB file. Use this parameter if you have an in-memory - representation of the NWB file instead of a file path. - h5py_file : h5py.File or None, default: None - A h5py.File-like object representing the NWB file. (jfm) - cache : bool, default: False - Indicates whether to cache the file locally when using streaming. Caching can improve performance for - remote files. - stream_cache_path : str, Path, or None, default: None - Specifies the local path for caching the file. Relevant only if `cache` is True. - storage_options : dict | None = None, - These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. - This is only used on the "zarr" stream_mode. - use_pynwb : bool, default: False - Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py - to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. - - Returns - ------- - recording : NwbRecordingExtractor - The recording extractor for the NWB file. - - Examples - -------- - Run on local file: - - >>> from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor - >>> rec = NwbRecordingExtractor(filepath) - - Run on s3 URL from the DANDI Archive: - - >>> from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor - >>> from dandi.dandiapi import DandiAPIClient - >>> - >>> # get s3 path - >>> dandiset_id, filepath = "101116", "sub-001/sub-001_ecephys.nwb" - >>> with DandiAPIClient("https://api-staging.dandiarchive.org/api") as client: - >>> asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(filepath) - >>> s3_url = asset.get_content_url(follow_redirects=1, strip_query=True) - >>> - >>> rec = NwbRecordingExtractor(s3_url, stream_mode="fsspec", stream_cache_path="cache") - """ - - mode = "file" - name = "nwb" - installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" - - def __init__( - self, - file_path: str | Path | None = None, # provide either this or file - electrical_series_name: str | None = None, # deprecated - load_time_vector: bool = False, - samples_for_rate_estimation: int = 1_000, - stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, - stream_cache_path: str | Path | None = None, - electrical_series_path: str | None = None, - load_channel_properties: bool = True, - *, - file: BinaryIO | None = None, # file-like - provide either this or file_path or h5py_file - h5py_file: h5py.File | None = None, # provide either this or file_path or file - cache: bool = False, - storage_options: dict | None = None, - use_pynwb: bool = False, - ): - - if stream_mode == "ros3": - warnings.warn( - "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. " - "Use 'fsspec' stream_mode instead.", - DeprecationWarning, - ) - - if file_path is not None and file is not None: - raise ValueError("Provide either file_path or file, not both") - if file_path is not None and h5py_file is not None: - raise ValueError("Provide either h5py_file or file_path, not both") - if file is not None and h5py_file is not None: - raise ValueError("Provide either h5py_file or file, not both") - if file_path is None and file is None and h5py_file is None: - raise ValueError("Provide either file_path or file or h5py_file") - - if electrical_series_name is not None: - warning_msg = ( - "The `electrical_series_name` parameter is deprecated and will be removed in version 0.101.0.\n" - "Use `electrical_series_path` instead." - ) - if electrical_series_path is None: - warning_msg += f"\nSetting `electrical_series_path` to 'acquisition/{electrical_series_name}'." - electrical_series_path = f"acquisition/{electrical_series_name}" - else: - warning_msg += f"\nIgnoring `electrical_series_name` and using the provided `electrical_series_path`." - warnings.warn(warning_msg, DeprecationWarning, stacklevel=2) - - self.file_path = file_path - self.stream_mode = stream_mode - self.stream_cache_path = stream_cache_path - self.storage_options = storage_options - self.electrical_series_path = electrical_series_path - - if self.stream_mode is None and file is None and h5py_file is None: - self.backend = _get_backend_from_local_file(file_path) - else: - if self.stream_mode == "zarr": - self.backend = "zarr" - else: - self.backend = "hdf5" - - # extract info - if use_pynwb: - try: - import pynwb - except ImportError: - raise ImportError(self.installation_mesg) - - ( - channel_ids, - sampling_frequency, - dtype, - segment_data, - times_kwargs, - ) = self._fetch_recording_segment_info_pynwb(file, h5py_file, cache, load_time_vector, samples_for_rate_estimation) - else: - ( - channel_ids, - sampling_frequency, - dtype, - segment_data, - times_kwargs, - ) = self._fetch_recording_segment_info_backend(file, h5py_file, cache, load_time_vector, samples_for_rate_estimation) - BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) - recording_segment = NwbRecordingSegment( - electrical_series_data=segment_data, - times_kwargs=times_kwargs, - ) - self.add_recording_segment(recording_segment) - - # fetch and add main recording properties - if use_pynwb: - gains, offsets, locations, groups = self._fetch_main_properties_pynwb() - self.extra_requirements.append("pynwb") - else: - gains, offsets, locations, groups = self._fetch_main_properties_backend() - self.extra_requirements.append("h5py") - self.set_channel_gains(gains) - self.set_channel_offsets(offsets) - if locations is not None: - self.set_channel_locations(locations) - if groups is not None: - self.set_channel_groups(groups) - - # fetch and add additional recording properties - if load_channel_properties: - if use_pynwb: - electrodes_table = self._nwbfile.electrodes - electrodes_indices = self.electrical_series.electrodes.data[:] - columns = electrodes_table.colnames - else: - electrodes_table = self._file["/general/extracellular_ephys/electrodes"] - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - self._file, self.electrical_series, self.backend - ) - columns = electrodes_table.attrs["colnames"] - properties = self._fetch_other_properties(electrodes_table, electrodes_indices, columns) - - for property_name, property_values in properties.items(): - values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] - self.set_property(property_name, values) - - if stream_mode is None and file_path is not None: - file_path = str(Path(file_path).resolve()) - - if stream_mode == "fsspec" and stream_cache_path is not None: - stream_cache_path = str(Path(self.stream_cache_path).absolute()) - - # set serializability bools - if file is not None: - # not json serializable if file arg is provided - self._serializability["json"] = False - if h5py_file is not None: - # not json serializable if h5py_file arg is provided - self._serializability["json"] = False - - if storage_options is not None and stream_mode == "zarr": - warnings.warn( - "The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, " - "so the extractor will not be JSON/pickle serializable. Only in-memory mode will be available." - ) - # not serializable if storage_options is provided - self._serializability["json"] = False - self._serializability["pickle"] = False - - self._kwargs = { - "file_path": file_path, - "electrical_series_path": self.electrical_series_path, - "load_time_vector": load_time_vector, - "samples_for_rate_estimation": samples_for_rate_estimation, - "stream_mode": stream_mode, - "load_channel_properties": load_channel_properties, - "storage_options": storage_options, - "cache": cache, - "stream_cache_path": stream_cache_path, - "file": file, - "h5py_file": h5py_file - } - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._file.close() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - - def _fetch_recording_segment_info_pynwb(self, file, h5py_file, cache, load_time_vector, samples_for_rate_estimation): - self._nwbfile = read_nwbfile( - backend=self.backend, - file_path=self.file_path, - file=file, - h5py_file=h5py_file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - ) - electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) - # The indices in the electrode table corresponding to this electrical series - electrodes_indices = electrical_series.electrodes.data[:] - # The table for all the electrodes in the nwbfile - electrodes_table = self._nwbfile.electrodes - - sampling_frequency, times_kwargs = _fetch_time_info_pynwb( - electrical_series=electrical_series, - samples_for_rate_estimation=samples_for_rate_estimation, - load_time_vector=load_time_vector, - ) - - # Fill channel properties dictionary from electrodes table - if "channel_name" in electrodes_table.colnames: - channel_ids = [ - electrical_series.electrodes["channel_name"][electrodes_index] - for electrodes_index in electrodes_indices - ] - else: - channel_ids = [electrical_series.electrodes.table.id[x] for x in electrodes_indices] - electrical_series_data = electrical_series.data - dtype = electrical_series_data.dtype - - # need this later - self.electrical_series = electrical_series - - return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs - - def _fetch_recording_segment_info_backend(self, file, h5py_file, cache, load_time_vector, samples_for_rate_estimation): - open_file = read_file_from_backend( - file_path=self.file_path, - file=file, - h5py_file=h5py_file, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - ) - - # If the electrical_series_path is not given, `_find_neurodata_type_from_backend` will be called - # And returns a list with the electrical_series_paths available in the file. - # If there is only one electrical series, the electrical_series_path is set to the name of the series, - # otherwise an error is raised. - if self.electrical_series_path is None: - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - # if electrical_series_path is None: - if len(available_electrical_series) == 1: - self.electrical_series_path = available_electrical_series[0] - else: - raise ValueError( - "Multiple ElectricalSeries found in the file. " - "Please specify the 'electrical_series_path' argument:" - f"Available options are: {available_electrical_series}." - ) - - # Open the electrical series. In case of failure, raise an error with the available options. - try: - electrical_series = open_file[self.electrical_series_path] - except KeyError: - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - raise ValueError( - f"{self.electrical_series_path} not found in the NWB file!" - f"Available options are: {available_electrical_series}." - ) - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - open_file, electrical_series, self.backend - ) - # The table for all the electrodes in the nwbfile - electrodes_table = open_file["/general/extracellular_ephys/electrodes"] - electrode_table_columns = electrodes_table.attrs["colnames"] - - # Get sampling frequency - if "starting_time" in electrical_series.keys(): - t_start = electrical_series["starting_time"][()] - sampling_frequency = electrical_series["starting_time"].attrs["rate"] - elif "timestamps" in electrical_series.keys(): - timestamps = electrical_series["timestamps"][:] - t_start = timestamps[0] - sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) - - if load_time_vector and timestamps is not None: - times_kwargs = dict(time_vector=electrical_series["timestamps"]) - else: - times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) - - # If channel names are present, use them as channel_ids instead of the electrode ids - if "channel_name" in electrode_table_columns: - channel_names = electrodes_table["channel_name"] - channel_ids = channel_names[electrodes_indices] - # Decode if bytes with utf-8 - channel_ids = [x.decode("utf-8") if isinstance(x, bytes) else x for x in channel_ids] - - else: - channel_ids = [electrodes_table["id"][x] for x in electrodes_indices] - - dtype = electrical_series["data"].dtype - electrical_series_data = electrical_series["data"] - - # need this for later - self.electrical_series = electrical_series - self._file = open_file - - return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs - - def _fetch_locations_and_groups(self, electrodes_table, electrodes_indices): - # Channel locations - locations = None - if "rel_x" in electrodes_table: - if "rel_y" in electrodes_table: - ndim = 3 if "rel_z" in electrodes_table else 2 - locations = np.zeros((self.get_num_channels(), ndim), dtype=float) - locations[:, 0] = electrodes_table["rel_x"][electrodes_indices] - locations[:, 1] = electrodes_table["rel_y"][electrodes_indices] - if "rel_z" in electrodes_table: - locations[:, 2] = electrodes_table["rel_z"][electrodes_indices] - - # allow x, y, z instead of rel_x, rel_y, rel_z - if locations is None: - if "x" in electrodes_table: - if "y" in electrodes_table: - ndim = 3 if "z" in electrodes_table else 2 - locations = np.zeros((self.get_num_channels(), ndim), dtype=float) - locations[:, 0] = electrodes_table["x"][electrodes_indices] - locations[:, 1] = electrodes_table["y"][electrodes_indices] - if "z" in electrodes_table: - locations[:, 2] = electrodes_table["z"][electrodes_indices] - - # Channel groups - groups = None - if "group_name" in electrodes_table: - groups = electrodes_table["group_name"][electrodes_indices][:] - if groups is not None: - groups = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in groups]) - return locations, groups - - def _fetch_other_properties(self, electrodes_table, electrodes_indices, columns): - ######### - # Extract and re-name properties from nwbfile TODO: Should be a function - ######## - properties = dict() - properties_to_skip = [ - "id", - "rel_x", - "rel_y", - "rel_z", - "group", - "group_name", - "channel_name", - "offset", - ] - rename_properties = dict(location="brain_area") - - for column in columns: - if column in properties_to_skip: - continue - else: - column_name = rename_properties.get(column, column) - properties[column_name] = electrodes_table[column][electrodes_indices] - - return properties - - def _fetch_main_properties_pynwb(self): - """ - Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: - - - gains - - offsets - - locations - - groups - """ - electrodes_indices = self.electrical_series.electrodes.data[:] - electrodes_table = self._nwbfile.electrodes - - # Channels gains - for RecordingExtractor, these are values to cast traces to uV - gains = self.electrical_series.conversion * 1e6 - if self.electrical_series.channel_conversion is not None: - gains = self.electrical_series.conversion * self.electrical_series.channel_conversion[:] * 1e6 - - # Channel offsets - offset = self.electrical_series.offset if hasattr(self.electrical_series, "offset") else 0 - if offset == 0 and "offset" in electrodes_table: - offset = electrodes_table["offset"].data[electrodes_indices] - offsets = offset * 1e6 - - locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) - - return gains, offsets, locations, groups - - def _fetch_main_properties_backend(self): - """ - Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: - - - gains - - offsets - - locations - - groups - """ - electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( - self._file, self.electrical_series, self.backend - ) - electrodes_table = self._file["/general/extracellular_ephys/electrodes"] - - # Channels gains - for RecordingExtractor, these are values to cast traces to uV - data_attributes = self.electrical_series["data"].attrs - electrical_series_conversion = data_attributes["conversion"] - gains = electrical_series_conversion * 1e6 - channel_conversion = self.electrical_series.get("channel_conversion", None) - if channel_conversion: - gains *= self.electrical_series["channel_conversion"][:] - - # Channel offsets - offset = data_attributes["offset"] if "offset" in data_attributes else 0 - if offset == 0 and "offset" in electrodes_table: - offset = electrodes_table["offset"][electrodes_indices] - offsets = offset * 1e6 - - # Channel locations and groups - locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) - - return gains, offsets, locations, groups - - @staticmethod - def fetch_available_electrical_series_paths( - file_path: str | Path, - stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, - storage_options: dict | None = None, - ) -> list[str]: - """ - Retrieves the paths to all ElectricalSeries objects within a neurodata file. - - Parameters - ---------- - file_path : str | Path - The path to the neurodata file to be analyzed. - stream_mode : "fsspec" | "remfile" | "zarr" | None, optional - Determines the streaming mode for reading the file. Use this for optimized reading from - different sources, such as local disk or remote servers. - storage_options : dict | None = None, - These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. - This is only used on the "zarr" stream_mode. - Returns - ------- - list of str - A list of paths to all ElectricalSeries objects found in the file. - - - Notes - ----- - The paths are returned as strings, and can be used to load the desired ElectricalSeries object. - Examples of paths are: - - "acquisition/ElectricalSeries1" - - "acquisition/ElectricalSeries2" - - "processing/ecephys/LFP/ElectricalSeries1" - - "processing/my_custom_module/MyContainer/ElectricalSeries2" - """ - - if stream_mode is None: - backend = _get_backend_from_local_file(file_path) - else: - if stream_mode == "zarr": - backend = "zarr" - else: - backend = "hdf5" - - file_handle = read_file_from_backend( - file_path=file_path, - stream_mode=stream_mode, - storage_options=storage_options, - ) - - electrical_series_paths = _find_neurodata_type_from_backend( - file_handle, - neurodata_type="ElectricalSeries", - backend=backend, - ) - return electrical_series_paths - - -class NwbRecordingSegment(BaseRecordingSegment): - def __init__(self, electrical_series_data, times_kwargs): - BaseRecordingSegment.__init__(self, **times_kwargs) - self.electrical_series_data = electrical_series_data - self._num_samples = self.electrical_series_data.shape[0] - - def get_num_samples(self): - """Returns the number of samples in this signal block - - Returns: - SampleIndex : Number of samples in the signal block - """ - return self._num_samples - - def get_traces(self, start_frame, end_frame, channel_indices): - electrical_series_data = self.electrical_series_data - if electrical_series_data.ndim == 1: - traces = electrical_series_data[start_frame:end_frame][:, np.newaxis] - elif isinstance(channel_indices, slice): - traces = electrical_series_data[start_frame:end_frame, channel_indices] - else: - # channel_indices is np.ndarray - if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): - # get around h5py constraint that it does not allow datasets - # to be indexed out of order - sorted_channel_indices = np.sort(channel_indices) - resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) - recordings = electrical_series_data[start_frame:end_frame, sorted_channel_indices] - traces = recordings[:, resorted_indices] - else: - traces = electrical_series_data[start_frame:end_frame, channel_indices] - - return traces - - -class NwbSortingExtractor(BaseSorting): - """Load an NWBFile as a SortingExtractor. - Parameters - ---------- - file_path : str or Path - Path to NWB file. - electrical_series_path : str or None, default: None - The name of the ElectricalSeries (if multiple ElectricalSeries are present). - sampling_frequency : float or None, default: None - The sampling frequency in Hz (required if no ElectricalSeries is available). - unit_table_path : str or None, default: "units" - The path of the unit table in the NWB file. - samples_for_rate_estimation : int, default: 100000 - The number of timestamp samples to use to estimate the rate. - Used if "rate" is not specified in the ElectricalSeries. - stream_mode : "fsspec" | "remfile" | "zarr" | None, default: None - The streaming mode to use. If None it assumes the file is on the local disk. - stream_cache_path : str or Path or None, default: None - Local path for caching. If None it uses the system temporary directory. - load_unit_properties : bool, default: True - If True, all the unit properties are loaded from the NWB file and stored as properties. - t_start : float or None, default: None - This is the time at which the corresponding ElectricalSeries start. NWB stores its spikes as times - and the `t_start` is used to convert the times to seconds. Concrently, the returned frames are computed as: - - `frames = (times - t_start) * sampling_frequency`. - - As SpikeInterface always considers the first frame to be at the beginning of the recording independently - of the `t_start`. - - When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal - to `electrical_series_path`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the - first timestamp in the `ElectricalSeries.timestamps`. - cache : bool, default: False - If True, the file is cached in the file passed to stream_cache_path - if False, the file is not cached. - storage_options : dict | None = None, - These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. - This is only used on the "zarr" stream_mode. - use_pynwb : bool, default: False - Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py - to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. - - Returns - ------- - sorting : NwbSortingExtractor - The sorting extractor for the NWB file. - """ - - mode = "file" - installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" - name = "nwb" - - def __init__( - self, - file_path: str | Path, - electrical_series_path: str | None = None, - sampling_frequency: float | None = None, - samples_for_rate_estimation: int = 1_000, - stream_mode: str | None = None, - stream_cache_path: str | Path | None = None, - load_unit_properties: bool = True, - unit_table_path: str = "units", - *, - t_start: float | None = None, - cache: bool = False, - storage_options: dict | None = None, - use_pynwb: bool = False, - ): - - if stream_mode == "ros3": - warnings.warn( - "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. " - "Use 'fsspec' stream_mode instead.", - DeprecationWarning, - ) - - self.stream_mode = stream_mode - self.stream_cache_path = stream_cache_path - self.electrical_series_path = electrical_series_path - self.file_path = file_path - self.t_start = t_start - self.provided_or_electrical_series_sampling_frequency = sampling_frequency - self.storage_options = storage_options - self.units_table = None - - if self.stream_mode is None: - self.backend = _get_backend_from_local_file(file_path) - else: - if self.stream_mode == "zarr": - self.backend = "zarr" - else: - self.backend = "hdf5" - - if use_pynwb: - try: - import pynwb - except ImportError: - raise ImportError(self.installation_mesg) - - unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_pynwb( - unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache - ) - else: - unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_backend( - unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache - ) - - BaseSorting.__init__( - self, sampling_frequency=self.provided_or_electrical_series_sampling_frequency, unit_ids=unit_ids - ) - - sorting_segment = NwbSortingSegment( - spike_times_data=spike_times_data, - spike_times_index_data=spike_times_index_data, - sampling_frequency=self.sampling_frequency, - t_start=self.t_start, - ) - self.add_sorting_segment(sorting_segment) - - # fetch and add sorting properties - if load_unit_properties: - if use_pynwb: - columns = [c.name for c in self.units_table.columns] - self.extra_requirements.append("pynwb") - else: - columns = list(self.units_table.keys()) - self.extra_requirements.append("h5py") - properties = self._fetch_properties(columns) - for property_name, property_values in properties.items(): - values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] - self.set_property(property_name, values) - - if stream_mode is None and file_path is not None: - file_path = str(Path(file_path).resolve()) - - if storage_options is not None and stream_mode == "zarr": - warnings.warn( - "The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, " - "so the extractor will not be JSON/pickle serializable. Only in-memory mode will be available." - ) - # not serializable if storage_options is provided - self._serializability["json"] = False - self._serializability["pickle"] = False - - self._kwargs = { - "file_path": file_path, - "electrical_series_path": self.electrical_series_path, - "sampling_frequency": sampling_frequency, - "samples_for_rate_estimation": samples_for_rate_estimation, - "cache": cache, - "stream_mode": stream_mode, - "stream_cache_path": stream_cache_path, - "storage_options": storage_options, - "load_unit_properties": load_unit_properties, - "t_start": self.t_start, - } - - def __del__(self): - # backend mode - if hasattr(self, "_file"): - if hasattr(self._file, "store"): - self._file.store.close() - else: - self._file.close() - # pynwb mode - elif hasattr(self, "_nwbfile"): - io = self._nwbfile.get_read_io() - if io is not None: - io.close() - - def _fetch_sorting_segment_info_pynwb( - self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False - ): - self._nwbfile = read_nwbfile( - backend=self.backend, - file_path=self.file_path, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - timestamps = None - if self.provided_or_electrical_series_sampling_frequency is None: - # defines the electrical series from where the sorting came from - # important to know the sampling_frequency - self.electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) - # get rate - if self.electrical_series.rate is not None: - self.provided_or_electrical_series_sampling_frequency = self.electrical_series.rate - self.t_start = self.electrical_series.starting_time - else: - if hasattr(self.electrical_series, "timestamps"): - if self.electrical_series.timestamps is not None: - timestamps = self.electrical_series.timestamps - self.provided_or_electrical_series_sampling_frequency = 1 / np.median( - np.diff(timestamps[:samples_for_rate_estimation]) - ) - self.t_start = timestamps[0] - assert ( - self.provided_or_electrical_series_sampling_frequency is not None - ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" - assert ( - self.t_start is not None - ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" - if unit_table_path == "units": - units_table = self._nwbfile.units - else: - units_table = _retrieve_unit_table_pynwb(self._nwbfile, unit_table_path=unit_table_path) - - name_to_column_data = {c.name: c for c in units_table.columns} - spike_times_data = name_to_column_data.pop("spike_times").data - spike_times_index_data = name_to_column_data.pop("spike_times_index").data - - units_ids = name_to_column_data.pop("unit_name", None) - if units_ids is None: - units_ids = units_table["id"].data - - # need this for later - self.units_table = units_table - - return units_ids, spike_times_data, spike_times_index_data - - def _fetch_sorting_segment_info_backend( - self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False - ): - open_file = read_file_from_backend( - file_path=self.file_path, - stream_mode=self.stream_mode, - cache=cache, - stream_cache_path=self.stream_cache_path, - storage_options=self.storage_options, - ) - - timestamps = None - - if self.provided_or_electrical_series_sampling_frequency is None or self.t_start is None: - # defines the electrical series from where the sorting came from - # important to know the sampling_frequency - available_electrical_series = _find_neurodata_type_from_backend( - open_file, neurodata_type="ElectricalSeries", backend=self.backend - ) - if self.electrical_series_path is None: - if len(available_electrical_series) == 1: - self.electrical_series_path = available_electrical_series[0] - else: - raise ValueError( - "Multiple ElectricalSeries found in the file. " - "Please specify the 'electrical_series_path' argument:" - f"Available options are: {available_electrical_series}." - ) - else: - if self.electrical_series_path not in available_electrical_series: - raise ValueError( - f"'{self.electrical_series_path}' not found in the file. " - f"Available options are: {available_electrical_series}" - ) - electrical_series = open_file[self.electrical_series_path] - - # Get sampling frequency - if "starting_time" in electrical_series.keys(): - self.t_start = electrical_series["starting_time"][()] - self.provided_or_electrical_series_sampling_frequency = electrical_series["starting_time"].attrs["rate"] - elif "timestamps" in electrical_series.keys(): - timestamps = electrical_series["timestamps"][:] - self.t_start = timestamps[0] - self.provided_or_electrical_series_sampling_frequency = 1.0 / np.median( - np.diff(timestamps[:samples_for_rate_estimation]) - ) - - assert ( - self.provided_or_electrical_series_sampling_frequency is not None - ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" - assert ( - self.t_start is not None - ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" - - if unit_table_path is None: - available_unit_table_paths = _find_neurodata_type_from_backend( - open_file, neurodata_type="Units", backend=self.backend - ) - if len(available_unit_table_paths) == 1: - unit_table_path = available_unit_table_paths[0] - else: - raise ValueError( - "Multiple Units tables found in the file. " - "Please specify the 'unit_table_path' argument:" - f"Available options are: {available_unit_table_paths}." - ) - # Try to open the unit table. If it fails, raise an error with the available options. - try: - units_table = open_file[unit_table_path] - except KeyError: - available_unit_table_paths = _find_neurodata_type_from_backend( - open_file, neurodata_type="Units", backend=self.backend - ) - raise ValueError( - f"{unit_table_path} not found in the NWB file!" f"Available options are: {available_unit_table_paths}." - ) - self.units_table_location = unit_table_path - units_table = open_file[self.units_table_location] - - spike_times_data = units_table["spike_times"] - spike_times_index_data = units_table["spike_times_index"] - - if "unit_name" in units_table: - unit_ids = units_table["unit_name"] - else: - unit_ids = units_table["id"] - - decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x - unit_ids = [decode_to_string(id) for id in unit_ids] - - # need this for later - self.units_table = units_table - - return unit_ids, spike_times_data, spike_times_index_data - - def _fetch_properties(self, columns): - units_table = self.units_table - - properties_to_skip = ["spike_times", "spike_times_index", "unit_name", "id"] - index_columns = [name for name in columns if name.endswith("_index")] - nested_ragged_array_properties = [name for name in columns if f"{name}_index_index" in columns] - - # Filter those properties that are nested ragged arrays - skip_properties = properties_to_skip + index_columns + nested_ragged_array_properties - properties_to_add = [name for name in columns if name not in skip_properties] - - properties = dict() - for property_name in properties_to_add: - data = units_table[property_name][:] - corresponding_index_name = f"{property_name}_index" - not_ragged_array = corresponding_index_name not in columns - if not_ragged_array: - values = data[:] - else: # TODO if we want we could make this recursive to handle nested ragged arrays - data_index = units_table[corresponding_index_name] - if hasattr(data_index, "data"): - # for pynwb we need to get the data from the data attribute - data_index = data_index.data[:] - else: - data_index = data_index[:] - index_spacing = np.diff(data_index, prepend=0) - all_index_spacing_are_the_same = np.unique(index_spacing).size == 1 - if all_index_spacing_are_the_same: - if hasattr(units_table[corresponding_index_name], "data"): - # ragged array indexing is handled by pynwb - values = data - else: - # ravel array based on data_index - start_indices = [0] + list(data_index[:-1]) - end_indices = list(data_index) - values = [ - data[start_index:end_index] for start_index, end_index in zip(start_indices, end_indices) - ] - else: - warnings.warn(f"Skipping {property_name} because of unequal shapes across units") - continue - properties[property_name] = values - - return properties - - -class NwbSortingSegment(BaseSortingSegment): - def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): - BaseSortingSegment.__init__(self) - self.spike_times_data = spike_times_data - self.spike_times_index_data = spike_times_index_data - self.spike_times_data = spike_times_data - self.spike_times_index_data = spike_times_index_data - self._sampling_frequency = sampling_frequency - self._t_start = t_start - - def get_unit_spike_train( - self, - unit_id, - start_frame: Optional[int] = None, - end_frame: Optional[int] = None, - ) -> np.ndarray: - # Extract the spike times for the unit - unit_index = self.parent_extractor.id_to_index(unit_id) - if unit_index == 0: - start_index = 0 - else: - start_index = self.spike_times_index_data[unit_index - 1] - end_index = self.spike_times_index_data[unit_index] - spike_times = self.spike_times_data[start_index:end_index] - - # Transform spike times to frames and subset - frames = np.round((spike_times - self._t_start) * self._sampling_frequency) - - start_index = 0 - if start_frame is not None: - start_index = np.searchsorted(frames, start_frame, side="left") - - if end_frame is not None: - end_index = np.searchsorted(frames, end_frame, side="left") - else: - end_index = frames.size - - return frames[start_index:end_index].astype("int64", copy=False) - - -read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording") -read_nwb_sorting = define_function_from_class(source_class=NwbSortingExtractor, name="read_nwb_sorting") - - -def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_series_path=None): - """Reads NWB file into SpikeInterface extractors. - - Parameters - ---------- - file_path : str or Path - Path to NWB file. - load_recording : bool, default: True - If True, the recording object is loaded. - load_sorting : bool, default: False - If True, the recording object is loaded. - electrical_series_path : str or None, default: None - The name of the ElectricalSeries (if multiple ElectricalSeries are present) - - Returns - ------- - extractors : extractor or tuple - Single RecordingExtractor/SortingExtractor or tuple with both - (depending on "load_recording"/"load_sorting") arguments. - """ - outputs = () - if load_recording: - rec = read_nwb_recording(file_path, electrical_series_path=electrical_series_path) - outputs = outputs + (rec,) - if load_sorting: - sorting = read_nwb_sorting(file_path, electrical_series_path=electrical_series_path) - outputs = outputs + (sorting,) - - if len(outputs) == 1: - outputs = outputs[0] - - return outputs diff --git a/examples/DANDI/preprocess_ephys.py b/examples/DANDI/preprocess_ephys.py deleted file mode 100644 index ff454a6..0000000 --- a/examples/DANDI/preprocess_ephys.py +++ /dev/null @@ -1,111 +0,0 @@ -import numpy as np -import lindi -import pynwb -from pynwb.ecephys import ElectricalSeries -import spikeinterface.preprocessing as spre -from nwbextractors import NwbRecordingExtractor -from qfc.codecs import QFCCodec -from qfc import qfc_estimate_quant_scale_factor - -QFCCodec.register_codec() - - -def preprocess_ephys(): - # https://neurosift.app/?p=/nwb&url=https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/&dandisetId=000463&dandisetVersion=draft - url = "https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/" - - print('Creating LINDI file') - with lindi.LindiH5pyFile.from_hdf5_file(url) as f: - f.write_lindi_file("example.nwb.lindi.tar") - - cache = lindi.LocalCache() - - print('Reading LINDI file') - with lindi.LindiH5pyFile.from_lindi_file("example.nwb.lindi.tar", mode="r", local_cache=cache) as f: - electrical_series_path = '/acquisition/ElectricalSeries' - - print("Loading recording") - recording = NwbRecordingExtractor( - h5py_file=f, electrical_series_path=electrical_series_path - ) - print(recording.get_channel_ids()) - - num_frames = recording.get_num_frames() - start_time_sec = 0 - # duration_sec = 300 - duration_sec = num_frames / recording.get_sampling_frequency() - start_frame = int(start_time_sec * recording.get_sampling_frequency()) - end_frame = int(np.minimum(num_frames, (start_time_sec + duration_sec) * recording.get_sampling_frequency())) - recording = recording.frame_slice( - start_frame=start_frame, - end_frame=end_frame - ) - - # bandpass filter - print("Filtering recording") - freq_min = 300 - freq_max = 6000 - recording_filtered = spre.bandpass_filter( - recording, freq_min=freq_min, freq_max=freq_max, dtype=np.float32 - ) # important to specify dtype here - f.close() - - traces0 = recording_filtered.get_traces(start_frame=0, end_frame=int(1 * recording_filtered.get_sampling_frequency())) - traces0 = traces0.astype(dtype=traces0.dtype, order='C') - - # noise_level = estimate_noise_level(traces0) - # print(f'Noise level: {noise_level}') - # scale_factor = qfc_estimate_quant_scale_factor(traces0, target_residual_stdev=noise_level * 0.2) - - compression_method = 'zlib' - zlib_level = 3 - zstd_level = 3 - - scale_factor = qfc_estimate_quant_scale_factor( - traces0, - target_compression_ratio=10, - compression_method=compression_method, - zlib_level=zlib_level, - zstd_level=zstd_level - ) - print(f'Quant. scale factor: {scale_factor}') - codec = QFCCodec( - quant_scale_factor=scale_factor, - dtype='float32', - segment_length=int(recording_filtered.get_sampling_frequency() * 1), - compression_method=compression_method, - zlib_level=zlib_level, - zstd_level=zstd_level - ) - traces0_compressed = codec.encode(traces0) - compression_ratio = traces0.size * 2 / len(traces0_compressed) - print(f'Compression ratio: {compression_ratio}') - - print("Writing filtered recording to LINDI file") - with lindi.LindiH5pyFile.from_lindi_file("example.nwb.lindi.tar", mode="a", local_cache=cache) as f: - with pynwb.NWBHDF5IO(file=f, mode='a') as io: - nwbfile = io.read() - - electrical_series = nwbfile.acquisition['ElectricalSeries'] - electrical_series_pre = ElectricalSeries( - name="ElectricalSeries_pre", - data=pynwb.H5DataIO( - recording_filtered.get_traces(), - chunks=(30000, recording.get_num_channels()), - compression=codec - ), - electrodes=electrical_series.electrodes, - starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time - rate=recording_filtered.get_sampling_frequency(), - ) - nwbfile.add_acquisition(electrical_series_pre) # type: ignore - io.write(nwbfile) - - -def estimate_noise_level(traces): - noise_level = np.median(np.abs(traces - np.median(traces))) / 0.6745 - return noise_level - - -if __name__ == "__main__": - preprocess_ephys() \ No newline at end of file diff --git a/examples/amend_remote_nwb_as_lindi_tar.py b/examples/amend_remote_nwb_as_lindi_tar.py new file mode 100644 index 0000000..d4c7ebf --- /dev/null +++ b/examples/amend_remote_nwb_as_lindi_tar.py @@ -0,0 +1,32 @@ +import numpy as np +import pynwb +from pynwb.file import TimeSeries +import lindi + +# Load the remote NWB file from DANDI +h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" +f = lindi.LindiH5pyFile.from_hdf5_file(h5_url) + +# Write to a local .lindi.tar file +f.write_lindi_file('example.nwb.lindi.tar') +f.close() + +# Open with pynwb and add new data +g = lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar', mode='r+') +with pynwb.NWBHDF5IO(file=g, mode="a") as io: + nwbfile = io.read() + timeseries_test = TimeSeries( + name="test", + data=np.array([1, 2, 3, 4, 5, 4, 3, 2, 1]), + rate=1., + unit='s' + ) + ts = nwbfile.processing['behavior'].add(timeseries_test) # type: ignore + io.write(nwbfile) # type: ignore + +# Later on, you can read the file again +h = lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar') +with pynwb.NWBHDF5IO(file=h, mode="r") as io: + nwbfile = io.read() + test_timeseries = nwbfile.processing['behavior']['test'] # type: ignore + print(test_timeseries) diff --git a/examples/benchmark1.py b/examples/benchmark1.py index 40f174f..36ce1a7 100644 --- a/examples/benchmark1.py +++ b/examples/benchmark1.py @@ -8,74 +8,94 @@ import numcodecs +# Benchmark writing a large number of small and large datasets to .lindi.tar, +# .zarr, .h5, and .dat files + + def create_dataset(size): return np.random.rand(size) -def benchmark_h5py(file_path, num_small_datasets, num_large_datasets, small_size, large_size, chunks, compression, mode): +def benchmark_h5py( + file_path, + num_small_datasets, + num_large_datasets, + small_size, + large_size, + chunks, + compression, + mode, +): start_time = time.time() - if mode == 'dat': - with open(file_path, 'wb') as f: + if mode == "dat": + with open(file_path, "wb") as f: # Write small datasets - print('Writing small datasets') + print("Writing small datasets") for i in range(num_small_datasets): data = create_dataset(small_size) f.write(data.tobytes()) # Write large datasets - print('Writing large datasets') + print("Writing large datasets") for i in range(num_large_datasets): data = create_dataset(large_size) - if compression == 'gzip': + if compression == "gzip": data_zipped = gzip.compress(data.tobytes(), compresslevel=4) f.write(data_zipped) elif compression is None: f.write(data.tobytes()) else: raise ValueError(f"Unknown compressor: {compression}") - elif mode == 'zarr': + elif mode == "zarr": if os.path.exists(file_path): import shutil + shutil.rmtree(file_path) store = zarr.DirectoryStore(file_path) root = zarr.group(store) - if compression == 'gzip': + if compression == "gzip": compressor = numcodecs.GZip(level=4) else: compressor = None # Write small datasets - print('Writing small datasets') + print("Writing small datasets") for i in range(num_small_datasets): data = create_dataset(small_size) - root.create_dataset(f'small_dataset_{i}', data=data) + root.create_dataset(f"small_dataset_{i}", data=data) # Write large datasets - print('Writing large datasets') + print("Writing large datasets") for i in range(num_large_datasets): data = create_dataset(large_size) - root.create_dataset(f'large_dataset_{i}', data=data, chunks=chunks, compressor=compressor) + root.create_dataset( + f"large_dataset_{i}", data=data, chunks=chunks, compressor=compressor + ) else: - if mode == 'h5': - f = h5py.File(file_path, 'w') + if mode == "h5": + f = h5py.File(file_path, "w") + elif mode == "lindi": + f = lindi.LindiH5pyFile.from_lindi_file(file_path, mode="w") else: - f = lindi.LindiH5pyFile.from_lindi_file(file_path, mode='w') + raise ValueError(f"Unknown mode: {mode}") # Write small datasets - print('Writing small datasets') + print("Writing small datasets") for i in range(num_small_datasets): data = create_dataset(small_size) - ds = f.create_dataset(f'small_dataset_{i}', data=data) - ds.attrs['attr1'] = 1 + ds = f.create_dataset(f"small_dataset_{i}", data=data) + ds.attrs["attr1"] = 1 # Write large datasets - print('Writing large datasets') + print("Writing large datasets") for i in range(num_large_datasets): data = create_dataset(large_size) - ds = f.create_dataset(f'large_dataset_{i}', data=data, chunks=chunks, compression=compression) - ds.attrs['attr1'] = 1 + ds = f.create_dataset( + f"large_dataset_{i}", data=data, chunks=chunks, compression=compression + ) + ds.attrs["attr1"] = 1 f.close() @@ -83,15 +103,17 @@ def benchmark_h5py(file_path, num_small_datasets, num_large_datasets, small_size total_time = end_time - start_time # Calculate total data size - total_size = (num_small_datasets * small_size + num_large_datasets * large_size) * 8 # 8 bytes per float64 - total_size_gb = total_size / (1024 ** 3) + total_size = ( + num_small_datasets * small_size + num_large_datasets * large_size + ) * 8 # 8 bytes per float64 + total_size_gb = total_size / (1024**3) print("Benchmark Results:") print(f"Total time: {total_time:.2f} seconds") print(f"Total data size: {total_size_gb:.2f} GB") print(f"Write speed: {total_size_gb / total_time:.2f} GB/s") - h5py_file_size = os.path.getsize(file_path) / (1024 ** 3) + h5py_file_size = os.path.getsize(file_path) / (1024**3) print(f"File size: {h5py_file_size:.2f} GB") return total_time, total_size_gb @@ -109,17 +131,54 @@ def benchmark_h5py(file_path, num_small_datasets, num_large_datasets, small_size compression = None # 'gzip' or None chunks = (large_size / 20,) - print('Lindi Benchmark') - lindi_time, total_size = benchmark_h5py(file_path_lindi, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='lindi') - print('') - print('Zarr Benchmark') - lindi_time, total_size = benchmark_h5py(file_path_zarr, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='zarr') - print('') - print('H5PY Benchmark') - h5py_time, total_size = benchmark_h5py(file_path_h5, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='h5') - print('') - print('DAT Benchmark') - dat, total_size = benchmark_h5py(file_path_dat, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='dat') + print("Lindi Benchmark") + lindi_time, total_size = benchmark_h5py( + file_path_lindi, + num_small_datasets, + num_large_datasets, + small_size, + large_size, + chunks=chunks, + compression=compression, + mode="lindi", + ) + print("") + print("Zarr Benchmark") + lindi_time, total_size = benchmark_h5py( + file_path_zarr, + num_small_datasets, + num_large_datasets, + small_size, + large_size, + chunks=chunks, + compression=compression, + mode="zarr", + ) + print("") + print("H5PY Benchmark") + h5py_time, total_size = benchmark_h5py( + file_path_h5, + num_small_datasets, + num_large_datasets, + small_size, + large_size, + chunks=chunks, + compression=compression, + mode="h5", + ) + print("") + print("DAT Benchmark") + dat, total_size = benchmark_h5py( + file_path_dat, + num_small_datasets, + num_large_datasets, + small_size, + large_size, + chunks=chunks, + compression=compression, + mode="dat", + ) import shutil - shutil.copyfile(file_path_lindi, file_path_lindi + '.tar') + + shutil.copyfile(file_path_lindi, file_path_lindi + ".tar") diff --git a/examples/example_a.py b/examples/create_and_read_lindi_json.py similarity index 93% rename from examples/example_a.py rename to examples/create_and_read_lindi_json.py index 3f6390c..89d228f 100644 --- a/examples/example_a.py +++ b/examples/create_and_read_lindi_json.py @@ -11,4 +11,4 @@ with lindi.LindiH5pyFile.from_lindi_file('example.lindi.json', mode='r') as f: print(f.attrs['attr1']) print(f.attrs['attr2']) - print(f['dataset1'][...]) \ No newline at end of file + print(f['dataset1'][...]) diff --git a/examples/example_b.py b/examples/create_and_read_lindi_tar.py similarity index 93% rename from examples/example_b.py rename to examples/create_and_read_lindi_tar.py index c982e58..1b2b6c9 100644 --- a/examples/example_b.py +++ b/examples/create_and_read_lindi_tar.py @@ -12,4 +12,4 @@ with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='r') as f: print(f.attrs['attr1']) print(f.attrs['attr2']) - print(f['dataset1'][...]) \ No newline at end of file + print(f['dataset1'][...]) diff --git a/examples/example1.py b/examples/example1.py deleted file mode 100644 index f90b136..0000000 --- a/examples/example1.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -import pynwb -import lindi - -# Define the URL for a remote NWB file -h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" - -# Create a read-only Zarr store as a wrapper for the h5 file -store = lindi.LindiH5ZarrStore.from_file(h5_url) - -# Generate a reference file system -rfs = store.to_reference_file_system() - -# Save it to a file for later use -with open("example.nwb.lindi.json", "w") as f: - json.dump(rfs, f, indent=2) - -# Create an h5py-like client from the reference file system -client = lindi.LindiH5pyFile.from_reference_file_system(rfs) - -# Open using pynwb -with pynwb.NWBHDF5IO(file=client, mode="r") as io: - nwbfile = io.read() - print(nwbfile) - - print('Electrode group at shank0:') - print(nwbfile.electrode_groups["shank0"]) # type: ignore - - print('Electrode group at index 0:') - print(nwbfile.electrodes.group[0]) # type: ignore diff --git a/examples/example2.py b/examples/example2.py deleted file mode 100644 index 10e7f61..0000000 --- a/examples/example2.py +++ /dev/null @@ -1,13 +0,0 @@ -import pynwb -import lindi - -# Define the URL for a remote .nwb.lindi.json file -url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json' - -# Load the h5py-like client from the reference file system -client = lindi.LindiH5pyFile.from_lindi_file(url) - -# Open using pynwb -with pynwb.NWBHDF5IO(file=client, mode="r") as io: - nwbfile = io.read() - print(nwbfile) diff --git a/examples/example_ammend_remote_nwb.py b/examples/example_ammend_remote_nwb.py deleted file mode 100644 index 348ec31..0000000 --- a/examples/example_ammend_remote_nwb.py +++ /dev/null @@ -1,33 +0,0 @@ -import numpy as np -import lindi -import pynwb - - -def example_ammend_remote_nwb(): - url = 'https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/' - with lindi.LindiH5pyFile.from_hdf5_file(url) as f: - f.write_lindi_file('example.nwb.lindi.tar') - with lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar', mode='r+') as f: - - # Can't figure out how to modify something using pyNWB - # with pynwb.NWBHDF5IO(file=f, mode='r+') as io: - # nwbfile = io.read() - # print(nwbfile) - # nwbfile.session_description = 'Modified session description' - # io.write(nwbfile) - - f['session_description'][()] = 'new session description' - - # Create something that will become a new file in the tar - ds = f.create_dataset('new_dataset', data=np.random.rand(10000, 1000), chunks=(1000, 200)) - ds[20, 20] = 42 - - with lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar', mode='r') as f: - with pynwb.NWBHDF5IO(file=f, mode='r') as io: - nwbfile = io.read() - print(nwbfile) - print(f['new_dataset'][20, 20]) - - -if __name__ == '__main__': - example_ammend_remote_nwb() diff --git a/examples/example_create_zarr_nwb.py b/examples/example_create_zarr_nwb.py deleted file mode 100644 index eb2dd1e..0000000 --- a/examples/example_create_zarr_nwb.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Any -import shutil -import os -import zarr -import pynwb -import lindi - - -def example_create_zarr_nwb(): - zarr_dirname = 'example_nwb.zarr' - if os.path.exists(zarr_dirname): - shutil.rmtree(zarr_dirname) - - nwbfile = _create_sample_nwb_file() - - store = zarr.DirectoryStore(zarr_dirname) - zarr.group(store=store) # create a root group - with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as client: - with pynwb.NWBHDF5IO(file=client, mode='r+') as io: - io.write(nwbfile) # type: ignore - - -def _create_sample_nwb_file(): - from datetime import datetime - from uuid import uuid4 - - import numpy as np - from dateutil.tz import tzlocal - - from pynwb import NWBFile - from pynwb.ecephys import LFP, ElectricalSeries - - nwbfile: Any = NWBFile( - session_description="my first synthetic recording", - identifier=str(uuid4()), - session_start_time=datetime.now(tzlocal()), - experimenter=[ - "Baggins, Bilbo", - ], - lab="Bag End Laboratory", - institution="University of Middle Earth at the Shire", - experiment_description="I went on an adventure to reclaim vast treasures.", - session_id="LONELYMTN001", - ) - - device = nwbfile.create_device( - name="array", description="the best array", manufacturer="Probe Company 9000" - ) - - nwbfile.add_electrode_column(name="label", description="label of electrode") - - nshanks = 4 - nchannels_per_shank = 3 - electrode_counter = 0 - - for ishank in range(nshanks): - # create an electrode group for this shank - electrode_group = nwbfile.create_electrode_group( - name="shank{}".format(ishank), - description="electrode group for shank {}".format(ishank), - device=device, - location="brain area", - ) - # add electrodes to the electrode table - for ielec in range(nchannels_per_shank): - nwbfile.add_electrode( - group=electrode_group, - label="shank{}elec{}".format(ishank, ielec), - location="brain area", - ) - electrode_counter += 1 - - all_table_region = nwbfile.create_electrode_table_region( - region=list(range(electrode_counter)), # reference row indices 0 to N-1 - description="all electrodes", - ) - - raw_data = np.random.randn(50, 12) - raw_electrical_series = ElectricalSeries( - name="ElectricalSeries", - data=raw_data, - electrodes=all_table_region, - starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time - rate=20000.0, # in Hz - ) - - nwbfile.add_acquisition(raw_electrical_series) - - lfp_data = np.random.randn(50, 12) - lfp_electrical_series = ElectricalSeries( - name="ElectricalSeries", - data=lfp_data, - electrodes=all_table_region, - starting_time=0.0, - rate=200.0, - ) - - lfp = LFP(electrical_series=lfp_electrical_series) - - ecephys_module = nwbfile.create_processing_module( - name="ecephys", description="processed extracellular electrophysiology data" - ) - ecephys_module.add(lfp) - - nwbfile.add_unit_column(name="quality", description="sorting quality") - - firing_rate = 20 - n_units = 10 - res = 1000 - duration = 20 - for n_units_per_shank in range(n_units): - spike_times = ( - np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res - ) - nwbfile.add_unit(spike_times=spike_times, quality="good") - - return nwbfile - - -if __name__ == '__main__': - example_create_zarr_nwb() diff --git a/examples/example_d.py b/examples/example_d.py deleted file mode 100644 index aca4749..0000000 --- a/examples/example_d.py +++ /dev/null @@ -1,15 +0,0 @@ -import numpy as np -import lindi - -# Create a new lindi binary file -with lindi.LindiH5pyFile.from_lindi_file('example.lindi.d', mode='w') as f: - f.attrs['attr1'] = 'value1' - f.attrs['attr2'] = 7 - ds = f.create_dataset('dataset1', shape=(1000, 1000), dtype='f') - ds[...] = np.random.rand(1000, 1000) - -# Later read the file -with lindi.LindiH5pyFile.from_lindi_file('example.lindi.d', mode='r') as f: - print(f.attrs['attr1']) - print(f.attrs['attr2']) - print(f['dataset1'][...]) \ No newline at end of file diff --git a/examples/example_edit_nwb.py b/examples/example_edit_nwb.py deleted file mode 100644 index 351eca5..0000000 --- a/examples/example_edit_nwb.py +++ /dev/null @@ -1,32 +0,0 @@ -import lindi -import h5py -import pynwb - - -# Define the URL for a remote .nwb.lindi.json file -url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json' - -# Load the h5py-like client from the reference file system -client = lindi.LindiH5pyFile.from_lindi_file(url, mode='r+') - -# modify the age of the subject -subject = client['general']['subject'] # type: ignore -assert isinstance(subject, h5py.Group) -del subject['age'] # type: ignore -subject.create_dataset('age', data=b'3w') - -# Create a new reference file system -rfs_new = client.to_reference_file_system() - -# Optionally write to a file -# import json -# with open('new.nwb.lindi.json', 'w') as f: -# json.dump(rfs_new, f) - -# Load a new h5py-like client from the new reference file system -client_new = lindi.LindiH5pyFile.from_reference_file_system(rfs_new) - -# Open using pynwb and verify that the subject age has been updated -with pynwb.NWBHDF5IO(file=client, mode="r") as io: - nwbfile = io.read() - print(nwbfile) diff --git a/examples/example_tar_nwb.py b/examples/example_tar_nwb.py deleted file mode 100644 index cd03dd0..0000000 --- a/examples/example_tar_nwb.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import Any -import pynwb -import h5py -import lindi - - -nwb_lindi_fname = 'example.nwb.lindi.tar' -nwb_fname = 'example.nwb' - - -def test_write_lindi(): - print('test_write_lindi') - nwbfile = _create_sample_nwb_file() - with lindi.LindiH5pyFile.from_lindi_file(nwb_lindi_fname, mode='w') as client: - with pynwb.NWBHDF5IO(file=client, mode='w') as io: - io.write(nwbfile) # type: ignore - - -def test_read_lindi(): - print('test_read_lindi') - with lindi.LindiH5pyFile.from_lindi_file(nwb_lindi_fname, mode='r') as client: - with pynwb.NWBHDF5IO(file=client, mode='r') as io: - nwbfile = io.read() - print(nwbfile) - - -def test_write_h5(): - print('test_write_h5') - nwbfile = _create_sample_nwb_file() - with h5py.File(nwb_fname, 'w') as h5f: - with pynwb.NWBHDF5IO(file=h5f, mode='w') as io: - io.write(nwbfile) # type: ignore - - -def test_read_h5(): - print('test_read_h5') - with h5py.File(nwb_fname, 'r') as h5f: - with pynwb.NWBHDF5IO(file=h5f, mode='r') as io: - nwbfile = io.read() - print(nwbfile) - - -def _create_sample_nwb_file(): - from datetime import datetime - from uuid import uuid4 - - import numpy as np - from dateutil.tz import tzlocal - - from pynwb import NWBFile - from pynwb.ecephys import LFP, ElectricalSeries - - nwbfile: Any = NWBFile( - session_description="my first synthetic recording", - identifier=str(uuid4()), - session_start_time=datetime.now(tzlocal()), - experimenter=[ - "Baggins, Bilbo", - ], - lab="Bag End Laboratory", - institution="University of Middle Earth at the Shire", - experiment_description="I went on an adventure to reclaim vast treasures.", - session_id="LONELYMTN001", - ) - - device = nwbfile.create_device( - name="array", description="the best array", manufacturer="Probe Company 9000" - ) - - nwbfile.add_electrode_column(name="label", description="label of electrode") - - nshanks = 4 - nchannels_per_shank = 3 - electrode_counter = 0 - - for ishank in range(nshanks): - # create an electrode group for this shank - electrode_group = nwbfile.create_electrode_group( - name="shank{}".format(ishank), - description="electrode group for shank {}".format(ishank), - device=device, - location="brain area", - ) - # add electrodes to the electrode table - for ielec in range(nchannels_per_shank): - nwbfile.add_electrode( - group=electrode_group, - label="shank{}elec{}".format(ishank, ielec), - location="brain area", - ) - electrode_counter += 1 - - all_table_region = nwbfile.create_electrode_table_region( - region=list(range(electrode_counter)), # reference row indices 0 to N-1 - description="all electrodes", - ) - - raw_data = np.random.randn(50, 12) - raw_electrical_series = ElectricalSeries( - name="ElectricalSeries", - data=raw_data, - electrodes=all_table_region, - starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time - rate=20000.0, # in Hz - ) - - nwbfile.add_acquisition(raw_electrical_series) - - lfp_data = np.random.randn(5000, 12) - lfp_electrical_series = ElectricalSeries( - name="ElectricalSeries", - data=lfp_data, - electrodes=all_table_region, - starting_time=0.0, - rate=200.0, - ) - - lfp = LFP(electrical_series=lfp_electrical_series) - - ecephys_module = nwbfile.create_processing_module( - name="ecephys", description="processed extracellular electrophysiology data" - ) - ecephys_module.add(lfp) - - nwbfile.add_unit_column(name="quality", description="sorting quality") - - firing_rate = 20 - n_units = 10 - res = 1000 - duration = 2000 - for n_units_per_shank in range(n_units): - spike_times = ( - np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res - ) - nwbfile.add_unit(spike_times=spike_times, quality="good") - - return nwbfile - - -if __name__ == '__main__': - test_write_lindi() - test_read_lindi() - print('_________________________________') - print('') - - test_write_h5() - test_read_h5() diff --git a/examples/lindi_demo.ipynb b/examples/lindi_demo.ipynb deleted file mode 100644 index 5082386..0000000 --- a/examples/lindi_demo.ipynb +++ /dev/null @@ -1,208 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "52802664-85e7-433a-a8c3-f2645847423b", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "%pip install -q lindi\n", - "%pip install -q pynwb" - ] - }, - { - "cell_type": "markdown", - "id": "94dadefb-c04a-4aea-b906-b4cc3a263570", - "metadata": {}, - "source": [ - "### Lazy-load a remote NWB/HDF5 file for efficient access to metadata and data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65b50dec-5c6e-40cf-aa93-c18d220b74bb", - "metadata": {}, - "outputs": [], - "source": [ - "import pynwb\n", - "import lindi\n", - "\n", - "# URL of the remote NWB file\n", - "h5_url = \"https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/\"\n", - "\n", - "# Set up a local cache\n", - "local_cache = lindi.LocalCache(cache_dir='lindi_cache')\n", - "\n", - "# Create the h5py-like client\n", - "client = lindi.LindiH5pyFile.from_hdf5_file(h5_url, local_cache=local_cache)\n", - "\n", - "# Open using pynwb\n", - "with pynwb.NWBHDF5IO(file=client, mode=\"r\") as io:\n", - " nwbfile = io.read()\n", - " print(nwbfile)\n", - "\n", - "# The downloaded data will be cached locally, so subsequent reads will be faster" - ] - }, - { - "cell_type": "markdown", - "id": "79a2ce81-c3da-4c9c-a013-fd5fa34762f7", - "metadata": {}, - "source": [ - "### Represent a remote NWB/HDF5 file as a .nwb.lindi.json file" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d11db473-ab25-4d1e-ac15-7a16654c4bcf", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import lindi\n", - "\n", - "# URL of the remote NWB file\n", - "h5_url = \"https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/\"\n", - "\n", - "# Create the h5py-like client\n", - "client = lindi.LindiH5pyFile.from_hdf5_file(h5_url)\n", - "\n", - "client.write_lindi_file('example.lindi.json')\n", - "\n", - "# See the next example for how to read this file" - ] - }, - { - "cell_type": "markdown", - "id": "2550661e-ee88-45d8-b6b2-8c6f8fe4dfaa", - "metadata": {}, - "source": [ - "### Read a local or remote .nwb.lindi.json file using pynwb or other tools" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9973c935-db2e-468f-ba9c-7f5e7153b319", - "metadata": {}, - "outputs": [], - "source": [ - "import pynwb\n", - "import lindi\n", - "\n", - "# URL of the remote .nwb.lindi.json file\n", - "url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json'\n", - "\n", - "# Load the h5py-like client\n", - "client = lindi.LindiH5pyFile.from_lindi_file(url)\n", - "\n", - "# Open using pynwb\n", - "with pynwb.NWBHDF5IO(file=client, mode=\"r\") as io:\n", - " nwbfile = io.read()\n", - " print(nwbfile)" - ] - }, - { - "cell_type": "markdown", - "id": "ea03ad43-7d0b-4cfd-9992-bdb1ada6b13c", - "metadata": {}, - "source": [ - "### Edit a .nwb.lindi.json file using pynwb or other tools" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c66931b-3ea9-4222-afc8-a93b25457e37", - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "import lindi\n", - "\n", - "# URL of the remote .nwb.lindi.json file\n", - "url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json'\n", - "\n", - "# Load the h5py-like client for the reference file system\n", - "# in read-write mode\n", - "client = lindi.LindiH5pyFile.from_lindi_file(url, mode=\"r+\")\n", - "\n", - "# Edit an attribute\n", - "client.attrs['new_attribute'] = 'new_value'\n", - "\n", - "# Save the changes to a new .nwb.lindi.json file\n", - "client.write_lindi_file('new.nwb.lindi.json')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fe19e0f6-1c62-42e9-9af0-4a57c8a61364", - "metadata": {}, - "outputs": [], - "source": [ - "# Now load that file\n", - "client2 = lindi.LindiH5pyFile.from_lindi_file('new.nwb.lindi.json')\n", - "print(client2.attrs['new_attribute'])" - ] - }, - { - "cell_type": "markdown", - "id": "4a2addfc-58ed-4e79-9c64-b7ec95cb12f5", - "metadata": {}, - "source": [ - "### Add datasets to a .nwb.lindi.json file using a local staging area" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6e87640d-1927-43c1-89c1-c1274a11f185", - "metadata": {}, - "outputs": [], - "source": [ - "import lindi\n", - "\n", - "# URL of the remote .nwb.lindi.json file\n", - "url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json'\n", - "\n", - "# Load the h5py-like client for the reference file system\n", - "# in read-write mode with a staging area\n", - "with lindi.StagingArea.create(base_dir='lindi_staging') as staging_area:\n", - " client = lindi.LindiH5pyFile.from_lindi_file(\n", - " url,\n", - " mode=\"r+\",\n", - " staging_area=staging_area\n", - " )\n", - " # add datasets to client using pynwb or other tools\n", - " # upload the changes to the remote .nwb.lindi.json file" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/example_c.py b/examples/represent_remote_nwb_as_lindi_json.py similarity index 93% rename from examples/example_c.py rename to examples/represent_remote_nwb_as_lindi_json.py index 279dfc2..f77aa9b 100644 --- a/examples/example_c.py +++ b/examples/represent_remote_nwb_as_lindi_json.py @@ -1,4 +1,3 @@ -import json import pynwb import lindi @@ -20,6 +19,7 @@ # Save as LINDI JSON f.write_lindi_file('example.nwb.lindi.json') +f.close() # Later, read directly from the LINDI JSON file g = lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.json') @@ -33,4 +33,4 @@ print(nwbfile.electrode_groups["shank0"]) # type: ignore print('Electrode group at index 0:') - print(nwbfile.electrodes.group[0]) # type: ignore \ No newline at end of file + print(nwbfile.electrodes.group[0]) # type: ignore diff --git a/examples/write_lindi_binary.py b/examples/write_lindi_binary.py deleted file mode 100644 index 8442321..0000000 --- a/examples/write_lindi_binary.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np -import lindi - - -def write_lindi_binary(): - with lindi.LindiH5pyFile.from_lindi_file('test.lindi.tar', mode='w') as f: - f.attrs['test'] = 42 - ds = f.create_dataset('data', shape=(1000, 1000), dtype='f4') - ds[...] = np.random.rand(1000, 1000) - - -def test_read(): - f = lindi.LindiH5pyFile.from_lindi_file('test.lindi.tar', mode='r') - print(f.attrs['test']) - print(f['data'][0, 0]) - f.close() - - -if __name__ == "__main__": - write_lindi_binary() - test_read() diff --git a/lindi/File/File.py b/lindi/File/File.py deleted file mode 100644 index 59e0f4c..0000000 --- a/lindi/File/File.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Literal -import os -import h5py -from ..LindiH5pyFile.LindiH5pyFile import LindiH5pyFile -from ..LindiStagingStore.StagingArea import StagingArea -from ..LocalCache.LocalCache import LocalCache - - -class File(h5py.File): - """ - A drop-in replacement for h5py.File that is either a lindi.LindiH5pyFile or - h5py.File depending on whether the file name ends with .lindi.json or not. - """ - def __new__(cls, name, mode: Literal['r', 'r+', 'w', 'w-', 'x', 'a'] = 'r', **kwds): - if isinstance(name, str) and name.endswith('.lindi.json'): - # should we raise exceptions on select unsupported kwds? or just go with the flow? - if mode != 'r': - staging_area = StagingArea.create(dir=name + '.d') - else: - staging_area = None - local_cache_dir = os.environ.get('LINDI_LOCAL_CACHE_DIR', None) - if local_cache_dir is not None: - local_cache = LocalCache(cache_dir=local_cache_dir) - else: - local_cache = None - - return LindiH5pyFile.from_lindi_file( - name, - mode=mode, - staging_area=staging_area, - local_cache=local_cache - ) - else: - return h5py.File(name, mode=mode, **kwds) diff --git a/lindi/File/__init__.py b/lindi/File/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index 37681c7..6d7ecc7 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -1,4 +1,4 @@ -from typing import Union, Literal, Callable +from typing import Union, Literal import os import json import tempfile @@ -12,8 +12,6 @@ from .LindiH5pyReference import LindiH5pyReference from .LindiReferenceFileSystemStore import LindiReferenceFileSystemStore -from ..LindiStagingStore.StagingArea import StagingArea -from ..LindiStagingStore.LindiStagingStore import LindiStagingStore, _apply_templates from ..LindiH5ZarrStore.LindiH5ZarrStoreOpts import LindiH5ZarrStoreOpts from ..LocalCache.LocalCache import LocalCache @@ -26,10 +24,6 @@ LindiFileMode = Literal["r", "r+", "w", "w-", "x", "a"] -# Accepts a string path to a file, uploads (or copies) it somewhere, and returns a string URL -# (or local path) -UploadFileFunc = Callable[[str], str] - class LindiH5pyFile(h5py.File): def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, None] = None, _mode: LindiFileMode = "r", _local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): @@ -53,7 +47,7 @@ def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, Non self._is_open = True @staticmethod - def from_lindi_file(url_or_path: str, *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None): + def from_lindi_file(url_or_path: str, *, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None): """ Create a LindiH5pyFile from a URL or path to a .lindi.json file. @@ -62,7 +56,6 @@ def from_lindi_file(url_or_path: str, *, mode: LindiFileMode = "r", staging_area return LindiH5pyFile.from_reference_file_system( url_or_path, mode=mode, - staging_area=staging_area, local_cache=local_cache ) @@ -108,7 +101,7 @@ def from_hdf5_file( ) @staticmethod - def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): + def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): """ Create a LindiH5pyFile from a reference file system. @@ -120,9 +113,6 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo be created. mode : Literal["r", "r+", "w", "w-", "x", "a"], optional The mode to open the file object in, by default "r". - staging_area : Union[StagingArea, None], optional - The staging area to use for writing data, preparing for upload. This - is only used in write mode, by default None. local_cache : Union[LocalCache, None], optional The local cache to use for caching data, by default None. _source_url_or_path : Union[str, None], optional @@ -152,7 +142,6 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo return LindiH5pyFile.from_reference_file_system( data, mode=mode, - staging_area=staging_area, local_cache=local_cache, _source_tar_file=tar_file, _source_url_or_path=rfs, @@ -193,7 +182,6 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo return LindiH5pyFile.from_reference_file_system( data, mode=mode, - staging_area=staging_area, local_cache=local_cache, _source_url_or_path=rfs, _source_tar_file=tar_file, @@ -208,11 +196,7 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo _source_tar_file=_source_tar_file ) source_is_url = _source_url_or_path is not None and (_source_url_or_path.startswith("http://") or _source_url_or_path.startswith("https://")) - if staging_area: - if _source_tar_file and not source_is_url: - raise Exception("Cannot use staging area when source is a local tar file") - store = LindiStagingStore(base_store=store, staging_area=staging_area) - elif _source_url_or_path and _source_tar_file and not source_is_url: + if _source_url_or_path and _source_tar_file and not source_is_url: store = LindiTarStore(base_store=store, tar_file=_source_tar_file) return LindiH5pyFile.from_zarr_store( store, @@ -274,9 +258,6 @@ def to_reference_file_system(self): if self._zarr_store is None: raise Exception("Cannot convert to reference file system without zarr store") zarr_store = self._zarr_store - if isinstance(zarr_store, LindiStagingStore): - zarr_store.consolidate_chunks() - zarr_store = zarr_store._base_store if isinstance(zarr_store, LindiTarStore): zarr_store = zarr_store._base_store if isinstance(zarr_store, LindiH5ZarrStore): @@ -289,58 +270,6 @@ def to_reference_file_system(self): LindiReferenceFileSystemStore.use_templates_in_rfs(rfs_copy) return rfs_copy - def upload( - self, - *, - on_upload_blob: UploadFileFunc, - on_upload_main: UploadFileFunc - ): - """ - Consolidate the chunks in the staging area, upload them to a storage - system, updating the references in the base store, and then upload the - updated reference file system .json file. - - Parameters - ---------- - on_upload_blob : StoreFileFunc - A function that takes a string path to a blob file, uploads or copies it - somewhere, and returns a string URL (or local path). - on_upload_main : StoreFileFunc - A function that takes a string path to the main .json file, stores - it somewhere, and returns a string URL (or local path). - - Returns - ------- - str - The URL (or local path) of the uploaded reference file system .json - file. - """ - rfs = self.to_reference_file_system() - blobs_to_upload = set() - # Get the set of all local URLs in rfs['refs'] - for k, v in rfs['refs'].items(): - if isinstance(v, list) and len(v) == 3: - url = _apply_templates(v[0], rfs.get('templates', {})) - if not url.startswith("http://") and not url.startswith("https://"): - local_path = url - blobs_to_upload.add(local_path) - # Upload each of the local blobs using the given upload function and get a mapping from - # the original file paths to the URLs of the uploaded files - blob_mapping = _upload_blobs(blobs_to_upload, on_upload_blob=on_upload_blob) - # Replace the local URLs in rfs['refs'] with URLs of the uploaded files - for k, v in rfs['refs'].items(): - if isinstance(v, list) and len(v) == 3: - url1 = _apply_templates(v[0], rfs.get('templates', {})) - url2 = blob_mapping.get(url1, None) - if url2 is not None: - v[0] = url2 - # Write the updated LINDI file to a temp directory and upload it - with tempfile.TemporaryDirectory() as tmpdir: - rfs_fname = f"{tmpdir}/rfs.lindi.json" - LindiReferenceFileSystemStore.use_templates_in_rfs(rfs) - _write_rfs_to_file(rfs=rfs, output_file_name=rfs_fname) - return on_upload_main(rfs_fname) - def write_lindi_file(self, filename: str, *, generation_metadata: Union[dict, None] = None): """ Write the reference file system to a lindi or .lindi.json file. @@ -568,15 +497,6 @@ def require_dataset(self, name, shape, dtype, exact=False, **kwds): raise Exception("Cannot require dataset in read-only mode") return self._the_group.require_dataset(name, shape, dtype, exact=exact, **kwds) - ############################## - # staging store - @property - def staging_store(self): - store = self._zarr_store - if not isinstance(store, LindiStagingStore): - return None - return store - def _download_file(url: str, filename: str) -> None: headers = { @@ -650,35 +570,6 @@ def _deep_copy(obj): return obj -def _upload_blobs( - blobs: set, - *, - on_upload_blob: UploadFileFunc -) -> dict: - """ - Upload all the blobs in a set to a storage system and return a mapping from - the original file paths to the URLs of the uploaded files. - """ - blob_mapping = {} - for i, blob in enumerate(blobs): - size = os.path.getsize(blob) - print(f'Uploading blob {i + 1} of {len(blobs)} {blob} ({_format_size_bytes(size)})') - blob_url = on_upload_blob(blob) - blob_mapping[blob] = blob_url - return blob_mapping - - -def _format_size_bytes(size_bytes: int) -> str: - if size_bytes < 1024: - return f"{size_bytes} bytes" - elif size_bytes < 1024 * 1024: - return f"{size_bytes / 1024:.1f} KB" - elif size_bytes < 1024 * 1024 * 1024: - return f"{size_bytes / 1024 / 1024:.1f} MB" - else: - return f"{size_bytes / 1024 / 1024 / 1024:.1f} GB" - - def _load_rfs_from_url(url: str): file_size = _get_file_size_of_remote_file(url) if file_size < 1024 * 1024 * 2: @@ -832,3 +723,10 @@ def _update_internal_references_to_remote_tar_file(rfs: dict, remote_url: str, r raise Exception(f"Unexpected length for reference: {len(v)}") LindiReferenceFileSystemStore.use_templates_in_rfs(rfs) + + +def _apply_templates(x: str, templates: dict) -> str: + if '{{' in x and '}}' in x: + for key, val in templates.items(): + x = x.replace('{{' + key + '}}', val) + return x diff --git a/lindi/LindiStagingStore/LindiStagingStore.py b/lindi/LindiStagingStore/LindiStagingStore.py deleted file mode 100644 index 019d982..0000000 --- a/lindi/LindiStagingStore/LindiStagingStore.py +++ /dev/null @@ -1,218 +0,0 @@ -import os -from zarr.storage import Store as ZarrStore -from ..LindiH5pyFile.LindiReferenceFileSystemStore import LindiReferenceFileSystemStore -from .StagingArea import StagingArea, _random_str - - -class LindiStagingStore(ZarrStore): - """ - A Zarr store that allows supplementing a base LindiReferenceFileSystemStore - where the large data blobs are stored in a staging area. After writing new - data to the store, the data blobs can be consolidated into larger files and - then uploaded to a custom storage system, for example DANDI or a cloud - bucket. - """ - def __init__(self, *, base_store: LindiReferenceFileSystemStore, staging_area: StagingArea): - """ - Create a LindiStagingStore. - - Parameters - ---------- - base_store : LindiReferenceFileSystemStore - The base store that this store supplements. - staging_area : StagingArea - The staging area where large data blobs are stored. - """ - self._base_store = base_store - self._staging_area = staging_area - - def __getitem__(self, key: str): - return self._base_store.__getitem__(key) - - def __setitem__(self, key: str, value: bytes): - key_parts = key.split("/") - key_base_name = key_parts[-1] - if key_base_name.startswith('.') or key_base_name.endswith('.json'): # always inline .zattrs, .zgroup, .zarray, zarr.json - inline = True - else: - # presumably it is a chunk of an array - if not isinstance(value, bytes): - raise ValueError("Value must be bytes") - size = len(value) - inline = size < 1000 # this should be a configurable threshold - if inline: - # If inline, save in memory - return self._base_store.__setitem__(key, value) - else: - # If not inline, save it as a file in the staging directory - key_without_initial_slash = key if not key.startswith("/") else key[1:] - stored_file_path = self._staging_area.store_file(key_without_initial_slash, value) - - self._set_ref_reference(key_without_initial_slash, stored_file_path, 0, len(value)) - - def __delitem__(self, key: str): - # We don't delete the file from the staging directory, because that - # would be dangerous if the file was part of a consolidated file. - return self._base_store.__delitem__(key) - - def __iter__(self): - return self._base_store.__iter__() - - def __len__(self): - return self._base_store.__len__() - - # These methods are overridden from BaseStore - def is_readable(self): - return True - - def is_writeable(self): - return True - - def is_listable(self): - return True - - def is_erasable(self): - return False - - def _set_ref_reference(self, key: str, filename: str, offset: int, size: int): - rfs = self._base_store.rfs - if 'refs' not in rfs: - # this shouldn't happen, but we'll be defensive - rfs['refs'] = {} - rfs['refs'][key] = [ - filename, - offset, - size - ] - - def consolidate_chunks(self): - """ - Consolidate the chunks in the staging area. - """ - rfs = self._base_store.rfs - refs_keys_by_reference_parent_path = {} - for k, v in rfs['refs'].items(): - if isinstance(v, list) and len(v) == 3: - url = v[0] - if not url.startswith(self._staging_area.directory + '/'): - continue - parent_path = os.path.dirname(url) - if parent_path not in refs_keys_by_reference_parent_path: - refs_keys_by_reference_parent_path[parent_path] = [] - refs_keys_by_reference_parent_path[parent_path].append(k) - for root, dirs, files1 in os.walk(self._staging_area._directory): - files = [ - f for f in files1 - if not f.startswith('.') and not f.endswith('.json') and not f.startswith('consolidated.') - ] - if len(files) <= 1: - continue - refs_keys_for_this_dir = refs_keys_by_reference_parent_path.get(root, []) - if len(refs_keys_for_this_dir) <= 1: - continue - - # sort so that the files are in order 0.0.0, 0.0.1, 0.0.2, ... - files = _sort_by_chunk_key(files) - - print(f'Consolidating {len(files)} files in {root}') - - offset = 0 - offset_maps = {} - consolidated_id = _random_str(8) - consolidated_index = 0 - max_size_of_consolidated_file = 1024 * 1024 * 1024 # 1 GB, a good size for cloud bucket files - consolidated_fname = f"{root}/consolidated.{consolidated_id}.{consolidated_index}" - consolidated_f = open(consolidated_fname, "wb") - try: - for fname in files: - full_fname = f"{root}/{fname}" - with open(full_fname, "rb") as f2: - consolidated_f.write(f2.read()) - offset_maps[full_fname] = (consolidated_fname, offset) - offset += os.path.getsize(full_fname) - if offset > max_size_of_consolidated_file: - consolidated_f.close() - consolidated_index += 1 - consolidated_fname = f"{root}/consolidated.{consolidated_id}.{consolidated_index}" - consolidated_f = open(consolidated_fname, "wb") - offset = 0 - finally: - consolidated_f.close() - for key in refs_keys_for_this_dir: - filename, old_offset, old_size = rfs['refs'][key] - if filename not in offset_maps: - continue - consolidated_fname, new_offset = offset_maps[filename] - rfs['refs'][key] = [consolidated_fname, new_offset + old_offset, old_size] - # remove the old files - for fname in files: - os.remove(f"{root}/{fname}") - - def copy_chunks_to_staging_area(self, *, download_remote: bool): - """ - Copy the chunks in the base store to the staging area. This is done - in preparation for uploading to a storage system. - - Parameters - ---------- - download_remote : bool - If True, download the remote chunks to the staging area. If False, - just copy the local chunks. - """ - if download_remote: - raise NotImplementedError("Downloading remote chunks not yet implemented") - rfs = self._base_store.rfs - templates = rfs.get('templates', {}) - for k, v in rfs['refs'].items(): - if isinstance(v, list) and len(v) == 3: - url = _apply_templates(v[0], templates) - if url.startswith('http://') or url.startswith('https://'): - if download_remote: - raise NotImplementedError("Downloading remote chunks not yet implemented") - continue - elif url.startswith(self._staging_area.directory + '/'): - # already in the staging area - continue - else: - # copy the local file to the staging area - path0 = url - chunk_data = _read_chunk_data(path0, v[1], v[2]) - stored_file_path = self._staging_area.store_file(k, chunk_data) - self._set_ref_reference(k, stored_file_path, 0, v[2]) - - -def _apply_templates(x: str, templates: dict) -> str: - if '{{' in x and '}}' in x: - for key, val in templates.items(): - x = x.replace('{{' + key + '}}', val) - return x - - -def _sort_by_chunk_key(files: list) -> list: - # first verify that all the files have the same number of parts - num_parts = None - for fname in files: - parts = fname.split('.') - if num_parts is None: - num_parts = len(parts) - elif len(parts) != num_parts: - raise ValueError(f"Files have different numbers of parts: {files}") - # Verify that all the parts are integers - for fname in files: - parts = fname.split('.') - for p in parts: - try: - int(p) - except ValueError: - raise ValueError(f"File part is not an integer: {fname}") - - def _chunk_key(fname: str) -> tuple: - parts = fname.split('.') - return tuple(int(p) for p in parts) - return sorted(files, key=_chunk_key) - - -def _read_chunk_data(filename: str, offset: int, size: int) -> bytes: - with open(filename, "rb") as f: - f.seek(offset) - return f.read(size) diff --git a/lindi/LindiStagingStore/StagingArea.py b/lindi/LindiStagingStore/StagingArea.py deleted file mode 100644 index 460bc9b..0000000 --- a/lindi/LindiStagingStore/StagingArea.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Union -import os -import random -import string -import datetime -import shutil - - -class StagingArea: - """ - A staging area where files can be stored temporarily before being - consolidated and uploaded to a storage system. - - This class is a context manager, so it can be used in a `with` statement to - ensure that the staging area is cleaned up when it is no longer needed. - """ - def __init__(self, *, _directory: str) -> None: - """ - Do not call this constructor directly. Instead, use the `create` method - to create a new staging area. - """ - self._directory = os.path.abspath(_directory) - - @staticmethod - def create(*, base_dir: Union[str, None] = None, dir: Union[str, None] = None) -> 'StagingArea': - """ - Create a new staging area. Provide either `base_dir` or `dir`, but not - both. - - Parameters - ---------- - base_dir : str or None - If provided, the base directory where the staging area will be - created. The staging directory will be a subdirectory of this - directory. - dir : str or None - If provided, the exact directory where the staging area will be - created. It is okay if this directory already exists. - """ - if base_dir is not None and dir is not None: - raise ValueError("Provide either base_dir or dir, but not both") - if base_dir is not None: - dir = os.path.join(base_dir, _create_random_id()) - if dir is None: - raise ValueError("Provide either base_dir or dir") - return StagingArea(_directory=dir) - - def cleanup(self) -> None: - """ - Clean up the staging area, deleting all files in it. This method is - called automatically when the staging area is used as a context manager - in a `with` statement. - """ - if os.path.exists(self._directory): - shutil.rmtree(self._directory) - - def __enter__(self) -> 'StagingArea': - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.cleanup() - - @property - def directory(self) -> str: - """ - The directory where the files are stored. - """ - return self._directory - - def store_file(self, relpath: str, value: bytes) -> str: - """ - Store a file in the staging area. - - Parameters - ---------- - relpath : str - The relative path to the file, relative to the staging area root. - value : bytes - The contents of the file. - """ - path = os.path.join(self._directory, relpath) - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'wb') as f: - f.write(value) - return path - - def get_full_path(self, relpath: str) -> str: - """ - Get the full path to a file in the staging area. - - Parameters - ---------- - relpath : str - The relative path to the file, relative to the staging area root. - """ - return os.path.join(self._directory, relpath) - - -def _create_random_id(): - # This is going to be a timestamp suitable for alphabetical chronological order plus a random string - return f"{_timestamp_str()}-{_random_str(8)}" - - -def _timestamp_str(): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") - - -def _random_str(n): - return ''.join(random.choices(string.ascii_lowercase + string.digits, k=n)) diff --git a/lindi/LindiStagingStore/__init__.py b/lindi/LindiStagingStore/__init__.py deleted file mode 100644 index 7ab3d49..0000000 --- a/lindi/LindiStagingStore/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .LindiStagingStore import LindiStagingStore, StagingArea # noqa: F401 diff --git a/lindi/__init__.py b/lindi/__init__.py index be60471..6e5e7f6 100644 --- a/lindi/__init__.py +++ b/lindi/__init__.py @@ -1,6 +1,4 @@ from .LindiH5ZarrStore import LindiH5ZarrStore, LindiH5ZarrStoreOpts # noqa: F401 from .LindiH5pyFile import LindiH5pyFile, LindiH5pyGroup, LindiH5pyDataset, LindiH5pyHardLink, LindiH5pySoftLink # noqa: F401 -from .LindiStagingStore import LindiStagingStore, StagingArea # noqa: F401 from .LocalCache.LocalCache import LocalCache, ChunkTooLargeError # noqa: F401 -from .File.File import File # noqa: F401 from .LindiRemfile.additional_url_resolvers import add_additional_url_resolver # noqa: F401 diff --git a/tests/test_lindi_file.py b/tests/test_lindi_file.py deleted file mode 100644 index 720579a..0000000 --- a/tests/test_lindi_file.py +++ /dev/null @@ -1,23 +0,0 @@ -import tempfile -import numpy as np -import h5py -import lindi - - -def test_lindi_file(): - with tempfile.TemporaryDirectory() as tmpdir: - fname = f'{tmpdir}/test.lindi.json' - with lindi.File(fname, 'w') as f: - f.create_dataset('data', data=np.arange(500000, dtype=np.uint32), chunks=(100000,)) - - with lindi.File(fname, 'r') as f: - ds = f['data'] - assert isinstance(ds, h5py.Dataset) - assert ds.shape == (500000,) - assert ds.chunks == (100000,) - assert ds.dtype == np.uint32 - assert np.all(ds[:] == np.arange(500000, dtype=np.uint32)) - - -if __name__ == '__main__': - test_lindi_file() diff --git a/tests/test_staging_area.py b/tests/test_staging_area.py deleted file mode 100644 index 763cb45..0000000 --- a/tests/test_staging_area.py +++ /dev/null @@ -1,66 +0,0 @@ -import tempfile -import os -import numpy as np -import lindi -import shutil - - -def test_staging_area(): - with tempfile.TemporaryDirectory() as tmpdir: - staging_area = lindi.StagingArea.create(base_dir=tmpdir + '/staging_area') - client = lindi.LindiH5pyFile.from_reference_file_system(None, mode='r+', staging_area=staging_area) - X = np.random.randn(1000, 1000).astype(np.float32) - client.create_dataset('large_array', data=X, chunks=(400, 400)) - total_size = _get_total_size_of_directory(tmpdir) - assert total_size >= X.nbytes * 0.5, f'{total_size} < {X.nbytes} * 0.5' # take into consideration compression - rfs = client.to_reference_file_system() - client2 = lindi.LindiH5pyFile.from_reference_file_system(rfs, mode='r') - assert isinstance(client2, lindi.LindiH5pyFile) - X1 = client['large_array'] - assert isinstance(X1, lindi.LindiH5pyDataset) - X2 = client2['large_array'] - assert isinstance(X2, lindi.LindiH5pyDataset) - assert np.allclose(X1[:], X2[:]) - - upload_dir = f'{tmpdir}/upload_dir' - os.makedirs(upload_dir, exist_ok=True) - output_fname = f'{tmpdir}/output.lindi.json' - - def on_upload_blob(fname: str): - random_fname = f'{upload_dir}/{_random_string(10)}' - shutil.copy(fname, random_fname) - return random_fname - - def on_upload_main(fname: str): - shutil.copy(fname, output_fname) - return output_fname - - assert client.staging_store - client.upload( - on_upload_blob=on_upload_blob, - on_upload_main=on_upload_main - ) - - client3 = lindi.LindiH5pyFile.from_lindi_file(output_fname, mode='r') - X3 = client3['large_array'] - assert isinstance(X3, lindi.LindiH5pyDataset) - assert np.allclose(X1[:], X3[:]) - - -def _get_total_size_of_directory(directory): - total_size = 0 - for dirpath, dirnames, filenames in os.walk(directory): - for f in filenames: - fp = os.path.join(dirpath, f) - total_size += os.path.getsize(fp) - return total_size - - -def _random_string(n): - import random - import string - return ''.join(random.choices(string.ascii_uppercase + string.digits, k=n)) - - -if __name__ == '__main__': - test_staging_area()