Skip to content

Commit

Permalink
encode nan, inf, -inf attr as strings
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Mar 15, 2024
1 parent 1e78461 commit 35121e1
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 6 deletions.
11 changes: 10 additions & 1 deletion .vscode/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@
"version": "2.0.0",
"tasks": [
{
"label": "Test",
"label": "Run tests",
"type": "shell",
"command": "bash -ic .vscode/tasks/test.sh",
"presentation": {
"clear": true
},
"detail": "Run tests"
},
{
"label": "Run quck tests",
"type": "shell",
"command": "bash -ic .vscode/tasks/quick_test.sh",
"presentation": {
"clear": true
},
"detail": "Run quick tests"
}
]
}
7 changes: 7 additions & 0 deletions .vscode/tasks/quick_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/bash
set -ex

# black --check .
flake8 .
# pyright
pytest --cov=lindi --cov-report=xml --cov-report=term -m "not slow" tests/
51 changes: 47 additions & 4 deletions lindi/LindiH5Store/LindiH5Store.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def close(self):

@staticmethod
def from_file(
hdf5_file_name_or_url: str, *,
hdf5_file_name_or_url: str,
*,
opts: LindiH5StoreOpts = LindiH5StoreOpts(),
url: Union[str, None] = None
url: Union[str, None] = None,
):
"""
Create a LindiH5Store from a file or url.
Expand Down Expand Up @@ -348,7 +349,9 @@ def _get_external_array_link(self, parent_key: str, h5_item: h5py.Dataset):
"name": parent_key,
}
else:
print(f'WARNING when creating external array link for {parent_key}: url is not set, so external array link will not work')
print(
f"WARNING when creating external array link for {parent_key}: url is not set, so external array link will not work"
)
return self._external_array_links[parent_key]

def listdir(self, path: str = "") -> List[str]:
Expand Down Expand Up @@ -484,4 +487,44 @@ def _get_chunk_names_for_dataset(chunk_coords_shape: List[int]) -> List[str]:
def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]:
if x is None:
return None
return json.dumps(json.loads(x.decode("utf-8"))).encode("utf-8")
a = json.loads(x.decode("utf-8"))
return json.dumps(a, cls=FloatJSONEncoder).encode("utf-8")


# From https://github.com/rly/h5tojson/blob/b162ff7f61160a48f1dc0026acb09adafdb422fa/h5tojson/h5tojson.py#L121-L156
class FloatJSONEncoder(json.JSONEncoder):
"""JSON encoder that converts NaN, Inf, and -Inf to strings."""

def encode(self, obj, *args, **kwargs): # type: ignore
"""Convert NaN, Inf, and -Inf to strings."""
obj = FloatJSONEncoder._convert_nan(obj)
return super().encode(obj, *args, **kwargs)

def iterencode(self, obj, *args, **kwargs): # type: ignore
"""Convert NaN, Inf, and -Inf to strings."""
obj = FloatJSONEncoder._convert_nan(obj)
return super().iterencode(obj, *args, **kwargs)

@staticmethod
def _convert_nan(obj):
"""Convert NaN, Inf, and -Inf from a JSON object to strings."""
if isinstance(obj, dict):
return {k: FloatJSONEncoder._convert_nan(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [FloatJSONEncoder._convert_nan(v) for v in obj]
elif isinstance(obj, float):
return FloatJSONEncoder._nan_to_string(obj)
return obj

@staticmethod
def _nan_to_string(obj: float):
"""Convert NaN, Inf, and -Inf from a float to a string."""
if np.isnan(obj):
return "NaN"
elif np.isinf(obj):
if obj > 0:
return "Infinity"
else:
return "-Infinity"
else:
return float(obj)
5 changes: 4 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
[pytest]
addopts = --verbose
log_cli = true
log_cli_level = INFO
log_cli_level = INFO
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
network: marks tests as network (deselect with '-m "not network"')
29 changes: 29 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,30 @@ def test_attributes():
raise ValueError("Attribute mismatch")


def test_nan_inf_attr():
print("Testing NaN, Inf, and -Inf attributes")
with tempfile.TemporaryDirectory() as tmpdir:
filename = f"{tmpdir}/test.h5"
with h5py.File(filename, "w") as f:
f.create_dataset("X", data=[1, 2, 3])
f["X"].attrs["nan"] = np.nan
f["X"].attrs["inf"] = np.inf
f["X"].attrs["ninf"] = -np.inf
h5f = h5py.File(filename, "r")
with LindiH5Store.from_file(filename, url=filename) as store:
rfs = store.to_reference_file_system()
client = LindiClient.from_reference_file_system(rfs)

X1 = h5f["X"]
assert isinstance(X1, h5py.Dataset)
X2 = client["X"]
assert isinstance(X2, LindiDataset)

assert X2.attrs["nan"] == 'NaN'
assert X2.attrs["inf"] == 'Infinity'
assert X2.attrs["ninf"] == '-Infinity'


def _check_equal(a, b):
# allow comparison of bytes and strings
if isinstance(a, str):
Expand Down Expand Up @@ -184,6 +208,11 @@ def _check_equal(a, b):
assert isinstance(b, np.ndarray)
return _check_arrays_equal(a, b)

# test for NaNs (we need to use np.isnan because NaN != NaN in python)
if isinstance(a, float) and isinstance(b, float):
if np.isnan(a) and np.isnan(b):
return True

return a == b


Expand Down
3 changes: 3 additions & 0 deletions tests/test_with_real_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import remfile
from lindi import LindiH5Store, LindiClient
import lindi
import pytest

examples = []

Expand Down Expand Up @@ -272,6 +273,8 @@ def _hdf5_visit_items(item, callback):
return


@pytest.mark.network
@pytest.mark.slow
def test_with_real_data():
example_num = 0
example = examples[example_num]
Expand Down

0 comments on commit 35121e1

Please sign in to comment.