Skip to content

Commit

Permalink
Relaxed ice structures and plotting scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Dec 12, 2023
1 parent 3fe7020 commit a8b9347
Show file tree
Hide file tree
Showing 7 changed files with 440 additions and 6 deletions.
12 changes: 11 additions & 1 deletion mlspm/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"AFM-ice-Au111-monolayer": "https://zenodo.org/records/10049832/files/AFM-ice-Au111-monolayer.tar.gz?download=1",
"AFM-ice-Au111-bilayer": "https://zenodo.org/records/10049856/files/AFM-ice-Au111-bilayer.tar.gz?download=1",
"AFM-ice-exp": "https://zenodo.org/records/10054847/files/exp_data_ice.tar.gz?download=1",
"AFM-ice-relaxed": "https://zenodo.org/records/10362511/files/relaxed_structures.tar.gz?download=1",
}


Expand All @@ -29,6 +30,14 @@ def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)

def _common_parent(paths):
path_parts = [list(Path(p).parts) for p in paths]
common_part = Path()
for parts in zip(*path_parts):
p = parts[0]
if all(part == p for part in parts):
common_part /= p
return common_part

def download_dataset(name: str, target_dir: PathLike):
"""
Expand All @@ -40,6 +49,7 @@ def download_dataset(name: str, target_dir: PathLike):
- ``'AFM-ice-Au111-monolayer'``: https://doi.org/10.5281/zenodo.10049832
- ``'AFM-ice-Au111-bilayer'``: https://doi.org/10.5281/zenodo.10049856
- ``'AFM-ice-exp'``: https://doi.org/10.5281/zenodo.10054847
- ``'AFM-ice-relaxed'``: https://doi.org/10.5281/zenodo.10362511
Arguments:
name: Name of dataset to download.
Expand All @@ -64,7 +74,7 @@ def download_dataset(name: str, target_dir: PathLike):
with tarfile.open(temp_file, "r") as ft:
print("Reading archive files...")
members = []
base_dir = os.path.commonprefix(ft.getnames())
base_dir = _common_parent(ft.getnames())
for m in ft.getmembers():
if m.isfile():
# relative_to(base_dir) here gets rid of a common parent directory within the archive (if any),
Expand Down
4 changes: 3 additions & 1 deletion papers/ice_structure_discovery/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This folder contains the source code and links to the datasets that were used fo

The subdirectories contain various scripts for training and running predictions with the models:
- `training`: Scripts for training the atom position and graph construction models, and evaluating the trained models.
- `prediction`: Scripts for reproducing the result in Fig. 2 of the paper using the pretrained models.
- `predictions`: Scripts for reproducing the results figures of the paper using the pretrained models.

## Data

Expand All @@ -25,4 +25,6 @@ Training datasets:

Experimental data: https://doi.org/10.5281/zenodo.10054847

Final relaxed geometries: https://doi.org/10.5281/zenodo.10362511

Pretrained weights for the models: https://doi.org/10.5281/zenodo.10054348
6 changes: 4 additions & 2 deletions papers/ice_structure_discovery/predictions/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
The scripts here can be used to reproduce the results in Fig. 2 of the paper.
The scripts here can be used to reproduce the results figures in the paper.
- `predict_experiments.py`: Runs the prediction for all of the experimental AFM images of ice on Cu(111) and Au(111) using the three models pretrained on the Cu(111), Au(111)-monolayer, and Au(111)-bilayer datasets, and saves them on disk.
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure.
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure as in Fig. 2 of the paper.
- `plot_relaxed_structures.py`: Plots the on-surface structures relaxed with a neural network potential and DFT as well as the corresponding simulations and experimental images as in Fig. 3 of the paper.
- `plot_prediction_extra.py`: Plots the prediction and the relaxed structure with corresponding simulations and experimental images for the one extra ice cluster not in the main results figure.
132 changes: 132 additions & 0 deletions papers/ice_structure_discovery/predictions/plot_prediction_extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#!/usr/bin/env python3

from pathlib import Path

import matplotlib.pyplot as plt
from torch import scatter
from ppafm.ocl.oclUtils import init_env

from plot_predictions import get_data as get_data_prediction, MM_TO_INCH
from plot_predictions import plot_graph as plot_graph_prediction
from plot_relaxed_structures import get_data as get_data_relaxed
from plot_relaxed_structures import plot_graph as plot_graph_relaxed

# Set matplotlib font rendering to use LaTex
plt.rcParams.update({"text.usetex": True, "font.family": "serif", "font.serif": ["Computer Modern Roman"]})


def init_fig(width=140, left_margin=4, top_margin=4, row_gap=6, gap=0.5):
ax_size = (width - left_margin - 5 * gap) / 5

left_margin *= MM_TO_INCH
top_margin *= MM_TO_INCH
row_gap *= MM_TO_INCH
gap *= MM_TO_INCH
ax_size *= MM_TO_INCH
width *= MM_TO_INCH
height = top_margin + 2 * (ax_size + gap) + row_gap
fig = plt.figure(figsize=(width, height))

axes = []

y = height - top_margin - ax_size
x = left_margin
axes_ = []
for _ in range(5):
rect = [x / width, y / height, ax_size / width, ax_size / height]
ax = fig.add_axes(rect)
ax.set_xticks([])
ax.set_yticks([])
for axis in ["top", "bottom", "left", "right"]:
ax.spines[axis].set_linewidth(0.5)
axes_.append(ax)
x += ax_size + gap
axes.append(axes_)

y = height - top_margin - 2 * ax_size - row_gap
x = left_margin + 2 * (ax_size + gap)
axes_ = []
for _ in range(3):
rect = [x / width, y / height, ax_size / width, ax_size / height]
ax = fig.add_axes(rect)
ax.set_xticks([])
ax.set_yticks([])
for axis in ["top", "bottom", "left", "right"]:
ax.spines[axis].set_linewidth(0.5)
axes_.append(ax)
x += ax_size + gap
axes.append(axes_)

return fig, axes


if __name__ == "__main__":
init_env(i_platform=1)

exp_data_dir = Path("./exp_data")
sim_data_dir = Path("./relaxed_structures/")
scatter_size = 5
zmin = -5.0
zmax = 0.5
classes = [[1], [8], [29, 79]]
class_colors = ["w", "r"]
fontsize = 7

params = {
"pred_dir": "predictions_au111-bilayer",
"sim_name": "hartree_I",
"exp_name": "Ying_Jiang_4",
"label": "I",
"dist": 4.8,
"rot_angle": -25.000,
"amp": 2.0,
"nz": 7,
"offset": (0.0, 0.0),
}

exp_data, pred_mol, sim_pred = get_data_prediction(params, exp_data_dir, classes)
opt_mol, sim_opt, _, sw_opt = get_data_relaxed(params, exp_data_dir, sim_data_dir, classes)

fig, axes = init_fig()

# Plot data
axes[0][0].imshow(exp_data['data'][:, :, 0].T, origin="lower", cmap="gray")
axes[0][1].imshow(exp_data['data'][:, :, -1].T, origin="lower", cmap="gray")
plot_graph_prediction(
axes[0][2],
pred_mol,
box_borders=[[0, 0, zmin], [exp_data["lengthX"], exp_data["lengthY"], zmax]],
zmin=zmin,
zmax=zmax,
scatter_size=scatter_size,
class_colors=class_colors,
)
axes[0][3].imshow(sim_pred[:, :, 0].T, origin="lower", cmap="gray")
axes[0][4].imshow(sim_pred[:, :, -1].T, origin="lower", cmap="gray")
plot_graph_relaxed(
axes[1][0],
opt_mol,
box_borders=[[sw_opt[0][0], sw_opt[0][1], zmin], [sw_opt[1][0], sw_opt[1][1], zmax]],
zmin=zmin,
zmax=zmax,
scatter_size=scatter_size,
class_colors=class_colors,
)
axes[1][1].imshow(sim_opt[:, :, 0].T, origin="lower", cmap="gray")
axes[1][2].imshow(sim_opt[:, :, -1].T, origin="lower", cmap="gray")

# Set labels
y = 1.08
axes[0][0].text(
-0.08, 0.5, params["label"], transform=axes[0][0].transAxes, fontsize=fontsize, va="center", ha="center", rotation="vertical"
)
axes[0][0].text(0.5, y, "Exp.\ AFM (far)", transform=axes[0][0].transAxes, fontsize=fontsize, va="center", ha="center")
axes[0][1].text(0.5, y, "Exp.\ AFM (close)", transform=axes[0][1].transAxes, fontsize=fontsize, va="center", ha="center")
axes[0][2].text(0.5, y, "Pred.\ geom.", transform=axes[0][2].transAxes, fontsize=fontsize, va="center", ha="center")
axes[0][3].text(0.5, y, "Sim.\ AFM (far)", transform=axes[0][3].transAxes, fontsize=fontsize, va="center", ha="center")
axes[0][4].text(0.5, y, "Sim.\ AFM (close)", transform=axes[0][4].transAxes, fontsize=fontsize, va="center", ha="center")
axes[1][0].text(0.5, y, "Opt.\ geom.", transform=axes[1][0].transAxes, fontsize=fontsize, va="center", ha="center")
axes[1][1].text(0.5, y, "Sim.\ AFM (far)", transform=axes[1][1].transAxes, fontsize=fontsize, va="center", ha="center")
axes[1][2].text(0.5, y, "Sim.\ AFM (close)", transform=axes[1][2].transAxes, fontsize=fontsize, va="center", ha="center")

plt.savefig(f"sims_extra.png", dpi=400)
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def plot_graph(ax, mol, box_borders, class_colors, scatter_size, zmin, zmax):
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_2_2', 'label': 'E', 'dist': 4.9, 'offset': ( 0.0, 0.0)},
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_3' , 'label': 'F', 'dist': 4.8, 'offset': ( 0.0, -2.0)},
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_5' , 'label': 'G', 'dist': 5.0, 'offset': ( 2.0, 0.0)},
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_6' , 'label': 'H', 'dist': 4.8, 'offset': ( 1.5, 2.0)}
{'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_6' , 'label': 'H', 'dist': 4.8, 'offset': ( 1.5, 2.0)},
# {'pred_dir': 'predictions_au111-bilayer' , 'exp_name': 'Ying_Jiang_4' , 'label': 'I', 'dist': 4.8, 'offset': ( 0.0, 0.0)}
]

data = [get_data(p, exp_data_dir, classes) for p in params]
Expand Down
Loading

0 comments on commit a8b9347

Please sign in to comment.