From 07617631213e81cf38da1e6840bc8b0b2f568392 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 25 Aug 2023 02:13:47 -0400 Subject: [PATCH] simplify: support model deviation of energy per atom (#1312) Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> --- dpgen/simplify/arginfo.py | 16 ++++ dpgen/simplify/simplify.py | 30 +++++-- tests/simplify/test_post_model_devi.py | 116 +++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 7 deletions(-) create mode 100644 tests/simplify/test_post_model_devi.py diff --git a/dpgen/simplify/arginfo.py b/dpgen/simplify/arginfo.py index 41ca5676f..9aa8f0234 100644 --- a/dpgen/simplify/arginfo.py +++ b/dpgen/simplify/arginfo.py @@ -32,6 +32,8 @@ def general_simplify_arginfo() -> Argument: doc_model_devi_f_trust_hi = ( "The higher bound of forces for the selection for the model deviation." ) + doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2." + doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2." return [ Argument("labeled", bool, optional=True, default=False, doc=doc_labeled), @@ -50,6 +52,20 @@ def general_simplify_arginfo() -> Argument: optional=False, doc=doc_model_devi_f_trust_hi, ), + Argument( + "model_devi_e_trust_lo", + float, + optional=True, + default=float("inf"), + doc=doc_model_devi_e_trust_lo, + ), + Argument( + "model_devi_e_trust_hi", + float, + optional=True, + default=float("inf"), + doc=doc_model_devi_e_trust_hi, + ), ] diff --git a/dpgen/simplify/simplify.py b/dpgen/simplify/simplify.py index 51ff364cb..e5dc24d7c 100644 --- a/dpgen/simplify/simplify.py +++ b/dpgen/simplify/simplify.py @@ -268,6 +268,8 @@ def post_model_devi(iter_index, jdata, mdata): f_trust_lo = jdata["model_devi_f_trust_lo"] f_trust_hi = jdata["model_devi_f_trust_hi"] + e_trust_lo = jdata["model_devi_e_trust_lo"] + e_trust_hi = jdata["model_devi_e_trust_hi"] type_map = jdata.get("type_map", []) sys_accurate = dpdata.MultiSystems(type_map=type_map) @@ -285,16 +287,30 @@ def post_model_devi(iter_index, jdata, mdata): if line.startswith("# data.rest.old"): name = (line.split()[1]).split("/")[-1] elif line.startswith("#"): - pass + columns = line.split()[1:] + cidx_step = columns.index("step") + cidx_max_devi_f = columns.index("max_devi_f") + try: + cidx_devi_e = columns.index("devi_e") + except ValueError: + # DeePMD-kit < 2.2.2 + cidx_devi_e = None else: - idx = int(line.split()[0]) - f_devi = float(line.split()[4]) + idx = int(line.split()[cidx_step]) + f_devi = float(line.split()[cidx_max_devi_f]) + if cidx_devi_e is not None: + e_devi = float(line.split()[cidx_devi_e]) + else: + e_devi = 0.0 subsys = sys_entire[name][idx] - if f_trust_lo <= f_devi < f_trust_hi: - sys_candinate.append(subsys) - elif f_devi >= f_trust_hi: + if f_devi >= f_trust_hi or e_devi >= e_trust_hi: sys_failed.append(subsys) - elif f_devi < f_trust_lo: + elif ( + f_trust_lo <= f_devi < f_trust_hi + or e_trust_lo <= e_devi < e_trust_hi + ): + sys_candinate.append(subsys) + elif f_devi < f_trust_lo and e_devi < e_trust_lo: sys_accurate.append(subsys) else: raise RuntimeError("reach a place that should NOT be reached...") diff --git a/tests/simplify/test_post_model_devi.py b/tests/simplify/test_post_model_devi.py new file mode 100644 index 000000000..0eeac7fc2 --- /dev/null +++ b/tests/simplify/test_post_model_devi.py @@ -0,0 +1,116 @@ +import os +import shutil +import sys +import unittest +from pathlib import Path + +import dpdata +import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +__package__ = "simplify" +from .context import dpgen + + +class TestSimplifyModelDevi(unittest.TestCase): + def setUp(self): + self.work_path = Path("iter.000001") / dpgen.simplify.simplify.model_devi_name + self.work_path.mkdir(exist_ok=True, parents=True) + self.system = dpdata.System( + data={ + "atom_names": ["H"], + "atom_numbs": [1], + "atom_types": np.zeros((1,), dtype=int), + "coords": np.zeros((1, 1, 3), dtype=np.float32), + "cells": np.zeros((1, 3, 3), dtype=np.float32), + "orig": np.zeros(3, dtype=np.float32), + "nopbc": True, + "energies": np.zeros((1,), dtype=np.float32), + "forces": np.zeros((1, 1, 3), dtype=np.float32), + } + ) + self.system.to_deepmd_npy( + self.work_path / "data.rest.old" / self.system.formula + ) + model_devi = np.array([[0, 0.2, 0.1, 0.15, 0.2, 0.1, 0.15, 0.2]]) + np.savetxt( + self.work_path / "details", + model_devi, + fmt=["%12d"] + ["%19.6e" for _ in range(7)], + header="data.rest.old/" + + self.system.formula + + "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e", + ) + + def tearDown(self): + shutil.rmtree("iter.000001", ignore_errors=True) + + def test_post_model_devi_f_candidate(self): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_f_trust_lo": 0.15, + "model_devi_f_trust_hi": 0.25, + "model_devi_e_trust_lo": float("inf"), + "model_devi_e_trust_hi": float("inf"), + "iter_pick_number": 1, + }, + {}, + ) + assert (self.work_path / "data.picked" / self.system.formula).exists() + + def test_post_model_devi_e_candidate(self): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_e_trust_lo": 0.15, + "model_devi_e_trust_hi": 0.25, + "model_devi_f_trust_lo": float("inf"), + "model_devi_f_trust_hi": float("inf"), + "iter_pick_number": 1, + }, + {}, + ) + assert (self.work_path / "data.picked" / self.system.formula).exists() + + def test_post_model_devi_f_failed(self): + with self.assertRaises(RuntimeError): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_f_trust_lo": 0.0, + "model_devi_f_trust_hi": 0.0, + "model_devi_e_trust_lo": float("inf"), + "model_devi_e_trust_hi": float("inf"), + "iter_pick_number": 1, + }, + {}, + ) + + def test_post_model_devi_e_failed(self): + with self.assertRaises(RuntimeError): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_e_trust_lo": 0.0, + "model_devi_e_trust_hi": 0.0, + "model_devi_f_trust_lo": float("inf"), + "model_devi_f_trust_hi": float("inf"), + "iter_pick_number": 1, + }, + {}, + ) + + def test_post_model_devi_accurate(self): + dpgen.simplify.simplify.post_model_devi( + 1, + { + "model_devi_e_trust_lo": 0.3, + "model_devi_e_trust_hi": 0.4, + "model_devi_f_trust_lo": 0.3, + "model_devi_f_trust_hi": 0.4, + "iter_pick_number": 1, + }, + {}, + ) + assert (self.work_path / "data.accurate" / self.system.formula).exists()