Skip to content

Commit

Permalink
fix: distance protect and shape of model devi inconsistence by xiaoya…
Browse files Browse the repository at this point in the history
…ng wang
  • Loading branch information
wangzyphysics committed May 31, 2024
1 parent 036f3a2 commit 0f0ab27
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 1 deletion.
12 changes: 12 additions & 0 deletions dpgen2/exploration/render/traj_render_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,20 @@ def get_model_devi(

return model_devi

def _load_one_model_devi_deprecated(self, fname, model_devi):
dd = np.loadtxt(fname)
model_devi.add(DeviManager.MAX_DEVI_V, dd[:, 1])
model_devi.add(DeviManager.MIN_DEVI_V, dd[:, 2])
model_devi.add(DeviManager.AVG_DEVI_V, dd[:, 3])
model_devi.add(DeviManager.MAX_DEVI_F, dd[:, 4])
model_devi.add(DeviManager.MIN_DEVI_F, dd[:, 5])
model_devi.add(DeviManager.AVG_DEVI_F, dd[:, 6])

def _load_one_model_devi(self, fname, model_devi):
dd = np.loadtxt(fname)
if len(np.shape(dd)) == 1: # In case model-devi.out is 1-dimensional
dd = dd.reshape((1, len(dd)))

model_devi.add(DeviManager.MAX_DEVI_V, dd[:, 1])
model_devi.add(DeviManager.MIN_DEVI_V, dd[:, 2])
model_devi.add(DeviManager.AVG_DEVI_V, dd[:, 3])
Expand Down
101 changes: 100 additions & 1 deletion dpgen2/op/run_caly_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def atoms2lmpdump(atoms, struc_idx, type_map, ignore=False):
return dump_str


def parse_traj(traj_file):
def parse_traj_deprecated(traj_file):
from ase import ( # type: ignore
Atoms,
)
Expand Down Expand Up @@ -272,6 +272,105 @@ def parse_traj(traj_file):
return selected_traj


def parse_traj(traj_file):
from ase import ( # type: ignore
Atoms,
)
from ase.build import ( # type: ignore
make_supercell,
)
from ase.io import ( # type: ignore
read,
)

safe_dist_dict = {
"He": 0.0,
"Li": 1.5,
"Na": 1.45,
"K": 2.3,
"Rb": 2.5,
"Mg": 1.7,
"Ca": 2.3,
"Sr": 2.5,
"Al": 1.7,
"Sc": 2.0,
"Y": 2.1,
"La": 2.5,
"Ti": 2.0,
"Zr": 2.1,
"Hf": 2.4,
"Mo": 2.1,
"W": 2.3,
"B": 1.1,
"C": 1.1,
"Si": 1.6,
"P": 1.5,
"As": 2.0,
"S": 1.5,
"Se": 2.1,
"Te": 2.0,
"Br": 2.3,
"H": 0.813,
}

trajs: List[Atoms] = read(traj_file, index=":", format="traj") # type: ignore
dthresh = 0.72
numb_traj = len(trajs)
assert numb_traj >= 1, "traj file is broken."

# 1st Filter, initial configuration
origin = trajs[0]
origin = make_supercell(origin, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])
dis_mtx = origin.get_all_distances(mic=True)
row, col = np.diag_indices_from(dis_mtx)
dis_mtx[row, col] = np.nan
is_reasonable = np.nanmin(dis_mtx) > dthresh

selected_traj: Union[List[Atoms], None] = None
if is_reasonable:
if len(trajs) >= 20:
selected_traj = [trajs[iii] for iii in [4, 9, -10, -5, -1]]
elif 5 <= len(trajs) < 20:
selected_traj = [
trajs[np.random.randint(3, len(trajs) - 1)] for _ in range(4)
]
selected_traj.append(trajs[-1])
elif 3 <= len(trajs) < 5:
selected_traj = [trajs[round((len(trajs) - 1) / 2)]]
selected_traj.append(trajs[-1])
elif len(trajs) == 2:
selected_traj = [trajs[0], trajs[-1]]
else:
selected_traj = [trajs[0]]

# 2nd filter for selected traj. It filters out all FRAMES that are to close.
i_keep = []
for t in selected_traj:
t2 = make_supercell(t, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])

frame_is_reasonable = True
dist_dict = t2.get_all_distances(mic=True)
atype = t2.get_chemical_symbols()
for a in range(len(atype)):
for b in range(a + 1, len(atype)):
dd = dist_dict[a][b]
dr = (
(safe_dist_dict[atype[a]] + safe_dist_dict[atype[b]])
* 0.529
/ 1.2
)
if dd < dr:
frame_is_reasonable = False

if frame_is_reasonable:
i_keep.append(selected_traj.index(t))
selected_traj = [selected_traj[iii] for iii in i_keep]
else:
selected_traj = None

return selected_traj


def write_model_devi_out(devi: np.ndarray, fname: Union[str, Path], header: str = ""):
assert devi.shape[1] == 8
header = "%s\n%10s" % (header, "step")
Expand Down

0 comments on commit 0f0ab27

Please sign in to comment.