diff --git a/.github/workflows/linter_checks.yml b/.github/workflows/linter_checks.yml index acf2726..87ffa72 100644 --- a/.github/workflows/linter_checks.yml +++ b/.github/workflows/linter_checks.yml @@ -19,4 +19,4 @@ jobs: - name: Run flake8 run: cd lindi && flake8 --config ../.flake8 - name: Run pyright - run: cd lindi && pyright + run: cd lindi && pyright . diff --git a/.gitignore b/.gitignore index c55f5c4..9b0872b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +benchmark.* +*.lindi *.lindi.json* *.nwb diff --git a/.vscode/tasks/test.sh b/.vscode/tasks/test.sh index c87184e..3d01bfa 100755 --- a/.vscode/tasks/test.sh +++ b/.vscode/tasks/test.sh @@ -5,7 +5,7 @@ set -ex cd lindi flake8 . -pyright +pyright . cd .. pytest --cov=lindi --cov-report=xml --cov-report=term tests/ diff --git a/README.md b/README.md index 2952edb..9f984d0 100644 --- a/README.md +++ b/README.md @@ -6,19 +6,39 @@ :warning: Please note, LINDI is currently under development and should not yet be used in practice. -**HDF5 as Zarr as JSON for NWB** +LINDI is a cloud-friendly file format and Python library designed for managing scientific data, especially Neurodata Without Borders (NWB) datasets. It offers an alternative to [HDF5](https://docs.hdfgroup.org/hdf5/v1_14/_intro_h_d_f5.html) and [Zarr](https://zarr.dev/), maintaining compatibility with both, while providing features tailored for linking to remote datasets stored in the cloud, such as those on the [DANDI Archive](https://www.dandiarchive.org/). LINDI's unique structure and capabilities make it particularly well-suited for efficient data access and management in cloud environments. -LINDI provides a JSON representation of NWB (Neurodata Without Borders) data where the large data chunks are stored separately from the main metadata. This enables efficient storage, composition, and sharing of NWB files on cloud systems such as [DANDI](https://www.dandiarchive.org/) without duplicating the large data blobs. +**What is a LINDI file?** -LINDI provides: +A LINDI file is a cloud-friendly format for storing scientific data, designed to be compatible with HDF5 and Zarr while offering unique advantages. It comes in two types: JSON/text format (.lindi.json) and binary format (.lindi.tar). -- A specification for representing arbitrary HDF5 files as Zarr stores. This handles scalar datasets, references, soft links, and compound data types for datasets. -- A Zarr wrapper for remote or local HDF5 files (LindiH5ZarrStore). -- A mechanism for creating .lindi.json (or .nwb.lindi.json) files that reference data chunks in external files, inspired by [kerchunk](https://github.com/fsspec/kerchunk). -- An h5py-like interface for reading from and writing to these data sources that can be used with [pynwb](https://pynwb.readthedocs.io/en/stable/). -- A mechanism for uploading and downloading these data sources to and from cloud storage, including DANDI. +In the JSON format, the hierarchical group structure, attributes, and small datasets are stored in a JSON structure, with references to larger data chunks stored in external files (inspired by [kerchunk](https://github.com/fsspec/kerchunk)). This format is human-readable and easily inspected and edited. On the other hand, the binary format is a .tar file that contains the JSON file along with optional internal data chunks referenced by the JSON file, in addition to external chunks. This format allows for efficient cloud storage and random access. 
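+
+Because the JSON flavor is plain text, a .lindi.json file can be opened and inspected with ordinary JSON tooling. Here is a minimal sketch (the file name is a placeholder; the Usage section below shows how such a file is created):
+
+```python
+import json
+
+# Load a LINDI JSON file as ordinary JSON to inspect its structure
+# (hypothetical file name; any .lindi.json file will do)
+with open('example.lindi.json') as fp:
+    contents = json.load(fp)
+
+# Print the top-level keys to get a feel for how the data are organized
+print(list(contents.keys()))
+```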
-This project was inspired by [kerchunk](https://github.com/fsspec/kerchunk) and [hdmf-zarr](https://hdmf-zarr.readthedocs.io/en/latest/index.html) and depends on [zarr](https://zarr.readthedocs.io/en/stable/), [h5py](https://www.h5py.org/) and [numcodecs](https://numcodecs.readthedocs.io/en/stable/). +The main advantage of the JSON LINDI format is its readability and ease of modification, while the binary LINDI format offers the ability to include internal data chunks, providing flexibility in data storage and retrieval. Both formats are optimized for cloud use, enabling efficient downloading and access from cloud storage. + +**What are the main use cases?** + +LINDI files are particularly useful in the following scenarios: + +**Efficient NWB File Representation on DANDI**: A LINDI JSON file can represent an NWB file stored on the DANDI Archive (or other remote system). By downloading a condensed JSON file, the entire group structure can be retrieved in a single request, facilitating efficient loading of NWB files. For instance, [Neurosift](https://github.com/flatironinstitute/neurosift) utilizes pre-generated LINDI JSON files to streamline the loading process of NWB files from DANDI. + +**Creating Amended NWB Files**: LINDI allows for the creation of amended NWB files that add new data objects to existing NWB files without duplicating the entire file. This is achieved by generating a binary LINDI file that references the original NWB file and includes additional data objects stored as internal data chunks. This approach saves storage space and reduces redundancy. + +**Why not use Zarr?** + +While Zarr is a cloud-friendly alternative to HDF5, it has notable limitations. Zarr archives often consist of thousands of individual files, making them cumbersome to manage. In contrast, LINDI files adopt a single-file approach similar to HDF5, enhancing manageability while retaining cloud-friendliness. Another limitation is that Zarr lacks a mechanism for referencing data chunks in external datasets, which LINDI provides. Additionally, Zarr does not support certain features utilized by PyNWB, such as compound data types and references, which are supported by both HDF5 and LINDI. + +**Why not use HDF5?** + +HDF5 is not well-suited for cloud environments because accessing a remote HDF5 file often requires a large number of small requests to retrieve metadata before larger data chunks can be downloaded. LINDI addresses this by storing the entire group structure in a single JSON file, which can be downloaded in one request. Additionally, HDF5 lacks a built-in mechanism for referencing data chunks in external datasets. Furthermore, HDF5 does not support custom Python codecs, a feature available in both Zarr and LINDI. These advantages make LINDI a more efficient and versatile option for cloud-based data storage and access. + +**Does LINDI use Zarr?** + +Yes, LINDI leverages the Zarr format to store data, including attributes and group hierarchies. However, instead of using directories and files like Zarr, LINDI stores all data within a single JSON structure. This structure includes references to large data chunks, which can reside in remote files (e.g., an HDF5 NWB file on DANDI) or within internal data chunks in the binary LINDI file. Although NWB relies on certain HDF5 features not supported by Zarr, LINDI provides mechanisms to represent these features in Zarr, ensuring compatibility and extending functionality. + +**Is the tar format really cloud-friendly?** + +With LINDI, yes.
See [docs/tar.md](docs/tar.md) for details. ## Installation @@ -33,116 +53,101 @@ cd lindi pip install -e . ``` -## Use cases +## Usage -* Lazy-load a remote NWB/HDF5 file for efficient access to metadata and data. -* Represent a remote NWB/HDF5 file as a .nwb.lindi.json file. -* Read a local or remote .nwb.lindi.json file using pynwb or other tools. -* Edit a .nwb.lindi.json file using pynwb or other tools. -* Add datasets to a .nwb.lindi.json file using a local staging area. -* Upload a .nwb.lindi.json file with staged datasets to a cloud storage service such as DANDI. +**Creating and reading a LINDI file** -### Lazy-load a remote NWB/HDF5 file for efficient access to metadata and data +The simplest way to start is to use LINDI like HDF5. ```python -import pynwb import lindi -# URL of the remote NWB file -h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" - -# Set up a local cache -local_cache = lindi.LocalCache(cache_dir='lindi_cache') - -# Create the h5py-like client -client = lindi.LindiH5pyFile.from_hdf5_file(h5_url, local_cache=local_cache) - -# Open using pynwb -with pynwb.NWBHDF5IO(file=client, mode="r") as io: - nwbfile = io.read() - print(nwbfile) - -# The downloaded data will be cached locally, so subsequent reads will be faster +# Create a new lindi.json file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.json', mode='w') as f: + f.attrs['attr1'] = 'value1' + f.attrs['attr2'] = 7 + ds = f.create_dataset('dataset1', shape=(10,), dtype='f') + ds[...] = 12 + +# Later read the file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.json', mode='r') as f: + print(f.attrs['attr1']) + print(f.attrs['attr2']) + print(f['dataset1'][...]) ``` -### Represent a remote NWB/HDF5 file as a .nwb.lindi.json file +You can inspect the example.lindi.json file to get an idea of how the data are stored. If you are familiar with the internal Zarr format, you will recognize the .zgroup and .zarray entries and the layout of the chunks. + +Because the above dataset is very small, it can all fit reasonably inside the JSON file. For storing larger arrays (the usual case) it is better to use the binary format. Just use a .lindi.tar extension instead of .lindi.json. ```python -import json +import numpy as np import lindi -# URL of the remote NWB file -h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" - -# Create the h5py-like client -client = lindi.LindiH5pyFile.from_hdf5_file(h5_url) - -client.write_lindi_file('example.lindi.json') - -# See the next example for how to read this file +# Create a new lindi binary file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='w') as f: + f.attrs['attr1'] = 'value1' + f.attrs['attr2'] = 7 + ds = f.create_dataset('dataset1', shape=(1000, 1000), dtype='f') + ds[...]
= np.random.rand(1000, 1000) + +# Later read the file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='r') as f: + print(f.attrs['attr1']) + print(f.attrs['attr2']) + print(f['dataset1'][...]) ``` -### Read a local or remote .nwb.lindi.json file using pynwb or other tools +**Loading a remote NWB file from DANDI** ```python +import json import pynwb import lindi -# URL of the remote .nwb.lindi.json file -url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json' - -# Load the h5py-like client -client = lindi.LindiH5pyFile.from_lindi_file(url) +# Define the URL for a remote NWB file +h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" -# Open using pynwb -with pynwb.NWBHDF5IO(file=client, mode="r") as io: +# Load as LINDI and view using pynwb +f = lindi.LindiH5pyFile.from_hdf5_file(h5_url) +with pynwb.NWBHDF5IO(file=f, mode="r") as io: nwbfile = io.read() + print('NWB via LINDI') print(nwbfile) -``` -### Edit a .nwb.lindi.json file using pynwb or other tools + print('Electrode group at shank0:') + print(nwbfile.electrode_groups["shank0"]) # type: ignore -```python -import json -import lindi + print('Electrode group at index 0:') + print(nwbfile.electrodes.group[0]) # type: ignore -# URL of the remote .nwb.lindi.json file -url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json' +# Save as LINDI JSON +f.write_lindi_file('example.nwb.lindi.json') -# Load the h5py-like client for the reference file system -# in read-write mode -client = lindi.LindiH5pyFile.from_lindi_file(url, mode="r+") +# Later, read directly from the LINDI JSON file +g = lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.json') +with pynwb.NWBHDF5IO(file=g, mode="r") as io: + nwbfile = io.read() + print('') + print('NWB from LINDI JSON:') + print(nwbfile) -# Edit an attribute -client.attrs['new_attribute'] = 'new_value' + print('Electrode group at shank0:') + print(nwbfile.electrode_groups["shank0"]) # type: ignore -# Save the changes to a new .nwb.lindi.json file -client.write_lindi_file('new.nwb.lindi.json') + print('Electrode group at index 0:') + print(nwbfile.electrodes.group[0]) # type: ignore ``` -### Add datasets to a .nwb.lindi.json file using a local staging area +## Amending an NWB file -```python -import lindi +Basically you save the remote NWB as a local binary LINDI file, and then add additional data objects to it. -# URL of the remote .nwb.lindi.json file -url = 'https://lindi.neurosift.org/dandi/dandisets/000939/assets/56d875d6-a705-48d3-944c-53394a389c85/nwb.lindi.json' - -# Load the h5py-like client for the reference file system -# in read-write mode with a staging area -with lindi.StagingArea.create(base_dir='lindi_staging') as staging_area: - client = lindi.LindiH5pyFile.from_lindi_file( - url, - mode="r+", - staging_area=staging_area - ) - # add datasets to client using pynwb or other tools - # upload the changes to the remote .nwb.lindi.json file -``` +TODO: finish this section -### Upload a .nwb.lindi.json file with staged datasets to a cloud storage service such as DANDI +## Notes -See [this example](https://github.com/magland/lindi-dandi/blob/main/devel/lindi_test_2.py). 
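+
+As a note on the amending workflow described above (see the TODO under "Amending an NWB file"), a rough sketch using the current API might look like the following; the URL, file names, and dataset name are placeholders, and in practice new objects would be added through pynwb rather than as raw datasets:
+
+```python
+import numpy as np
+import lindi
+
+# Placeholder URL of a remote NWB (HDF5) file, e.g. on DANDI
+h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/"
+
+# Represent the remote NWB file as a local binary LINDI file; the large data
+# chunks stay in the remote file and are only referenced
+# (assumes write_lindi_file accepts a .lindi.tar target)
+f = lindi.LindiH5pyFile.from_hdf5_file(h5_url)
+f.write_lindi_file('amended.nwb.lindi.tar')
+f.close()
+
+# Re-open the local file in read-write mode and add a new data object;
+# the new chunks are stored as internal data chunks in the .lindi.tar file
+with lindi.LindiH5pyFile.from_lindi_file('amended.nwb.lindi.tar', mode='r+') as g:
+    ds = g.create_dataset('new_dataset', shape=(1000,), dtype='f')
+    ds[...] = np.random.rand(1000)
+```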
+This project was inspired by [kerchunk](https://github.com/fsspec/kerchunk) and [hdmf-zarr](https://hdmf-zarr.readthedocs.io/en/latest/index.html) and depends on [zarr](https://zarr.readthedocs.io/en/stable/), [h5py](https://www.h5py.org/) and [numcodecs](https://numcodecs.readthedocs.io/en/stable/). ## For developers diff --git a/devel/write_tar_header.py b/devel/write_tar_header.py new file mode 100644 index 0000000..aa666dc --- /dev/null +++ b/devel/write_tar_header.py @@ -0,0 +1,256 @@ +import tempfile +import tarfile +import io + + +test_file_name = "text/file.txt" +test_file_size = 8100 + + +def create_tar_header_using_tarfile(file_name: str, file_size: int) -> bytes: + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_fname = f"{tmpdirname}/test.tar" + with tarfile.open(tmp_fname, "w") as tar: + tarinfo = tarfile.TarInfo(file_name) + tarinfo.size = file_size + fobj = io.BytesIO(b"0" * file_size) + tar.addfile(tarinfo, fileobj=fobj) + tar.close() + with open(tmp_fname, "rb") as f: + header = f.read(512) + return header + + +def practice_form_header(file_name: str, file_size: int, header: bytes): + h = b'' + # We use USTAR format only + + # file name + a = header[0:100] + b = file_name.encode() + b"\x00" * (100 - len(file_name)) + assert a == b + h += b + + # file mode + a = header[100:108] + b = b"0000644\x00" # 644 is the default permission - you can read and write, but others can only read + assert a == b + h += b + + # uid + a = header[108:116] + b = b"0000000\x00" # 0 is the default user id + assert a == b + h += b + + # gid + a = header[116:124] + b = b"0000000\x00" # 0 is the default group id + assert a == b + h += b + + # size + a = header[124:136] + # we need an octal representation of the size + b = f"{file_size:011o}".encode() + b"\x00" # 11 octal digits + assert a == b + h += b + + # mtime + a = header[136:148] + b = b"00000000000\x00" # 0 is the default modification time + assert a == b + h += b + + # chksum + # We'll determine the checksum after creating the full header + h += b" " * 8 # 8 spaces for now + + # typeflag + a = header[156:157] + b = b"0" # default typeflag is 0 representing a regular file + assert a == b + h += b + + # linkname + a = header[157:257] + b = b"\x00" * 100 # no link name + assert a == b + h += b + + # magic + a = header[257:263] + b = b"ustar\x00" # specifies the ustar format + assert a == b + h += b + + # version + a = header[263:265] + b = b"00" # ustar version + assert a == b + h += b + + # uname + a = header[265:297] + b = b"\x00" * 32 # no user name + assert a == b + h += b + + # gname + a = header[297:329] + b = b"\x00" * 32 # no group name + assert a == b + h += b + + # devmajor + a = header[329:337] + b = b"\x00" * 8 # no device major number + assert a == b + h += b + + # devminor + a = header[337:345] + b = b"\x00" * 8 # no device minor number + assert a == b + h += b + + # prefix + a = header[345:500] + b = b"\x00" * 155 # no prefix + assert a == b + h += b + + # padding + a = header[500:] + b = b"\x00" * 12 # padding + assert a == b + h += b + + # Now we calculate the checksum + chksum = _compute_checksum_for_header(h) + h = h[:148] + chksum + h[156:] + + assert h == header + + +def create_tar_header(file_name: str, file_size: int) -> bytes: + # We use USTAR format only + h = b'' + + # file name + a = file_name.encode() + b"\x00" * (100 - len(file_name)) + h += a + + # file mode + a = b"0000644\x00" # 644 is the default permission - you can read and write, but others can only read + h += a + + # uid + a = b"0000000\x00" # 0 is 
the default user id + h += a + + # gid + a = b"0000000\x00" # 0 is the default group id + h += a + + # size + # we need an octal representation of the size + a = f"{file_size:011o}".encode() + b"\x00" # 11 octal digits + h += a + + # mtime + a = b"00000000000\x00" # 0 is the default modification time + h += a + + # chksum + # We'll determine the checksum after creating the full header + a = b" " * 8 # 8 spaces for now + h += a + + # typeflag + a = b"0" # default typeflag is 0 representing a regular file + h += a + + # linkname + a = b"\x00" * 100 # no link name + h += a + + # magic + a = b"ustar\x00" # specifies the ustar format + h += a + + # version + a = b"00" # ustar version + h += a + + # uname + a = b"\x00" * 32 # no user name + h += a + + # gname + a = b"\x00" * 32 # no group name + h += a + + # devmajor + a = b"\x00" * 8 # no device major number + h += a + + # devminor + a = b"\x00" * 8 # no device minor number + h += a + + # prefix + a = b"\x00" * 155 # no prefix + h += a + + # padding + a = b"\x00" * 12 # padding + h += a + + # Now we calculate the checksum + chksum = _compute_checksum_for_header(h) + h = h[:148] + chksum + h[156:] + + assert len(h) == 512 + + return h + + +def _compute_checksum_for_header(header: bytes) -> bytes: + # From https://en.wikipedia.org/wiki/Tar_(computing) + # The checksum is calculated by taking the sum of the unsigned byte values + # of the header record with the eight checksum bytes taken to be ASCII + # spaces (decimal value 32). It is stored as a six digit octal number with + # leading zeroes followed by a NUL and then a space. Various implementations + # do not adhere to this format. In addition, some historic tar + # implementations treated bytes as signed. Implementations typically + # calculate the checksum both ways, and treat it as good if either the + # signed or unsigned sum matches the included checksum. + + header_byte_list = [] + for byte in header: + header_byte_list.append(byte) + for i in range(148, 156): + header_byte_list[i] = 32 + sum = 0 + for byte in header_byte_list: + sum += byte + checksum = oct(sum).encode()[2:] + while len(checksum) < 6: + checksum = b"0" + checksum + checksum += b"\0 " + return checksum + + +def main(): + header1 = create_tar_header_using_tarfile(test_file_name, test_file_size) + practice_form_header(test_file_name, test_file_size, header1) + header2 = create_tar_header(test_file_name, test_file_size) + + assert header1 == header2 + + print("Success!") + + +if __name__ == "__main__": + main() diff --git a/docs/tar.md b/docs/tar.md new file mode 100644 index 0000000..cc11b04 --- /dev/null +++ b/docs/tar.md @@ -0,0 +1,15 @@ +# LINDI binary (tar) format + +In addition to a JSON/text format, LINDI offers a binary format packaged as a tar archive, which includes a specialized lindi.json file in the standard JSON format as well as other files including binary chunks. The `lindi.json` file can reference a mix of external references and internal binary chunks. + +**General structure of a tar archive**: Tar is a simple and widely-used format that houses binary files sequentially, with each file record beginning with a 512-byte header that describes the file (name, size, etc.), followed by the content rounded up to 512-byte blocks. The archive is terminated by two 512-byte blocks filled with zeros. + +**Cloud Optimization**: Tar archives are typically not optimized for cloud storage due to their sequential file arrangement which necessitates reading all headers for index construction. 
To address this, LINDI introduces two crucial files within each archive: + +`.tar_entry.json`: This must always be the first file in the archive, fixed at 1024 bytes (padded with whitespace if necessary). It specifies the byte range for the `.tar_index.json` file, allowing it to be quickly located and read. + +`.tar_index.json`: Contains names and byte ranges of all other files in the archive, enabling efficient random access after the initial two requests (one for `.tar_entry.json` and one for `.tar_index.json`). + +**Handling Updates and Data Growth**: Traditional tar clients do not allow for file resizing or deletion, posing a challenge when updating files like `lindi.json` that might grow as data is added. LINDI circumvents these issues by padding `lindi.json` and `.tar_index.json` with extra whitespace, allowing for in-place expansion up to a predetermined limit without modifying the tar structure. If expansion beyond this limit is necessary, the original file is renamed to a placeholder (e.g., `./trash/xxxxx`), effectively removing it from use, and a new version of the file is appended to the end of the archive. + +**Efficient Cloud Interaction**: With the special structure of `.tar_entry.json` and `.tar_index.json`, clients can download the index with minimal requests, reducing the overhead typical of cloud interactions with large tar archives. \ No newline at end of file diff --git a/examples/DANDI/nwbextractors.py b/examples/DANDI/nwbextractors.py new file mode 100644 index 0000000..37c5af2 --- /dev/null +++ b/examples/DANDI/nwbextractors.py @@ -0,0 +1,1427 @@ +from __future__ import annotations +from pathlib import Path +from typing import List, Optional, Literal, Dict, BinaryIO +import warnings + +import numpy as np +import h5py + +from spikeinterface import get_global_tmp_folder +from spikeinterface.core import BaseRecording, BaseRecordingSegment, BaseSorting, BaseSortingSegment +from spikeinterface.core.core_tools import define_function_from_class + + +def read_file_from_backend( + *, + file_path: str | Path | None, + file: BinaryIO | None = None, + h5py_file: h5py.File | None = None, + stream_mode: Literal["ffspec", "remfile"] | None = None, + cache: bool = False, + stream_cache_path: str | Path | None = None, + storage_options: dict | None = None, +): + """ + Reads a file from a hdf5 or zarr backend. 
+ """ + if stream_mode == "fsspec": + import h5py + import fsspec + from fsspec.implementations.cached import CachingFileSystem + + assert file_path is not None, "file_path must be specified when using stream_mode='fsspec'" + + fsspec_file_system = fsspec.filesystem("http") + + if cache: + stream_cache_path = stream_cache_path if stream_cache_path is not None else str(get_global_tmp_folder()) + caching_file_system = CachingFileSystem( + fs=fsspec_file_system, + cache_storage=str(stream_cache_path), + ) + ffspec_file = caching_file_system.open(path=file_path, mode="rb") + else: + ffspec_file = fsspec_file_system.open(file_path, "rb") + + if _is_hdf5_file(ffspec_file): + open_file = h5py.File(ffspec_file, "r") + else: + raise RuntimeError(f"{file_path} is not a valid HDF5 file!") + + elif stream_mode == "ros3": + import h5py + + assert file_path is not None, "file_path must be specified when using stream_mode='ros3'" + + drivers = h5py.registered_drivers() + assertion_msg = "ROS3 support not enabled, use: install -c conda-forge h5py>=3.2 to enable streaming" + assert "ros3" in drivers, assertion_msg + open_file = h5py.File(name=file_path, mode="r", driver="ros3") + + elif stream_mode == "remfile": + import remfile + import h5py + + assert file_path is not None, "file_path must be specified when using stream_mode='remfile'" + rfile = remfile.File(file_path) + if _is_hdf5_file(rfile): + open_file = h5py.File(rfile, "r") + else: + raise RuntimeError(f"{file_path} is not a valid HDF5 file!") + + elif stream_mode == "zarr": + import zarr + + open_file = zarr.open(file_path, mode="r", storage_options=storage_options) + + elif file_path is not None: # local + file_path = str(Path(file_path).resolve()) + backend = _get_backend_from_local_file(file_path) + if backend == "zarr": + import zarr + + open_file = zarr.open(file_path, mode="r") + else: + import h5py + + open_file = h5py.File(name=file_path, mode="r") + elif file is not None: + import h5py + open_file = h5py.File(file, "r") + return open_file + elif h5py_file is not None: + return h5py_file + else: + raise ValueError("Provide either file_path or file or h5py_file") + + +def read_nwbfile( + *, + backend: Literal["hdf5", "zarr"], + file_path: str | Path | None, + file: BinaryIO | None = None, + h5py_file: h5py.File | None = None, + stream_mode: Literal["ffspec", "remfile", "zarr"] | None = None, + cache: bool = False, + stream_cache_path: str | Path | None = None, + storage_options: dict | None = None, +) -> "NWBFile": + """ + Read an NWB file and return the NWBFile object. + + Parameters + ---------- + file_path : Path, str or None + The path to the NWB file. Either provide this or file. + file : file-like object or None + The file-like object to read from. Either provide this or file_path. + stream_mode : "fsspec" | "remfile" | None, default: None + The streaming mode to use. If None it assumes the file is on the local disk. + cache : bool, default: False + If True, the file is cached in the file passed to stream_cache_path + if False, the file is not cached. + stream_cache_path : str or None, default: None + The path to the cache storage, when default to None it uses the a temporary + folder. + Returns + ------- + nwbfile : NWBFile + The NWBFile object. + + Notes + ----- + This function can stream data from the "fsspec", and "rem" protocols. 
+ + + Examples + -------- + >>> nwbfile = read_nwbfile(file_path="data.nwb", backend="hdf5", stream_mode="fsspec") + """ + + if file_path is not None and file is not None: + raise ValueError("Provide either file_path or file, not both") + if file_path is None and h5py_file is not None: + raise ValueError("Provide either h5py_file or file_path, not both") + if file is not None and h5py_file is not None: + raise ValueError("Provide either h5py_file or file, not both") + if file_path is None and file is None and h5py_file is None: + raise ValueError("Provide either file_path or file or h5py_file") + + open_file = read_file_from_backend( + file_path=file_path, + file=file, + h5py_file=h5py_file, + stream_mode=stream_mode, + cache=cache, + stream_cache_path=stream_cache_path, + storage_options=storage_options, + ) + if backend == "hdf5": + from pynwb import NWBHDF5IO + + io = NWBHDF5IO(file=open_file, mode="r", load_namespaces=True) + else: + from hdmf_zarr import NWBZarrIO + + io = NWBZarrIO(path=open_file.store, mode="r", load_namespaces=True) + + nwbfile = io.read() + return nwbfile + + +def _retrieve_electrical_series_pynwb( + nwbfile: "NWBFile", electrical_series_path: Optional[str] = None +) -> "ElectricalSeries": + """ + Get an ElectricalSeries object from an NWBFile. + + Parameters + ---------- + nwbfile : NWBFile + The NWBFile object from which to extract the ElectricalSeries. + electrical_series_path : str, default: None + The name of the ElectricalSeries to extract. If not specified, it will return the first found ElectricalSeries + if there's only one; otherwise, it raises an error. + + Returns + ------- + ElectricalSeries + The requested ElectricalSeries object. + + Raises + ------ + ValueError + If no acquisitions are found in the NWBFile or if multiple acquisitions are found but no electrical_series_path + is provided. + AssertionError + If the specified electrical_series_path is not present in the NWBFile. + """ + from pynwb.ecephys import ElectricalSeries + + electrical_series_dict: Dict[str, ElectricalSeries] = {} + + for item in nwbfile.all_children(): + if isinstance(item, ElectricalSeries): + # remove data and skip first "/" + electrical_series_key = item.data.name.replace("/data", "")[1:] + electrical_series_dict[electrical_series_key] = item + + if electrical_series_path is not None: + if electrical_series_path not in electrical_series_dict: + raise ValueError(f"{electrical_series_path} not found in the NWBFile. ") + electrical_series = electrical_series_dict[electrical_series_path] + else: + electrical_series_list = list(electrical_series_dict.keys()) + if len(electrical_series_list) > 1: + raise ValueError( + f"More than one acquisition found! You must specify 'electrical_series_path'. \n" + f"Options in current file are: {[e for e in electrical_series_list]}" + ) + if len(electrical_series_list) == 0: + raise ValueError("No acquisitions found in the .nwb file.") + electrical_series = electrical_series_dict[electrical_series_list[0]] + + return electrical_series + + +def _retrieve_unit_table_pynwb(nwbfile: "NWBFile", unit_table_path: Optional[str] = None) -> "Units": + """ + Get an Units object from an NWBFile. + Units tables can be either the main unit table (nwbfile.units) or in the processing module. + + Parameters + ---------- + nwbfile : NWBFile + The NWBFile object from which to extract the Units. + unit_table_path : str, default: None + The path of the Units to extract. 
If not specified, it will return the first found Units + if there's only one; otherwise, it raises an error. + + Returns + ------- + Units + The requested Units object. + + Raises + ------ + ValueError + If no unit tables are found in the NWBFile or if multiple unit tables are found but no unit_table_path + is provided. + AssertionError + If the specified unit_table_path is not present in the NWBFile. + """ + from pynwb.misc import Units + + unit_table_dict: Dict[str:Units] = {} + + for item in nwbfile.all_children(): + if isinstance(item, Units): + # retrieve name of "id" column and skip first "/" + unit_table_key = item.id.data.name.replace("/id", "")[1:] + unit_table_dict[unit_table_key] = item + + if unit_table_path is not None: + if unit_table_path not in unit_table_dict: + raise ValueError(f"{unit_table_path} not found in the NWBFile. ") + unit_table = unit_table_dict[unit_table_path] + else: + unit_table_list: List[Units] = list(unit_table_dict.keys()) + + if len(unit_table_list) > 1: + raise ValueError( + f"More than one unit table found! You must specify 'unit_table_list_name'. \n" + f"Options in current file are: {[e for e in unit_table_list]}" + ) + if len(unit_table_list) == 0: + raise ValueError("No unit table found in the .nwb file.") + unit_table = unit_table_dict[unit_table_list[0]] + + return unit_table + + +def _is_hdf5_file(filename_or_file): + if isinstance(filename_or_file, (str, Path)): + import h5py + + filename = str(filename_or_file) + is_hdf5 = h5py.h5f.is_hdf5(filename.encode("utf-8")) + else: + file_signature = filename_or_file.read(8) + # Source of the magic number https://docs.hdfgroup.org/hdf5/develop/_f_m_t3.html + is_hdf5 = file_signature == b"\x89HDF\r\n\x1a\n" + + return is_hdf5 + + +def _get_backend_from_local_file(file_path: str | Path) -> str: + """ + Returns the file backend from a file path ("hdf5", "zarr") + + Parameters + ---------- + file_path : str or Path + The path to the file. + + Returns + ------- + backend : str + The file backend ("hdf5", "zarr") + """ + file_path = Path(file_path) + if file_path.is_file(): + if _is_hdf5_file(file_path): + backend = "hdf5" + else: + raise RuntimeError(f"{file_path} is not a valid HDF5 file!") + elif file_path.is_dir(): + try: + import zarr + + with zarr.open(file_path, "r") as f: + backend = "zarr" + except: + raise RuntimeError(f"{file_path} is not a valid Zarr folder!") + else: + raise RuntimeError(f"File {file_path} is not an existing file or folder!") + return backend + + +def _find_neurodata_type_from_backend(group, path="", result=None, neurodata_type="ElectricalSeries", backend="hdf5"): + """ + Recursively searches for groups with the specified neurodata_type hdf5 or zarr object, + and returns a list with their paths. 
+ """ + if backend == "hdf5": + import h5py + + group_class = h5py.Group + else: + import zarr + + group_class = zarr.Group + + if result is None: + result = [] + + for neurodata_name, value in group.items(): + # Check if it's a group and if it has the neurodata_type + if isinstance(value, group_class): + current_path = f"{path}/{neurodata_name}" if path else neurodata_name + if value.attrs.get("neurodata_type") == neurodata_type: + result.append(current_path) + _find_neurodata_type_from_backend( + value, current_path, result, neurodata_type, backend + ) # Recursive call for sub-groups + return result + + +def _fetch_time_info_pynwb(electrical_series, samples_for_rate_estimation, load_time_vector=False): + """ + Extracts the sampling frequency and the time vector from an ElectricalSeries object. + """ + sampling_frequency = None + if hasattr(electrical_series, "rate"): + sampling_frequency = electrical_series.rate + + if hasattr(electrical_series, "starting_time"): + t_start = electrical_series.starting_time + else: + t_start = None + + timestamps = None + if hasattr(electrical_series, "timestamps"): + if electrical_series.timestamps is not None: + timestamps = electrical_series.timestamps + t_start = electrical_series.timestamps[0] + + # TimeSeries need to have either timestamps or rate + if sampling_frequency is None: + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + + if load_time_vector and timestamps is not None: + times_kwargs = dict(time_vector=electrical_series.timestamps) + else: + times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) + + return sampling_frequency, times_kwargs + + +def _retrieve_electrodes_indices_from_electrical_series_backend(open_file, electrical_series, backend="hdf5"): + """ + Retrieves the indices of the electrodes from the electrical series. + For the Zarr backend, the electrodes are stored in the electrical_series.attrs["zarr_link"]. + """ + if "electrodes" not in electrical_series: + if backend == "zarr": + import zarr + + # links must be resolved + zarr_links = electrical_series.attrs["zarr_link"] + electrodes_path = None + for zarr_link in zarr_links: + if zarr_link["name"] == "electrodes": + electrodes_path = zarr_link["path"] + assert electrodes_path is not None, "electrodes must be present in the electrical series" + electrodes_indices = open_file[electrodes_path][:] + else: + raise ValueError("electrodes must be present in the electrical series") + else: + electrodes_indices = electrical_series["electrodes"][:] + return electrodes_indices + + +class NwbRecordingExtractor(BaseRecording): + """Load an NWBFile as a RecordingExtractor. + + Parameters + ---------- + file_path : str, Path or None + Path to the NWB file or an s3 URL. Use this parameter to specify the file location + if not using the `file` or `h5py_file` parameter. + electrical_series_name : str or None, default: None + Deprecated, use `electrical_series_path` instead. + electrical_series_path : str or None, default: None + The name of the ElectricalSeries object within the NWB file. This parameter is crucial + when the NWB file contains multiple ElectricalSeries objects. It helps in identifying + which specific series to extract data from. If there is only one ElectricalSeries and + this parameter is not set, that unique series will be used by default. + If multiple ElectricalSeries are present and this parameter is not set, an error is raised. 
+ The `electrical_series_path` corresponds to the path within the NWB file, e.g., + 'acquisition/MyElectricalSeries`. + load_time_vector : bool, default: False + If set to True, the time vector is also loaded into the recording object. Useful for + cases where precise timing information is required. + samples_for_rate_estimation : int, default: 1000 + The number of timestamp samples used for estimating the sampling rate. This is relevant + when the 'rate' attribute is not available in the ElectricalSeries. + stream_mode : "fsspec" | "remfile" | "zarr" | None, default: None + Determines the streaming mode for reading the file. Use this for optimized reading from + different sources, such as local disk or remote servers. + load_channel_properties : bool, default: True + If True, all the channel properties are loaded from the NWB file and stored as properties. + For streaming purposes, it can be useful to set this to False to speed up streaming. + file : file-like object or None, default: None + A file-like object representing the NWB file. Use this parameter if you have an in-memory + representation of the NWB file instead of a file path. + h5py_file : h5py.File or None, default: None + A h5py.File-like object representing the NWB file. (jfm) + cache : bool, default: False + Indicates whether to cache the file locally when using streaming. Caching can improve performance for + remote files. + stream_cache_path : str, Path, or None, default: None + Specifies the local path for caching the file. Relevant only if `cache` is True. + storage_options : dict | None = None, + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. + use_pynwb : bool, default: False + Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py + to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. + + Returns + ------- + recording : NwbRecordingExtractor + The recording extractor for the NWB file. 
+ + Examples + -------- + Run on local file: + + >>> from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor + >>> rec = NwbRecordingExtractor(filepath) + + Run on s3 URL from the DANDI Archive: + + >>> from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor + >>> from dandi.dandiapi import DandiAPIClient + >>> + >>> # get s3 path + >>> dandiset_id, filepath = "101116", "sub-001/sub-001_ecephys.nwb" + >>> with DandiAPIClient("https://api-staging.dandiarchive.org/api") as client: + >>> asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(filepath) + >>> s3_url = asset.get_content_url(follow_redirects=1, strip_query=True) + >>> + >>> rec = NwbRecordingExtractor(s3_url, stream_mode="fsspec", stream_cache_path="cache") + """ + + mode = "file" + name = "nwb" + installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" + + def __init__( + self, + file_path: str | Path | None = None, # provide either this or file + electrical_series_name: str | None = None, # deprecated + load_time_vector: bool = False, + samples_for_rate_estimation: int = 1_000, + stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, + stream_cache_path: str | Path | None = None, + electrical_series_path: str | None = None, + load_channel_properties: bool = True, + *, + file: BinaryIO | None = None, # file-like - provide either this or file_path or h5py_file + h5py_file: h5py.File | None = None, # provide either this or file_path or file + cache: bool = False, + storage_options: dict | None = None, + use_pynwb: bool = False, + ): + + if stream_mode == "ros3": + warnings.warn( + "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. " + "Use 'fsspec' stream_mode instead.", + DeprecationWarning, + ) + + if file_path is not None and file is not None: + raise ValueError("Provide either file_path or file, not both") + if file_path is not None and h5py_file is not None: + raise ValueError("Provide either h5py_file or file_path, not both") + if file is not None and h5py_file is not None: + raise ValueError("Provide either h5py_file or file, not both") + if file_path is None and file is None and h5py_file is None: + raise ValueError("Provide either file_path or file or h5py_file") + + if electrical_series_name is not None: + warning_msg = ( + "The `electrical_series_name` parameter is deprecated and will be removed in version 0.101.0.\n" + "Use `electrical_series_path` instead." + ) + if electrical_series_path is None: + warning_msg += f"\nSetting `electrical_series_path` to 'acquisition/{electrical_series_name}'." + electrical_series_path = f"acquisition/{electrical_series_name}" + else: + warning_msg += f"\nIgnoring `electrical_series_name` and using the provided `electrical_series_path`." 
+ warnings.warn(warning_msg, DeprecationWarning, stacklevel=2) + + self.file_path = file_path + self.stream_mode = stream_mode + self.stream_cache_path = stream_cache_path + self.storage_options = storage_options + self.electrical_series_path = electrical_series_path + + if self.stream_mode is None and file is None and h5py_file is None: + self.backend = _get_backend_from_local_file(file_path) + else: + if self.stream_mode == "zarr": + self.backend = "zarr" + else: + self.backend = "hdf5" + + # extract info + if use_pynwb: + try: + import pynwb + except ImportError: + raise ImportError(self.installation_mesg) + + ( + channel_ids, + sampling_frequency, + dtype, + segment_data, + times_kwargs, + ) = self._fetch_recording_segment_info_pynwb(file, h5py_file, cache, load_time_vector, samples_for_rate_estimation) + else: + ( + channel_ids, + sampling_frequency, + dtype, + segment_data, + times_kwargs, + ) = self._fetch_recording_segment_info_backend(file, h5py_file, cache, load_time_vector, samples_for_rate_estimation) + BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) + recording_segment = NwbRecordingSegment( + electrical_series_data=segment_data, + times_kwargs=times_kwargs, + ) + self.add_recording_segment(recording_segment) + + # fetch and add main recording properties + if use_pynwb: + gains, offsets, locations, groups = self._fetch_main_properties_pynwb() + self.extra_requirements.append("pynwb") + else: + gains, offsets, locations, groups = self._fetch_main_properties_backend() + self.extra_requirements.append("h5py") + self.set_channel_gains(gains) + self.set_channel_offsets(offsets) + if locations is not None: + self.set_channel_locations(locations) + if groups is not None: + self.set_channel_groups(groups) + + # fetch and add additional recording properties + if load_channel_properties: + if use_pynwb: + electrodes_table = self._nwbfile.electrodes + electrodes_indices = self.electrical_series.electrodes.data[:] + columns = electrodes_table.colnames + else: + electrodes_table = self._file["/general/extracellular_ephys/electrodes"] + electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( + self._file, self.electrical_series, self.backend + ) + columns = electrodes_table.attrs["colnames"] + properties = self._fetch_other_properties(electrodes_table, electrodes_indices, columns) + + for property_name, property_values in properties.items(): + values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] + self.set_property(property_name, values) + + if stream_mode is None and file_path is not None: + file_path = str(Path(file_path).resolve()) + + if stream_mode == "fsspec" and stream_cache_path is not None: + stream_cache_path = str(Path(self.stream_cache_path).absolute()) + + # set serializability bools + if file is not None: + # not json serializable if file arg is provided + self._serializability["json"] = False + if h5py_file is not None: + # not json serializable if h5py_file arg is provided + self._serializability["json"] = False + + if storage_options is not None and stream_mode == "zarr": + warnings.warn( + "The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, " + "so the extractor will not be JSON/pickle serializable. Only in-memory mode will be available." 
+ ) + # not serializable if storage_options is provided + self._serializability["json"] = False + self._serializability["pickle"] = False + + self._kwargs = { + "file_path": file_path, + "electrical_series_path": self.electrical_series_path, + "load_time_vector": load_time_vector, + "samples_for_rate_estimation": samples_for_rate_estimation, + "stream_mode": stream_mode, + "load_channel_properties": load_channel_properties, + "storage_options": storage_options, + "cache": cache, + "stream_cache_path": stream_cache_path, + "file": file, + "h5py_file": h5py_file + } + + def __del__(self): + # backend mode + if hasattr(self, "_file"): + if hasattr(self._file, "store"): + self._file.store.close() + else: + self._file.close() + # pynwb mode + elif hasattr(self, "_nwbfile"): + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + + def _fetch_recording_segment_info_pynwb(self, file, h5py_file, cache, load_time_vector, samples_for_rate_estimation): + self._nwbfile = read_nwbfile( + backend=self.backend, + file_path=self.file_path, + file=file, + h5py_file=h5py_file, + stream_mode=self.stream_mode, + cache=cache, + stream_cache_path=self.stream_cache_path, + ) + electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) + # The indices in the electrode table corresponding to this electrical series + electrodes_indices = electrical_series.electrodes.data[:] + # The table for all the electrodes in the nwbfile + electrodes_table = self._nwbfile.electrodes + + sampling_frequency, times_kwargs = _fetch_time_info_pynwb( + electrical_series=electrical_series, + samples_for_rate_estimation=samples_for_rate_estimation, + load_time_vector=load_time_vector, + ) + + # Fill channel properties dictionary from electrodes table + if "channel_name" in electrodes_table.colnames: + channel_ids = [ + electrical_series.electrodes["channel_name"][electrodes_index] + for electrodes_index in electrodes_indices + ] + else: + channel_ids = [electrical_series.electrodes.table.id[x] for x in electrodes_indices] + electrical_series_data = electrical_series.data + dtype = electrical_series_data.dtype + + # need this later + self.electrical_series = electrical_series + + return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs + + def _fetch_recording_segment_info_backend(self, file, h5py_file, cache, load_time_vector, samples_for_rate_estimation): + open_file = read_file_from_backend( + file_path=self.file_path, + file=file, + h5py_file=h5py_file, + stream_mode=self.stream_mode, + cache=cache, + stream_cache_path=self.stream_cache_path, + ) + + # If the electrical_series_path is not given, `_find_neurodata_type_from_backend` will be called + # And returns a list with the electrical_series_paths available in the file. + # If there is only one electrical series, the electrical_series_path is set to the name of the series, + # otherwise an error is raised. + if self.electrical_series_path is None: + available_electrical_series = _find_neurodata_type_from_backend( + open_file, neurodata_type="ElectricalSeries", backend=self.backend + ) + # if electrical_series_path is None: + if len(available_electrical_series) == 1: + self.electrical_series_path = available_electrical_series[0] + else: + raise ValueError( + "Multiple ElectricalSeries found in the file. " + "Please specify the 'electrical_series_path' argument:" + f"Available options are: {available_electrical_series}." + ) + + # Open the electrical series. 
In case of failure, raise an error with the available options. + try: + electrical_series = open_file[self.electrical_series_path] + except KeyError: + available_electrical_series = _find_neurodata_type_from_backend( + open_file, neurodata_type="ElectricalSeries", backend=self.backend + ) + raise ValueError( + f"{self.electrical_series_path} not found in the NWB file!" + f"Available options are: {available_electrical_series}." + ) + electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( + open_file, electrical_series, self.backend + ) + # The table for all the electrodes in the nwbfile + electrodes_table = open_file["/general/extracellular_ephys/electrodes"] + electrode_table_columns = electrodes_table.attrs["colnames"] + + # Get sampling frequency + if "starting_time" in electrical_series.keys(): + t_start = electrical_series["starting_time"][()] + sampling_frequency = electrical_series["starting_time"].attrs["rate"] + elif "timestamps" in electrical_series.keys(): + timestamps = electrical_series["timestamps"][:] + t_start = timestamps[0] + sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation])) + + if load_time_vector and timestamps is not None: + times_kwargs = dict(time_vector=electrical_series["timestamps"]) + else: + times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start) + + # If channel names are present, use them as channel_ids instead of the electrode ids + if "channel_name" in electrode_table_columns: + channel_names = electrodes_table["channel_name"] + channel_ids = channel_names[electrodes_indices] + # Decode if bytes with utf-8 + channel_ids = [x.decode("utf-8") if isinstance(x, bytes) else x for x in channel_ids] + + else: + channel_ids = [electrodes_table["id"][x] for x in electrodes_indices] + + dtype = electrical_series["data"].dtype + electrical_series_data = electrical_series["data"] + + # need this for later + self.electrical_series = electrical_series + self._file = open_file + + return channel_ids, sampling_frequency, dtype, electrical_series_data, times_kwargs + + def _fetch_locations_and_groups(self, electrodes_table, electrodes_indices): + # Channel locations + locations = None + if "rel_x" in electrodes_table: + if "rel_y" in electrodes_table: + ndim = 3 if "rel_z" in electrodes_table else 2 + locations = np.zeros((self.get_num_channels(), ndim), dtype=float) + locations[:, 0] = electrodes_table["rel_x"][electrodes_indices] + locations[:, 1] = electrodes_table["rel_y"][electrodes_indices] + if "rel_z" in electrodes_table: + locations[:, 2] = electrodes_table["rel_z"][electrodes_indices] + + # allow x, y, z instead of rel_x, rel_y, rel_z + if locations is None: + if "x" in electrodes_table: + if "y" in electrodes_table: + ndim = 3 if "z" in electrodes_table else 2 + locations = np.zeros((self.get_num_channels(), ndim), dtype=float) + locations[:, 0] = electrodes_table["x"][electrodes_indices] + locations[:, 1] = electrodes_table["y"][electrodes_indices] + if "z" in electrodes_table: + locations[:, 2] = electrodes_table["z"][electrodes_indices] + + # Channel groups + groups = None + if "group_name" in electrodes_table: + groups = electrodes_table["group_name"][electrodes_indices][:] + if groups is not None: + groups = np.array([x.decode("utf-8") if isinstance(x, bytes) else x for x in groups]) + return locations, groups + + def _fetch_other_properties(self, electrodes_table, electrodes_indices, columns): + ######### + # Extract and re-name properties from nwbfile TODO: Should be a 
function + ######## + properties = dict() + properties_to_skip = [ + "id", + "rel_x", + "rel_y", + "rel_z", + "group", + "group_name", + "channel_name", + "offset", + ] + rename_properties = dict(location="brain_area") + + for column in columns: + if column in properties_to_skip: + continue + else: + column_name = rename_properties.get(column, column) + properties[column_name] = electrodes_table[column][electrodes_indices] + + return properties + + def _fetch_main_properties_pynwb(self): + """ + Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: + + - gains + - offsets + - locations + - groups + """ + electrodes_indices = self.electrical_series.electrodes.data[:] + electrodes_table = self._nwbfile.electrodes + + # Channels gains - for RecordingExtractor, these are values to cast traces to uV + gains = self.electrical_series.conversion * 1e6 + if self.electrical_series.channel_conversion is not None: + gains = self.electrical_series.conversion * self.electrical_series.channel_conversion[:] * 1e6 + + # Channel offsets + offset = self.electrical_series.offset if hasattr(self.electrical_series, "offset") else 0 + if offset == 0 and "offset" in electrodes_table: + offset = electrodes_table["offset"].data[electrodes_indices] + offsets = offset * 1e6 + + locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) + + return gains, offsets, locations, groups + + def _fetch_main_properties_backend(self): + """ + Fetches the main properties from the NWBFile and stores them in the RecordingExtractor, including: + + - gains + - offsets + - locations + - groups + """ + electrodes_indices = _retrieve_electrodes_indices_from_electrical_series_backend( + self._file, self.electrical_series, self.backend + ) + electrodes_table = self._file["/general/extracellular_ephys/electrodes"] + + # Channels gains - for RecordingExtractor, these are values to cast traces to uV + data_attributes = self.electrical_series["data"].attrs + electrical_series_conversion = data_attributes["conversion"] + gains = electrical_series_conversion * 1e6 + channel_conversion = self.electrical_series.get("channel_conversion", None) + if channel_conversion: + gains *= self.electrical_series["channel_conversion"][:] + + # Channel offsets + offset = data_attributes["offset"] if "offset" in data_attributes else 0 + if offset == 0 and "offset" in electrodes_table: + offset = electrodes_table["offset"][electrodes_indices] + offsets = offset * 1e6 + + # Channel locations and groups + locations, groups = self._fetch_locations_and_groups(electrodes_table, electrodes_indices) + + return gains, offsets, locations, groups + + @staticmethod + def fetch_available_electrical_series_paths( + file_path: str | Path, + stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, + storage_options: dict | None = None, + ) -> list[str]: + """ + Retrieves the paths to all ElectricalSeries objects within a neurodata file. + + Parameters + ---------- + file_path : str | Path + The path to the neurodata file to be analyzed. + stream_mode : "fsspec" | "remfile" | "zarr" | None, optional + Determines the streaming mode for reading the file. Use this for optimized reading from + different sources, such as local disk or remote servers. + storage_options : dict | None = None, + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. 
+ Returns + ------- + list of str + A list of paths to all ElectricalSeries objects found in the file. + + + Notes + ----- + The paths are returned as strings, and can be used to load the desired ElectricalSeries object. + Examples of paths are: + - "acquisition/ElectricalSeries1" + - "acquisition/ElectricalSeries2" + - "processing/ecephys/LFP/ElectricalSeries1" + - "processing/my_custom_module/MyContainer/ElectricalSeries2" + """ + + if stream_mode is None: + backend = _get_backend_from_local_file(file_path) + else: + if stream_mode == "zarr": + backend = "zarr" + else: + backend = "hdf5" + + file_handle = read_file_from_backend( + file_path=file_path, + stream_mode=stream_mode, + storage_options=storage_options, + ) + + electrical_series_paths = _find_neurodata_type_from_backend( + file_handle, + neurodata_type="ElectricalSeries", + backend=backend, + ) + return electrical_series_paths + + +class NwbRecordingSegment(BaseRecordingSegment): + def __init__(self, electrical_series_data, times_kwargs): + BaseRecordingSegment.__init__(self, **times_kwargs) + self.electrical_series_data = electrical_series_data + self._num_samples = self.electrical_series_data.shape[0] + + def get_num_samples(self): + """Returns the number of samples in this signal block + + Returns: + SampleIndex : Number of samples in the signal block + """ + return self._num_samples + + def get_traces(self, start_frame, end_frame, channel_indices): + electrical_series_data = self.electrical_series_data + if electrical_series_data.ndim == 1: + traces = electrical_series_data[start_frame:end_frame][:, np.newaxis] + elif isinstance(channel_indices, slice): + traces = electrical_series_data[start_frame:end_frame, channel_indices] + else: + # channel_indices is np.ndarray + if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0): + # get around h5py constraint that it does not allow datasets + # to be indexed out of order + sorted_channel_indices = np.sort(channel_indices) + resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices]) + recordings = electrical_series_data[start_frame:end_frame, sorted_channel_indices] + traces = recordings[:, resorted_indices] + else: + traces = electrical_series_data[start_frame:end_frame, channel_indices] + + return traces + + +class NwbSortingExtractor(BaseSorting): + """Load an NWBFile as a SortingExtractor. + Parameters + ---------- + file_path : str or Path + Path to NWB file. + electrical_series_path : str or None, default: None + The name of the ElectricalSeries (if multiple ElectricalSeries are present). + sampling_frequency : float or None, default: None + The sampling frequency in Hz (required if no ElectricalSeries is available). + unit_table_path : str or None, default: "units" + The path of the unit table in the NWB file. + samples_for_rate_estimation : int, default: 100000 + The number of timestamp samples to use to estimate the rate. + Used if "rate" is not specified in the ElectricalSeries. + stream_mode : "fsspec" | "remfile" | "zarr" | None, default: None + The streaming mode to use. If None it assumes the file is on the local disk. + stream_cache_path : str or Path or None, default: None + Local path for caching. If None it uses the system temporary directory. + load_unit_properties : bool, default: True + If True, all the unit properties are loaded from the NWB file and stored as properties. + t_start : float or None, default: None + This is the time at which the corresponding ElectricalSeries start. 
NWB stores its spikes as times + and the `t_start` is used to convert the times to seconds. Concrently, the returned frames are computed as: + + `frames = (times - t_start) * sampling_frequency`. + + As SpikeInterface always considers the first frame to be at the beginning of the recording independently + of the `t_start`. + + When a `t_start` is not provided it will be inferred from the corresponding ElectricalSeries with name equal + to `electrical_series_path`. The `t_start` then will be either the `ElectricalSeries.starting_time` or the + first timestamp in the `ElectricalSeries.timestamps`. + cache : bool, default: False + If True, the file is cached in the file passed to stream_cache_path + if False, the file is not cached. + storage_options : dict | None = None, + These are the additional kwargs (e.g. AWS credentials) that are passed to the zarr.open convenience function. + This is only used on the "zarr" stream_mode. + use_pynwb : bool, default: False + Uses the pynwb library to read the NWB file. Setting this to False, the default, uses h5py + to read the file. Using h5py can improve performance by bypassing some of the PyNWB validations. + + Returns + ------- + sorting : NwbSortingExtractor + The sorting extractor for the NWB file. + """ + + mode = "file" + installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" + name = "nwb" + + def __init__( + self, + file_path: str | Path, + electrical_series_path: str | None = None, + sampling_frequency: float | None = None, + samples_for_rate_estimation: int = 1_000, + stream_mode: str | None = None, + stream_cache_path: str | Path | None = None, + load_unit_properties: bool = True, + unit_table_path: str = "units", + *, + t_start: float | None = None, + cache: bool = False, + storage_options: dict | None = None, + use_pynwb: bool = False, + ): + + if stream_mode == "ros3": + warnings.warn( + "The 'ros3' stream_mode is deprecated and will be removed in version 0.103.0. 
" + "Use 'fsspec' stream_mode instead.", + DeprecationWarning, + ) + + self.stream_mode = stream_mode + self.stream_cache_path = stream_cache_path + self.electrical_series_path = electrical_series_path + self.file_path = file_path + self.t_start = t_start + self.provided_or_electrical_series_sampling_frequency = sampling_frequency + self.storage_options = storage_options + self.units_table = None + + if self.stream_mode is None: + self.backend = _get_backend_from_local_file(file_path) + else: + if self.stream_mode == "zarr": + self.backend = "zarr" + else: + self.backend = "hdf5" + + if use_pynwb: + try: + import pynwb + except ImportError: + raise ImportError(self.installation_mesg) + + unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_pynwb( + unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache + ) + else: + unit_ids, spike_times_data, spike_times_index_data = self._fetch_sorting_segment_info_backend( + unit_table_path=unit_table_path, samples_for_rate_estimation=samples_for_rate_estimation, cache=cache + ) + + BaseSorting.__init__( + self, sampling_frequency=self.provided_or_electrical_series_sampling_frequency, unit_ids=unit_ids + ) + + sorting_segment = NwbSortingSegment( + spike_times_data=spike_times_data, + spike_times_index_data=spike_times_index_data, + sampling_frequency=self.sampling_frequency, + t_start=self.t_start, + ) + self.add_sorting_segment(sorting_segment) + + # fetch and add sorting properties + if load_unit_properties: + if use_pynwb: + columns = [c.name for c in self.units_table.columns] + self.extra_requirements.append("pynwb") + else: + columns = list(self.units_table.keys()) + self.extra_requirements.append("h5py") + properties = self._fetch_properties(columns) + for property_name, property_values in properties.items(): + values = [x.decode("utf-8") if isinstance(x, bytes) else x for x in property_values] + self.set_property(property_name, values) + + if stream_mode is None and file_path is not None: + file_path = str(Path(file_path).resolve()) + + if storage_options is not None and stream_mode == "zarr": + warnings.warn( + "The `storage_options` parameter will not be propagated to JSON or pickle files for security reasons, " + "so the extractor will not be JSON/pickle serializable. Only in-memory mode will be available." 
+ ) + # not serializable if storage_options is provided + self._serializability["json"] = False + self._serializability["pickle"] = False + + self._kwargs = { + "file_path": file_path, + "electrical_series_path": self.electrical_series_path, + "sampling_frequency": sampling_frequency, + "samples_for_rate_estimation": samples_for_rate_estimation, + "cache": cache, + "stream_mode": stream_mode, + "stream_cache_path": stream_cache_path, + "storage_options": storage_options, + "load_unit_properties": load_unit_properties, + "t_start": self.t_start, + } + + def __del__(self): + # backend mode + if hasattr(self, "_file"): + if hasattr(self._file, "store"): + self._file.store.close() + else: + self._file.close() + # pynwb mode + elif hasattr(self, "_nwbfile"): + io = self._nwbfile.get_read_io() + if io is not None: + io.close() + + def _fetch_sorting_segment_info_pynwb( + self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False + ): + self._nwbfile = read_nwbfile( + backend=self.backend, + file_path=self.file_path, + stream_mode=self.stream_mode, + cache=cache, + stream_cache_path=self.stream_cache_path, + storage_options=self.storage_options, + ) + + timestamps = None + if self.provided_or_electrical_series_sampling_frequency is None: + # defines the electrical series from where the sorting came from + # important to know the sampling_frequency + self.electrical_series = _retrieve_electrical_series_pynwb(self._nwbfile, self.electrical_series_path) + # get rate + if self.electrical_series.rate is not None: + self.provided_or_electrical_series_sampling_frequency = self.electrical_series.rate + self.t_start = self.electrical_series.starting_time + else: + if hasattr(self.electrical_series, "timestamps"): + if self.electrical_series.timestamps is not None: + timestamps = self.electrical_series.timestamps + self.provided_or_electrical_series_sampling_frequency = 1 / np.median( + np.diff(timestamps[:samples_for_rate_estimation]) + ) + self.t_start = timestamps[0] + assert ( + self.provided_or_electrical_series_sampling_frequency is not None + ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" + assert ( + self.t_start is not None + ), "Couldn't load a starting time for the sorting. 
Please provide it with the 't_start' argument" + if unit_table_path == "units": + units_table = self._nwbfile.units + else: + units_table = _retrieve_unit_table_pynwb(self._nwbfile, unit_table_path=unit_table_path) + + name_to_column_data = {c.name: c for c in units_table.columns} + spike_times_data = name_to_column_data.pop("spike_times").data + spike_times_index_data = name_to_column_data.pop("spike_times_index").data + + units_ids = name_to_column_data.pop("unit_name", None) + if units_ids is None: + units_ids = units_table["id"].data + + # need this for later + self.units_table = units_table + + return units_ids, spike_times_data, spike_times_index_data + + def _fetch_sorting_segment_info_backend( + self, unit_table_path: str = None, samples_for_rate_estimation: int = 1000, cache: bool = False + ): + open_file = read_file_from_backend( + file_path=self.file_path, + stream_mode=self.stream_mode, + cache=cache, + stream_cache_path=self.stream_cache_path, + storage_options=self.storage_options, + ) + + timestamps = None + + if self.provided_or_electrical_series_sampling_frequency is None or self.t_start is None: + # defines the electrical series from where the sorting came from + # important to know the sampling_frequency + available_electrical_series = _find_neurodata_type_from_backend( + open_file, neurodata_type="ElectricalSeries", backend=self.backend + ) + if self.electrical_series_path is None: + if len(available_electrical_series) == 1: + self.electrical_series_path = available_electrical_series[0] + else: + raise ValueError( + "Multiple ElectricalSeries found in the file. " + "Please specify the 'electrical_series_path' argument:" + f"Available options are: {available_electrical_series}." + ) + else: + if self.electrical_series_path not in available_electrical_series: + raise ValueError( + f"'{self.electrical_series_path}' not found in the file. " + f"Available options are: {available_electrical_series}" + ) + electrical_series = open_file[self.electrical_series_path] + + # Get sampling frequency + if "starting_time" in electrical_series.keys(): + self.t_start = electrical_series["starting_time"][()] + self.provided_or_electrical_series_sampling_frequency = electrical_series["starting_time"].attrs["rate"] + elif "timestamps" in electrical_series.keys(): + timestamps = electrical_series["timestamps"][:] + self.t_start = timestamps[0] + self.provided_or_electrical_series_sampling_frequency = 1.0 / np.median( + np.diff(timestamps[:samples_for_rate_estimation]) + ) + + assert ( + self.provided_or_electrical_series_sampling_frequency is not None + ), "Couldn't load sampling frequency. Please provide it with the 'sampling_frequency' argument" + assert ( + self.t_start is not None + ), "Couldn't load a starting time for the sorting. Please provide it with the 't_start' argument" + + if unit_table_path is None: + available_unit_table_paths = _find_neurodata_type_from_backend( + open_file, neurodata_type="Units", backend=self.backend + ) + if len(available_unit_table_paths) == 1: + unit_table_path = available_unit_table_paths[0] + else: + raise ValueError( + "Multiple Units tables found in the file. " + "Please specify the 'unit_table_path' argument:" + f"Available options are: {available_unit_table_paths}." + ) + # Try to open the unit table. If it fails, raise an error with the available options. 
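# --------------------------------------------------------------------------
# The units table read here stores spike times as a ragged array: spike_times
# holds all units' times concatenated, and spike_times_index holds each unit's
# exclusive end offset. A self-contained toy example of that layout and of the
# frame conversion used by NwbSortingSegment.get_unit_spike_train further
# below; the sampling frequency, t_start, and spike times are illustrative.
import numpy as np

spike_times_data = np.array([0.1, 0.25, 0.9, 1.2, 1.5, 2.0])  # seconds, all units concatenated
spike_times_index_data = np.array([2, 5, 6])                   # unit 0 -> [0:2], unit 1 -> [2:5], unit 2 -> [5:6]

sampling_frequency = 30_000.0  # Hz, illustrative
t_start = 0.0                  # seconds, illustrative

def unit_spike_frames(unit_index):
    start = 0 if unit_index == 0 else spike_times_index_data[unit_index - 1]
    end = spike_times_index_data[unit_index]
    spike_times = spike_times_data[start:end]
    return np.round((spike_times - t_start) * sampling_frequency).astype("int64")

for u in range(len(spike_times_index_data)):
    print(u, unit_spike_frames(u))
# --------------------------------------------------------------------------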
+ try: + units_table = open_file[unit_table_path] + except KeyError: + available_unit_table_paths = _find_neurodata_type_from_backend( + open_file, neurodata_type="Units", backend=self.backend + ) + raise ValueError( + f"{unit_table_path} not found in the NWB file!" f"Available options are: {available_unit_table_paths}." + ) + self.units_table_location = unit_table_path + units_table = open_file[self.units_table_location] + + spike_times_data = units_table["spike_times"] + spike_times_index_data = units_table["spike_times_index"] + + if "unit_name" in units_table: + unit_ids = units_table["unit_name"] + else: + unit_ids = units_table["id"] + + decode_to_string = lambda x: x.decode("utf-8") if isinstance(x, bytes) else x + unit_ids = [decode_to_string(id) for id in unit_ids] + + # need this for later + self.units_table = units_table + + return unit_ids, spike_times_data, spike_times_index_data + + def _fetch_properties(self, columns): + units_table = self.units_table + + properties_to_skip = ["spike_times", "spike_times_index", "unit_name", "id"] + index_columns = [name for name in columns if name.endswith("_index")] + nested_ragged_array_properties = [name for name in columns if f"{name}_index_index" in columns] + + # Filter those properties that are nested ragged arrays + skip_properties = properties_to_skip + index_columns + nested_ragged_array_properties + properties_to_add = [name for name in columns if name not in skip_properties] + + properties = dict() + for property_name in properties_to_add: + data = units_table[property_name][:] + corresponding_index_name = f"{property_name}_index" + not_ragged_array = corresponding_index_name not in columns + if not_ragged_array: + values = data[:] + else: # TODO if we want we could make this recursive to handle nested ragged arrays + data_index = units_table[corresponding_index_name] + if hasattr(data_index, "data"): + # for pynwb we need to get the data from the data attribute + data_index = data_index.data[:] + else: + data_index = data_index[:] + index_spacing = np.diff(data_index, prepend=0) + all_index_spacing_are_the_same = np.unique(index_spacing).size == 1 + if all_index_spacing_are_the_same: + if hasattr(units_table[corresponding_index_name], "data"): + # ragged array indexing is handled by pynwb + values = data + else: + # ravel array based on data_index + start_indices = [0] + list(data_index[:-1]) + end_indices = list(data_index) + values = [ + data[start_index:end_index] for start_index, end_index in zip(start_indices, end_indices) + ] + else: + warnings.warn(f"Skipping {property_name} because of unequal shapes across units") + continue + properties[property_name] = values + + return properties + + +class NwbSortingSegment(BaseSortingSegment): + def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): + BaseSortingSegment.__init__(self) + self.spike_times_data = spike_times_data + self.spike_times_index_data = spike_times_index_data + self.spike_times_data = spike_times_data + self.spike_times_index_data = spike_times_index_data + self._sampling_frequency = sampling_frequency + self._t_start = t_start + + def get_unit_spike_train( + self, + unit_id, + start_frame: Optional[int] = None, + end_frame: Optional[int] = None, + ) -> np.ndarray: + # Extract the spike times for the unit + unit_index = self.parent_extractor.id_to_index(unit_id) + if unit_index == 0: + start_index = 0 + else: + start_index = self.spike_times_index_data[unit_index - 1] + end_index = 
self.spike_times_index_data[unit_index] + spike_times = self.spike_times_data[start_index:end_index] + + # Transform spike times to frames and subset + frames = np.round((spike_times - self._t_start) * self._sampling_frequency) + + start_index = 0 + if start_frame is not None: + start_index = np.searchsorted(frames, start_frame, side="left") + + if end_frame is not None: + end_index = np.searchsorted(frames, end_frame, side="left") + else: + end_index = frames.size + + return frames[start_index:end_index].astype("int64", copy=False) + + +read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording") +read_nwb_sorting = define_function_from_class(source_class=NwbSortingExtractor, name="read_nwb_sorting") + + +def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_series_path=None): + """Reads NWB file into SpikeInterface extractors. + + Parameters + ---------- + file_path : str or Path + Path to NWB file. + load_recording : bool, default: True + If True, the recording object is loaded. + load_sorting : bool, default: False + If True, the recording object is loaded. + electrical_series_path : str or None, default: None + The name of the ElectricalSeries (if multiple ElectricalSeries are present) + + Returns + ------- + extractors : extractor or tuple + Single RecordingExtractor/SortingExtractor or tuple with both + (depending on "load_recording"/"load_sorting") arguments. + """ + outputs = () + if load_recording: + rec = read_nwb_recording(file_path, electrical_series_path=electrical_series_path) + outputs = outputs + (rec,) + if load_sorting: + sorting = read_nwb_sorting(file_path, electrical_series_path=electrical_series_path) + outputs = outputs + (sorting,) + + if len(outputs) == 1: + outputs = outputs[0] + + return outputs diff --git a/examples/DANDI/preprocess_ephys.py b/examples/DANDI/preprocess_ephys.py new file mode 100644 index 0000000..ff454a6 --- /dev/null +++ b/examples/DANDI/preprocess_ephys.py @@ -0,0 +1,111 @@ +import numpy as np +import lindi +import pynwb +from pynwb.ecephys import ElectricalSeries +import spikeinterface.preprocessing as spre +from nwbextractors import NwbRecordingExtractor +from qfc.codecs import QFCCodec +from qfc import qfc_estimate_quant_scale_factor + +QFCCodec.register_codec() + + +def preprocess_ephys(): + # https://neurosift.app/?p=/nwb&url=https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/&dandisetId=000463&dandisetVersion=draft + url = "https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/" + + print('Creating LINDI file') + with lindi.LindiH5pyFile.from_hdf5_file(url) as f: + f.write_lindi_file("example.nwb.lindi.tar") + + cache = lindi.LocalCache() + + print('Reading LINDI file') + with lindi.LindiH5pyFile.from_lindi_file("example.nwb.lindi.tar", mode="r", local_cache=cache) as f: + electrical_series_path = '/acquisition/ElectricalSeries' + + print("Loading recording") + recording = NwbRecordingExtractor( + h5py_file=f, electrical_series_path=electrical_series_path + ) + print(recording.get_channel_ids()) + + num_frames = recording.get_num_frames() + start_time_sec = 0 + # duration_sec = 300 + duration_sec = num_frames / recording.get_sampling_frequency() + start_frame = int(start_time_sec * recording.get_sampling_frequency()) + end_frame = int(np.minimum(num_frames, (start_time_sec + duration_sec) * recording.get_sampling_frequency())) + recording = recording.frame_slice( + start_frame=start_frame, + 
end_frame=end_frame + ) + + # bandpass filter + print("Filtering recording") + freq_min = 300 + freq_max = 6000 + recording_filtered = spre.bandpass_filter( + recording, freq_min=freq_min, freq_max=freq_max, dtype=np.float32 + ) # important to specify dtype here + f.close() + + traces0 = recording_filtered.get_traces(start_frame=0, end_frame=int(1 * recording_filtered.get_sampling_frequency())) + traces0 = traces0.astype(dtype=traces0.dtype, order='C') + + # noise_level = estimate_noise_level(traces0) + # print(f'Noise level: {noise_level}') + # scale_factor = qfc_estimate_quant_scale_factor(traces0, target_residual_stdev=noise_level * 0.2) + + compression_method = 'zlib' + zlib_level = 3 + zstd_level = 3 + + scale_factor = qfc_estimate_quant_scale_factor( + traces0, + target_compression_ratio=10, + compression_method=compression_method, + zlib_level=zlib_level, + zstd_level=zstd_level + ) + print(f'Quant. scale factor: {scale_factor}') + codec = QFCCodec( + quant_scale_factor=scale_factor, + dtype='float32', + segment_length=int(recording_filtered.get_sampling_frequency() * 1), + compression_method=compression_method, + zlib_level=zlib_level, + zstd_level=zstd_level + ) + traces0_compressed = codec.encode(traces0) + compression_ratio = traces0.size * 2 / len(traces0_compressed) + print(f'Compression ratio: {compression_ratio}') + + print("Writing filtered recording to LINDI file") + with lindi.LindiH5pyFile.from_lindi_file("example.nwb.lindi.tar", mode="a", local_cache=cache) as f: + with pynwb.NWBHDF5IO(file=f, mode='a') as io: + nwbfile = io.read() + + electrical_series = nwbfile.acquisition['ElectricalSeries'] + electrical_series_pre = ElectricalSeries( + name="ElectricalSeries_pre", + data=pynwb.H5DataIO( + recording_filtered.get_traces(), + chunks=(30000, recording.get_num_channels()), + compression=codec + ), + electrodes=electrical_series.electrodes, + starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time + rate=recording_filtered.get_sampling_frequency(), + ) + nwbfile.add_acquisition(electrical_series_pre) # type: ignore + io.write(nwbfile) + + +def estimate_noise_level(traces): + noise_level = np.median(np.abs(traces - np.median(traces))) / 0.6745 + return noise_level + + +if __name__ == "__main__": + preprocess_ephys() \ No newline at end of file diff --git a/examples/benchmark1.py b/examples/benchmark1.py new file mode 100644 index 0000000..40f174f --- /dev/null +++ b/examples/benchmark1.py @@ -0,0 +1,125 @@ +import os +import h5py +import numpy as np +import time +import lindi +import gzip +import zarr +import numcodecs + + +def create_dataset(size): + return np.random.rand(size) + + +def benchmark_h5py(file_path, num_small_datasets, num_large_datasets, small_size, large_size, chunks, compression, mode): + start_time = time.time() + + if mode == 'dat': + with open(file_path, 'wb') as f: + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + f.write(data.tobytes()) + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + if compression == 'gzip': + data_zipped = gzip.compress(data.tobytes(), compresslevel=4) + f.write(data_zipped) + elif compression is None: + f.write(data.tobytes()) + else: + raise ValueError(f"Unknown compressor: {compression}") + elif mode == 'zarr': + if os.path.exists(file_path): + import shutil + shutil.rmtree(file_path) + store = 
zarr.DirectoryStore(file_path) + root = zarr.group(store) + + if compression == 'gzip': + compressor = numcodecs.GZip(level=4) + else: + compressor = None + + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + root.create_dataset(f'small_dataset_{i}', data=data) + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + root.create_dataset(f'large_dataset_{i}', data=data, chunks=chunks, compressor=compressor) + else: + if mode == 'h5': + f = h5py.File(file_path, 'w') + else: + f = lindi.LindiH5pyFile.from_lindi_file(file_path, mode='w') + + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + ds = f.create_dataset(f'small_dataset_{i}', data=data) + ds.attrs['attr1'] = 1 + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + ds = f.create_dataset(f'large_dataset_{i}', data=data, chunks=chunks, compression=compression) + ds.attrs['attr1'] = 1 + + f.close() + + end_time = time.time() + total_time = end_time - start_time + + # Calculate total data size + total_size = (num_small_datasets * small_size + num_large_datasets * large_size) * 8 # 8 bytes per float64 + total_size_gb = total_size / (1024 ** 3) + + print("Benchmark Results:") + print(f"Total time: {total_time:.2f} seconds") + print(f"Total data size: {total_size_gb:.2f} GB") + print(f"Write speed: {total_size_gb / total_time:.2f} GB/s") + + h5py_file_size = os.path.getsize(file_path) / (1024 ** 3) + print(f"File size: {h5py_file_size:.2f} GB") + + return total_time, total_size_gb + + +if __name__ == "__main__": + file_path_h5 = "benchmark.h5" + file_path_lindi = "benchmark.lindi.tar" + file_path_dat = "benchmark.dat" + file_path_zarr = "benchmark.zarr" + num_small_datasets = 0 + num_large_datasets = 5 + small_size = 1000 + large_size = 100000000 + compression = None # 'gzip' or None + chunks = (large_size / 20,) + + print('Lindi Benchmark') + lindi_time, total_size = benchmark_h5py(file_path_lindi, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='lindi') + print('') + print('Zarr Benchmark') + lindi_time, total_size = benchmark_h5py(file_path_zarr, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='zarr') + print('') + print('H5PY Benchmark') + h5py_time, total_size = benchmark_h5py(file_path_h5, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='h5') + print('') + print('DAT Benchmark') + dat, total_size = benchmark_h5py(file_path_dat, num_small_datasets, num_large_datasets, small_size, large_size, chunks=chunks, compression=compression, mode='dat') + + import shutil + shutil.copyfile(file_path_lindi, file_path_lindi + '.tar') diff --git a/examples/example_a.py b/examples/example_a.py new file mode 100644 index 0000000..3f6390c --- /dev/null +++ b/examples/example_a.py @@ -0,0 +1,14 @@ +import lindi + +# Create a new lindi.json file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.json', mode='w') as f: + f.attrs['attr1'] = 'value1' + f.attrs['attr2'] = 7 + ds = f.create_dataset('dataset1', shape=(10,), dtype='f') + ds[...] 
= 12 + +# Later read the file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.json', mode='r') as f: + print(f.attrs['attr1']) + print(f.attrs['attr2']) + print(f['dataset1'][...]) \ No newline at end of file diff --git a/examples/example_ammend_remote_nwb.py b/examples/example_ammend_remote_nwb.py new file mode 100644 index 0000000..348ec31 --- /dev/null +++ b/examples/example_ammend_remote_nwb.py @@ -0,0 +1,33 @@ +import numpy as np +import lindi +import pynwb + + +def example_ammend_remote_nwb(): + url = 'https://api.dandiarchive.org/api/assets/2e6b590a-a2a4-4455-bb9b-45cc3d7d7cc0/download/' + with lindi.LindiH5pyFile.from_hdf5_file(url) as f: + f.write_lindi_file('example.nwb.lindi.tar') + with lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar', mode='r+') as f: + + # Can't figure out how to modify something using pyNWB + # with pynwb.NWBHDF5IO(file=f, mode='r+') as io: + # nwbfile = io.read() + # print(nwbfile) + # nwbfile.session_description = 'Modified session description' + # io.write(nwbfile) + + f['session_description'][()] = 'new session description' + + # Create something that will become a new file in the tar + ds = f.create_dataset('new_dataset', data=np.random.rand(10000, 1000), chunks=(1000, 200)) + ds[20, 20] = 42 + + with lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.tar', mode='r') as f: + with pynwb.NWBHDF5IO(file=f, mode='r') as io: + nwbfile = io.read() + print(nwbfile) + print(f['new_dataset'][20, 20]) + + +if __name__ == '__main__': + example_ammend_remote_nwb() diff --git a/examples/example_b.py b/examples/example_b.py new file mode 100644 index 0000000..c982e58 --- /dev/null +++ b/examples/example_b.py @@ -0,0 +1,15 @@ +import numpy as np +import lindi + +# Create a new lindi binary file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='w') as f: + f.attrs['attr1'] = 'value1' + f.attrs['attr2'] = 7 + ds = f.create_dataset('dataset1', shape=(1000, 1000), dtype='f') + ds[...] 
= np.random.rand(1000, 1000) + +# Later read the file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.tar', mode='r') as f: + print(f.attrs['attr1']) + print(f.attrs['attr2']) + print(f['dataset1'][...]) \ No newline at end of file diff --git a/examples/example_c.py b/examples/example_c.py new file mode 100644 index 0000000..279dfc2 --- /dev/null +++ b/examples/example_c.py @@ -0,0 +1,36 @@ +import json +import pynwb +import lindi + +# Define the URL for a remote NWB file +h5_url = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/" + +# Load as LINDI and view using pynwb +f = lindi.LindiH5pyFile.from_hdf5_file(h5_url) +with pynwb.NWBHDF5IO(file=f, mode="r") as io: + nwbfile = io.read() + print('NWB via LINDI') + print(nwbfile) + + print('Electrode group at shank0:') + print(nwbfile.electrode_groups["shank0"]) # type: ignore + + print('Electrode group at index 0:') + print(nwbfile.electrodes.group[0]) # type: ignore + +# Save as LINDI JSON +f.write_lindi_file('example.nwb.lindi.json') + +# Later, read directly from the LINDI JSON file +g = lindi.LindiH5pyFile.from_lindi_file('example.nwb.lindi.json') +with pynwb.NWBHDF5IO(file=g, mode="r") as io: + nwbfile = io.read() + print('') + print('NWB from LINDI JSON:') + print(nwbfile) + + print('Electrode group at shank0:') + print(nwbfile.electrode_groups["shank0"]) # type: ignore + + print('Electrode group at index 0:') + print(nwbfile.electrodes.group[0]) # type: ignore \ No newline at end of file diff --git a/examples/example_d.py b/examples/example_d.py new file mode 100644 index 0000000..aca4749 --- /dev/null +++ b/examples/example_d.py @@ -0,0 +1,15 @@ +import numpy as np +import lindi + +# Create a new lindi binary file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.d', mode='w') as f: + f.attrs['attr1'] = 'value1' + f.attrs['attr2'] = 7 + ds = f.create_dataset('dataset1', shape=(1000, 1000), dtype='f') + ds[...] 
= np.random.rand(1000, 1000) + +# Later read the file +with lindi.LindiH5pyFile.from_lindi_file('example.lindi.d', mode='r') as f: + print(f.attrs['attr1']) + print(f.attrs['attr2']) + print(f['dataset1'][...]) \ No newline at end of file diff --git a/examples/example_tar_nwb.py b/examples/example_tar_nwb.py new file mode 100644 index 0000000..cd03dd0 --- /dev/null +++ b/examples/example_tar_nwb.py @@ -0,0 +1,147 @@ +from typing import Any +import pynwb +import h5py +import lindi + + +nwb_lindi_fname = 'example.nwb.lindi.tar' +nwb_fname = 'example.nwb' + + +def test_write_lindi(): + print('test_write_lindi') + nwbfile = _create_sample_nwb_file() + with lindi.LindiH5pyFile.from_lindi_file(nwb_lindi_fname, mode='w') as client: + with pynwb.NWBHDF5IO(file=client, mode='w') as io: + io.write(nwbfile) # type: ignore + + +def test_read_lindi(): + print('test_read_lindi') + with lindi.LindiH5pyFile.from_lindi_file(nwb_lindi_fname, mode='r') as client: + with pynwb.NWBHDF5IO(file=client, mode='r') as io: + nwbfile = io.read() + print(nwbfile) + + +def test_write_h5(): + print('test_write_h5') + nwbfile = _create_sample_nwb_file() + with h5py.File(nwb_fname, 'w') as h5f: + with pynwb.NWBHDF5IO(file=h5f, mode='w') as io: + io.write(nwbfile) # type: ignore + + +def test_read_h5(): + print('test_read_h5') + with h5py.File(nwb_fname, 'r') as h5f: + with pynwb.NWBHDF5IO(file=h5f, mode='r') as io: + nwbfile = io.read() + print(nwbfile) + + +def _create_sample_nwb_file(): + from datetime import datetime + from uuid import uuid4 + + import numpy as np + from dateutil.tz import tzlocal + + from pynwb import NWBFile + from pynwb.ecephys import LFP, ElectricalSeries + + nwbfile: Any = NWBFile( + session_description="my first synthetic recording", + identifier=str(uuid4()), + session_start_time=datetime.now(tzlocal()), + experimenter=[ + "Baggins, Bilbo", + ], + lab="Bag End Laboratory", + institution="University of Middle Earth at the Shire", + experiment_description="I went on an adventure to reclaim vast treasures.", + session_id="LONELYMTN001", + ) + + device = nwbfile.create_device( + name="array", description="the best array", manufacturer="Probe Company 9000" + ) + + nwbfile.add_electrode_column(name="label", description="label of electrode") + + nshanks = 4 + nchannels_per_shank = 3 + electrode_counter = 0 + + for ishank in range(nshanks): + # create an electrode group for this shank + electrode_group = nwbfile.create_electrode_group( + name="shank{}".format(ishank), + description="electrode group for shank {}".format(ishank), + device=device, + location="brain area", + ) + # add electrodes to the electrode table + for ielec in range(nchannels_per_shank): + nwbfile.add_electrode( + group=electrode_group, + label="shank{}elec{}".format(ishank, ielec), + location="brain area", + ) + electrode_counter += 1 + + all_table_region = nwbfile.create_electrode_table_region( + region=list(range(electrode_counter)), # reference row indices 0 to N-1 + description="all electrodes", + ) + + raw_data = np.random.randn(50, 12) + raw_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=raw_data, + electrodes=all_table_region, + starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time + rate=20000.0, # in Hz + ) + + nwbfile.add_acquisition(raw_electrical_series) + + lfp_data = np.random.randn(5000, 12) + lfp_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=lfp_data, + electrodes=all_table_region, + starting_time=0.0, + rate=200.0, + ) + + 
lfp = LFP(electrical_series=lfp_electrical_series) + + ecephys_module = nwbfile.create_processing_module( + name="ecephys", description="processed extracellular electrophysiology data" + ) + ecephys_module.add(lfp) + + nwbfile.add_unit_column(name="quality", description="sorting quality") + + firing_rate = 20 + n_units = 10 + res = 1000 + duration = 2000 + for n_units_per_shank in range(n_units): + spike_times = ( + np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res + ) + nwbfile.add_unit(spike_times=spike_times, quality="good") + + return nwbfile + + +if __name__ == '__main__': + test_write_lindi() + test_read_lindi() + print('_________________________________') + print('') + + test_write_h5() + test_read_h5() diff --git a/examples/write_lindi_binary.py b/examples/write_lindi_binary.py new file mode 100644 index 0000000..8442321 --- /dev/null +++ b/examples/write_lindi_binary.py @@ -0,0 +1,21 @@ +import numpy as np +import lindi + + +def write_lindi_binary(): + with lindi.LindiH5pyFile.from_lindi_file('test.lindi.tar', mode='w') as f: + f.attrs['test'] = 42 + ds = f.create_dataset('data', shape=(1000, 1000), dtype='f4') + ds[...] = np.random.rand(1000, 1000) + + +def test_read(): + f = lindi.LindiH5pyFile.from_lindi_file('test.lindi.tar', mode='r') + print(f.attrs['test']) + print(f['data'][0, 0]) + f.close() + + +if __name__ == "__main__": + write_lindi_binary() + test_read() diff --git a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py index bafbcaf..6ff58a2 100644 --- a/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py +++ b/lindi/LindiH5ZarrStore/LindiH5ZarrStore.py @@ -23,6 +23,109 @@ from ..LocalCache.LocalCache import ChunkTooLargeError, LocalCache from ..LindiRemfile.LindiRemfile import LindiRemfile from .LindiH5ZarrStoreOpts import LindiH5ZarrStoreOpts +from ..LindiH5pyFile.LindiReferenceFileSystemStore import _get_padded_size, _pad_chunk + + +class SplitDatasetH5Item: + """ + Represents a dataset that is a single contiguous chunk in the hdf5 file, but + is split into multiple chunks for efficient slicing in the zarr store. + """ + def __init__(self, h5_item, *, contiguous_dataset_max_chunk_size: Union[int, None]): + self._h5_item = h5_item + self._contiguous_dataset_max_chunk_size = contiguous_dataset_max_chunk_size + should_split = False + if contiguous_dataset_max_chunk_size is not None: + codecs = h5_filters_to_codecs(h5_item) + if codecs is None or len(codecs) == 0: # does not have compression + if h5_item.chunks is None or h5_item.chunks == h5_item.shape: # only one chunk + if h5_item.dtype.kind in ['i', 'u', 'f']: # integer or float + size_bytes = int(np.prod(h5_item.shape)) * h5_item.dtype.itemsize + if size_bytes > contiguous_dataset_max_chunk_size: # large enough to split + should_split = True + self._do_split = should_split + if should_split: + size0 = int(np.prod(h5_item.shape[1:])) * h5_item.dtype.itemsize + # We want each chunk to be of size around + # contiguous_dataset_max_chunk_size. So if nn is the size of a chunk + # in the first dimension, then nn * size0 should be approximately + # contiguous_dataset_max_chunk_size. 
So nn should be approximately + # contiguous_dataset_max_chunk_size // size0 + nn = contiguous_dataset_max_chunk_size // size0 + if nn == 0: + # The chunk size should not be zero + nn = 1 + self._split_chunk_shape = (nn,) + h5_item.shape[1:] + if h5_item.chunks is not None: + zero_chunk_coords = (0,) * h5_item.ndim + try: + byte_offset, byte_count = _get_chunk_byte_range(h5_item, zero_chunk_coords) + except Exception as e: + raise Exception( + f"Error getting byte range for chunk when trying to split contiguous dataset {h5_item.name}: {e}" + ) + else: + # Get the byte range in the file for the contiguous dataset + byte_offset, byte_count = _get_byte_range_for_contiguous_dataset(h5_item) + self._split_chunk_byte_offset = byte_offset + self._split_chunk_byte_count = byte_count + self._num_chunks = int(np.prod(h5_item.shape[0:]) + np.prod(self._split_chunk_shape) - 1) // int(np.prod(self._split_chunk_shape)) + else: + self._split_chunk_shape = None + self._split_chunk_byte_offset = None + self._split_chunk_byte_count = None + self._num_chunks = None + + def get_chunk_byte_range(self, chunk_coords: Tuple[int, ...]): + if len(chunk_coords) != self.ndim: + raise Exception(f"SplitDatasetH5Item: Chunk coordinates {chunk_coords} do not match dataset dimensions") + for i in range(1, len(chunk_coords)): + if chunk_coords[i] != 0: + raise Exception(f"SplitDatasetH5Item: Unexpected non-zero chunk coordinate {chunk_coords[i]}") + if self._split_chunk_byte_offset is None: + raise Exception("SplitDatasetH5Item: Unexpected _split_chunk_byte_offset is None") + if self._split_chunk_shape is None: + raise Exception("SplitDatasetH5Item: Unexpected _split_chunk_shape is None") + chunk_index = chunk_coords[0] + byte_offset = self._split_chunk_byte_offset + chunk_index * int(np.prod(self._split_chunk_shape)) * self.dtype.itemsize + byte_count = int(np.prod(self._split_chunk_shape)) * self.dtype.itemsize + if byte_offset + byte_count > self._split_chunk_byte_offset + self._split_chunk_byte_count: + byte_count = self._split_chunk_byte_offset + self._split_chunk_byte_count - byte_offset + return byte_offset, byte_count + + @property + def shape(self): + return self._h5_item.shape + + @property + def dtype(self): + return self._h5_item.dtype + + @property + def name(self): + return self._h5_item.name + + @property + def chunks(self): + if self._do_split: + return self._split_chunk_shape + return self._h5_item.chunks + + @property + def ndim(self): + return self._h5_item.ndim + + @property + def fillvalue(self): + return self._h5_item.fillvalue + + @property + def attrs(self): + return self._h5_item.attrs + + @property + def size(self): + return self._h5_item.size class LindiH5ZarrStore(Store): @@ -65,6 +168,9 @@ def __init__( # it when the chunk is requested. self._inline_arrays: Dict[str, InlineArray] = {} + # For large contiguous arrays, we want to split them into smaller chunks. 
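# --------------------------------------------------------------------------
# Worked example of the splitting arithmetic implemented by SplitDatasetH5Item
# above, using illustrative numbers: a contiguous float32 dataset of shape
# (1_000_000, 32) and the default 20 MB limit from LindiH5ZarrStoreOpts.
import numpy as np

shape = (1_000_000, 32)
itemsize = np.dtype("float32").itemsize               # 4 bytes
contiguous_dataset_max_chunk_size = 1000 * 1000 * 20  # default option value

size0 = int(np.prod(shape[1:])) * itemsize                 # bytes per row: 128
nn = max(contiguous_dataset_max_chunk_size // size0, 1)    # rows per split chunk: 156_250
split_chunk_shape = (nn,) + shape[1:]                      # (156_250, 32)

chunk_elems = int(np.prod(split_chunk_shape))
total_elems = int(np.prod(shape))
num_chunks = (total_elems + chunk_elems - 1) // chunk_elems  # ceiling division -> 7

# Byte range of split chunk i within the contiguous dataset, clipped at the end
# (same idea as SplitDatasetH5Item.get_chunk_byte_range; byte_offset_0 stands in
# for the offset of the whole dataset within the HDF5 file).
byte_offset_0 = 0
byte_count_0 = total_elems * itemsize
i = num_chunks - 1
offset = byte_offset_0 + i * chunk_elems * itemsize
count = min(chunk_elems * itemsize, byte_offset_0 + byte_count_0 - offset)
print(num_chunks, split_chunk_shape, offset, count)   # 7 (156250, 32) 120000000 8000000
# --------------------------------------------------------------------------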
+ self._split_datasets: Dict[str, SplitDatasetH5Item] = {} + self._external_array_links: Dict[str, Union[dict, None]] = {} @staticmethod @@ -118,6 +224,16 @@ def close(self): self._file = None def __getitem__(self, key): + val = self._get_helper(key) + + if val is not None: + padded_size = _get_padded_size(self, key, val) + if padded_size is not None: + val = _pad_chunk(val, padded_size) + + return val + + def _get_helper(self, key: str): """Get an item from the store (required by base class).""" parts = [part for part in key.split("/") if part] if len(parts) == 0: @@ -180,6 +296,8 @@ def __contains__(self, key): return False if not isinstance(h5_item, h5py.Dataset): return False + if self._split_datasets.get(key_parent, None) is not None: + h5_item = self._split_datasets[key_parent] external_array_link = self._get_external_array_link(key_parent, h5_item) if external_array_link is not None: # The chunk files do not exist for external array links @@ -278,7 +396,7 @@ def _get_zgroup_bytes(self, parent_key: str): zarr.group(store=memory_store) return reformat_json(memory_store.get(".zgroup")) - def _get_inline_array(self, key: str, h5_dataset: h5py.Dataset): + def _get_inline_array(self, key: str, h5_dataset: Union[h5py.Dataset, SplitDatasetH5Item]): if key in self._inline_arrays: return self._inline_arrays[key] self._inline_arrays[key] = InlineArray(h5_dataset) @@ -299,6 +417,11 @@ def _get_zarray_bytes(self, parent_key: str): filters = h5_filters_to_codecs(h5_item) + split_dataset = SplitDatasetH5Item(h5_item, contiguous_dataset_max_chunk_size=self._opts.contiguous_dataset_max_chunk_size) + if split_dataset._do_split: + self._split_datasets[parent_key] = split_dataset + h5_item = split_dataset + # We create a dummy zarr dataset with the appropriate shape, chunks, # dtype, and filters and then copy the .zarray JSON text from it memory_store = MemoryStore() @@ -370,6 +493,9 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): if not isinstance(h5_item, h5py.Dataset): raise Exception(f"Item {key_parent} is not a dataset") + if self._split_datasets.get(key_parent, None) is not None: + h5_item = self._split_datasets[key_parent] + external_array_link = self._get_external_array_link(key_parent, h5_item) if external_array_link is not None: raise Exception( @@ -418,7 +544,10 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): if h5_item.chunks is not None: # Get the byte range in the file for the chunk. try: - byte_offset, byte_count = _get_chunk_byte_range(h5_item, chunk_coords) + if isinstance(h5_item, SplitDatasetH5Item): + byte_offset, byte_count = h5_item.get_chunk_byte_range(chunk_coords) + else: + byte_offset, byte_count = _get_chunk_byte_range(h5_item, chunk_coords) except Exception as e: raise Exception( f"Error getting byte range for chunk {key_parent}/{key_name}. Shape: {h5_item.shape}, Chunks: {h5_item.chunks}, Chunk coords: {chunk_coords}: {e}" @@ -430,6 +559,8 @@ def _get_chunk_file_bytes_data(self, key_parent: str, key_name: str): raise Exception( f"Chunk coordinates {chunk_coords} are not (0, 0, 0, ...) 
for contiguous dataset {key_parent} with dtype {h5_item.dtype} and shape {h5_item.shape}" ) + if isinstance(h5_item, SplitDatasetH5Item): + raise Exception(f'Unexpected SplitDatasetH5Item for contiguous dataset {key_parent}') # Get the byte range in the file for the contiguous dataset byte_offset, byte_count = _get_byte_range_for_contiguous_dataset(h5_item) return byte_offset, byte_count, None @@ -440,6 +571,9 @@ def _add_chunk_info_to_refs(self, key_parent: str, add_ref: Callable, add_ref_ch h5_item = self._h5f.get('/' + key_parent, None) assert isinstance(h5_item, h5py.Dataset) + if self._split_datasets.get(key_parent, None) is not None: + h5_item = self._split_datasets[key_parent] + # If the shape is (0,), (0, 0), (0, 0, 0), etc., then do not add any chunk references if np.prod(h5_item.shape) == 0: return @@ -467,7 +601,7 @@ def _add_chunk_info_to_refs(self, key_parent: str, add_ref: Callable, add_ref_ch # does not provide a way to hook in a progress bar # We use max number of chunks instead of actual number of chunks because get_num_chunks is slow # for remote datasets. - num_chunks = _get_max_num_chunks(h5_item) # NOTE: unallocated chunks are counted + num_chunks = _get_max_num_chunks(shape=h5_item.shape, chunk_size=h5_item.chunks) # NOTE: unallocated chunks are counted pbar = tqdm( total=num_chunks, desc=f"Writing chunk info for {key_parent}", @@ -477,24 +611,35 @@ def _add_chunk_info_to_refs(self, key_parent: str, add_ref: Callable, add_ref_ch chunk_size = h5_item.chunks - def store_chunk_info(chunk_info): - # Get the byte range in the file for each chunk. - chunk_offset: Tuple[int, ...] = chunk_info.chunk_offset - byte_offset = chunk_info.byte_offset - byte_count = chunk_info.size - key_name = ".".join([str(a // b) for a, b in zip(chunk_offset, chunk_size)]) - add_ref_chunk(f"{key_parent}/{key_name}", (self._url, byte_offset, byte_count)) - pbar.update() + if isinstance(h5_item, SplitDatasetH5Item): + assert h5_item._num_chunks is not None, "Unexpected: _num_chunks is None" + for i in range(h5_item._num_chunks): + chunk_coords = (i,) + (0,) * (h5_item.ndim - 1) + byte_offset, byte_count = h5_item.get_chunk_byte_range(chunk_coords) + key_name = ".".join([str(x) for x in chunk_coords]) + add_ref_chunk(f"{key_parent}/{key_name}", (self._url, byte_offset, byte_count)) + pbar.update() + else: + def store_chunk_info(chunk_info): + # Get the byte range in the file for each chunk. + chunk_offset: Tuple[int, ...] 
= chunk_info.chunk_offset + byte_offset = chunk_info.byte_offset + byte_count = chunk_info.size + key_name = ".".join([str(a // b) for a, b in zip(chunk_offset, chunk_size)]) + add_ref_chunk(f"{key_parent}/{key_name}", (self._url, byte_offset, byte_count)) + pbar.update() + + _apply_to_all_chunk_info(h5_item, store_chunk_info) - _apply_to_all_chunk_info(h5_item, store_chunk_info) pbar.close() else: # Get the byte range in the file for the contiguous dataset + assert not isinstance(h5_item, SplitDatasetH5Item), "Unexpected SplitDatasetH5Item for contiguous dataset" byte_offset, byte_count = _get_byte_range_for_contiguous_dataset(h5_item) key_name = ".".join("0" for _ in range(h5_item.ndim)) add_ref_chunk(f"{key_parent}/{key_name}", (self._url, byte_offset, byte_count)) - def _get_external_array_link(self, parent_key: str, h5_item: h5py.Dataset): + def _get_external_array_link(self, parent_key: str, h5_item: Union[h5py.Dataset, SplitDatasetH5Item]): # First check the memory cache if parent_key in self._external_array_links: return self._external_array_links[parent_key] @@ -510,7 +655,7 @@ def _get_external_array_link(self, parent_key: str, h5_item: h5py.Dataset): (shape[i] + chunks[i] - 1) // chunks[i] if chunks[i] != 0 else 0 for i in range(len(shape)) ] - num_chunks = np.prod(chunk_coords_shape) + num_chunks = int(np.prod(chunk_coords_shape)) if num_chunks > self._opts.num_dataset_chunks_threshold: if self._url is not None: self._external_array_links[parent_key] = { @@ -663,7 +808,7 @@ def _process_dataset(key, item: h5py.Dataset): class InlineArray: - def __init__(self, h5_dataset: h5py.Dataset): + def __init__(self, h5_dataset: Union[h5py.Dataset, SplitDatasetH5Item]): self._additional_zarr_attributes = {} if h5_dataset.shape == (): self._additional_zarr_attributes["_SCALAR"] = True @@ -686,9 +831,15 @@ def __init__(self, h5_dataset: h5py.Dataset): # For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']] self._additional_zarr_attributes["_COMPOUND_DTYPE"] = compound_dtype if self._is_inline: + if isinstance(h5_dataset, SplitDatasetH5Item): + raise Exception('SplitDatasetH5Item should not be an inline array') memory_store = MemoryStore() dummy_group = zarr.group(store=memory_store) size_is_zero = np.prod(h5_dataset.shape) == 0 + if isinstance(h5_dataset, SplitDatasetH5Item): + h5_item = h5_dataset._h5_item + else: + h5_item = h5_dataset create_zarr_dataset_from_h5_data( zarr_parent_group=dummy_group, name='X', @@ -700,8 +851,8 @@ def __init__(self, h5_dataset: h5py.Dataset): label=f'{h5_dataset.name}', h5_shape=h5_dataset.shape, h5_dtype=h5_dataset.dtype, - h5f=h5_dataset.file, - h5_data=h5_dataset[...] + h5f=h5_item.file, + h5_data=h5_item[...] ) self._zarray_bytes = reformat_json(memory_store['X/.zarray']) if not size_is_zero: diff --git a/lindi/LindiH5ZarrStore/LindiH5ZarrStoreOpts.py b/lindi/LindiH5ZarrStore/LindiH5ZarrStoreOpts.py index 40fe998..d8ea82e 100644 --- a/lindi/LindiH5ZarrStore/LindiH5ZarrStoreOpts.py +++ b/lindi/LindiH5ZarrStore/LindiH5ZarrStoreOpts.py @@ -13,5 +13,12 @@ class LindiH5ZarrStoreOpts: the dataset will be represented as an external array link. If None, then no datasets will be represented as external array links (equivalent to a threshold of 0). Default is 1000. + + contiguous_dataset_max_chunk_size (Union[int, None]): For large + contiguous arrays in the hdf5 file that are not chunked, this option + specifies the maximum size in bytes of the zarr chunks that will be + created. 
If None, then the entire array will be represented as a single + chunk. Default is 1000 * 1000 * 20 """ num_dataset_chunks_threshold: Union[int, None] = 1000 + contiguous_dataset_max_chunk_size: Union[int, None] = 1000 * 1000 * 20 diff --git a/lindi/LindiH5ZarrStore/_util.py b/lindi/LindiH5ZarrStore/_util.py index 0badbae..681164a 100644 --- a/lindi/LindiH5ZarrStore/_util.py +++ b/lindi/LindiH5ZarrStore/_util.py @@ -12,15 +12,16 @@ def _read_bytes(file: IO, offset: int, count: int): return file.read(count) -def _get_max_num_chunks(h5_dataset: h5py.Dataset): +def _get_max_num_chunks(*, shape, chunk_size): """Get the maximum number of chunks in an h5py dataset. This is similar to h5_dataset.id.get_num_chunks() but significantly faster. It does not account for whether some chunks are allocated. """ - chunk_size = h5_dataset.chunks assert chunk_size is not None - return math.prod([math.ceil(a / b) for a, b in zip(h5_dataset.shape, chunk_size)]) + if np.prod(chunk_size) == 0: + return 0 + return math.prod([math.ceil(a / b) for a, b in zip(shape, chunk_size)]) def _apply_to_all_chunk_info(h5_dataset: h5py.Dataset, callback: Callable): diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index 5968a34..37681c7 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -20,6 +20,9 @@ from ..LindiH5ZarrStore._util import _write_rfs_to_file +from ..tar.lindi_tar import LindiTarFile +from ..tar.LindiTarStore import LindiTarStore + LindiFileMode = Literal["r", "r+", "w", "w-", "x", "a"] @@ -29,7 +32,7 @@ class LindiH5pyFile(h5py.File): - def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, None] = None, _mode: LindiFileMode = "r", _local_cache: Union[LocalCache, None] = None, _local_file_path: Union[str, None] = None): + def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, None] = None, _mode: LindiFileMode = "r", _local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): """ Do not use this constructor directly. Instead, use: from_lindi_file, from_h5py_file, from_reference_file_system, from_zarr_store, or @@ -40,22 +43,28 @@ def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, Non self._mode: LindiFileMode = _mode self._the_group = LindiH5pyGroup(_zarr_group, self) self._local_cache = _local_cache - self._local_file_path = _local_file_path + self._source_url_or_path = _source_url_or_path + self._source_tar_file = _source_tar_file + self._close_source_tar_file_on_close = _close_source_tar_file_on_close # see comment in LindiH5pyGroup self._id = f'{id(self._zarr_group)}/' + self._is_open = True + @staticmethod - def from_lindi_file(url_or_path: str, *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None, local_file_path: Union[str, None] = None): + def from_lindi_file(url_or_path: str, *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None): """ Create a LindiH5pyFile from a URL or path to a .lindi.json file. For a description of parameters, see from_reference_file_system(). 
""" - if local_file_path is None: - if not url_or_path.startswith("http://") and not url_or_path.startswith("https://"): - local_file_path = url_or_path - return LindiH5pyFile.from_reference_file_system(url_or_path, mode=mode, staging_area=staging_area, local_cache=local_cache, local_file_path=local_file_path) + return LindiH5pyFile.from_reference_file_system( + url_or_path, + mode=mode, + staging_area=staging_area, + local_cache=local_cache + ) @staticmethod def from_hdf5_file( @@ -99,7 +108,7 @@ def from_hdf5_file( ) @staticmethod - def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None, local_file_path: Union[str, None] = None): + def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMode = "r", staging_area: Union[StagingArea, None] = None, local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): """ Create a LindiH5pyFile from a reference file system. @@ -116,11 +125,12 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo is only used in write mode, by default None. local_cache : Union[LocalCache, None], optional The local cache to use for caching data, by default None. - local_file_path : Union[str, None], optional - If rfs is not a string or is a remote url, this is the path to the - local file for the purpose of writing to it. It is required in this - case if mode is not "r". If rfs is a string and not a remote url, it - must be equal to local_file_path if provided. + _source_url_or_path : Union[str, None], optional + Internal use only + _source_tar_file : Union[LindiTarFile, None], optional + Internal use only + _close_source_tar_file_on_close : bool, optional + Internal use only """ if rfs is None: rfs = { @@ -132,25 +142,25 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo } if isinstance(rfs, str): + if _source_url_or_path is not None: + raise Exception("_source_file_path is not None even though rfs is a string") + if _source_tar_file is not None: + raise Exception("_source_tar_file is not None even though rfs is a string") rfs_is_url = rfs.startswith("http://") or rfs.startswith("https://") - if local_file_path is not None and not rfs_is_url and rfs != local_file_path: - raise Exception(f"rfs is not a remote url, so local_file_path must be the same as rfs, but got: {rfs} and {local_file_path}") if rfs_is_url: - with tempfile.TemporaryDirectory() as tmpdir: - filename = f"{tmpdir}/temp.lindi.json" - _download_file(rfs, filename) - with open(filename, "r") as f: - data = json.load(f) - assert isinstance(data, dict) # prevent infinite recursion - return LindiH5pyFile.from_reference_file_system(data, mode=mode, staging_area=staging_area, local_cache=local_cache, local_file_path=local_file_path) + data, tar_file = _load_rfs_from_url(rfs) + return LindiH5pyFile.from_reference_file_system( + data, + mode=mode, + staging_area=staging_area, + local_cache=local_cache, + _source_tar_file=tar_file, + _source_url_or_path=rfs, + _close_source_tar_file_on_close=_close_source_tar_file_on_close + ) else: - empty_rfs = { - "refs": { - '.zgroup': { - 'zarr_format': 2 - } - }, - } + # local file (or directory) + need_to_create_empty_file = False if mode == "r": # Readonly, file must exist (default) if not os.path.exists(rfs): @@ -161,36 +171,62 @@ def 
from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo raise Exception(f"File does not exist: {rfs}") elif mode == "w": # Create file, truncate if exists - with open(rfs, "w") as f: - json.dump(empty_rfs, f) + need_to_create_empty_file = True + elif mode in ["w-", "x"]: # Create file, fail if exists if os.path.exists(rfs): raise Exception(f"File already exists: {rfs}") - with open(rfs, "w") as f: - json.dump(empty_rfs, f) + need_to_create_empty_file = True elif mode == "a": # Read/write if exists, create otherwise - if os.path.exists(rfs): - with open(rfs, "r") as f: - data = json.load(f) + if not os.path.exists(rfs): + need_to_create_empty_file = True else: raise Exception(f"Unhandled mode: {mode}") - with open(rfs, "r") as f: - data = json.load(f) + if need_to_create_empty_file: + is_tar = rfs.endswith(".tar") + is_dir = rfs.endswith(".d") + _create_empty_lindi_file(rfs, is_tar=is_tar, is_dir=is_dir) + data, tar_file = _load_rfs_from_local_file_or_dir(rfs) assert isinstance(data, dict) # prevent infinite recursion - return LindiH5pyFile.from_reference_file_system(data, mode=mode, staging_area=staging_area, local_cache=local_cache, local_file_path=local_file_path) + return LindiH5pyFile.from_reference_file_system( + data, + mode=mode, + staging_area=staging_area, + local_cache=local_cache, + _source_url_or_path=rfs, + _source_tar_file=tar_file, + _close_source_tar_file_on_close=True + ) elif isinstance(rfs, dict): # This store does not need to be closed - store = LindiReferenceFileSystemStore(rfs, local_cache=local_cache) + store = LindiReferenceFileSystemStore( + rfs, + local_cache=local_cache, + _source_url_or_path=_source_url_or_path, + _source_tar_file=_source_tar_file + ) + source_is_url = _source_url_or_path is not None and (_source_url_or_path.startswith("http://") or _source_url_or_path.startswith("https://")) if staging_area: + if _source_tar_file and not source_is_url: + raise Exception("Cannot use staging area when source is a local tar file") store = LindiStagingStore(base_store=store, staging_area=staging_area) - return LindiH5pyFile.from_zarr_store(store, mode=mode, local_file_path=local_file_path, local_cache=local_cache) + elif _source_url_or_path and _source_tar_file and not source_is_url: + store = LindiTarStore(base_store=store, tar_file=_source_tar_file) + return LindiH5pyFile.from_zarr_store( + store, + mode=mode, + local_cache=local_cache, + _source_url_or_path=_source_url_or_path, + _source_tar_file=_source_tar_file, + _close_source_tar_file_on_close=_close_source_tar_file_on_close + ) else: raise Exception(f"Unhandled type for rfs: {type(rfs)}") @staticmethod - def from_zarr_store(zarr_store: ZarrStore, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, local_file_path: Union[str, None] = None): + def from_zarr_store(zarr_store: ZarrStore, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): """ Create a LindiH5pyFile from a zarr store. 
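# --------------------------------------------------------------------------
# Usage sketch for the extension-based dispatch above: when a new local file is
# created, '.tar' selects the tar container, '.d' the directory representation,
# and anything else the JSON reference file system. File names below are
# illustrative; the API calls mirror the examples added in this diff
# (examples/example_a.py and examples/example_b.py).
import numpy as np
import lindi

# JSON representation: human-readable reference file system
with lindi.LindiH5pyFile.from_lindi_file("demo.lindi.json", mode="w") as f:
    f.attrs["kind"] = "json"

# Tar representation: can hold internal data chunks alongside the JSON index
with lindi.LindiH5pyFile.from_lindi_file("demo.lindi.tar", mode="w") as f:
    f.create_dataset("data", data=np.arange(10))

# Reading back uses the same entry point for either format
with lindi.LindiH5pyFile.from_lindi_file("demo.lindi.tar", mode="r") as f:
    print(f["data"][:])
# --------------------------------------------------------------------------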
@@ -207,10 +243,10 @@ def from_zarr_store(zarr_store: ZarrStore, mode: LindiFileMode = "r", local_cach # does not need to be closed zarr_group = zarr.open(store=zarr_store, mode=mode) assert isinstance(zarr_group, zarr.Group) - return LindiH5pyFile.from_zarr_group(zarr_group, _zarr_store=zarr_store, mode=mode, local_cache=local_cache, local_file_path=local_file_path) + return LindiH5pyFile.from_zarr_group(zarr_group, _zarr_store=zarr_store, mode=mode, local_cache=local_cache, _source_url_or_path=_source_url_or_path, _source_tar_file=_source_tar_file, _close_source_tar_file_on_close=_close_source_tar_file_on_close) @staticmethod - def from_zarr_group(zarr_group: zarr.Group, *, mode: LindiFileMode = "r", _zarr_store: Union[ZarrStore, None] = None, local_cache: Union[LocalCache, None] = None, local_file_path: Union[str, None] = None): + def from_zarr_group(zarr_group: zarr.Group, *, mode: LindiFileMode = "r", _zarr_store: Union[ZarrStore, None] = None, local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False): """ Create a LindiH5pyFile from a zarr group. @@ -228,7 +264,7 @@ def from_zarr_group(zarr_group: zarr.Group, *, mode: LindiFileMode = "r", _zarr_ See from_zarr_store(). """ - return LindiH5pyFile(zarr_group, _zarr_store=_zarr_store, _mode=mode, _local_cache=local_cache, _local_file_path=local_file_path) + return LindiH5pyFile(zarr_group, _zarr_store=_zarr_store, _mode=mode, _local_cache=local_cache, _source_url_or_path=_source_url_or_path, _source_tar_file=_source_tar_file, _close_source_tar_file_on_close=_close_source_tar_file_on_close) def to_reference_file_system(self): """ @@ -241,6 +277,8 @@ def to_reference_file_system(self): if isinstance(zarr_store, LindiStagingStore): zarr_store.consolidate_chunks() zarr_store = zarr_store._base_store + if isinstance(zarr_store, LindiTarStore): + zarr_store = zarr_store._base_store if isinstance(zarr_store, LindiH5ZarrStore): return zarr_store.to_reference_file_system() if not isinstance(zarr_store, LindiReferenceFileSystemStore): @@ -305,23 +343,36 @@ def upload( def write_lindi_file(self, filename: str, *, generation_metadata: Union[dict, None] = None): """ - Write the reference file system to a .lindi.json file. + Write the reference file system to a lindi or .lindi.json file. Parameters ---------- filename : str - The filename to write to. It must end with '.lindi.json'. + The filename to write to. It must end with '.lindi.json' or '.lindi.tar'. generation_metadata : Union[dict, None], optional The optional generation metadata to include in the reference file system, by default None. This information dict is simply set to the 'generationMetadata' key in the reference file system. 
""" - if not filename.endswith(".lindi.json"): - raise Exception("Filename must end with '.lindi.json'") + if not filename.endswith(".lindi.json") and not filename.endswith(".lindi.tar"): + raise Exception("Filename must end with '.lindi.json' or '.lindi.tar'") rfs = self.to_reference_file_system() + if self._source_tar_file: + source_is_remote = self._source_url_or_path is not None and (self._source_url_or_path.startswith("http://") or self._source_url_or_path.startswith("https://")) + if not source_is_remote: + raise Exception("Cannot write to lindi file if the source is a local lindi tar file because it would not be able to resolve the local references within the tar file.") + assert self._source_url_or_path is not None + _update_internal_references_to_remote_tar_file(rfs, self._source_url_or_path, self._source_tar_file) if generation_metadata is not None: rfs['generationMetadata'] = generation_metadata - _write_rfs_to_file(rfs=rfs, output_file_name=filename) + if filename.endswith(".lindi.json"): + _write_rfs_to_file(rfs=rfs, output_file_name=filename) + elif filename.endswith(".lindi.tar"): + LindiTarFile.create(filename, rfs=rfs) + elif filename.endswith(".d"): + LindiTarFile.create(filename, rfs=rfs, dir_representation=True) + else: + raise Exception("Unhandled file extension") @property def attrs(self): # type: ignore @@ -355,12 +406,27 @@ def swmr_mode(self, value): # type: ignore raise Exception("Getting swmr_mode is not allowed") def close(self): + if not self._is_open: + print('Warning: LINDI file already closed.') + return self.flush() + if self._close_source_tar_file_on_close and self._source_tar_file: + self._source_tar_file.close() + self._is_open = False def flush(self): - if self._mode != 'r' and self._local_file_path is not None: + if not self._is_open: + return + if self._mode != 'r' and self._source_url_or_path is not None: + is_url = self._source_url_or_path.startswith("http://") or self._source_url_or_path.startswith("https://") + if is_url: + raise Exception("Cannot write to URL") rfs = self.to_reference_file_system() - _write_rfs_to_file(rfs=rfs, output_file_name=self._local_file_path) + if self._source_tar_file: + self._source_tar_file.write_rfs(rfs) + self._source_tar_file._update_index_in_file() # very important + else: + _write_rfs_to_file(rfs=rfs, output_file_name=self._source_url_or_path) def __enter__(self): # type: ignore return self @@ -446,6 +512,12 @@ def get(self, name, default=None, getclass=False, getlink=False): raise Exception("Getting class is not allowed") return self._get_item(name, getlink=getlink, default=default) + def keys(self): # type: ignore + return self._the_group.keys() + + def items(self): + return self._the_group.items() + def __iter__(self): return self._the_group.__iter__() @@ -605,3 +677,158 @@ def _format_size_bytes(size_bytes: int) -> str: return f"{size_bytes / 1024 / 1024:.1f} MB" else: return f"{size_bytes / 1024 / 1024 / 1024:.1f} GB" + + +def _load_rfs_from_url(url: str): + file_size = _get_file_size_of_remote_file(url) + if file_size < 1024 * 1024 * 2: + # if it's a small file, we'll just download the whole thing + with tempfile.TemporaryDirectory() as tmpdir: + tmp_fname = f"{tmpdir}/temp.lindi.json" + _download_file(url, tmp_fname) + data, tar_file = _load_rfs_from_local_file_or_dir(tmp_fname) + return data, tar_file + else: + # if it's a large file, we start by downloading the entry file and then the index file + tar_entry_buf = _download_file_byte_range(url, 0, 512) + is_tar = _check_is_tar_header(tar_entry_buf[:512]) + 
if is_tar: + tar_file = LindiTarFile(url) + rfs_json = tar_file.read_file("lindi.json") + rfs = json.loads(rfs_json) + return rfs, tar_file + else: + # In this case, it must be a regular json file + with tempfile.TemporaryDirectory() as tmpdir: + tmp_fname = f"{tmpdir}/temp.lindi.json" + _download_file(url, tmp_fname) + with open(tmp_fname, "r") as f: + return json.load(f), None + + +def _load_rfs_from_local_file_or_dir(fname: str): + if os.path.isdir(fname): + dir_file = LindiTarFile(fname, dir_representation=True) + rfs_json = dir_file.read_file("lindi.json") + rfs = json.loads(rfs_json) + return rfs, dir_file + file_size = os.path.getsize(fname) + if file_size >= 512: + # Read first bytes to check if it's a tar file + with open(fname, "rb") as f: + tar_entry_buf = f.read(512) + is_tar = _check_is_tar_header(tar_entry_buf) + if is_tar: + tar_file = LindiTarFile(fname) + rfs_json = tar_file.read_file("lindi.json") + rfs = json.loads(rfs_json) + return rfs, tar_file + + # Must be a regular json file + with open(fname, "r") as f: + return json.load(f), None + + +def _check_is_tar_header(header_buf: bytes) -> bool: + if len(header_buf) < 512: + return False + + # We're only going to support ustar format + # get the ustar indicator at bytes 257-262 + if header_buf[257:262] == b"ustar" and header_buf[262] == 0: + # Note that it's unlikely but possible that a json file could have the + # string "ustar" at these bytes, but it would not have a null byte at + # byte 262 + return True + + # Check for any 0 bytes in the header + if b"\0" in header_buf: + print(header_buf[257:262]) + raise Exception("Problem with lindi file: 0 byte found in header, but not ustar tar format") + + return False + + +def _get_file_size_of_remote_file(url: str) -> int: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" + } + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req) as response: + return int(response.headers['Content-Length']) + + +def _download_file_byte_range(url: str, start: int, end: int) -> bytes: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3", + "Range": f"bytes={start}-{end - 1}" + } + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req) as response: + return response.read() + + +empty_rfs = { + "refs": { + ".zgroup": { + "zarr_format": 2 + } + } +} + + +def _create_empty_lindi_file(fname: str, *, is_tar: bool = False, is_dir: bool = False): + if is_tar: + if is_dir: + raise Exception("Cannot be both tar and dir") + LindiTarFile.create(fname, rfs=empty_rfs) + elif is_dir: + LindiTarFile.create(fname, rfs=empty_rfs, dir_representation=True) + else: + with open(fname, "w") as f: + json.dump(empty_rfs, f) + + +def _update_internal_references_to_remote_tar_file(rfs: dict, remote_url: str, remote_tar_file: LindiTarFile): + # This is tricky. This happens when the source is a remote tar file and we + # are trying to write the lindi file locally, but we need to update the + # internal references to point to the remote tar file. Yikes. + + # First we remove all templates to simplify the process. We will restore them below. 
+ LindiReferenceFileSystemStore.remove_templates_in_rfs(rfs) + + for k, v in rfs['refs'].items(): + if isinstance(v, list): + if len(v) == 3: + url = v[0] + if url.startswith('./'): + internal_path = url[2:] + if not remote_tar_file._dir_representation: + info = remote_tar_file.get_file_info(internal_path) + start_byte = info['d'] + num_bytes = info['s'] + v[0] = remote_url + v[1] = start_byte + v[1] + if v[1] + v[2] > start_byte + num_bytes: + raise Exception(f"Reference goes beyond end of file: {v[1] + v[2]} > {num_bytes}") + # v[2] stays the same, it is the size + else: + v[0] = remote_url + '/' + internal_path + elif len(v) == 1: + # This is a reference to the full file + url = v[0] + if url.startswith('./'): + internal_path = url[2:] + if not remote_tar_file._dir_representation: + info = remote_tar_file.get_file_info(internal_path) + start_byte = info['d'] + num_bytes = info['s'] + v[0] = remote_url + v.append(start_byte) + v.append(num_bytes) + else: + v[0] = remote_url + '/' + internal_path + else: + raise Exception(f"Unexpected length for reference: {len(v)}") + + LindiReferenceFileSystemStore.use_templates_in_rfs(rfs) diff --git a/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py b/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py index 7b51fcc..e0069b2 100644 --- a/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py +++ b/lindi/LindiH5pyFile/LindiReferenceFileSystemStore.py @@ -1,10 +1,14 @@ from typing import Literal, Dict, Union +import os import json +import time import base64 +import numpy as np import requests from zarr.storage import Store as ZarrStore from ..LocalCache.LocalCache import ChunkTooLargeError, LocalCache +from ..tar.lindi_tar import LindiTarFile class LindiReferenceFileSystemStore(ZarrStore): @@ -68,7 +72,15 @@ class LindiReferenceFileSystemStore(ZarrStore): It is okay for rfs to be modified outside of this class, and the changes will be reflected immediately in the store. """ - def __init__(self, rfs: dict, *, mode: Literal["r", "r+"] = "r+", local_cache: Union[LocalCache, None] = None): + def __init__( + self, + rfs: dict, + *, + mode: Literal["r", "r+"] = "r+", + local_cache: Union[LocalCache, None] = None, + _source_url_or_path: Union[str, None] = None, + _source_tar_file: Union[LindiTarFile, None] = None + ): """ Create a LindiReferenceFileSystemStore. 
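For orientation, the reference file system ("rfs") dict that `_update_internal_references_to_remote_tar_file` and the store below operate on has three kinds of `refs` values: inline JSON, base64/raw strings, and `[url_or_path, offset, length]` triples, where `./`-prefixed paths mean "inside this lindi file" and `{{name}}` placeholders are expanded from `templates`. An abridged, made-up example:

```python
# Illustrative rfs only; the URL, template name "u1", keys, and sizes are invented.
rfs = {
    "templates": {"u1": "https://example.org/sub-01.nwb"},
    "refs": {
        ".zgroup": {"zarr_format": 2},                            # inline JSON
        "acquisition/ts/data/.zarray": {"chunks": [100], "dtype": "<f4"},  # abridged
        "acquisition/ts/data/0": ["{{u1}}", 2048, 400],           # remote chunk: [url, offset, length]
        "acquisition/ts/data/1": ["./blobs/acquisition/ts/data/1", 0, 400],  # internal chunk in the tar
    },
}

# remove_templates_in_rfs() (added further down) expands "{{u1}}" back to the full URL so
# that every reference is self-contained before being rewritten to point at the remote tar;
# use_templates_in_rfs() then re-introduces templates to keep the JSON compact.
```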
@@ -113,6 +125,8 @@ def __init__(self, rfs: dict, *, mode: Literal["r", "r+"] = "r+", local_cache: U self.rfs = rfs self.mode = mode self.local_cache = local_cache + self._source_url_or_path = _source_url_or_path + self._source_tar_file = _source_tar_file # These methods are overridden from MutableMapping def __contains__(self, key: object): @@ -121,6 +135,16 @@ def __contains__(self, key: object): return key in self.rfs["refs"] def __getitem__(self, key: str): + val = self._get_helper(key) + + if val is not None: + padded_size = _get_padded_size(self, key, val) + if padded_size is not None: + val = _pad_chunk(val, padded_size) + + return val + + def _get_helper(self, key: str): if key not in self.rfs["refs"]: raise KeyError(key) x = self.rfs["refs"][key] @@ -134,22 +158,52 @@ def __getitem__(self, key: str): elif isinstance(x, list): if len(x) != 3: raise Exception("list must have 3 elements") # pragma: no cover - url = x[0] + url_or_path = x[0] offset = x[1] length = x[2] - if '{{' in url and '}}' in url and 'templates' in self.rfs: + if '{{' in url_or_path and '}}' in url_or_path and 'templates' in self.rfs: for k, v in self.rfs["templates"].items(): - url = url.replace("{{" + k + "}}", v) - if self.local_cache is not None: - x = self.local_cache.get_remote_chunk(url=url, offset=offset, size=length) - if x is not None: - return x - val = _read_bytes_from_url_or_path(url, offset, length) - if self.local_cache is not None: - try: - self.local_cache.put_remote_chunk(url=url, offset=offset, size=length, data=val) - except ChunkTooLargeError: - print(f'Warning: unable to cache chunk of size {length} on LocalCache (key: {key})') + url_or_path = url_or_path.replace("{{" + k + "}}", v) + is_url = url_or_path.startswith('http://') or url_or_path.startswith('https://') + if url_or_path.startswith('./'): + if self._source_url_or_path is None: + raise Exception(f"Cannot resolve relative path {url_or_path} without source file path") + if self._source_tar_file is None: + raise Exception(f"Cannot resolve relative path {url_or_path} without source file type") + if self._source_tar_file and (not self._source_tar_file._dir_representation): + start_byte, end_byte = self._source_tar_file.get_file_byte_range(file_name=url_or_path[2:]) + if start_byte + offset + length > end_byte: + raise Exception(f"Chunk {key} is out of bounds in tar file {url_or_path}") + url_or_path = self._source_url_or_path + offset = offset + start_byte + elif self._source_tar_file and self._source_tar_file._dir_representation: + fname = self._source_tar_file._tar_path_or_url + '/' + url_or_path[2:] + if not os.path.exists(fname): + raise Exception(f"File does not exist: {fname}") + file_size = os.path.getsize(fname) + if offset + length > file_size: + raise Exception(f"Chunk {key} is out of bounds in tar file {url_or_path}: {fname}") + url_or_path = fname + else: + if is_url: + raise Exception(f"Cannot resolve relative path {url_or_path} for URL that is not a tar") + else: + source_file_parent_dir = '/'.join(self._source_url_or_path.split('/')[:-1]) + abs_path = source_file_parent_dir + '/' + url_or_path[2:] + url_or_path = abs_path + if is_url: + if self.local_cache is not None: + x = self.local_cache.get_remote_chunk(url=url_or_path, offset=offset, size=length) + if x is not None: + return x + val = _read_bytes_from_url_or_path(url_or_path, offset, length) + if self.local_cache is not None: + try: + self.local_cache.put_remote_chunk(url=url_or_path, offset=offset, size=length, data=val) + except ChunkTooLargeError: + 
print(f'Warning: unable to cache chunk of size {length} on LocalCache (key: {key})') + else: + val = _read_bytes_from_url_or_path(url_or_path, offset, length) return val else: # should not happen given checks in __init__, but self.rfs is mutable @@ -241,6 +295,22 @@ def use_templates_in_rfs(rfs: dict) -> None: if url in template_names_for_urls: v[0] = '{{' + template_names_for_urls[url] + '}}' + @staticmethod + def remove_templates_in_rfs(rfs: dict) -> None: + """ + Utility for removing templates from an rfs. This is the opposite of + use_templates_in_rfs. + """ + templates0 = rfs.get('templates', {}) + for k, v in rfs['refs'].items(): + if isinstance(v, list): + url = v[0] + if '{{' in url and '}}' in url: + template_name = url[2:-2].strip() + if template_name in templates0: + v[0] = templates0[template_name] + rfs['templates'] = {} + def _read_bytes_from_url_or_path(url_or_path: str, offset: int, length: int): """ @@ -248,18 +318,74 @@ def _read_bytes_from_url_or_path(url_or_path: str, offset: int, length: int): """ from ..LindiRemfile.LindiRemfile import _resolve_url if url_or_path.startswith('http://') or url_or_path.startswith('https://'): - url_resolved = _resolve_url(url_or_path) # handle DANDI auth - range_start = offset - range_end = offset + length - 1 - range_header = f"bytes={range_start}-{range_end}" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3", - "Range": range_header - } - response = requests.get(url_resolved, headers=headers) - response.raise_for_status() - return response.content + num_retries = 8 + for try_num in range(num_retries): + try: + url_resolved = _resolve_url(url_or_path) # handle DANDI auth + range_start = offset + range_end = offset + length - 1 + range_header = f"bytes={range_start}-{range_end}" + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3", + "Range": range_header + } + response = requests.get(url_resolved, headers=headers) + response.raise_for_status() + return response.content + except Exception as e: + if try_num == num_retries - 1: + raise e + else: + delay = 0.1 * 2 ** try_num + print(f'Retry load data from {url_or_path} in {delay} seconds') + time.sleep(delay) + raise Exception(f"Failed to load data from {url_or_path}") else: with open(url_or_path, 'rb') as f: f.seek(offset) return f.read(length) + + +def _is_chunk_base_key(base_key: str) -> bool: + a = base_key.split('.') + if len(a) == 0: + return False + for x in a: + # check if integer + try: + int(x) + except ValueError: + return False + return True + + +def _get_itemsize(dtype: str) -> int: + d = np.dtype(dtype) + return d.itemsize + + +def _pad_chunk(data: bytes, expected_chunk_size: int) -> bytes: + return data + b'\0' * (expected_chunk_size - len(data)) + + +def _get_padded_size(store, key: str, val: bytes): + # If the key is a chunk and it's smaller than the expected size, then we + # need to pad it with zeros. This can happen if this is the final chunk + # in a contiguous hdf5 dataset. 
See + # https://github.com/NeurodataWithoutBorders/lindi/pull/84 + base_key = key.split('/')[-1] + if val and _is_chunk_base_key(base_key): + parent_key = key.split('/')[:-1] + zarray_key = '/'.join(parent_key) + '/.zarray' + if zarray_key in store: + zarray_json = store.__getitem__(zarray_key) + assert isinstance(zarray_json, bytes) + zarray = json.loads(zarray_json) + chunk_shape = zarray['chunks'] + dtype = zarray['dtype'] + if np.dtype(dtype).kind in ['i', 'u', 'f']: + expected_chunk_size = int(np.prod(chunk_shape)) * _get_itemsize(dtype) + if len(val) < expected_chunk_size: + return expected_chunk_size + + return None diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py index cadbcab..0111603 100644 --- a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py +++ b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py @@ -83,6 +83,8 @@ def create_dataset( _zarr_compressor = numcodecs.GZip(level=level) else: raise Exception(f'Compression {compression} is not supported') + elif compression is None: + _zarr_compressor = None else: raise Exception(f'Unexpected type for compression: {type(compression)}') diff --git a/lindi/conversion/create_zarr_dataset_from_h5_data.py b/lindi/conversion/create_zarr_dataset_from_h5_data.py index fa2fe45..851897a 100644 --- a/lindi/conversion/create_zarr_dataset_from_h5_data.py +++ b/lindi/conversion/create_zarr_dataset_from_h5_data.py @@ -19,7 +19,7 @@ def create_zarr_dataset_from_h5_data( name: str, label: str, h5_chunks: Union[Tuple, None], - zarr_compressor: Union[Codec, Literal['default']] = 'default' + zarr_compressor: Union[Codec, Literal['default'], None] = 'default' ): """Create a zarr dataset from an h5py dataset. @@ -43,9 +43,9 @@ def create_zarr_dataset_from_h5_data( The name of the h5py dataset for error messages. h5_chunks : tuple The chunk shape of the h5py dataset. - zarr_compressor : numcodecs.abc.Codec + zarr_compressor : numcodecs.abc.Codec, 'default', or None The codec compressor to use when writing the dataset. If default, the - default compressor will be used. + default compressor will be used. When None, no compressor will be used. """ if h5_dtype is None: raise Exception(f'No dtype in h5_to_zarr_dataset_prep for dataset {label}') @@ -58,7 +58,7 @@ def create_zarr_dataset_from_h5_data( if h5_data is None: raise Exception(f'Data must be provided for scalar dataset {label}') - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for scalar datasets') if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]: @@ -113,13 +113,14 @@ def create_zarr_dataset_from_h5_data( if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]: # integer, unsigned integer, float, bool # This is the normal case of a chunked dataset with a numeric (or boolean) dtype if h5_chunks is None: - # We require that chunks be specified when writing a dataset with more - # than 1 million elements. This is because zarr may default to - # suboptimal chunking. Note that the default for h5py is to use the - # entire dataset as a single chunk. - total_size = np.prod(h5_shape) if len(h5_shape) > 0 else 1 - if total_size > 1000 * 1000: - raise Exception(f'Chunks must be specified explicitly when writing dataset of shape {h5_shape}') + # # We require that chunks be specified when writing a dataset with more + # # than 1 million elements. This is because zarr may default to + # # suboptimal chunking. 
Note that the default for h5py is to use the + # # entire dataset as a single chunk. + # total_size = int(np.prod(h5_shape)) if len(h5_shape) > 0 else 1 + # if total_size > 1000 * 1000: + # raise Exception(f'Chunks must be specified explicitly when writing dataset of shape {h5_shape}') + h5_chunks = _get_default_chunks(h5_shape, h5_dtype) # Note that we are not using the same filters as in the h5py dataset return zarr_parent_group.create_dataset( name, @@ -131,7 +132,7 @@ def create_zarr_dataset_from_h5_data( ) elif h5_dtype.kind == 'O': # For type object, we are going to use the JSON codec - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for object datasets') if h5_data is not None: if isinstance(h5_data, h5py.Dataset): @@ -149,7 +150,7 @@ def create_zarr_dataset_from_h5_data( object_codec=object_codec ) elif h5_dtype.kind == 'S': # byte string - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for byte string datasets') if h5_data is None: raise Exception(f'Data must be provided when converting dataset {label} with dtype {h5_dtype}') @@ -161,11 +162,11 @@ def create_zarr_dataset_from_h5_data( data=h5_data ) elif h5_dtype.kind == 'U': # unicode string - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for unicode string datasets') raise Exception(f'Array of unicode strings not supported: dataset {label} with dtype {h5_dtype} and shape {h5_shape}') elif h5_dtype.kind == 'V' and h5_dtype.fields is not None: # compound dtype - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for compound datasets') if h5_data is None: raise Exception(f'Data must be provided when converting compound dataset {label}') @@ -252,3 +253,15 @@ def h5_object_data_to_zarr_data(h5_data: Union[np.ndarray, list], *, h5f: Union[ else: raise Exception(f'Cannot handle value of type {type(val)} in dataset {label} with dtype {h5_data.dtype} and shape {h5_data.shape}') return zarr_data + + +def _get_default_chunks(shape: Tuple, dtype: Any) -> Tuple: + dtype_size = np.dtype(dtype).itemsize + shape_prod_0 = np.prod(shape[1:]) + optimal_chunk_size_bytes = 1024 * 1024 * 20 # 20 MB + optimal_chunk_size = optimal_chunk_size_bytes // (dtype_size * shape_prod_0) + if optimal_chunk_size <= shape[0]: + return shape + if optimal_chunk_size < 1: + return (1,) + shape[1:] + return (optimal_chunk_size,) + shape[1:] diff --git a/lindi/tar/LindiTarStore.py b/lindi/tar/LindiTarStore.py new file mode 100644 index 0000000..d64b4d2 --- /dev/null +++ b/lindi/tar/LindiTarStore.py @@ -0,0 +1,87 @@ +import numpy as np +from zarr.storage import Store as ZarrStore +from ..LindiH5pyFile.LindiReferenceFileSystemStore import LindiReferenceFileSystemStore +from .lindi_tar import LindiTarFile + + +class LindiTarStore(ZarrStore): + def __init__(self, *, base_store: LindiReferenceFileSystemStore, tar_file: LindiTarFile): + self._base_store = base_store + self._tar_file = tar_file + + def __getitem__(self, key: str): + return self._base_store.__getitem__(key) + + def __setitem__(self, key: str, value: bytes): + self.setitems({key: value}) + + def setitems(self, items_dict: dict): + for key, value in items_dict.items(): + key_parts = key.split("/") + 
key_base_name = key_parts[-1] + + files_to_write_to_tar = {} + + if key_base_name.startswith('.') or key_base_name.endswith('.json'): # always inline .zattrs, .zgroup, .zarray, zarr.json + inline = True + else: + # presumably it is a chunk of an array + if isinstance(value, np.ndarray): + value = value.tobytes() + if not isinstance(value, bytes): + print(f"key: {key}, value type: {type(value)}") + raise ValueError("Value must be bytes") + size = len(value) + inline = size < 1000 # this should be a configurable threshold + if inline: + # If inline, save in memory + return self._base_store.__setitem__(key, value) + else: + # If not inline, save it as a new file in the tar file + key_without_initial_slash = key if not key.startswith("/") else key[1:] + fname_in_tar = f'blobs/{key_without_initial_slash}' + if self._tar_file.has_file_with_name(fname_in_tar): + v = 2 + while self._tar_file.has_file_with_name(f'{fname_in_tar}.v{v}'): + v += 1 + fname_in_tar = f'{fname_in_tar}.v{v}' + files_to_write_to_tar[fname_in_tar] = value + + self._set_ref_reference(key_without_initial_slash, f'./{fname_in_tar}', 0, len(value)) + + self._tar_file.write_files(files_to_write_to_tar) + + def __delitem__(self, key: str): + # We don't actually delete the file from the tar, but maybe it would be + # smart to put it in .trash in the future + return self._base_store.__delitem__(key) + + def __iter__(self): + return self._base_store.__iter__() + + def __len__(self): + return self._base_store.__len__() + + # These methods are overridden from BaseStore + def is_readable(self): + return True + + def is_writeable(self): + return True + + def is_listable(self): + return True + + def is_erasable(self): + return False + + def _set_ref_reference(self, key: str, filename: str, offset: int, size: int): + rfs = self._base_store.rfs + if 'refs' not in rfs: + # this shouldn't happen, but we'll be defensive + rfs['refs'] = {} + rfs['refs'][key] = [ + filename, + offset, + size + ] diff --git a/lindi/tar/create_tar_header.py b/lindi/tar/create_tar_header.py new file mode 100644 index 0000000..7e045aa --- /dev/null +++ b/lindi/tar/create_tar_header.py @@ -0,0 +1,107 @@ +def create_tar_header(file_name: str, file_size: int) -> bytes: + # We use USTAR format only + h = b'' + + # file name + a = file_name.encode() + b"\x00" * (100 - len(file_name)) + h += a + + # file mode + a = b"0000644\x00" # 644 is the default permission - you can read and write, but others can only read + h += a + + # uid + a = b"0000000\x00" # 0 is the default user id + h += a + + # gid + a = b"0000000\x00" # 0 is the default group id + h += a + + # size + # we need an octal representation of the size + a = f"{file_size:011o}".encode() + b"\x00" # 11 octal digits + h += a + + # mtime + a = b"00000000000\x00" # 0 is the default modification time + h += a + + # chksum + # We'll determine the checksum after creating the full header + a = b" " * 8 # 8 spaces for now + h += a + + # typeflag + a = b"0" # default typeflag is 0 representing a regular file + h += a + + # linkname + a = b"\x00" * 100 # no link name + h += a + + # magic + a = b"ustar\x00" # specifies the ustar format + h += a + + # version + a = b"00" # ustar version + h += a + + # uname + a = b"\x00" * 32 # no user name + h += a + + # gname + a = b"\x00" * 32 # no group name + h += a + + # devmajor + a = b"\x00" * 8 # no device major number + h += a + + # devminor + a = b"\x00" * 8 # no device minor number + h += a + + # prefix + a = b"\x00" * 155 # no prefix + h += a + + # padding + a = b"\x00" * 12 # 
padding + h += a + + # Now we calculate the checksum + chksum = _compute_checksum_for_header(h) + h = h[:148] + chksum + h[156:] + + assert len(h) == 512 + + return h + + +def _compute_checksum_for_header(header: bytes) -> bytes: + # From https://en.wikipedia.org/wiki/Tar_(computing) + # The checksum is calculated by taking the sum of the unsigned byte values + # of the header record with the eight checksum bytes taken to be ASCII + # spaces (decimal value 32). It is stored as a six digit octal number with + # leading zeroes followed by a NUL and then a space. Various implementations + # do not adhere to this format. In addition, some historic tar + # implementations treated bytes as signed. Implementations typically + # calculate the checksum both ways, and treat it as good if either the + # signed or unsigned sum matches the included checksum. + + header_byte_list = [] + for byte in header: + header_byte_list.append(byte) + for i in range(148, 156): + header_byte_list[i] = 32 + sum = 0 + for byte in header_byte_list: + sum += byte + checksum = oct(sum).encode()[2:] + while len(checksum) < 6: + checksum = b"0" + checksum + checksum += b"\0 " + return checksum diff --git a/lindi/tar/lindi_tar.py b/lindi/tar/lindi_tar.py new file mode 100644 index 0000000..2398786 --- /dev/null +++ b/lindi/tar/lindi_tar.py @@ -0,0 +1,457 @@ +import os +import json +import random +import urllib.request +from .create_tar_header import create_tar_header + + +TAR_ENTRY_JSON_SIZE = 1024 +INITIAL_TAR_INDEX_JSON_SIZE = 1024 * 8 +INITIAL_LINDI_JSON_SIZE = 1024 * 8 + + +class LindiTarFile: + def __init__(self, tar_path_or_url: str, dir_representation=False): + self._tar_path_or_url = tar_path_or_url + self._dir_representation = dir_representation + self._is_remote = tar_path_or_url.startswith("http://") or tar_path_or_url.startswith("https://") + + if not dir_representation: + # Load the entry json + entry_json = _load_bytes_from_local_or_remote_file(self._tar_path_or_url, 512, 512 + TAR_ENTRY_JSON_SIZE) + entry = json.loads(entry_json) + index_info = entry['index'] + + # Load the index json + index_json = _load_bytes_from_local_or_remote_file(self._tar_path_or_url, index_info['d'], index_info['d'] + index_info['s']) + self._index = json.loads(index_json) + self._index_has_changed = False + + self._index_lookup = {} + for file in self._index['files']: + self._index_lookup[file['n']] = file + + # Verify that the index file correctly has a reference to itself + index_info_2 = self._index_lookup.get(".tar_index.json", None) + if index_info_2 is None: + raise ValueError("File .tar_index.json not found in index") + for k in ['n', 'o', 'd', 's']: + if k not in index_info_2: + raise ValueError(f"File .tar_index.json does not have key {k}") + if k not in index_info: + raise ValueError(f"File .tar_index.json does not have key {k}") + if index_info_2[k] != index_info[k]: + raise ValueError(f"File .tar_index.json has unexpected value for key {k}") + # Verify that the index file correctly as a reference to the entry file + entry_info = self._index_lookup.get(".tar_entry.json", None) + if entry_info is None: + raise ValueError("File .tar_entry.json not found in index") + if entry_info['n'] != ".tar_entry.json": + raise ValueError("File .tar_entry.json has unexpected name") + if entry_info['o'] != 0: + raise ValueError("File .tar_entry.json has unexpected offset") + if entry_info['d'] != 512: + raise ValueError("File .tar_entry.json has unexpected data offset") + if entry_info['s'] != TAR_ENTRY_JSON_SIZE: + raise ValueError("File 
.tar_entry.json has unexpected size") + self._file = open(self._tar_path_or_url, "r+b") if not self._is_remote else None + else: + self._index = None + self._index_has_changed = False + self._index_lookup = None + self._file = None + + def close(self): + if not self._dir_representation: + self._update_index_in_file() + if self._file is not None: + self._file.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def get_file_info(self, file_name: str): + if self._dir_representation: + raise ValueError("Cannot get file info in a directory representation") + assert self._index_lookup + return self._index_lookup.get(file_name, None) + + def overwrite_file_content(self, file_name: str, data: bytes): + if self._is_remote: + raise ValueError("Cannot overwrite file content in a remote tar file") + if not self._dir_representation: + if self._file is None: + raise ValueError("File is not open") + info = self.get_file_info(file_name) + if info is None: + raise FileNotFoundError(f"File {file_name} not found") + if info['s'] != len(data): + raise ValueError("Unexpected problem in overwrite_file_content(): data size must match the size of the existing file") + self._file.seek(info['d']) + self._file.write(data) + else: + # for safety: + file_parts = file_name.split("/") + for part in file_parts[:-1]: + if part.startswith('..'): + raise ValueError(f"Invalid path: {file_name}") + fname = self._tar_path_or_url + "/" + file_name + with open(fname, "wb") as f: + f.write(data) + + def trash_file(self, file_name: str): + if self._is_remote: + raise ValueError("Cannot trash a file in a remote tar file") + if not self._dir_representation: + if self._file is None: + raise ValueError("File is not open") + assert self._index + assert self._index_lookup + info = self.get_file_info(file_name) + if info is None: + raise FileNotFoundError(f"File {file_name} not found") + zeros = b"-" * info['s'] + self._file.seek(info['d']) + self._file.write(zeros) + self._change_name_of_file( + file_name, + f'.trash/{file_name}.{_create_random_string()}' + ) + self._index['files'] = [file for file in self._index['files'] if file['n'] != file_name] + del self._index_lookup[file_name] + self._index_has_changed = True + else: + # for safety: + file_parts = file_name.split("/") + for part in file_parts[:-1]: + if part.startswith('..'): + raise ValueError(f"Invalid path: {file_name}") + fname = self._tar_path_or_url + "/" + file_name + os.remove(fname) + + def write_rfs(self, rfs: dict): + if self._is_remote: + raise ValueError("Cannot write a file in a remote tar file") + + rfs_json = json.dumps(rfs, indent=2, sort_keys=True) + + if not self._dir_representation: + existing_lindi_json_info = self.get_file_info("lindi.json") + if existing_lindi_json_info is not None: + file_size = existing_lindi_json_info['s'] + if file_size >= len(rfs_json): + # We are going to overwrite the existing lindi.json with the new + # one. But first we pad it with spaces to the same size as the + # existing one. + padding = b" " * (file_size - len(rfs_json)) + rfs_json = rfs_json.encode() + padding + self.overwrite_file_content("lindi.json", rfs_json) + else: + # In this case we need to trash the existing file and write a new one + # at the end of the tar file. + self.trash_file("lindi.json") + rfs_json = _pad_bytes_to_leave_room_for_growth(rfs_json, INITIAL_LINDI_JSON_SIZE) + self.write_file("lindi.json", rfs_json) + else: + # We are writing a new lindi.json. 
+ rfs_json = _pad_bytes_to_leave_room_for_growth(rfs_json, INITIAL_LINDI_JSON_SIZE) + self.write_file("lindi.json", rfs_json) + else: + with open(self._tar_path_or_url + "/lindi.json", "wb") as f: + f.write(rfs_json.encode()) + + def get_file_byte_range(self, file_name: str) -> tuple: + if self._dir_representation: + raise ValueError("Cannot get file byte range in a directory representation") + info = self.get_file_info(file_name) + if info is None: + raise FileNotFoundError(f"File {file_name} not found in tar file") + return info['d'], info['d'] + info['s'] + + def has_file_with_name(self, file_name: str) -> bool: + if not self._dir_representation: + return self.get_file_info(file_name) is not None + else: + return os.path.exists(self._tar_path_or_url + "/" + file_name) + + def _change_name_of_file(self, file_name: str, new_file_name: str): + if self._dir_representation: + raise ValueError("Cannot change the name of a file in a directory representation") + if self._is_remote: + raise ValueError("Cannot change the name of a file in a remote tar file") + if self._file is None: + raise ValueError("File is not open") + info = self.get_file_info(file_name) + if info is None: + raise FileNotFoundError(f"File {file_name} not found") + header_start_byte = info['o'] + file_name_byte_range = (header_start_byte + 0, header_start_byte + 100) + file_name_prefix_byte_range = (header_start_byte + 345, header_start_byte + 345 + 155) + self._file.seek(file_name_byte_range[0]) + self._file.write(new_file_name.encode()) + # set the rest of the field to zeros + self._file.write(b"\x00" * (file_name_byte_range[1] - file_name_byte_range[0] - len(new_file_name))) + + self._file.seek(file_name_prefix_byte_range[0]) + self._file.write(b"\x00" * (file_name_prefix_byte_range[1] - file_name_prefix_byte_range[0])) + + _fix_checksum_in_header(self._file, header_start_byte) + + def write_file(self, file_name: str, data: bytes): + self.write_files({file_name: data}) + + def write_files(self, files: dict): + if self._is_remote: + raise ValueError("Cannot write a file in a remote tar file") + if not self._dir_representation: + assert self._index + assert self._index_lookup + if self._file is None: + raise ValueError("File is not open") + self._file.seek(-1024, 2) + hh = self._file.read(1024) + if hh != b"\x00" * 1024: + raise ValueError("The tar file does not end with 1024 bytes of zeros") + self._file.seek(-1024, 2) + + file_pos = self._file.tell() + + for file_name, data in files.items(): + x = { + 'n': file_name, + 'o': file_pos, + 'd': file_pos + 512, # we assume the header is 512 bytes + 's': len(data) + } + + # write the tar header + tar_header = create_tar_header(file_name, len(data)) + + # pad up to blocks of 512 + if len(data) % 512 != 0: + padding_len = 512 - len(data) % 512 + else: + padding_len = 0 + + self._file.write(tar_header) + self._file.write(data) + self._file.write(b"\x00" * padding_len) + file_pos += 512 + len(data) + padding_len + + self._index['files'].append(x) + self._index_lookup[file_name] = x + self._index_has_changed = True + + # write the 1024 bytes marking the end of the file + self._file.write(b"\x00" * 1024) + else: + for file_name, data in files.items(): + # for safety: + file_parts = file_name.split("/") + for part in file_parts[:-1]: + if part.startswith('..'): + raise ValueError(f"Invalid path: {file_name}") + fname = self._tar_path_or_url + "/" + file_name + parent_dir = os.path.dirname(fname) + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) + with open(fname, "wb") as 
f: + f.write(data) + + def read_file(self, file_name: str) -> bytes: + if not self._dir_representation: + info = self.get_file_info(file_name) + if info is None: + raise FileNotFoundError(f"File {file_name} not found for {self._tar_path_or_url}") + start_byte = info['d'] + size = info['s'] + return _load_bytes_from_local_or_remote_file(self._tar_path_or_url, start_byte, start_byte + size) + else: + return _load_all_bytes_from_local_or_remote_file(self._tar_path_or_url + "/" + file_name) + + @staticmethod + def create(fname: str, *, rfs: dict, dir_representation=False): + if not dir_representation: + with open(fname, "wb") as f: + # Define the sizes and names of the entry and index files + tar_entry_json_name = ".tar_entry.json" + tar_entry_json_size = TAR_ENTRY_JSON_SIZE + tar_index_json_size = INITIAL_TAR_INDEX_JSON_SIZE + tar_index_json_name = ".tar_index.json" + tar_index_json_offset = 512 + TAR_ENTRY_JSON_SIZE + tar_index_json_offset_data = tar_index_json_offset + 512 + + # Define the content of .tar_entry.json + initial_entry_json = json.dumps({ + 'index': { + 'n': tar_index_json_name, + 'o': tar_index_json_offset, + 'd': tar_index_json_offset_data, + 's': tar_index_json_size + } + }, indent=2, sort_keys=True) + initial_entry_json = initial_entry_json.encode() + b" " * (tar_entry_json_size - len(initial_entry_json)) + + # Define the content of .tar_index.json + initial_index_json = json.dumps({ + 'files': [ + { + 'n': tar_entry_json_name, + 'o': 0, + 'd': 512, + 's': tar_entry_json_size + }, + { + 'n': tar_index_json_name, + 'o': tar_index_json_offset, + 'd': tar_index_json_offset_data, + 's': tar_index_json_size + } + ] + }, indent=2, sort_keys=True) + initial_index_json = initial_index_json.encode() + b" " * (tar_index_json_size - len(initial_index_json)) + + # Write the initial entry file (.tar_entry.json). This will always + # be the first file in the tar file, and has a fixed size. + header = create_tar_header(tar_entry_json_name, tar_entry_json_size) + f.write(header) + f.write(initial_entry_json) + + # Write the initial index file (.tar_index.json) this will start as + # the second file in the tar file but as it grows outside the + # initial bounds, a new index file will be appended to the end of + # the tar, and then entry file will be updated accordingly to point + # to the new index file. 
+ header = create_tar_header(tar_index_json_name, tar_index_json_size) + f.write(header) + f.write(initial_index_json) + + f.write(b"\x00" * 1024) + else: + if os.path.exists(fname): + raise ValueError(f"Directory {fname} already exists") + os.makedirs(fname) + + # write the rfs file + tf = LindiTarFile(fname, dir_representation=dir_representation) + tf.write_rfs(rfs) + tf.close() + + def _update_index_in_file(self): + if self._dir_representation: + return + assert self._index + if not self._index_has_changed: + return + if self._is_remote: + raise ValueError("Cannot update the index in a remote tar file") + if self._file is None: + raise ValueError("File is not open") + existing_index_json = self.read_file(".tar_index.json") + new_index_json = json.dumps(self._index, indent=2, sort_keys=True) + if len(new_index_json) <= len(existing_index_json): + # we can overwrite the existing index file + new_index_json = new_index_json.encode() + b" " * (len(existing_index_json) - len(new_index_json)) + self.overwrite_file_content(".tar_index.json", new_index_json) + else: + # we must create a new index file + self.trash_file(".tar_index.json") + + # after we trash the file, the index has changed once again + new_index_json = json.dumps(self._index, indent=2, sort_keys=True) + new_index_json = _pad_bytes_to_leave_room_for_growth(new_index_json, INITIAL_TAR_INDEX_JSON_SIZE) + new_index_json_size = len(new_index_json) + self.write_file(".tar_index.json", new_index_json) + + # now the index has changed once again, but we assume it doesn't exceed the size + new_index_json = json.dumps(self._index, indent=2, sort_keys=True) + new_index_json = new_index_json.encode() + b" " * (new_index_json_size - len(new_index_json)) + self.overwrite_file_content(".tar_index.json", new_index_json) + + tar_index_info = next(file for file in self._index['files'] if file['n'] == ".tar_index.json") + new_entry_json = json.dumps({ + 'index': { + 'n': tar_index_info['n'], + 'o': tar_index_info['o'], + 'd': tar_index_info['d'], + 's': tar_index_info['s'] + } + }, indent=2, sort_keys=True) + new_entry_json = new_entry_json.encode() + b" " * (TAR_ENTRY_JSON_SIZE - len(new_entry_json)) + self._file.seek(512) + self._file.write(new_entry_json) + self._file.flush() + self._index_has_changed = False + + +def _download_file_byte_range(url: str, start: int, end: int) -> bytes: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3", + "Range": f"bytes={start}-{end - 1}" + } + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req) as response: + return response.read() + + +def _load_all_bytes_from_local_or_remote_file(path_or_url: str) -> bytes: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + with urllib.request.urlopen(path_or_url) as response: + return response.read() + else: + with open(path_or_url, "rb") as f: + return f.read() + + +def _load_bytes_from_local_or_remote_file(path_or_url: str, start: int, end: int) -> bytes: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + return _download_file_byte_range(path_or_url, start, end) + else: + with open(path_or_url, "rb") as f: + f.seek(start) + return f.read(end - start) + + +def _pad_bytes_to_leave_room_for_growth(x: str, initial_size: int) -> bytes: + total_size = initial_size + while total_size < len(x) * 4: + total_size *= 2 + padding = b" " * (total_size - len(x)) + return x.encode() + padding + + +def 
_fix_checksum_in_header(f, header_start_byte): + f.seek(header_start_byte) + header = f.read(512) + + # From https://en.wikipedia.org/wiki/Tar_(computing) + # The checksum is calculated by taking the sum of the unsigned byte values + # of the header record with the eight checksum bytes taken to be ASCII + # spaces (decimal value 32). It is stored as a six digit octal number with + # leading zeroes followed by a NUL and then a space. Various implementations + # do not adhere to this format. In addition, some historic tar + # implementations treated bytes as signed. Implementations typically + # calculate the checksum both ways, and treat it as good if either the + # signed or unsigned sum matches the included checksum. + + header_byte_list = [] + for byte in header: + header_byte_list.append(byte) + for i in range(148, 156): + header_byte_list[i] = 32 + sum = 0 + for byte in header_byte_list: + sum += byte + checksum = oct(sum).encode()[2:] + while len(checksum) < 6: + checksum = b"0" + checksum + checksum += b"\x00 " + f.seek(header_start_byte + 148) + f.write(checksum) + + +def _create_random_string(): + return "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=10)) diff --git a/pyproject.toml b/pyproject.toml index 43c0fad..a30295e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lindi" -version = "0.3.10" +version = "0.4.0a1" description = "" authors = [ "Jeremy Magland ", diff --git a/tests/test_local_cache.py b/tests/test_local_cache.py index 7fdb857..4d180ef 100644 --- a/tests/test_local_cache.py +++ b/tests/test_local_cache.py @@ -42,7 +42,7 @@ def test_remote_data_1(): elapsed_0 = elapsed if passnum == 1: elapsed_1 = elapsed - assert elapsed_1 < elapsed_0 * 0.3 # type: ignore + assert elapsed_1 < elapsed_0 * 0.6 # type: ignore def test_put_local_cache(): diff --git a/tests/test_split_contiguous_dataset.py b/tests/test_split_contiguous_dataset.py new file mode 100644 index 0000000..569fec3 --- /dev/null +++ b/tests/test_split_contiguous_dataset.py @@ -0,0 +1,29 @@ +import pytest +import lindi +import h5py + + +@pytest.mark.network +def test_split_contiguous_dataset(): + # https://neurosift.app/?p=/nwb&dandisetId=000935&dandisetVersion=draft&url=https://api.dandiarchive.org/api/assets/e18e787a-544a-438e-8396-f396efb3bd3d/download/ + h5_url = "https://api.dandiarchive.org/api/assets/e18e787a-544a-438e-8396-f396efb3bd3d/download/" + + opts = lindi.LindiH5ZarrStoreOpts( + contiguous_dataset_max_chunk_size=1000 * 1000 * 17 + ) + x = lindi.LindiH5pyFile.from_hdf5_file(h5_url, zarr_store_opts=opts) + d = x['acquisition/ElectricalSeries/data'] + assert isinstance(d, h5py.Dataset) + print(d.shape) + assert d[0][0] == 6.736724784228119e-06 + assert d[10 * 1000 * 1000][0] == -1.0145925267155008e-06 + rfs = x.to_reference_file_system() + zarray = rfs['refs']['acquisition/ElectricalSeries/data/.zarray'] + assert zarray['chunks'] == [66406, 32] + aa = rfs['refs']['acquisition/ElectricalSeries/data/5.0'] + assert aa[1] == 2415072880 + assert aa[2] == 16999936 + + +if __name__ == "__main__": + test_split_contiguous_dataset()
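Editor's note: the fixed layout written by `LindiTarFile.create` (a 512-byte ustar header, then the 1024-byte `.tar_entry.json` whose `index` entry locates `.tar_index.json`) means a reader can reach `lindi.json` with three range reads. Below is a minimal local-file sketch, assuming the constants above (`TAR_ENTRY_JSON_SIZE = 1024`) and a hypothetical path; it is not part of the library.

```python
# Sketch only: resolve lindi.json inside a .lindi.tar the same way LindiTarFile does.
import json

def read_lindi_json_from_tar(path: str) -> dict:
    with open(path, "rb") as f:
        # .tar_entry.json always sits right after the first 512-byte tar header
        f.seek(512)
        entry = json.loads(f.read(1024))      # space-padded; json.loads tolerates trailing whitespace
        index_info = entry["index"]           # {'n': name, 'o': header offset, 'd': data offset, 's': size}

        # .tar_index.json lists every member with the same n/o/d/s fields
        f.seek(index_info["d"])
        index = json.loads(f.read(index_info["s"]))

        info = next(x for x in index["files"] if x["n"] == "lindi.json")  # raises StopIteration if absent
        f.seek(info["d"])
        return json.loads(f.read(info["s"]))

# rfs = read_lindi_json_from_tar("example.lindi.tar")  # hypothetical file
```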