Add keys after the initial dataset has been constructed (#119)
* tests for adding keys after the initial dataset has been constructed

* code cleanup and update tests

* further cleanup and tests updates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* disable the tests for now

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* raise an error for now

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PythonFZ and pre-commit-ci[bot] authored Aug 23, 2024
1 parent 9252639 commit 5340cc0
Showing 3 changed files with 74 additions and 74 deletions.
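In short, after this commit every info or arrays key must already exist in the file before it can be extended; a frame that introduces a previously unseen key is rejected. A minimal sketch of the new user-facing behavior, mirroring the tests below (the file name is arbitrary):

import ase.build
import pytest

import znh5md

io = znh5md.IO("example.h5")  # hypothetical file name
water = ase.build.molecule("H2O")
io.append(water)  # the initial dataset fixes the set of known keys

water.info["key1"] = 1  # a key the file has not seen yet
with pytest.raises(ValueError):
    io.append(water)  # for now, new keys after construction raise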
4 changes: 4 additions & 0 deletions tests/test_datasets.py
@@ -40,6 +40,8 @@ def test_datasets(tmp_path, dataset, request):
         assert b.get_potential_energy() == a.get_potential_energy()
         assert isinstance(a.get_potential_energy(), float)
         assert isinstance(b.get_potential_energy(), float)
+    else:
+        assert b.calc is None
 
     assert set(a.arrays) == set(b.arrays)
     for key in a.arrays:
@@ -85,6 +87,8 @@ def test_datasets_extxyz(tmp_path, dataset, request):
         assert b.get_potential_energy() == a.get_potential_energy()
         assert isinstance(a.get_potential_energy(), float)
         assert isinstance(b.get_potential_energy(), float)
+    else:
+        assert b.calc is None
 
     assert set(a.arrays) == set(b.arrays)
     for key in a.arrays:
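The new else branch also pins down the round-trip behavior for frames written without a calculator. A minimal sketch, assuming a fresh file (name arbitrary):

import ase.build

import znh5md

io = znh5md.IO("plain.h5")  # hypothetical file name
water = ase.build.molecule("H2O")  # no calculator attached
io.append(water)

assert io[0].calc is None  # frames stored without results read back calculator-free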
32 changes: 32 additions & 0 deletions tests/test_io.py
@@ -67,6 +67,38 @@ def test_extend_empty(tmp_path):
     assert len(io) == 22
 
 
+def test_add_new_keys_info(tmp_path):
+    io = znh5md.IO(tmp_path / "test.h5")
+    water = ase.build.molecule("H2O")
+
+    io.append(water)
+    water.info["key1"] = 1
+    with pytest.raises(ValueError):
+        io.append(water)
+
+    # assert len(io) == 2
+    # assert "key1" not in io[0].info
+    # assert "key1" in io[1].info
+
+    # assert io[1].info["key1"] == 1
+
+
+def test_add_new_keys_arrays(tmp_path):
+    io = znh5md.IO(tmp_path / "test.h5")
+    water = ase.build.molecule("H2O")
+
+    io.append(water)
+    water.arrays["key1"] = np.zeros((len(water), 3))
+    with pytest.raises(ValueError):
+        io.append(water)
+
+    # assert len(io) == 2
+    # assert "key1" not in io[0].arrays
+    # assert "key1" in io[1].arrays
+
+    # assert np.allclose(io[1].arrays["key1"], np.zeros((len(water), 3)))
+
+
 def test_extend_single(tmp_path):
     vectors = np.random.rand(3, 3, 2, 3)
 
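Until appending new keys is supported, the pattern these tests push toward is to declare every custom key on the first frame, so that later frames only extend existing datasets. A sketch of that workaround, assuming each frame carries the same keys with matching shapes (file and key names arbitrary):

import ase.build
import numpy as np

import znh5md

io = znh5md.IO("trajectory.h5")  # hypothetical file name
water = ase.build.molecule("H2O")

# Declare every custom key on the very first frame ...
water.info["key1"] = 1
water.arrays["key2"] = np.zeros((len(water), 3))
io.append(water)

# ... so later frames only extend datasets that already exist.
water.info["key1"] = 2
io.append(water)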
112 changes: 38 additions & 74 deletions znh5md/io.py
@@ -254,13 +254,27 @@ def _create_particle_group(self, f, data: fmt.ASEData):
                 step=data.step,
                 json=data.metadata.get(key, {}).get("type") == "json",
             )
-        self._create_observables(
-            f,
-            data.observables,
-            data.metadata,
-            time=data.time,
-            step=data.step,
-        )
+        for key, value in tqdm(
+            data.observables.items(),
+            ncols=120,
+            desc="Creating observables",
+            disable=_disable_tqdm,
+        ):
+            if "observables" not in f:
+                g_observables_grp = f.create_group("observables")
+            g_observables_grp = f["observables"].create_group(self.particle_group)
+            self._create_group(
+                g_observables_grp,
+                key,
+                value,
+                data.metadata.get(key, {}).get("unit"),
+                calc=data.metadata.get(key, {}).get("calc")
+                if self.use_ase_calc
+                else None,
+                time=data.time,
+                step=data.step,
+                json=data.metadata.get(key, {}).get("type") == "json",
+            )
 
     def _create_group(
         self,
@@ -338,46 +352,6 @@ def _add_time_and_step(self, grp, step: np.ndarray, time: np.ndarray):
         ds_time[()] = self.timestep
         ds_step[()] = 1
 
-    def _create_observables(
-        self,
-        f,
-        info_data,
-        metadata: dict,
-        time: np.ndarray | None = None,
-        step: np.ndarray | None = None,
-    ):
-        if info_data:
-            g_observables = f.require_group("observables")
-            g_info = g_observables.require_group(self.particle_group)
-            for key, value in info_data.items():
-                g_observable = g_info.create_group(key)
-                ds_value = g_observable.create_dataset(
-                    "value",
-                    data=value,
-                    dtype=utils.get_h5py_dtype(value),
-                    chunks=True
-                    if self.chunk_size is None
-                    else tuple([self.chunk_size] + list(value.shape[1:])),
-                    maxshape=([None] * value.ndim),
-                    compression=self.compression,
-                    compression_opts=self.compression_opts,
-                )
-                if metadata.get(key, {}).get("type") == "json":
-                    ds_value.attrs["ZNH5MD_TYPE"] = "json"
-                if self.use_ase_calc and metadata.get(key, {}).get("calc") is not None:
-                    ds_value.attrs["ASE_CALCULATOR_RESULT"] = metadata[key]["calc"]
-                if metadata.get(key, {}).get("unit") and self.save_units:
-                    ds_value.attrs["unit"] = metadata[key]["unit"]
-                if time is None:
-                    time = np.arange(len(value)) * self.timestep
-                elif self.store == "linear":
-                    warnings.warn("time is ignored in 'linear' storage mode")
-                if step is None:
-                    step = np.arange(len(value))
-                elif self.store == "linear":
-                    warnings.warn("step is ignored in 'linear' storage mode")
-                self._add_time_and_step(g_observable, step, time)
-
     def _extend_existing_data(self, f, data: fmt.ASEData):
         g_particle_grp = f["particles"][self.particle_group]
         self._extend_group(
@@ -397,7 +371,21 @@ def _extend_existing_data(self, f, data: fmt.ASEData):
             self._extend_group(
                 g_particle_grp, key, value, step=data.step, time=data.time
             )
-        self._extend_observables(f, data.observables, step=data.step, time=data.time)
+        for key, value in tqdm(
+            data.observables.items(),
+            ncols=120,
+            desc="Extending observables",
+            disable=_disable_tqdm,
+        ):
+            if "observables" not in f:
+                g_observables_grp = f.create_group("observables")
+            if self.particle_group not in f["observables"]:
+                g_observables_grp = f["observables"].create_group(self.particle_group)
+            else:
+                g_observables_grp = f["observables"][self.particle_group]
+            self._extend_group(
+                g_observables_grp, key, value, step=data.step, time=data.time
+            )
 
     def _extend_group(
         self,
@@ -407,6 +395,8 @@ def _extend_group(
         step: np.ndarray | None = None,
         time: np.ndarray | None = None,
     ):
+        if name not in parent_grp:
+            raise ValueError(f"Group {name} not found in {parent_grp.name}")
         if data is not None and name in parent_grp:
             g_grp = parent_grp[name]
             utils.fill_dataset(g_grp["value"], data)
@@ -423,32 +413,6 @@ def _extend_group(
                 )
                 utils.fill_dataset(g_grp["step"], step)
 
-    def _extend_observables(
-        self,
-        f,
-        info_data,
-        step: np.ndarray | None = None,
-        time: np.ndarray | None = None,
-    ):
-        if f"observables/{self.particle_group}" in f:
-            g_observables = f[f"observables/{self.particle_group}"]
-            for key, value in info_data.items():
-                if key in g_observables:
-                    g_val = g_observables[key]
-                    utils.fill_dataset(g_val["value"], value)
-                    if self.store == "time":
-                        if time is None:
-                            last_time = g_val["time"][-1]
-                            time = np.arange(len(value)) * self.timestep + last_time
-                        if step is None:
-                            last_step = g_val["step"][-1]
-                            step = np.arange(len(value)) + last_step
-                        utils.fill_dataset(
-                            g_val["time"],
-                            time,
-                        )
-                        utils.fill_dataset(g_val["step"], step)
-
     def append(self, atoms: ase.Atoms):
         if not isinstance(atoms, ase.Atoms):
             raise ValueError("atoms must be an ASE Atoms object")
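For orientation, the inlined loops store each observable under observables/<particle_group>/<key> alongside its step and time bookkeeping. A quick way to inspect the resulting layout with h5py, assuming a file written as above (the particle group is typically "atoms"):

import h5py

# Print every group and dataset under /observables,
# e.g. atoms/key1/value, atoms/key1/step, atoms/key1/time.
with h5py.File("trajectory.h5", "r") as f:
    f["observables"].visit(print)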
