From dec2e6ef4cc3c92760571eb0486cd830bcae2603 Mon Sep 17 00:00:00 2001 From: Jeremy Magland Date: Sat, 3 Aug 2024 09:15:01 -0400 Subject: [PATCH] do not rely on tarfile --- examples/benchmark1.py | 121 ++++++++++++++++ lindi/LindiH5pyFile/LindiH5pyFile.py | 1 + .../writers/LindiH5pyGroupWriter.py | 2 + .../create_zarr_dataset_from_h5_data.py | 16 +-- lindi/tar/LindiTarStore.py | 4 + lindi/tar/lindi_tar.py | 129 +++++++++++------- 6 files changed, 218 insertions(+), 55 deletions(-) create mode 100644 examples/benchmark1.py diff --git a/examples/benchmark1.py b/examples/benchmark1.py new file mode 100644 index 0000000..acba6e5 --- /dev/null +++ b/examples/benchmark1.py @@ -0,0 +1,121 @@ +import os +import h5py +import numpy as np +import time +import lindi +import gzip +import zarr +import numcodecs + + +def create_dataset(size): + return np.random.rand(size) + + +def benchmark_h5py(file_path, num_small_datasets, num_large_datasets, small_size, large_size, compression, mode): + start_time = time.time() + + if mode == 'dat': + with open(file_path, 'wb') as f: + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + f.write(data.tobytes()) + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + if compression == 'gzip': + data_zipped = gzip.compress(data.tobytes(), compresslevel=4) + f.write(data_zipped) + elif compression is None: + f.write(data.tobytes()) + else: + raise ValueError(f"Unknown compressor: {compression}") + elif mode == 'zarr': + if os.path.exists(file_path): + import shutil + shutil.rmtree(file_path) + store = zarr.DirectoryStore(file_path) + root = zarr.group(store) + + if compression == 'gzip': + compressor = numcodecs.GZip(level=4) + else: + compressor = None + + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + root.create_dataset(f'small_dataset_{i}', data=data) + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + root.create_dataset(f'large_dataset_{i}', data=data, chunks=(1000,), compressor=compressor) + else: + if mode == 'h5': + f = h5py.File(file_path, 'w') + else: + f = lindi.LindiH5pyFile.from_lindi_file(file_path, mode='w') + + # Write small datasets + print('Writing small datasets') + for i in range(num_small_datasets): + data = create_dataset(small_size) + ds = f.create_dataset(f'small_dataset_{i}', data=data) + ds.attrs['attr1'] = 1 + + # Write large datasets + print('Writing large datasets') + for i in range(num_large_datasets): + data = create_dataset(large_size) + ds = f.create_dataset(f'large_dataset_{i}', data=data, chunks=(1000,), compression=compression) + ds.attrs['attr1'] = 1 + + f.close() + + end_time = time.time() + total_time = end_time - start_time + + # Calculate total data size + total_size = (num_small_datasets * small_size + num_large_datasets * large_size) * 8 # 8 bytes per float64 + total_size_gb = total_size / (1024 ** 3) + + print("H5PY Benchmark Results:") + print(f"Total time: {total_time:.2f} seconds") + print(f"Total data size: {total_size_gb:.2f} GB") + print(f"Write speed: {total_size_gb / total_time:.2f} GB/s") + + h5py_file_size = os.path.getsize(file_path) / (1024 ** 3) + print(f"File size: {h5py_file_size:.2f} GB") + + return total_time, total_size_gb + + +if __name__ == "__main__": + file_path_h5 = "benchmark.h5" + file_path_lindi = "benchmark.lindi" + file_path_dat = "benchmark.dat" + file_path_zarr = "benchmark.zarr" + num_small_datasets = 0 + num_large_datasets = 5 + small_size = 1000 + large_size = 10000000 + compression = None # 'gzip' or None + + print('Zarr Benchmark') + lindi_time, total_size = benchmark_h5py(file_path_zarr, num_small_datasets, num_large_datasets, small_size, large_size, compression=compression, mode='zarr') + print('') + print('Lindi Benchmark') + lindi_time, total_size = benchmark_h5py(file_path_lindi, num_small_datasets, num_large_datasets, small_size, large_size, compression=compression, mode='lindi') + print('') + print('H5PY Benchmark') + h5py_time, total_size = benchmark_h5py(file_path_h5, num_small_datasets, num_large_datasets, small_size, large_size, compression=compression, mode='h5') + print('') + print('DAT Benchmark') + lindi_time, total_size = benchmark_h5py(file_path_dat, num_small_datasets, num_large_datasets, small_size, large_size, compression=compression, mode='dat') diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index 485d316..594a190 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -400,6 +400,7 @@ def flush(self): rfs = self.to_reference_file_system() if self._source_tar_file: self._source_tar_file.write_rfs(rfs) + self._source_tar_file._update_index() # very important else: _write_rfs_to_file(rfs=rfs, output_file_name=self._source_url_or_path) diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py index cadbcab..0111603 100644 --- a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py +++ b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py @@ -83,6 +83,8 @@ def create_dataset( _zarr_compressor = numcodecs.GZip(level=level) else: raise Exception(f'Compression {compression} is not supported') + elif compression is None: + _zarr_compressor = None else: raise Exception(f'Unexpected type for compression: {type(compression)}') diff --git a/lindi/conversion/create_zarr_dataset_from_h5_data.py b/lindi/conversion/create_zarr_dataset_from_h5_data.py index 669c517..bed95c7 100644 --- a/lindi/conversion/create_zarr_dataset_from_h5_data.py +++ b/lindi/conversion/create_zarr_dataset_from_h5_data.py @@ -19,7 +19,7 @@ def create_zarr_dataset_from_h5_data( name: str, label: str, h5_chunks: Union[Tuple, None], - zarr_compressor: Union[Codec, Literal['default']] = 'default' + zarr_compressor: Union[Codec, Literal['default'], None] = 'default' ): """Create a zarr dataset from an h5py dataset. @@ -43,9 +43,9 @@ def create_zarr_dataset_from_h5_data( The name of the h5py dataset for error messages. h5_chunks : tuple The chunk shape of the h5py dataset. - zarr_compressor : numcodecs.abc.Codec + zarr_compressor : numcodecs.abc.Codec, 'default', or None The codec compressor to use when writing the dataset. If default, the - default compressor will be used. + default compressor will be used. When None, no compressor will be used. """ if h5_dtype is None: raise Exception(f'No dtype in h5_to_zarr_dataset_prep for dataset {label}') @@ -58,7 +58,7 @@ def create_zarr_dataset_from_h5_data( if h5_data is None: raise Exception(f'Data must be provided for scalar dataset {label}') - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for scalar datasets') if _is_numeric_dtype(h5_dtype) or h5_dtype in [bool, np.bool_]: @@ -131,7 +131,7 @@ def create_zarr_dataset_from_h5_data( ) elif h5_dtype.kind == 'O': # For type object, we are going to use the JSON codec - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for object datasets') if h5_data is not None: if isinstance(h5_data, h5py.Dataset): @@ -149,7 +149,7 @@ def create_zarr_dataset_from_h5_data( object_codec=object_codec ) elif h5_dtype.kind == 'S': # byte string - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for byte string datasets') if h5_data is None: raise Exception(f'Data must be provided when converting dataset {label} with dtype {h5_dtype}') @@ -161,11 +161,11 @@ def create_zarr_dataset_from_h5_data( data=h5_data ) elif h5_dtype.kind == 'U': # unicode string - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for unicode string datasets') raise Exception(f'Array of unicode strings not supported: dataset {label} with dtype {h5_dtype} and shape {h5_shape}') elif h5_dtype.kind == 'V' and h5_dtype.fields is not None: # compound dtype - if zarr_compressor != 'default': + if zarr_compressor != 'default' and zarr_compressor is not None: raise Exception('zarr_compressor is not supported for compound datasets') if h5_data is None: raise Exception(f'Data must be provided when converting compound dataset {label}') diff --git a/lindi/tar/LindiTarStore.py b/lindi/tar/LindiTarStore.py index 8bd0f87..2651893 100644 --- a/lindi/tar/LindiTarStore.py +++ b/lindi/tar/LindiTarStore.py @@ -1,4 +1,5 @@ import random +import numpy as np from zarr.storage import Store as ZarrStore from ..LindiH5pyFile.LindiReferenceFileSystemStore import LindiReferenceFileSystemStore from .lindi_tar import LindiTarFile @@ -19,7 +20,10 @@ def __setitem__(self, key: str, value: bytes): inline = True else: # presumably it is a chunk of an array + if isinstance(value, np.ndarray): + value = value.tobytes() if not isinstance(value, bytes): + print(f"key: {key}, value type: {type(value)}") raise ValueError("Value must be bytes") size = len(value) inline = size < 1000 # this should be a configurable threshold diff --git a/lindi/tar/lindi_tar.py b/lindi/tar/lindi_tar.py index 1192072..b5d92e1 100644 --- a/lindi/tar/lindi_tar.py +++ b/lindi/tar/lindi_tar.py @@ -6,8 +6,8 @@ TAR_ENTRY_JSON_SIZE = 1024 -INITIAL_TAR_INDEX_JSON_SIZE = 1024 * 256 -INITIAL_LINDI_JSON_SIZE = 1024 * 256 +INITIAL_TAR_INDEX_JSON_SIZE = 1024 * 8 +INITIAL_LINDI_JSON_SIZE = 1024 * 8 class LindiTarFile: @@ -23,6 +23,24 @@ def __init__(self, tar_path_or_url: str): # Load the index json index_json = _load_bytes_from_local_or_remote_file(self._tar_path_or_url, index_info['d'], index_info['d'] + index_info['s']) self._index = json.loads(index_json) + self._index_has_changed = False + + self._index_lookup = {} + for file in self._index['files']: + self._index_lookup[file['n']] = file + + self._file = open(self._tar_path_or_url, "r+b") if not self._is_remote else None + + def close(self): + self._update_index() + if self._file is not None: + self._file.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() def get_file_info(self, file_name: str): for file in self._index['files']: @@ -33,25 +51,27 @@ def get_file_info(self, file_name: str): def overwrite_file_content(self, file_name: str, data: bytes): if self._is_remote: raise ValueError("Cannot overwrite file content in a remote tar file") + if self._file is None: + raise ValueError("File is not open") info = self.get_file_info(file_name) if info is None: raise FileNotFoundError(f"File {file_name} not found") if info['s'] != len(data): raise ValueError("Unexpected problem in overwrite_file_content(): data size must match the size of the existing file") - with open(self._tar_path_or_url, "r+b") as f: - f.seek(info['d']) - f.write(data) + self._file.seek(info['d']) + self._file.write(data) def trash_file(self, file_name: str, do_write_index=True): if self._is_remote: raise ValueError("Cannot trash a file in a remote tar file") + if self._file is None: + raise ValueError("File is not open") info = self.get_file_info(file_name) if info is None: raise FileNotFoundError(f"File {file_name} not found") zeros = b"-" * info['s'] - with open(self._tar_path_or_url, "r+b") as f: - f.seek(info['d']) - f.write(zeros) + self._file.seek(info['d']) + self._file.write(zeros) self._change_name_of_file(file_name, f'.trash/{file_name}.{_create_random_string()}', do_write_index=do_write_index) def write_rfs(self, rfs: dict): @@ -87,50 +107,49 @@ def get_file_byte_range(self, file_name: str) -> tuple: def _change_name_of_file(self, file_name: str, new_file_name: str, do_write_index=True): if self._is_remote: raise ValueError("Cannot change the name of a file in a remote tar file") + if self._file is None: + raise ValueError("File is not open") info = self.get_file_info(file_name) if info is None: raise FileNotFoundError(f"File {file_name} not found") header_start_byte = info['o'] file_name_byte_range = (header_start_byte + 0, header_start_byte + 100) file_name_prefix_byte_range = (header_start_byte + 345, header_start_byte + 345 + 155) - with open(self._tar_path_or_url, "r+b") as f: - f.seek(file_name_byte_range[0]) - f.write(new_file_name.encode()) - # set the rest of the field to zeros - f.write(b"\0" * (file_name_byte_range[1] - file_name_byte_range[0] - len(new_file_name))) - - f.seek(file_name_prefix_byte_range[0]) - f.write(b"\0" * (file_name_prefix_byte_range[1] - file_name_prefix_byte_range[0])) - - _fix_checksum_in_header(f, header_start_byte) - try: - file_in_index = next(file for file in self._index['files'] if file['n'] == file_name) - except StopIteration: + self._file.seek(file_name_byte_range[0]) + self._file.write(new_file_name.encode()) + # set the rest of the field to zeros + self._file.write(b"\0" * (file_name_byte_range[1] - file_name_byte_range[0] - len(new_file_name))) + + self._file.seek(file_name_prefix_byte_range[0]) + self._file.write(b"\0" * (file_name_prefix_byte_range[1] - file_name_prefix_byte_range[0])) + + _fix_checksum_in_header(self._file, header_start_byte) + file_in_index = self._index_lookup.get(file_name, None) + if file_in_index is None: raise ValueError(f"File {file_name} not found in index") file_in_index['n'] = new_file_name - if do_write_index: - self._update_index() + self._index_has_changed = True def write_file(self, file_name: str, data: bytes): if self._is_remote: raise ValueError("Cannot write a file in a remote tar file") - with tarfile.open(self._tar_path_or_url, "a") as tar: - tarinfo = tarfile.TarInfo(name=file_name) - tarinfo.size = len(data) - fileobj = io.BytesIO(data) - tar.addfile(tarinfo, fileobj) - with tarfile.open(self._tar_path_or_url, "r") as tar: - # TODO: do not call getmember here, because it may be slow instead - # parse the header of the new file directly and get the offset from - # there - info = tar.getmember(file_name) - self._index['files'].append({ - 'n': file_name, - 'o': info.offset, - 'd': info.offset_data, - 's': info.size - }) - self._update_index() + if self._file is None: + raise ValueError("File is not open") + self._file.seek(0, 2) # seek to the end of the file + file_pos = self._file.tell() + # write a dummy header + self._file.write(b" " * 512) + # write the data + self._file.write(data) + x = { + 'n': file_name, + 'o': file_pos, + 'd': file_pos + 512, # we assume the header is 512 bytes + 's': len(data) + } + self._index['files'].append(x) + self._index_lookup[file_name] = x + self._index_has_changed = True def read_file(self, file_name: str) -> bytes: info = self.get_file_info(file_name) @@ -199,10 +218,15 @@ def create(fname: str, *, rfs: dict): # write the rfs file tf = LindiTarFile(fname) tf.write_rfs(rfs) + tf.close() def _update_index(self): + if not self._index_has_changed: + return if self._is_remote: raise ValueError("Cannot update the index in a remote tar file") + if self._file is None: + raise ValueError("File is not open") existing_index_json = self.read_file(".tar_index.json") new_index_json = json.dumps(self._index, indent=2, sort_keys=True) if len(new_index_json) <= len(existing_index_json): @@ -220,15 +244,24 @@ def _update_index(self): self.write_file(".tar_index.json", new_index_json) # now we need to update the entry file - tar_index_info = self.get_file_info(".tar_index.json") - if tar_index_info is None: - raise ValueError("tar_index_info is None") + # tar_index_info = self.get_file_info(".tar_index.json") + # if tar_index_info is None: + # raise ValueError("tar_index_info is None") + # new_entry_json = json.dumps({ + # 'index': { + # 'n': tar_index_info.name, + # 'o': tar_index_info.offset, + # 'd': tar_index_info.offset_data, + # 's': tar_index_info.size + # } + # }, indent=2, sort_keys=True) + tar_index_info = next(file for file in self._index['files'] if file['n'] == ".tar_index.json") new_entry_json = json.dumps({ 'index': { - 'n': tar_index_info.name, - 'o': tar_index_info.offset, - 'd': tar_index_info.offset_data, - 's': tar_index_info.size + 'n': tar_index_info['n'], + 'o': tar_index_info['o'], + 'd': tar_index_info['d'], + 's': tar_index_info['s'] } }, indent=2, sort_keys=True) new_entry_json = new_entry_json.encode() + b" " * (TAR_ENTRY_JSON_SIZE - len(new_entry_json)) @@ -237,6 +270,8 @@ def _update_index(self): # this is to avoid calling the potentially expensive getmember() method f.seek(512) f.write(new_entry_json) + self._file.flush() + self._index_has_changed = False def _download_file_byte_range(url: str, start: int, end: int) -> bytes: