Skip to content

Commit

Permalink
expand tests
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Sep 20, 2024
1 parent a446794 commit bdf10d6
Show file tree
Hide file tree
Showing 18 changed files with 700 additions and 768 deletions.
2 changes: 1 addition & 1 deletion examples/amend_remote_nwb_as_lindi_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
rate=1.,
unit='s'
)
ts = nwbfile.processing['behavior'].add(timeseries_test) # type: ignore
nwbfile.processing['behavior'].add(timeseries_test) # type: ignore
io.write(nwbfile) # type: ignore

# Later on, you can read the file again
Expand Down
2 changes: 1 addition & 1 deletion lindi/LindiH5ZarrStore/LindiH5ZarrStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, h5_item, *, contiguous_dataset_max_chunk_size: Union[int, Non
nn = contiguous_dataset_max_chunk_size // size0
if nn == 0:
# The chunk size should not be zero
nn = 1
nn = 1 # pragma: no cover
self._split_chunk_shape = (nn,) + h5_item.shape[1:]
if h5_item.chunks is not None:
zero_chunk_coords = (0,) * h5_item.ndim
Expand Down
77 changes: 34 additions & 43 deletions lindi/LindiH5pyFile/LindiH5pyFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class LindiH5pyFile(h5py.File):
def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: Union[ZarrStore, None] = None, _mode: LindiFileMode = "r", _local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
def __init__(self, _zarr_group: zarr.Group, *, _zarr_store: ZarrStore, _mode: LindiFileMode = "r", _local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
"""
Do not use this constructor directly. Instead, use: from_lindi_file,
from_h5py_file, from_reference_file_system, from_zarr_store, or
Expand Down Expand Up @@ -92,7 +92,7 @@ def from_hdf5_file(
"""
from ..LindiH5ZarrStore.LindiH5ZarrStore import LindiH5ZarrStore # avoid circular import
if mode != "r":
raise Exception("Opening hdf5 file in write mode is not supported")
raise ValueError("Opening hdf5 file in write mode is not supported")
zarr_store = LindiH5ZarrStore.from_file(url_or_path, local_cache=local_cache, opts=zarr_store_opts, url=url)
return LindiH5pyFile.from_zarr_store(
zarr_store=zarr_store,
Expand All @@ -101,7 +101,7 @@ def from_hdf5_file(
)

@staticmethod
def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
def from_reference_file_system(rfs: Union[dict, str], *, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
"""
Create a LindiH5pyFile from a reference file system.
Expand All @@ -122,20 +122,11 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo
_close_source_tar_file_on_close : bool, optional
Internal use only
"""
if rfs is None:
rfs = {
"refs": {
'.zgroup': {
'zarr_format': 2
}
},
}

if isinstance(rfs, str):
if _source_url_or_path is not None:
raise Exception("_source_file_path is not None even though rfs is a string")
raise Exception("_source_file_path is not None even though rfs is a string") # pragma: no cover
if _source_tar_file is not None:
raise Exception("_source_tar_file is not None even though rfs is a string")
raise Exception("_source_tar_file is not None even though rfs is a string") # pragma: no cover
rfs_is_url = rfs.startswith("http://") or rfs.startswith("https://")
if rfs_is_url:
data, tar_file = _load_rfs_from_url(rfs)
Expand All @@ -153,26 +144,28 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo
if mode == "r":
# Readonly, file must exist (default)
if not os.path.exists(rfs):
raise Exception(f"File does not exist: {rfs}")
raise FileNotFoundError(f"File does not exist: {rfs}")
elif mode == "r+":
# Read/write, file must exist
if not os.path.exists(rfs):
raise Exception(f"File does not exist: {rfs}")
raise FileNotFoundError(f"File does not exist: {rfs}")
elif mode == "w":
# Create file, truncate if exists
need_to_create_empty_file = True

elif mode in ["w-", "x"]:
# Create file, fail if exists
if os.path.exists(rfs):
raise Exception(f"File already exists: {rfs}")
raise ValueError(f"File already exists: {rfs}")
need_to_create_empty_file = True
# Now that we have already checked for existence, let's just change mode to 'w'
mode = 'w'
elif mode == "a":
# Read/write if exists, create otherwise
if not os.path.exists(rfs):
need_to_create_empty_file = True
else:
raise Exception(f"Unhandled mode: {mode}")
raise Exception(f"Unhandled mode: {mode}") # pragma: no cover
if need_to_create_empty_file:
is_tar = rfs.endswith(".tar")
is_dir = rfs.endswith(".d")
Expand Down Expand Up @@ -207,7 +200,7 @@ def from_reference_file_system(rfs: Union[dict, str, None], *, mode: LindiFileMo
_close_source_tar_file_on_close=_close_source_tar_file_on_close
)
else:
raise Exception(f"Unhandled type for rfs: {type(rfs)}")
raise Exception(f"Unhandled type for rfs: {type(rfs)}") # pragma: no cover

@staticmethod
def from_zarr_store(zarr_store: ZarrStore, mode: LindiFileMode = "r", local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
Expand All @@ -230,7 +223,7 @@ def from_zarr_store(zarr_store: ZarrStore, mode: LindiFileMode = "r", local_cach
return LindiH5pyFile.from_zarr_group(zarr_group, _zarr_store=zarr_store, mode=mode, local_cache=local_cache, _source_url_or_path=_source_url_or_path, _source_tar_file=_source_tar_file, _close_source_tar_file_on_close=_close_source_tar_file_on_close)

@staticmethod
def from_zarr_group(zarr_group: zarr.Group, *, mode: LindiFileMode = "r", _zarr_store: Union[ZarrStore, None] = None, local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
def from_zarr_group(zarr_group: zarr.Group, *, mode: LindiFileMode = "r", _zarr_store: ZarrStore, local_cache: Union[LocalCache, None] = None, _source_url_or_path: Union[str, None] = None, _source_tar_file: Union[LindiTarFile, None] = None, _close_source_tar_file_on_close: bool = False):
"""
Create a LindiH5pyFile from a zarr group.
Expand All @@ -255,15 +248,13 @@ def to_reference_file_system(self):
Export the internal in-memory representation to a reference file system.
"""
from ..LindiH5ZarrStore.LindiH5ZarrStore import LindiH5ZarrStore # avoid circular import
if self._zarr_store is None:
raise Exception("Cannot convert to reference file system without zarr store")
zarr_store = self._zarr_store
if isinstance(zarr_store, LindiTarStore):
zarr_store = zarr_store._base_store
if isinstance(zarr_store, LindiH5ZarrStore):
return zarr_store.to_reference_file_system()
if not isinstance(zarr_store, LindiReferenceFileSystemStore):
raise Exception(f"Cannot create reference file system when zarr store has type {type(self._zarr_store)}")
raise Exception(f"Cannot create reference file system when zarr store has type {type(self._zarr_store)}") # pragma: no cover
rfs = zarr_store.rfs
rfs_copy = json.loads(json.dumps(rfs))
LindiReferenceFileSystemStore.replace_meta_file_contents_with_dicts_in_rfs(rfs_copy)
Expand All @@ -277,19 +268,19 @@ def write_lindi_file(self, filename: str, *, generation_metadata: Union[dict, No
Parameters
----------
filename : str
The filename to write to. It must end with '.lindi.json' or '.lindi.tar'.
The filename (or directory) to write to. It must end with '.lindi.json', '.lindi.tar', or '.lindi.d'.
generation_metadata : Union[dict, None], optional
The optional generation metadata to include in the reference file
system, by default None. This information dict is simply set to the
'generationMetadata' key in the reference file system.
"""
if not filename.endswith(".lindi.json") and not filename.endswith(".lindi.tar"):
raise Exception("Filename must end with '.lindi.json' or '.lindi.tar'")
if not filename.endswith(".lindi.json") and not filename.endswith(".lindi.tar") and not filename.endswith(".lindi.d"):
raise ValueError("Filename must end with '.lindi.json' or '.lindi.tar'")
rfs = self.to_reference_file_system()
if self._source_tar_file:
source_is_remote = self._source_url_or_path is not None and (self._source_url_or_path.startswith("http://") or self._source_url_or_path.startswith("https://"))
if not source_is_remote:
raise Exception("Cannot write to lindi file if the source is a local lindi tar file because it would not be able to resolve the local references within the tar file.")
raise ValueError("Cannot write to lindi file if the source is a local lindi tar file because it would not be able to resolve the local references within the tar file.")
assert self._source_url_or_path is not None
_update_internal_references_to_remote_tar_file(rfs, self._source_url_or_path, self._source_tar_file)
if generation_metadata is not None:
Expand All @@ -301,7 +292,7 @@ def write_lindi_file(self, filename: str, *, generation_metadata: Union[dict, No
elif filename.endswith(".d"):
LindiTarFile.create(filename, rfs=rfs, dir_representation=True)
else:
raise Exception("Unhandled file extension")
raise Exception("Unhandled file extension") # pragma: no cover

@property
def attrs(self): # type: ignore
Expand Down Expand Up @@ -336,20 +327,20 @@ def swmr_mode(self, value): # type: ignore

def close(self):
if not self._is_open:
print('Warning: LINDI file already closed.')
return
print('Warning: LINDI file already closed.') # pragma: no cover
return # pragma: no cover
self.flush()
if self._close_source_tar_file_on_close and self._source_tar_file:
self._source_tar_file.close()
self._is_open = False

def flush(self):
if not self._is_open:
return
return # pragma: no cover
if self._mode != 'r' and self._source_url_or_path is not None:
is_url = self._source_url_or_path.startswith("http://") or self._source_url_or_path.startswith("https://")
if is_url:
raise Exception("Cannot write to URL")
raise Exception("Cannot write to URL") # pragma: no cover
rfs = self.to_reference_file_system()
if self._source_tar_file:
self._source_tar_file.write_rfs(rfs)
Expand Down Expand Up @@ -394,7 +385,7 @@ def copy(self, source, dest, name=None,
raise Exception("name must be provided for copy")
src_item = self._get_item(source)
if not isinstance(src_item, (h5py.Group, h5py.Dataset)):
raise Exception(f"Unexpected type for source in copy: {type(src_item)}")
raise Exception(f"Unexpected type for source in copy: {type(src_item)}") # pragma: no cover
_recursive_copy(src_item, dest, name=name)

def __delitem__(self, name):
Expand All @@ -413,14 +404,14 @@ def _get_item(self, name, getlink=False, default=None):
raise Exception("Getting link is not allowed for references")
zarr_group = self._zarr_group
if name._source != '.':
raise Exception(f'For now, source of reference must be ".", got "{name._source}"')
raise Exception(f'For now, source of reference must be ".", got "{name._source}"') # pragma: no cover
if name._source_object_id is not None:
if name._source_object_id != zarr_group.attrs.get("object_id"):
raise Exception(f'Mismatch in source object_id: "{name._source_object_id}" and "{zarr_group.attrs.get("object_id")}"')
if name._source_object_id != zarr_group.attrs.get("object_id"): # pragma: no cover
raise Exception(f'Mismatch in source object_id: "{name._source_object_id}" and "{zarr_group.attrs.get("object_id")}"') # pragma: no cover
target = self[name._path]
if name._object_id is not None:
if name._object_id != target.attrs.get("object_id"):
raise Exception(f'Mismatch in object_id: "{name._object_id}" and "{target.attrs.get("object_id")}"')
if name._object_id != target.attrs.get("object_id"): # pragma: no cover
raise Exception(f'Mismatch in object_id: "{name._object_id}" and "{target.attrs.get("object_id")}"') # pragma: no cover
return target
# if it contains slashes, it's a path
if isinstance(name, str) and "/" in name:
Expand Down Expand Up @@ -477,24 +468,24 @@ def ref(self):
# write
def create_group(self, name, track_order=None):
if self._mode == 'r':
raise Exception("Cannot create group in read-only mode")
raise ValueError("Cannot create group in read-only mode")
if track_order is not None:
raise Exception("track_order is not supported (I don't know what it is)")
return self._the_group.create_group(name)

def require_group(self, name):
if self._mode == 'r':
raise Exception("Cannot require group in read-only mode")
raise ValueError("Cannot require group in read-only mode")
return self._the_group.require_group(name)

def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
if self._mode == 'r':
raise Exception("Cannot create dataset in read-only mode")
raise ValueError("Cannot create dataset in read-only mode")
return self._the_group.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds)

def require_dataset(self, name, shape, dtype, exact=False, **kwds):
if self._mode == 'r':
raise Exception("Cannot require dataset in read-only mode")
raise ValueError("Cannot require dataset in read-only mode")
return self._the_group.require_dataset(name, shape, dtype, exact=exact, **kwds)


Expand Down Expand Up @@ -522,7 +513,7 @@ def _recursive_copy(src_item: Union[h5py.Group, h5py.Dataset], dest: h5py.File,
# data because we can copy the reference.
if isinstance(src_item.file, LindiH5pyFile) and isinstance(dest, LindiH5pyFile):
if src_item.name is None:
raise Exception("src_item.name is None")
raise Exception("src_item.name is None") # pragma: no cover
src_item_name = _without_initial_slash(src_item.name)
src_zarr_store = src_item.file._zarr_store
dst_zarr_store = dest._zarr_store
Expand Down
1 change: 1 addition & 0 deletions lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def create_group(self, name, track_order=None):
def require_group(self, name):
if name in self.p:
ret = self.p[name]
from ..LindiH5pyGroup import LindiH5pyGroup # avoid circular import
if not isinstance(ret, LindiH5pyGroup):
raise Exception(f'Expected a group at {name} but got {type(ret)}')
return ret
Expand Down
1 change: 1 addition & 0 deletions lindi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .LindiH5pyFile import LindiH5pyFile, LindiH5pyGroup, LindiH5pyDataset, LindiH5pyHardLink, LindiH5pySoftLink # noqa: F401
from .LocalCache.LocalCache import LocalCache, ChunkTooLargeError # noqa: F401
from .LindiRemfile.additional_url_resolvers import add_additional_url_resolver # noqa: F401
from .LindiH5pyFile.LindiH5pyReference import LindiH5pyReference # noqa: F401
2 changes: 1 addition & 1 deletion lindi/conversion/create_zarr_dataset_from_h5_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _get_default_chunks(shape: Tuple, dtype: Any) -> Tuple:
shape_prod_0 = np.prod(shape[1:])
optimal_chunk_size_bytes = 1024 * 1024 * 20 # 20 MB
optimal_chunk_size = optimal_chunk_size_bytes // (dtype_size * shape_prod_0)
if optimal_chunk_size <= shape[0]:
if optimal_chunk_size > shape[0]:
return shape
if optimal_chunk_size < 1:
return (1,) + shape[1:]
Expand Down
35 changes: 25 additions & 10 deletions lindi/tar/lindi_tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@
INITIAL_LINDI_JSON_SIZE = 1024 * 8


# for tests
def _test_set(
    tar_entry_json_size: int,
    initial_tar_index_json_size: int,
    initial_lindi_json_size: int
):
    """
    Rebind the module-level size constants. Intended for tests only.

    Parameters
    ----------
    tar_entry_json_size : int
        New value assigned to the module global TAR_ENTRY_JSON_SIZE.
    initial_tar_index_json_size : int
        New value assigned to the module global INITIAL_TAR_INDEX_JSON_SIZE.
    initial_lindi_json_size : int
        New value assigned to the module global INITIAL_LINDI_JSON_SIZE.
    """
    global TAR_ENTRY_JSON_SIZE, INITIAL_TAR_INDEX_JSON_SIZE, INITIAL_LINDI_JSON_SIZE
    # Single tuple assignment: all three globals are rebound together.
    (
        TAR_ENTRY_JSON_SIZE,
        INITIAL_TAR_INDEX_JSON_SIZE,
        INITIAL_LINDI_JSON_SIZE,
    ) = (
        tar_entry_json_size,
        initial_tar_index_json_size,
        initial_lindi_json_size,
    )


class LindiTarFile:
def __init__(self, tar_path_or_url: str, dir_representation=False):
self._tar_path_or_url = tar_path_or_url
Expand Down Expand Up @@ -93,14 +105,18 @@ def overwrite_file_content(self, file_name: str, data: bytes):
self._file.seek(info['d'])
self._file.write(data)
else:
# for safety:
file_parts = file_name.split("/")
for part in file_parts[:-1]:
if part.startswith('..'):
raise ValueError(f"Invalid path: {file_name}")
fname = self._tar_path_or_url + "/" + file_name
with open(fname, "wb") as f:
f.write(data)
# Actually not ever used. The file is just replaced.
raise Exception('Overwriting file content in a directory representation is not supported') # pragma: no cover

# But if we did do it, it would look like this:
# # for safety:
# file_parts = file_name.split("/")
# for part in file_parts[:-1]:
# if part.startswith('..'):
# raise ValueError(f"Invalid path: {file_name}")
# fname = self._tar_path_or_url + "/" + file_name
# with open(fname, "wb") as f:
# f.write(data)

def trash_file(self, file_name: str):
if self._is_remote:
Expand Down Expand Up @@ -160,8 +176,7 @@ def write_rfs(self, rfs: dict):
rfs_json = _pad_bytes_to_leave_room_for_growth(rfs_json, INITIAL_LINDI_JSON_SIZE)
self.write_file("lindi.json", rfs_json)
else:
with open(self._tar_path_or_url + "/lindi.json", "wb") as f:
f.write(rfs_json.encode())
self.write_file("lindi.json", rfs_json.encode())

def get_file_byte_range(self, file_name: str) -> tuple:
if self._dir_representation:
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading

0 comments on commit bdf10d6

Please sign in to comment.