diff --git a/.gitignore b/.gitignore index 40483d5..e37d449 100644 --- a/.gitignore +++ b/.gitignore @@ -81,5 +81,11 @@ dmypy.json # VS Code .vscode +# Tex +*.aux +*.log +*.pdf + # Other molecules +*.csv diff --git a/README.md b/README.md index 118ec2b..94c74dc 100644 --- a/README.md +++ b/README.md @@ -30,4 +30,5 @@ pip install . ## Papers The [`papers`](papers) subdirectory contains training scripts and datasets for specific publications. Currently we have the following: - [Automated structure discovery in atomic force microscopy](papers/asd-afm) +- [Electrostatic Discovery Atomic Force Microscopy](papers/ed-afm) - [Structure discovery in Atomic Force Microscopy imaging of ice](papers/ice_structure_discovery) diff --git a/docs/source/reference/mlspm.models.rst b/docs/source/reference/mlspm.models.rst index de0705f..96af8e9 100644 --- a/docs/source/reference/mlspm.models.rst +++ b/docs/source/reference/mlspm.models.rst @@ -17,4 +17,12 @@ mlspm.models Alias of :class:`mlspm.image.models.ASDAFMNet` +.. class:: mlspm.models.AttentionUNet + + Alias of :class:`mlspm.image.models.AttentionUNet` + +.. class:: mlspm.models.EDAFMNet + + Alias of :class:`mlspm.image.models.EDAFMNet` + .. autofunction:: mlspm.models.download_weights diff --git a/mlspm/_weights.py b/mlspm/_weights.py index 2a21828..a00842d 100644 --- a/mlspm/_weights.py +++ b/mlspm/_weights.py @@ -11,6 +11,14 @@ "graph-ice-au111-bilayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-bilayer.pth?download=1", "asdafm-light": "https://zenodo.org/records/10514470/files/weights_asdafm_light.pth?download=1", "asdafm-heavy": "https://zenodo.org/records/10514470/files/weights_asdafm_heavy.pth?download=1", + "edafm-base": "https://zenodo.org/records/10606273/files/base.pth?download=1", + "edafm-single-channel": "https://zenodo.org/records/10606273/files/single-channel.pth?download=1", + "edafm-CO-Cl": "https://zenodo.org/records/10606273/files/CO-Cl.pth?download=1", + "edafm-Xe-Cl": "https://zenodo.org/records/10606273/files/Xe-Cl.pth?download=1", + "edafm-constant-noise": "https://zenodo.org/records/10606273/files/constant-noise.pth?download=1", + "edafm-uniform-noise": "https://zenodo.org/records/10606273/files/uniform_noise.pth?download=1", + "edafm-no-gradient": "https://zenodo.org/records/10606273/files/no-gradient.pth?download=1", + "edafm-matched-tips": "https://zenodo.org/records/10606273/files/matched-tips.pth?download=1", } @@ -21,11 +29,30 @@ def download_weights(weights_name: str, target_path: Optional[PathLike] = None) The following weights are available: - ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). (https://doi.org/10.5281/zenodo.10054348) - - ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348) - - ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348) - - ``'asdafm-light'``: ASDAFMNet trained on molecules containing the elements H, C, N, O, and F. (https://doi.org/10.5281/zenodo.10514470) - - ``'asdafm-heavy'``: ASDAFMNet trained on molecules additionally containing Si, P, S, Cl, and Br. (https://doi.org/10.5281/zenodo.10514470) - + - ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). + (https://doi.org/10.5281/zenodo.10054348) + - ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). + (https://doi.org/10.5281/zenodo.10054348) + - ``'asdafm-light'``: :class:`.ASDAFMNet` trained on molecules containing the elements H, C, N, O, and F. + (https://doi.org/10.5281/zenodo.10514470) + - ``'asdafm-heavy'``: :class:`.ASDAFMNet` trained on molecules additionally containing Si, P, S, Cl, and Br. + (https://doi.org/10.5281/zenodo.10514470) + - ``'edafm-base'``: :class:`.EDAFMNet` used for all predictions in the main ED-AFM paper and used for comparison in + the various tests in the supplementary information of the paper. (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-single-channel'``: :class:`.EDAFMNet` trained on only a single CO-tip AFM input. + (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-CO-Cl'``: :class:`.EDAFMNet` trained on alternative tip combination of CO and Cl. + (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-Xe-Cl'``: :class:`.EDAFMNet` trained on alternative tip combination of Xe and Cl. + (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-constant-noise'``: :class:`.EDAFMNet` trained using constant noise amplitude instead of normally distributed + amplitude. (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-uniform-noise'``: :class:`.EDAFMNet` trained using uniform random noise amplitude instead of normally + distributed amplitude. (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-no-gradient'``: :class:`.EDAFMNet` trained without background-gradient augmentation. + (https://doi.org/10.5281/zenodo.10606273) + - ``'edafm-matched-tips'``: :class:`.EDAFMNet` trained on data with matched tip distance between CO and Xe, + instead of independently randomized distances. (https://doi.org/10.5281/zenodo.10606273) Arguments: weights_name: Name of weights to download. diff --git a/mlspm/datasets.py b/mlspm/datasets.py index dedcc81..ad4b1d1 100644 --- a/mlspm/datasets.py +++ b/mlspm/datasets.py @@ -15,6 +15,8 @@ "AFM-ice-relaxed": "https://zenodo.org/records/10362511/files/relaxed_structures.tar.gz?download=1", "ASD-AFM-molecules": "https://zenodo.org/records/10562769/files/molecules.tar.gz?download=1", "AFM-camphor-exp": "https://zenodo.org/records/10562769/files/afm_camphor.tar.gz?download=1", + "ED-AFM-molecules": "https://zenodo.org/records/10609676/files/molecules_rebias.tar.gz?download=1", + "ED-AFM-data": "https://zenodo.org/records/10609676/files/edafm-data.tar.gz?download=1", } @@ -43,6 +45,8 @@ def download_dataset(name: str, target_dir: PathLike): - ``'AFM-ice-relaxed'``: https://doi.org/10.5281/zenodo.10362511 - ``'ASD-AFM-molecules'``: https://doi.org/10.5281/zenodo.10562769 - 'molecules.tar.gz' - ``'AFM-camphor-exp'``: https://doi.org/10.5281/zenodo.10562769 - 'afm_camphor.tar.gz' + - ``'ED-AFM-molecules'``: https://doi.org/10.5281/zenodo.10609676 - 'molecules_rebias.tar.gz' + - ``'ED-AFM-data'``: https://doi.org/10.5281/zenodo.10609676 - 'edafm-data.tar.gz' Arguments: name: Name of the dataset to download. diff --git a/mlspm/image/models.py b/mlspm/image/models.py index 70e07f4..b62aa38 100644 --- a/mlspm/image/models.py +++ b/mlspm/image/models.py @@ -1,17 +1,459 @@ -from turtle import forward -from typing import Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple import torch from torch import nn -from ..modules import _get_padding from .._weights import download_weights +from ..modules import Conv2dBlock, Conv3dBlock, UNetAttentionConv, _get_padding def _flatten_z_to_channels(x): return x.permute(0, 4, 1, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3)) +class AttentionUNet(nn.Module): + """ + Pytorch 3D-to-2D U-net model with attention. + + 3D conv -> concatenate -> 3D conv/pool/dropout -> 2D conv/dropout -> 2D upsampling/conv with skip connections + and attention. For multiple inputs, the inputs are first processed through separate 3D conv blocks before merging + by concatenating along channel axis. + + Arguments: + z_in: Size of input array in the z-dimension. + n_in: Number of input 3D images. + n_out: Number of output 2D maps. + in_channels: Number of channels in input array. + merge_block_channels: Number of channels in input merging 3D conv blocks. + merge_block_depth: Number of layers in each merge conv block. + conv3d_block_channels: Number channels in 3D conv blocks. + conv3d_block_depth: Number of layers in each 3D conv block. + conv3d_dropouts: Dropout rates after each conv3d block. + conv2d_block_channels: Number channels in 2D conv blocks. + conv2d_block_depth: Number of layers in each 2D conv block. + conv2d_dropouts: Dropout rates after each conv2d block. + attention_channels: Number of channels in conv layer within each attention block. + upscale2d_block_channels: Number of channels in each 2D conv block after upscale before skip connection. + upscale2d_block_depth: Number of layers in each 2D conv block after upscale before skip connection. + upscale2d_block_channels2: Number of channels in each 2D conv block after skip connection. + upscale2d_block_depth2: Number of layers in each 2D conv block after skip connection. + split_conv_block_channels: Number of channels in 2d conv blocks after splitting outputs. + split_conv_block_depth: Number of layers in each 2d conv block after splitting outputs. + res_connections: Whether to use residual connections in conv blocks. + out_convs_channels: Number of channels in splitted outputs. + out_relus: Whether to apply relu activation to the output 2D maps. + pool_type: Type of pooling to use. + pool_z_strides: Stride of pool layers in z direction. + padding_mode: Type of padding in each convolution layer. + activation: Activation to use after every layer except last one. + attention_activation: Type of activation to use for attention map. + device: Device to load model onto. + """ + + def __init__( + self, + z_in: int, + n_in: int = 1, + n_out: int = 3, + in_channels: int = 1, + merge_block_channels: List[int] = [8], + merge_block_depth: int = 2, + conv3d_block_channels: List[int] = [8, 16, 32], + conv3d_block_depth: int = 2, + conv3d_dropouts: List[float] = [0.0, 0.0, 0.0], + conv2d_block_channels: List[int] = [128], + conv2d_block_depth: int = 3, + conv2d_dropouts: List[float] = [0.1], + attention_channels: List[int] = [32, 32, 32], + upscale2d_block_channels: List[int] = [16, 16, 16], + upscale2d_block_depth: int = 1, + upscale2d_block_channels2: List[int] = [16, 16, 16], + upscale2d_block_depth2: int = 2, + split_conv_block_channels: List[int] = [16], + split_conv_block_depth: int = 3, + res_connections: bool = True, + out_convs_channels: int | List[int] = 1, + out_relus: bool | List[bool] = True, + pool_type: Literal["avg", "max"] = "avg", + pool_z_strides: List[int] = [2, 1, 2], + padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros", + activation: Literal["relu", "lrelu", "elu"] | nn.Module = "lrelu", + attention_activation: Literal["sigmoid", "softmax"] = "softmax", + device: str = "cuda", + ): + super().__init__() + + assert ( + len(conv3d_block_channels) + == len(conv3d_dropouts) + == len(upscale2d_block_channels) + == len(upscale2d_block_channels2) + == len(attention_channels) + ) + + if isinstance(activation, nn.Module): + self.act = activation + elif activation == "relu": + self.act = nn.ReLU() + elif activation == "lrelu": + self.act = nn.LeakyReLU() + elif activation == "elu": + self.act = nn.ELU() + else: + raise ValueError(f"Unknown activation function {activation}") + + if not isinstance(out_relus, list): + out_relus = [out_relus] * n_out + else: + assert len(out_relus) == n_out + + if not isinstance(out_convs_channels, list): + out_convs_channels = [out_convs_channels] * n_out + else: + assert len(out_convs_channels) == n_out + + self.out_relus = out_relus + self.relu_act = nn.ReLU() + + # Infer number of channels after 3D-to-2D flattening at each stage from the z_in size + z_size = z_in + attention_in_channels = [] + for pool_stride, conv3d_channels in zip(pool_z_strides, conv3d_block_channels): + attention_in_channels.append(conv3d_channels * z_size) + z_size = z_size // pool_stride + z_size -= max(0, 2 - pool_stride) + conv2d_in_channels = conv3d_block_channels[-1] * z_size + attention_in_channels = list(reversed(attention_in_channels)) + + # -- Input merge conv blocks -- + self.merge_convs = nn.ModuleList([None] * n_in) + for i in range(n_in): + self.merge_convs[i] = nn.ModuleList( + [ + Conv3dBlock( + in_channels, + merge_block_channels[0], + 3, + merge_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ] + ) + for j in range(len(merge_block_channels) - 1): + self.merge_convs[i].append( + Conv3dBlock( + merge_block_channels[j], + merge_block_channels[j + 1], + 3, + merge_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ) + + # -- Encoder conv blocks -- + self.conv3d_blocks = nn.ModuleList( + [ + Conv3dBlock( + n_in * merge_block_channels[-1], + conv3d_block_channels[0], + 3, + conv3d_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ] + ) + self.conv3d_dropouts = nn.ModuleList([nn.Dropout(conv3d_dropouts[0])]) + for i in range(len(conv3d_block_channels) - 1): + self.conv3d_blocks.append( + Conv3dBlock( + conv3d_block_channels[i], + conv3d_block_channels[i + 1], + 3, + conv3d_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ) + self.conv3d_dropouts.append(nn.Dropout(conv3d_dropouts[i + 1])) + + # -- Middle conv blocks -- + self.conv2d_blocks = nn.ModuleList( + [ + Conv2dBlock( + conv2d_in_channels, conv2d_block_channels[0], 3, conv2d_block_depth, padding_mode, res_connections, self.act, False + ) + ] + ) + self.conv2d_dropouts = nn.ModuleList([nn.Dropout(conv2d_dropouts[0])]) + for i in range(len(conv2d_block_channels) - 1): + self.conv2d_blocks.append( + Conv2dBlock( + conv2d_block_channels[i], + conv2d_block_channels[i + 1], + 3, + conv2d_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ) + self.conv2d_dropouts.append(nn.Dropout(conv2d_dropouts[i + 1])) + + # -- Decoder conv blocks -- + self.attentions = nn.ModuleList([]) + for c_att, c_conv in zip(attention_channels, attention_in_channels): + self.attentions.append( + UNetAttentionConv( + c_conv, conv2d_block_channels[-1], c_att, 3, padding_mode, self.act, attention_activation, upsample_mode="bilinear" + ) + ) + + self.upscale2d_blocks = nn.ModuleList( + [ + Conv2dBlock( + conv2d_block_channels[-1], + upscale2d_block_channels[0], + 3, + upscale2d_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ] + ) + for i in range(len(upscale2d_block_channels) - 1): + self.upscale2d_blocks.append( + Conv2dBlock( + upscale2d_block_channels2[i], + upscale2d_block_channels[i + 1], + 3, + upscale2d_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ) + + self.upscale2d_blocks2 = nn.ModuleList([]) + for i in range(len(upscale2d_block_channels2)): + self.upscale2d_blocks2.append( + Conv2dBlock( + upscale2d_block_channels[i] + attention_in_channels[i], + upscale2d_block_channels2[i], + 3, + upscale2d_block_depth2, + padding_mode, + res_connections, + self.act, + False, + ) + ) + + # -- Output split conv blocks -- + padding = _get_padding(3, 2) + self.out_convs = nn.ModuleList([]) + self.split_convs = nn.ModuleList([None] * n_out) + for i_out in range(n_out): + self.split_convs[i_out] = nn.ModuleList( + [ + Conv2dBlock( + upscale2d_block_channels2[-1], + split_conv_block_channels[0], + 3, + split_conv_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ] + ) + for i in range(len(split_conv_block_channels) - 1): + self.split_convs.append( + Conv2dBlock( + split_conv_block_channels[i], + split_conv_block_channels[i + 1], + 3, + split_conv_block_depth, + padding_mode, + res_connections, + self.act, + False, + ) + ) + + self.out_convs.append( + nn.Conv2d( + split_conv_block_channels[-1], out_convs_channels[i_out], kernel_size=3, padding=padding, padding_mode=padding_mode + ) + ) + + if pool_type == "avg": + pool = nn.AvgPool3d + elif pool_type == "max": + pool = nn.MaxPool3d + self.pools = nn.ModuleList([pool(2, stride=(2, 2, pz)) for pz in pool_z_strides]) + + self.upsample2d = nn.Upsample(scale_factor=2, mode="nearest") + self.device = device + self.n_out = n_out + self.n_in = n_in + + self.to(device) + + def _flatten(self, x): + return x.permute(0, 1, 4, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3)) + + def forward(self, x: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Do forward computation. + + Arguments: + x: Input AFM images of shape (batch, channels, x, y, z). + + Returns: + Tuple (**outputs**, **attention_maps**), where + + - **outputs** - Output arrays of shape ``(batch, out_convs_channels, x, y)`` or ``(batch, x, y)`` if + out_convs_channels == 1. + - **attention_maps** - Attention maps at each stage of skip-connections. + """ + assert len(x) == self.n_in + + # Do 3D convolutions for each input + in_branches = [] + for xi, convs in zip(x, self.merge_convs): + for conv in convs: + xi = self.act(conv(xi)) + in_branches.append(xi) + + # Merge input branches + x = torch.cat(in_branches, dim=1) + + # Encode + x_3ds = [] + for conv, dropout, pool in zip(self.conv3d_blocks, self.conv3d_dropouts, self.pools): + x = self.act(conv(x)) + x = dropout(x) + x_3ds.append(x) + x = pool(x) + + # Middle 2d convs + x = self._flatten(x) + for conv, dropout in zip(self.conv2d_blocks, self.conv2d_dropouts): + x = self.act(conv(x)) + x = dropout(x) + + # Compute attention maps + attention_maps = [] + x_gated = [] + for attention, x_3d in zip(self.attentions, reversed(x_3ds)): + g, a = attention(self._flatten(x_3d), x) + x_gated.append(g) + attention_maps.append(a) + + # Decode + for i, (conv1, conv2, xg) in enumerate(zip(self.upscale2d_blocks, self.upscale2d_blocks2, x_gated)): + x = self.upsample2d(x) + x = self.act(conv1(x)) + x = torch.cat([x, xg], dim=1) # Attention-gated skip connection + x = self.act(conv2(x)) + + # Split into different outputs + outputs = [] + for i, (split_convs, out_conv) in enumerate(zip(self.split_convs, self.out_convs)): + h = x + for conv in split_convs: + h = self.act(conv(h)) + h = out_conv(h) + + if self.out_relus[i]: + h = self.relu_act(h) + outputs.append(h.squeeze(1)) + + return outputs, attention_maps + + +class EDAFMNet(AttentionUNet): + """ + ED-AFM Attention U-net. + + This is the model used in the ED-AFM paper for task of predicting electrostatics from AFM images. + It is a subclass of the :class:`AttentionUNet` class with specific hyperparameters. + + The following pretrained weights are available: + + - ``'base'``: The base model used for all predictions in the main ED-AFM paper and used for comparison in the various test in the supplementary information of the paper. + - ``'single-channel'``: Model trained on only a single CO-tip AFM input. + - ``'CO-Cl'``: Model trained on alternative tip combination of CO and Cl. + - ``'Xe-Cl'``: Model trained on alternative tip combination of Xe and Cl. + - ``'constant-noise'``: Model trained using constant noise amplitude instead of normally distributed amplitude. + - ``'uniform-noise'``: Model trained using uniform random noise amplitude instead of normally distributed amplitude. + - ``'no-gradient'``: Model trained without background-gradient augmentation. + - ``'matched-tips'``: Model trained on data with matched tip distance between CO and Xe, instead of independently randomized distances. + + Arguments: + device: Device to load model onto. + trained_weights: If not None, load the specified pretrained weights to the model. + """ + + def __init__( + self, + device: str = "cuda", + pretrained_weights: Optional[ + Literal["base", "single-channel", "CO-Cl", "Xe-Cl", "constant-noise", "uniform-noise", "no-gradient", "matched-tips"] + ] = None, + ): + if pretrained_weights == "single-channel": + n_in = 1 + else: + n_in = 2 + + super().__init__( + z_in=6, + n_in=n_in, + n_out=1, + in_channels=1, + merge_block_channels=[32], + merge_block_depth=2, + conv3d_dropouts=[0.0, 0.0, 0.0], + conv3d_block_channels=[48, 96, 192], + conv3d_block_depth=3, + conv2d_block_channels=[512], + conv2d_block_depth=3, + conv2d_dropouts=[0.0], + upscale2d_block_channels=[256, 128, 64], + upscale2d_block_depth=2, + upscale2d_block_channels2=[256, 128, 64], + upscale2d_block_depth2=2, + split_conv_block_channels=[64], + split_conv_block_depth=3, + out_relus=[False], + pool_z_strides=[2, 1, 2], + activation=nn.LeakyReLU(negative_slope=0.1, inplace=True), + padding_mode="replicate", + device=device, + ) + + if pretrained_weights: + weights_path = download_weights(f"edafm-{pretrained_weights}") + self.load_state_dict(torch.load(weights_path)) + + class ASDAFMNet(nn.Module): """ The model used in the paper "Automated structure discovery in atomic force microscopy": https://doi.org/10.1126/sciadv.aay6913. diff --git a/mlspm/models.py b/mlspm/models.py index b0c6b40..b6254e2 100644 --- a/mlspm/models.py +++ b/mlspm/models.py @@ -1,3 +1,3 @@ -from .graph.models import PosNet, GraphImgNet, GraphImgNetIce -from .image.models import ASDAFMNet from ._weights import download_weights +from .graph.models import GraphImgNet, GraphImgNetIce, PosNet +from .image.models import ASDAFMNet, AttentionUNet, EDAFMNet diff --git a/mlspm/preprocessing.py b/mlspm/preprocessing.py index 23a16fa..5c5be09 100644 --- a/mlspm/preprocessing.py +++ b/mlspm/preprocessing.py @@ -1,5 +1,5 @@ import random -from typing import List, Tuple +from typing import List, Literal, Optional, Tuple import numpy as np import scipy.ndimage as nimg @@ -213,12 +213,136 @@ def interpolate_and_crop( def minimum_to_zero(Ys: List[np.ndarray]): - ''' + """ Shift values in arrays such that minimum is at zero. In-place operation. Arguments: Ys: Arrays of shape (batch_size, ...). - ''' + """ for Y in Ys: for j in range(Y.shape[0]): Y[j] -= Y[j].min() + + +def add_rotation_reflection( + X: List[np.ndarray], + Y: List[np.ndarray], + reflections: bool = True, + multiple: int = 2, + crop: Optional[Tuple[int]] = None, + per_batch_item: bool = False, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Augment batch with random rotations and reflections. + + Arguments: + X: AFM images to augment. Each array should be of shape ``(batch_size, x, y, z)``. + Y: Reference image descriptors to augment. Each array should be of shape ``(batch, x, y)``. + reflections: Whether to augment with reflections. If True, each rotation is randomly reflected with 50% probability. + multiple: Multiplier for how many rotations to generate for every sample. + crop: If not None, then output batch is cropped to specified size ``(x_size, y_size)`` in the middle of the image. + per_batch_item: If True, rotation is randomized per batch item, otherwise same rotation for all. + + Returns: + Tuple (**X**, **Y**), where + + - **X** - Batch of rotation-augmented AFM images of shape ``(batch*multiple, x_new, y_new, z)``. + - **Y** - Batch of rotation-augmented reference image descriptors of shape ``(batch*multiple, x_new, y_new)`` + """ + + X_aug = [[] for _ in range(len(X))] + Y_aug = [[] for _ in range(len(Y))] + + for _ in range(multiple): + if per_batch_item: + rotations = 360 * np.random.rand(len(X[0])) + else: + rotations = [360 * np.random.rand()] * len(X[0]) + if reflections: + flip = np.random.randint(2) + for k, x in enumerate(X): + x = x.copy() + for i in range(x.shape[0]): + for j in range(x.shape[-1]): + x[i, :, :, j] = np.array(Image.fromarray(x[i, :, :, j]).rotate(rotations[i], resample=Image.BICUBIC)) + if flip: + x = x[:, :, ::-1] + X_aug[k].append(x) + for k, y in enumerate(Y): + y = y.copy() + for i in range(y.shape[0]): + y[i, :, :] = np.array(Image.fromarray(y[i, :, :]).rotate(rotations[i], resample=Image.BICUBIC)) + if flip: + y = y[:, :, ::-1] + Y_aug[k].append(y) + + X = [np.concatenate(x, axis=0) for x in X_aug] + Y = [np.concatenate(y, axis=0) for y in Y_aug] + + if crop is not None: + x_start = (X[0].shape[1] - crop[0]) // 2 + y_start = (X[0].shape[2] - crop[1]) // 2 + X = [x[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for x in X] + Y = [y[:, x_start : x_start + crop[0], y_start : y_start + crop[1]] for y in Y] + + return X, Y + + +def random_crop( + X: List[np.ndarray], + Y: List[np.ndarray], + min_crop: float = 0.5, + max_aspect: float = 2.0, + multiple: int = 8, + distribution: Literal["flat", "exp-log"] = "flat", +) -> Tuple[np.ndarray, np.ndarray]: + """ + Randomly crop images in a batch to a different size and aspect ratio. + + Arguments: + X: AFM images to crop. Each array should be of shape ``(batch_size, x, y, z)``. + Y: Reference image descriptors to crop. Each array should be of shape ``(batch, x, y)``. + min_crop: Minimum crop size as a fraction of the original size. + max_aspect: Maximum aspect ratio for crop. Cannot be more than 1/min_crop. + multiple: The crop size is rounded down to the specified integer multiple. + distribution: 'flat' or 'exp-log'. How aspect ratios are distributed. If 'flat', then distribution is random uniform + between (1, max_aspect) and half of time is flipped. If 'exp-log', then distribution is exp of log of uniform + distribution over (1/max_aspect, max_aspect). 'exp-log' is more biased towards square aspect ratios. + + Returns: + Tuple (**X**, **Y**), where + + - **X** - Batch of cropped AFM images of shape ``(batch, x_new, y_new, z)``. + - **Y** - Batch of cropped reference image descriptors of shape ``(batch, x_new, y_new)``. + """ + assert 0 < min_crop <= 1.0 + assert max_aspect >= 1.0 + assert 1 / min_crop >= max_aspect + + if distribution == "flat": + aspect = np.random.uniform(1, max_aspect) + if np.random.rand() > 0.5: + aspect = 1 / aspect + elif distribution == "exp-log": + aspect = np.exp(np.random.uniform(np.log(1 / max_aspect), np.log(max_aspect))) + else: + raise ValueError(f"Unrecognized aspect ratio distribution {distribution}") + + x_size, y_size = X[0].shape[1], X[0].shape[2] + if aspect > 1.0: + height = int(np.random.uniform(int(min_crop * y_size), int(y_size / aspect))) + width = int(height * aspect) + else: + width = int(np.random.uniform(int(min_crop * x_size), int(x_size * aspect))) + height = int(width / aspect) + + width = width - (width % multiple) + height = height - (height % multiple) + + start_x = int(np.random.uniform(0, x_size - width - 1e-6)) + start_y = int(np.random.uniform(0, y_size - height - 1e-6)) + + X = [x[:, start_x : start_x + width, start_y : start_y + height] for x in X] + Y = [y[:, start_x : start_x + width, start_y : start_y + height] for y in Y] + + return X, Y diff --git a/papers/asd-afm/generate_data.py b/papers/asd-afm/generate_data.py index 2e76310..c5dafde 100644 --- a/papers/asd-afm/generate_data.py +++ b/papers/asd-afm/generate_data.py @@ -31,7 +31,7 @@ def on_sample_start(self): # Define simulator and image descriptor parameters scan_window = ((0, 0, 6.0), (15.875, 15.875, 7.9)) scan_dim = (128, 128, 19) - afmulator = AFMulator(pixPerAngstrome=5, scan_dim=scan_dim, scan_window=scan_window) + afmulator = AFMulator(pixPerAngstrome=5, scan_dim=scan_dim, scan_window=scan_window, tipR0=[0.0, 0.0, 4.0]) aux_maps = [ AtomicDisks(scan_dim=scan_dim, scan_window=scan_window, zmin=-1.2, zmax_s=-1.2, diskMode="sphere"), vdwSpheres(scan_dim=scan_dim, scan_window=scan_window, zmin=-1.5), @@ -41,7 +41,7 @@ def on_sample_start(self): "afmulator": afmulator, "aux_maps": aux_maps, "batch_size": 1, - "distAbove": 4.3, + "distAbove": 5.25, "iZPPs": [8], "Qs": [[-0.1, 0, 0, 0]], "QZs": [[0, 0, 0, 0]], diff --git a/papers/ed-afm/README.md b/papers/ed-afm/README.md new file mode 100644 index 0000000..b33527a --- /dev/null +++ b/papers/ed-afm/README.md @@ -0,0 +1,68 @@ +# Electrostatic Discovery Atomic Force Microscopy + +The scrpts in the original repository at https://github.com/SINGROUP/ED-AFM are reproduced here for convenience. + +Paper: [*N. Oinonen et al. Electrostatic Discovery Atomic Force Microscopy, ACS Nano 2022*](https://pubs.acs.org/doi/10.1021/acsnano.1c06840) + +Abstract: +_While offering unprecedented resolution of atomic and electronic structure, Scanning Probe Microscopy techniques have found greater challenges in providing reliable electrostatic characterization at the same scale. In this work, we introduce Electrostatic Discovery Atomic Force Microscopy, a machine learning based method which provides immediate quantitative maps of the electrostatic potential directly from Atomic Force Microscopy images with functionalized tips. We apply this to characterize the electrostatic properties of a variety of molecular systems and compare directly to reference simulations, demonstrating good agreement. This approach opens the door to reliable atomic scale electrostatic maps on any system with minimal computational overhead._ + +![Method schematic](https://github.com/SINGROUP/ED-AFM/blob/master/figures/method_schem.png) + +## ML model + +We use a U-net type convolutional neural network with attention gates in the skip connections. Similar model was used previously by Oktay et al. for segmenting medical images (https://arxiv.org/abs/1804.03999v2). + +![Model schematic](https://github.com/SINGROUP/ED-AFM/blob/master/figures/model_schem.png) +![AG schematic](https://github.com/SINGROUP/ED-AFM/blob/master/figures/AG_schem.png) + +The model implementation can be found in [`mlspm/image/models.py`](https://github.com/SINGROUP/ml-spm/blob/main/mlspm/image/models.py), where two modules can be found: [`AttentionUNet`](https://ml-spm.readthedocs.io/en/latest/reference/mlspm.image.html#mlspm.image.models.AttentionUNet), which is the generic version of the model, and [`EDAFMNet`](https://ml-spm.readthedocs.io/en/latest/reference/mlspm.image.html#mlspm.image.models.EDAFMNet), which is a subclass of the former specifying the exact hyperparameters for the model that we used. + +In `EDAFMNet` one can also specify pretrained weights of several types to download to the model using the `pretrained_weights` argument: + + - `'base'`: The base model used for all predictions in the main ED-AFM paper and used for comparison in the various test in the supplementary information of the paper. + - `'single-channel'`: Model trained on only a single CO-tip AFM input. + - `'CO-Cl'`: Model trained on alternative tip combination of CO and Cl. + - `'Xe-Cl`': Model trained on alternative tip combination of Xe and Cl. + - `'constant-noise'`: Model trained using constant noise amplitude instead of normally distributed amplitude. + - `'uniform-noise'`: Model trained using uniform random noise amplitude instead of normally distributed amplitude. + - `'no-gradient'`: Model trained without background-gradient augmentation. + - `'matched-tips'`: Model trained on data with matched tip distance between CO and Xe, instead of independently randomized distances. + +The model weights can also be downloaded directly from https://doi.org/10.5281/zenodo.10606273. The weights are saved in the state_dict format of PyTorch. + +## Data and model training + +We provide a database of molecular geometries that can be used to generate the full dataset using [`ppafm`](https://github.com/Probe-Particle/ppafm). The provided script `generate_data.py` does the data generation and will download the molecule database automatically. Alternatively, the molecule database can be downloaded directly from https://doi.org/10.5281/zenodo.10609676. + +After generating the dataset, the model training can be done using the provided script `run_train.sh`, which calls the actual training script `train.py` with the appropriate parameters. Note that performing the training using all the same settings as we used requires a significant amount of time and also a significant amount VRAM on the GPU, likely more than can be found on a single GPU. In our case the model training took ~5 days using 4 x Nvidia Tesla V100 (32GB) GPUs. However, inference on the trained model can be done even on a single lower-end GPU or on CPU. + +All the data used for the predictions in the paper can be found at https://doi.org/10.5281/zenodo.10609676. + +## Figures + +The scripts used to generate most of the figures in the paper are provided under the directory `figures`. Running the scripts will automatically download the required data. + +The scripts correspond to the figures as follows: + + - Fig. 1: `sims.py` + - Fig. 2: `ptcda.py` + - Fig. 3: `bcb.py` + - Fig. 4: `water.py` + - Fig. 5: `surface_sims_bcb_water.py` + - Fig. S1: `model_schem.tex` + - Fig. S3: `stats.py`\* + - Fig. S4: `esmap_sample.py` and then `esmap_schem.tex` + - Fig. S5: `stats_spring_constants.py`\* + - Fig. S6: `afm_stacks.py` and `afm_stacks2.py` + - Fig. S7: `sims_hartree.py` + - Fig. S8: `ptcda_surface_sim.py` + - Fig. S9: `single_tip.py` + - Fig. S10: `sims_Cl.py` + - Fig. S11: `height_dependence.py` + - Fig. S12: `extra_electron.py` + - Fig. S13: `background_gradient.py` + +\* Precalculated MSE values used by the plotting script are provided under `figures/stats`. The scripts used to calculate these values are also under `figures/stats`. + +You can also use `run_all.sh`-script to run all of the scripts in one go. Note that compiling the .tex files additionally requires a working LaTex installation on your system. diff --git a/papers/ed-afm/figures/afm_stacks.py b/papers/ed-afm/figures/afm_stacks.py new file mode 100644 index 0000000..d9d02ac --- /dev/null +++ b/papers/ed-afm/figures/afm_stacks.py @@ -0,0 +1,63 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from mlspm.datasets import download_dataset + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +if __name__ == "__main__": + data_dir = Path("./edafm-data") + fig_width = 160 + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load data + bcb_CO = np.load(data_dir / "BCB" / "data_CO_exp.npz") + bcb_Xe = np.load(data_dir / "BCB" / "data_Xe_exp.npz") + ptcda_CO = np.load(data_dir / "PTCDA" / "data_CO_exp.npz") + ptcda_Xe = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz") + + fig_width = 0.1 / 2.54 * fig_width + height_ratios = [2, 2, 2.45, 2.45] + fig = plt.figure(figsize=(fig_width, 0.85 * sum(height_ratios))) + fig_grid = fig.add_gridspec(4, 1, wspace=0, hspace=0.1, height_ratios=height_ratios) + + # BCB plots + for i, (sample, label) in enumerate(zip([bcb_CO, bcb_Xe], ["A", "B"])): + d = sample["data"] + l = sample["lengthX"] + axes = fig_grid[i, 0].subgridspec(2, 8, wspace=0.02, hspace=0.02).subplots().flatten() + for j, ax in enumerate(axes): + if j < d.shape[-1]: + ax.imshow(d[:, :, j].T, origin="lower", cmap="afmhot") + ax.axis("off") + axes[0].text( + -0.3, 0.8, label, horizontalalignment="center", verticalalignment="center", transform=axes[0].transAxes, fontsize=fontsize + ) + axes[0].plot([50, 50 + 5 / l * d.shape[0]], [470, 470], color="k") + + # PTCDA plots + for i, (sample, label) in enumerate(zip([ptcda_CO, ptcda_Xe], ["C", "D"])): + d = sample["data"] + l = sample["lengthX"] + axes = fig_grid[i + 2, 0].subgridspec(3, 6, wspace=0.02, hspace=0.02).subplots().flatten() + for j, ax in enumerate(axes): + if j < d.shape[-1]: + ax.imshow(d[:, :, j].T, origin="lower", cmap="afmhot") + ax.axis("off") + axes[0].text( + -0.22, 0.7, label, horizontalalignment="center", verticalalignment="center", transform=axes[0].transAxes, fontsize=fontsize + ) + axes[0].plot([20, 20 + 5 / l * d.shape[0]], [135, 135], color="k") + + plt.savefig("afm_stacks.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/afm_stacks2.py b/papers/ed-afm/figures/afm_stacks2.py new file mode 100644 index 0000000..bc507dc --- /dev/null +++ b/papers/ed-afm/figures/afm_stacks2.py @@ -0,0 +1,45 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + +from mlspm.datasets import download_dataset + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +if __name__ == "__main__": + data_dir = Path("./edafm-data") + fig_width = 160 + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load data + water_CO = np.load(data_dir / 'Water' / 'data_CO_exp.npz') + water_Xe = np.load(data_dir / 'Water' / 'data_Xe_exp.npz') + + fig = plt.figure(figsize=(0.1/2.54*fig_width, 5.0)) + fig_grid = fig.add_gridspec(2, 1, wspace=0, hspace=0.1) + + # Water plots + for i, (sample, label) in enumerate(zip([water_CO, water_Xe], ['E', 'F'])): + d = sample['data'] + l = sample['lengthX'] + axes = fig_grid[i, 0].subgridspec(3, 8, wspace=0.02, hspace=0.02).subplots().flatten() + for j, ax in enumerate(axes): + if j < d.shape[-1]: + ax.imshow(d[:,:,j].T, origin='lower', cmap='afmhot') + ax.axis('off') + axes[0].text(-0.3, 0.8, label, horizontalalignment='center', + verticalalignment='center', transform=axes[0].transAxes, fontsize=fontsize) + axes[0].plot([50, 50+5/l*d.shape[0]], [470, 470], color='k') + + plt.savefig('afm_stacks2.pdf', bbox_inches='tight', dpi=dpi) diff --git a/papers/ed-afm/figures/background_gradient.py b/papers/ed-afm/figures/background_gradient.py new file mode 100644 index 0000000..3eab663 --- /dev/null +++ b/papers/ed-afm/figures/background_gradient.py @@ -0,0 +1,236 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing_sim(batch): + + X, Y, xyzs = batch + + print(X[0].shape) + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + # Add background gradient + c = 0.3 + angle = -np.pi / 2 + x, y = np.meshgrid(np.arange(0, X[0].shape[1]), np.arange(0, X[0].shape[2]), indexing="ij") + n = [np.cos(angle), np.sin(angle), 1] + z = -(n[0]*x + n[1]*y) + z -= z.mean() + z /= np.ptp(z) + for x in X: + x += z[None, :, :, None]*c*np.ptp(x) + + return X, Y, xyzs + +def apply_preprocessing_exp(X, real_dim): + + # Pick slices + x0_start, x1_start = 2, 0 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X[1] = X[1][..., x1_start:x1_start+6] # Xe + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + X = [x[:,:,6:78] for x in X] + + return X + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + fig_width = 140 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (176, 144, 19), + "scan_window" : ((2.0, 2.0, 7.0), (24, 20, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 5.25, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / "PTCDA" / "mol.xyz"] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = InverseAFMtrainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Get simulation data + sim_data = next(iter(trainer)) + X_sim, ref, xyzs = apply_preprocessing_sim(sim_data) + X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim] + + # Load experimental data and preprocess + data1 = np.load(data_dir / "PTCDA" / "data_CO_exp.npz") + X1 = data1["data"] + afm_dim1 = (data1["lengthX"], data1["lengthY"]) + + data2 = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz") + X2 = data2["data"] + afm_dim2 = (data2["lengthX"], data2["lengthY"]) + + assert afm_dim1 == afm_dim2 + afm_dim = afm_dim1 + X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim) + X_exp_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_exp] + + # Load model with gradient augmentation + model_grad = EDAFMNet(device=device, pretrained_weights="base") + + # Load model without gradient augmentation + model_no_grad = EDAFMNet(device=device, pretrained_weights="no-gradient") + + with torch.no_grad(): + pred_sim_grad, attentions_sim_grad = model_grad(X_sim_cuda) + pred_sim_no_grad, attentions_sim_no_grad = model_no_grad(X_sim_cuda) + pred_exp, attentions_exp = model_no_grad(X_exp_cuda) + pred_sim_grad = [p.cpu().numpy() for p in pred_sim_grad] + pred_sim_no_grad = [p.cpu().numpy() for p in pred_sim_no_grad] + pred_exp = [p.cpu().numpy() for p in pred_exp] + attentions_sim_grad = [a.cpu().numpy() for a in attentions_sim_grad] + attentions_sim_no_grad = [a.cpu().numpy() for a in attentions_sim_no_grad] + attentions_exp = [a.cpu().numpy() for a in attentions_exp] + + # Create figure grid + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 4.4] + fig = plt.figure(figsize=(fig_width, 6*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(1, 2, wspace=0.3, hspace=0, width_ratios=width_ratios) + left_grid = fig_grid[0, 0].subgridspec(2, 1, wspace=0, hspace=0.1) + + pred_sim_grid = fig_grid[0, 1].subgridspec(2, 1, wspace=0, hspace=0.1) + pred_sim_no_grad_ax, cbar_sim_no_grad_ax = pred_sim_grid[0, 0].subgridspec(1, 2, wspace=0.05, + hspace=0, width_ratios=[1, 0.08]).subplots() + pred_sim_grad_ax, cbar_sim_grad_ax = pred_sim_grid[1, 0].subgridspec(1, 2, wspace=0.05, + hspace=0, width_ratios=[1, 0.08]).subplots() + pred_exp_ax, cbar_exp_ax = left_grid[0, 0].subgridspec(1, 2, wspace=0.05, width_ratios=[1, 0.05]).subplots() + afm_axes = left_grid[1, 0].subgridspec(len(X_sim), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + + # Plot AFM + for i, x in enumerate(X_sim): + for j, s in enumerate(X_slices): + + # Plot AFM slice + im = afm_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot") + afm_axes[i, j].set_axis_off() + + # Put tip names to the left of the AFM image rows + afm_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=afm_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out ES data limits + vmax_sim_no_grad = max(abs(pred_sim_no_grad[0].min()), abs(pred_sim_no_grad[0].max())) + vmax_sim_grad = max(abs(pred_sim_grad[0].min()), abs(pred_sim_grad[0].max())) + vmax_exp = max(abs(pred_exp[0].min()), abs(pred_exp[0].max())) + vmin_sim_no_grad = -vmax_sim_no_grad + vmin_sim_grad = -vmax_sim_grad + vmin_exp = -vmax_exp + + # Plot ES predictions + pred_sim_no_grad_ax.imshow(pred_sim_no_grad[0][0].T, origin="lower", cmap="coolwarm", + vmin=vmin_sim_no_grad, vmax=vmax_sim_no_grad) + pred_sim_grad_ax.imshow(pred_sim_grad[0][0].T, origin="lower", cmap="coolwarm", + vmin=vmin_sim_grad, vmax=vmax_sim_grad) + pred_exp_ax.imshow(pred_exp[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin_exp, vmax=vmax_exp) + + pred_sim_no_grad_ax.set_axis_off() + pred_sim_grad_ax.set_axis_off() + pred_exp_ax.set_axis_off() + + # Plot ES Map colorbar for no grad prediction + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin_sim_no_grad, vmax_sim_no_grad)) + cbar = plt.colorbar(m_es, cax=cbar_sim_no_grad_ax) + cbar.set_ticks([-0.1, 0.0, 0.1]) + cbar_sim_no_grad_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Plot ES Map colorbar for grad prediction + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin_sim_grad, vmax_sim_grad)) + cbar = plt.colorbar(m_es, cax=cbar_sim_grad_ax) + cbar.set_ticks([-0.1, 0.0, 0.1]) + cbar_sim_grad_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Plot ES Map colorbar for experimental prediction + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin_exp, vmax_exp)) + cbar = plt.colorbar(m_es, cax=cbar_exp_ax) + cbar.set_ticks([-0.04, 0.0, 0.04]) + cbar_exp_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Set labels + pred_exp_ax.text(-0.06, 0.98, "A", horizontalalignment="center", + verticalalignment="center", transform=pred_exp_ax.transAxes, fontsize=fontsize) + afm_axes[0, 0].text(-0.2, 1.0, "B", horizontalalignment="center", + verticalalignment="center", transform=afm_axes[0, 0].transAxes, fontsize=fontsize) + pred_sim_no_grad_ax.text(-0.08, 0.98, "C", horizontalalignment="center", + verticalalignment="center", transform=pred_sim_no_grad_ax.transAxes, fontsize=fontsize) + pred_sim_grad_ax.text(-0.08, 0.98, "D", horizontalalignment="center", + verticalalignment="center", transform=pred_sim_grad_ax.transAxes, fontsize=fontsize) + + plt.savefig("background_gradient.pdf", bbox_inches="tight", dpi=dpi) \ No newline at end of file diff --git a/papers/ed-afm/figures/bcb.py b/papers/ed-afm/figures/bcb.py new file mode 100644 index 0000000..34c4a5c --- /dev/null +++ b/papers/ed-afm/figures/bcb.py @@ -0,0 +1,214 @@ + +from pathlib import Path + +import imageio.v2 as imageio +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator +from scipy.ndimage import rotate, shift + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing_sim(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +def apply_preprocessing_exp(X, real_dim): + + # Pick slices + x0_start, x1_start = 4, 9 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X[1] = X[1][..., x1_start:x1_start+6] # Xe + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + + # Flip, rotate and shift Xe data + X[1] = X[1][:,::-1] + X[1] = rotate(X[1], angle=-12, axes=(2,1), reshape=False, mode="reflect") + X[1] = shift(X[1], shift=(0,-5,1,0), mode="reflect") + X = [x[:,0:96] for x in X] + + print(X[0].shape) + + return X + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 7.0), (18.0, 18.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 4.95, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / "BCB" / "mol.xyz"] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = InverseAFMtrainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Get simulation data + X_sim, ref, xyzs = apply_preprocessing_sim(next(iter(trainer))) + X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim] + + # Load experimental data and preprocess + data1 = np.load(data_dir / "BCB" / "data_CO_exp.npz") + X1 = data1["data"] + afm_dim1 = (data1["lengthX"], data1["lengthY"]) + + data2 = np.load(data_dir / "BCB" / "data_Xe_exp.npz") + X2 = data2["data"] + afm_dim2 = (data2["lengthX"], data2["lengthY"]) + + assert afm_dim1 == afm_dim2 + afm_dim = afm_dim1 + X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim) + X_exp_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_exp] + + # Load model for sim + model = EDAFMNet(device=device, pretrained_weights="base") + + # Get predictions + with torch.no_grad(): + pred_sim, attentions_sim = model(X_sim_cuda) + pred_exp, attentions_exp = model(X_exp_cuda) + pred_sim = [p.cpu().numpy() for p in pred_sim] + pred_exp = [p.cpu().numpy() for p in pred_exp] + attentions_sim = [a.cpu().numpy() for a in attentions_sim] + attentions_exp = [a.cpu().numpy() for a in attentions_exp] + + # Create figure grid + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.4] + height_ratios = [1, 1.08] + gap = 0.15 + fig = plt.figure(figsize=(fig_width, 8.85*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(1, len(width_ratios), wspace=0.02, hspace=0, width_ratios=width_ratios) + afm_grid = fig_grid[0, 0].subgridspec(2, 1, wspace=0, hspace=gap, height_ratios=height_ratios) + pred_grid = fig_grid[0, 1].subgridspec(2, 2, wspace=0.02, hspace=gap, height_ratios=height_ratios) + cbar_grid = fig_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0) + + # Get axes from grid + afm_sim_axes = afm_grid[0, 0].subgridspec(len(X_sim), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + afm_exp_axes = afm_grid[1, 0].subgridspec(len(X_exp), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + pred_sim_ax, ref_pc_ax, pred_exp_ax, geom_ax = pred_grid.subplots(squeeze=True).flatten() + cbar_ax = cbar_grid.subplots(squeeze=True) + + # Plot AFM + for k, (axes, X) in enumerate(zip([afm_sim_axes, afm_exp_axes], [X_sim, X_exp])): + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + + # Plot AFM slice + im = axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot") + axes[i, j].set_axis_off() + + # Put tip names to the left of the AFM image rows + axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred_sim[0].min()), abs(pred_sim[0].max()), + abs(pred_exp[0].min()), abs(pred_exp[0].max()), + abs(ref[0].min()), abs(ref[0].max()) + ) + vmin = -vmax + + # Plot predictions and references + pred_sim_ax.imshow(pred_sim[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + pred_exp_ax.imshow(pred_exp[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_pc_ax.imshow(ref[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot molecule geometry + xyz_img = np.flipud(imageio.imread(data_dir / "BCB" / "mol.png")) + geom_ax.imshow(xyz_img, origin="lower") + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks([-0.02, -0.01, 0.0, 0.01, 0.02]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_sim_ax.set_axis_off() + pred_exp_ax.set_axis_off() + ref_pc_ax.set_axis_off() + geom_ax.set_axis_off() + + # Set titles + afm_sim_axes[0, len(X_slices)//2].set_title("AFM simulation", fontsize=fontsize, y=0.94) + afm_exp_axes[0, len(X_slices)//2].set_title("AFM experiment", fontsize=fontsize, y=0.94) + pred_sim_ax.set_title("Sim. prediction", fontsize=fontsize, y=0.97) + pred_exp_ax.set_title("Exp. prediction", fontsize=fontsize, y=0.97) + ref_pc_ax.set_title("Reference", fontsize=fontsize, y=0.97) + + plt.savefig("bcb.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/esmap_sample.py b/papers/ed-afm/figures/esmap_sample.py new file mode 100644 index 0000000..b5f2c62 --- /dev/null +++ b/papers/ed-afm/figures/esmap_sample.py @@ -0,0 +1,89 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu + +from mlspm.datasets import download_dataset +from mlspm.utils import read_xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + save_dir = Path("./images/") # Where images are saved + + scan_window = ((-8, -8), (8, 8)) + scan_dim = (128, 128) + height = 4 + zmin = -2.0 + Rpp = 1.0 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment(i_platform=0) + FFcl.init(env) + + # Paths to molecule xyz file + xyz_path = data_dir / "BCB" / "mol.xyz" + + # Define AuxMaps + es_map = aux.ESMapConstant(scan_dim=scan_dim, scan_window=scan_window, height=height) + vdw = aux.vdwSpheres(scan_dim=scan_dim, scan_window=scan_window, zmin=zmin, Rpp=Rpp) + + # Make sure save directory exists + save_dir.mkdir(exist_ok=True, parents=True) + + # Load molecule + mol = read_xyzs([xyz_path])[0] + xyzqs = mol[:, :4] + xyzqs[:, :3] -= xyzqs[:, :3].mean(axis=0) + Zs = mol[:, 4].astype(int) + + # Compute decriptors + Y_es = es_map(xyzqs, Zs) + Y_vdw = vdw(xyzqs, Zs) + Y_vdw_mask = Y_vdw.copy() + Y_vdw_mask -= Y_vdw_mask.min() + Y_vdw_mask[Y_vdw_mask > 0.0] = 1.0 + Y_combined = Y_vdw_mask * Y_es + + # Plot ES field + plt.figure(figsize=tuple(0.01 * np.array(Y_es.shape)), dpi=100) + vmax = max(abs(Y_es.max()), abs(Y_es.min())) + vmin = -vmax + plt.imshow(Y_es.T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(save_dir / "es.png") + plt.close() + + # Plot vdw spheres + plt.figure(figsize=tuple(0.01 * np.array(Y_es.shape)), dpi=100) + plt.imshow(Y_vdw.T, origin="lower", cmap="viridis") + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(save_dir / "vdw.png") + plt.close() + + # Plot vdw mask + plt.figure(figsize=tuple(0.01 * np.array(Y_es.shape)), dpi=100) + plt.imshow(Y_vdw_mask.T, origin="lower", cmap="viridis") + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(save_dir / "vdw_mask.png") + plt.close() + + # Plot combined + plt.figure(figsize=tuple(0.01 * np.array(Y_es.shape)), dpi=100) + vmax = max(abs(Y_combined.max()), abs(Y_combined.min())) + vmin = -vmax + plt.imshow(Y_combined.T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + plt.axis("off") + plt.tight_layout(pad=0) + plt.savefig(save_dir / "es_cut.png") + plt.close() diff --git a/papers/ed-afm/figures/esmap_schem.tex b/papers/ed-afm/figures/esmap_schem.tex new file mode 100644 index 0000000..14b0e0b --- /dev/null +++ b/papers/ed-afm/figures/esmap_schem.tex @@ -0,0 +1,47 @@ + + +\documentclass[tikz]{standalone} + +\usepackage[english]{babel} +\usepackage[utf8]{inputenc} +\usepackage{amsfonts,amssymb,amsbsy} +\usepackage{xcolor} + +\usepackage{tikz} +\usetikzlibrary{arrows.meta, positioning, quotes, calc, intersections, decorations.pathreplacing} + +\begin{document} + + \begin{tikzpicture} + + \tikzset{ + myline/.style={draw=black!50!white, line width=0.4mm, rounded corners}, + myarrow/.style={myline, -{Latex[width=0.2cm, length=0.3cm]}}, + fontstyle/.style={font=\footnotesize}, + labelstyle/.style={fontstyle, yshift=-12mm} + } + + % Set images and labels + \path node[outer sep=-4mm] (0,0) (mol) {\includegraphics[width=20mm]{./edafm-data/BCB/mol.png}} node[labelstyle, align=center, yshift=2mm] {Molecule\\geometry} + ++(30mm,13mm) node (es) {\includegraphics[width=20mm]{./images/es.png}} node[labelstyle] {ES field} + ++(0,-26mm) node (vdw) {\includegraphics[width=20mm]{./images/vdw.png}} node[labelstyle] {vdW Spheres} + + ++(35mm, 0) node (mask) {\includegraphics[width=20mm]{./images/vdw_mask.png}} node[labelstyle] {Mask} + ++(0, 26mm) node[draw, circle, myline, minimum size=5mm, inner sep=0, outer sep=2mm] (times) {} + node[text=black!50!white, xshift=0.05mm, yshift=-0.05mm] {$\boldsymbol{\times}$} + + ++(35mm, 0) node (esmap) {\includegraphics[width=20mm]{./images/es_cut.png}} node[labelstyle] {ES Map}; + + % Draw arrows + \draw[myarrow] (mol) -- (es); + \draw[myarrow] (mol) -- (vdw); + \draw[myarrow] (vdw) -- (mask); + \draw[myarrow] (mask) -- (times); + \draw[myarrow] (es) -- (times); + \draw[myarrow] (times) -- (esmap); + + +ra \end{tikzpicture} + + +\end{document} \ No newline at end of file diff --git a/papers/ed-afm/figures/extra_electron.py b/papers/ed-afm/figures/extra_electron.py new file mode 100644 index 0000000..49339bf --- /dev/null +++ b/papers/ed-afm/figures/extra_electron.py @@ -0,0 +1,114 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +class Trainer(InverseAFMtrainer): + + # Override this method to set the Xe tip closer + def handle_distance(self): + if self.afmulator.iZPP == 54: + self.distAboveActive -= 0.4 + super().handle_distance() + + # Override position handling to center on the non-Cu atoms + def handle_positions(self): + sw = self.afmulator.scan_window + scan_center = np.array([sw[1][0] + sw[0][0], sw[1][1] + sw[0][1]]) / 2 + self.xyzs[:,:2] += scan_center - self.xyzs[self.Zs != 29,:2].mean(axis=0) + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.05) + + return X, Y, xyzs + + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + fig_width = 100 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (144, 104, 19), + "scan_window" : ((2.0, 2.0, 7.0), (20.0, 15.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 4.75, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / "PTCDA" / "mol-1.xyz"] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define generator + trainer = Trainer(afmulator, [], molecules, **generator_kwargs) + + # Get simulation data + X, _, _ = apply_preprocessing(next(iter(trainer))) + + # Create figure grid + fig_width = 0.1/2.54*fig_width + fig = plt.figure(figsize=(fig_width, 0.49*fig_width)) + axes = fig.add_gridspec(len(X), len(X_slices), wspace=0.02, hspace=0.02).subplots(squeeze=False) + + # Plot AFM + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + + # Plot AFM slice + im = axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot") + axes[i, j].set_axis_off() + + # Put tip names to the left of the AFM image rows + axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + plt.savefig("extra_electron.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/height_dependence.py b/papers/ed-afm/figures/height_dependence.py new file mode 100644 index 0000000..dcdfebe --- /dev/null +++ b/papers/ed-afm/figures/height_dependence.py @@ -0,0 +1,170 @@ + +import string +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import cm +from scipy.ndimage import rotate, shift + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing_bcb(X, real_dim): + + # Pick slices + X[0] = np.concatenate([X[0][..., i:i+6] for i in [5, 4, 3]], axis=0) + X[1] = np.concatenate([X[1][..., i:i+6] for i in [10, 9, 8, 4]], axis=0) + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + + # Flip, rotate and shift Xe data + X[1] = X[1][:,::-1] + X[1] = rotate(X[1], angle=-12, axes=(2,1), reshape=False, mode="reflect") + X[1] = shift(X[1], shift=(0,-5,1,0), mode="reflect") + X = [x[:,0:96] for x in X] + + return X + +def apply_preprocessing_ptcda(X, real_dim): + + # Pick slices + X[0] = np.concatenate([X[0][..., i:i+6] for i in [3, 2, 1]], axis=0) + X[1] = np.concatenate([X[1][..., i:i+6] for i in [6, 2, 1, 0]], axis=0) + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + X = [x[:,:,6:78] for x in X] + + return X + +if __name__ == "__main__": + + # Options + data_dir = Path("./edafm-data") # Path to data + device = "cuda" # Device to run inference on + fig_width = 150 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Load BCB data and preprocess + data_bcb = np.load(data_dir / "BCB" / "data_CO_exp.npz") + afm_dim_bcb = (data_bcb["lengthX"], data_bcb["lengthY"]) + X_bcb_CO = data_bcb["data"] + X_bcb_Xe = np.load(data_dir / "BCB" / "data_Xe_exp.npz")["data"] + X_bcb = apply_preprocessing_bcb([X_bcb_CO[None], X_bcb_Xe[None]], afm_dim_bcb) + + # Load PTCDA data and preprocess + data_ptcda = np.load(data_dir / "PTCDA" / "data_CO_exp.npz") + afm_dim_ptcda = (data_ptcda["lengthX"], data_ptcda["lengthY"]) + X_ptcda_CO = data_ptcda["data"] + X_ptcda_Xe = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz")["data"] + X_ptcda = apply_preprocessing_ptcda([X_ptcda_CO[None], X_ptcda_Xe[None]], afm_dim_ptcda) + + # Create figure grid + fig_width = 0.1/2.54*fig_width + height_ratios = [1, 0.525] + width_ratios = [1, 0.03] + fig = plt.figure(figsize=(fig_width, 0.86*sum(height_ratios)*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(2, 2, wspace=0.05, hspace=0.15, height_ratios=height_ratios, width_ratios=width_ratios) + + ticks = [ + [-0.03, -0.02, -0.01, 0.00, 0.01, 0.02, 0.03], + [-0.08, -0.04, 0.00, 0.04, 0.08] + ] + + offsets_labels = [ + [ + ["-0.1Å", "+0.0Å", "+0.1Å", "+0.5Å"], + ["-0.1Å", "+0.0Å", "+0.1Å"] + ], + [ + ["-0.6Å", "-0.2Å", "-0.1Å", "+0.0Å"], + ["-0.1Å", "+0.0Å", "+0.1Å"] + ] + ] + + # Do for both BCB and PTCDA + for k, X in enumerate([X_bcb, X_ptcda]): + + # Create subgrid for predictions and colorbar + pred_grid = fig_grid[k, 0].subgridspec(3, 4, wspace=0.01, hspace=0) + pred_axes = pred_grid.subplots(squeeze=False) + cbar_ax = fig_grid[k, 1].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + preds = np.zeros([3, 4, X[0].shape[1], X[0].shape[2]]) + for i in range(3): + for j in range(4): + + # Pick a subset of slices + X_ = [x.copy() for x in X] + X_[0] = X_[0][i:i+1] + X_[1] = X_[1][j:j+1] + X_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_] + + # Make prediction + with torch.no_grad(): + pred = model(X_cuda) + preds[i, j] = pred[0][0].cpu().numpy() + + # Figure out data limits + vmax = max(abs(preds.min()), abs(preds.max())) + vmin = -vmax + + # Plot predictions + for i in range(3): + for j in range(4): + pred_axes[i, j].imshow(preds[i, j].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + pred_axes[i, j].set_axis_off() + if i == 0: + pred_axes[i, j].text(0.5, 1.06+k*0.03, offsets_labels[k][0][j], horizontalalignment="center", + verticalalignment="center", transform=pred_axes[i, j].transAxes, + fontsize=fontsize-2) + if j == 0: + pred_axes[i, j].text(-0.06, 0.5, offsets_labels[k][1][i], horizontalalignment="center", + verticalalignment="center", transform=pred_axes[i, j].transAxes, + rotation="vertical", fontsize=fontsize-2) + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks[k]) + cbar_ax.tick_params(labelsize=fontsize-1) + + # Set Xe-shift title + ((x0, _), ( _, y)) = pred_axes[0, 0].get_position().get_points() + (( _, _), (x1, _)) = pred_axes[0, -1].get_position().get_points() + plt.text((x0 + x1)/2, y+0.03, "Xe-shift", fontsize=fontsize, + transform=fig.transFigure, horizontalalignment="center", verticalalignment="center") + + # Set CO-shift title + (( x, _), (_, y1)) = pred_axes[ 0, 0].get_position().get_points() + (( _, y0), (_, _)) = pred_axes[-1, 0].get_position().get_points() + plt.text(x0-0.04, (y0 + y1)/2, "CO-shift", fontsize=fontsize, + transform=fig.transFigure, horizontalalignment="center", verticalalignment="center", + rotation="vertical") + + # Set subfigure reference letters + grid_pos = pred_grid.get_grid_positions(fig) + x, y = grid_pos[2][0]-0.03, grid_pos[1][0]+0.01 + fig.text(x, y, string.ascii_uppercase[k], fontsize=fontsize, + horizontalalignment="center", verticalalignment="center") + + plt.savefig("height_dependence.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/model_schem.tex b/papers/ed-afm/figures/model_schem.tex new file mode 100644 index 0000000..7f53032 --- /dev/null +++ b/papers/ed-afm/figures/model_schem.tex @@ -0,0 +1,153 @@ + + +%\documentclass[12pt,a4paper]{article} +\documentclass[tikz]{standalone} + +\usepackage[english]{babel} +\usepackage[utf8]{inputenc} +\usepackage{amsfonts,amssymb,amsbsy} +\usepackage{xcolor} + +\usepackage{tikz} +\usetikzlibrary{arrows.meta, positioning, quotes, calc, intersections, decorations.pathreplacing} + +\begin{document} + + \begin{tikzpicture}[scale=0.92] + + \tikzset{ + cuboid/.pic={ + \tikzset{% + every edge quotes/.append style={midway, auto}, + /cuboid/.cd, + #1 + } + \draw [every edge/.append style={pic actions, densely dashed, opacity=.5}, pic actions] + (0,0,0) coordinate (sl) ++(0,\cubescale*\cubey,\cubescale*\cubez/2) coordinate (o) + -- ++(-\cubescale*\cubex,0,0) coordinate (a) -- ++(0,-\cubescale*\cubey,0) coordinate (b) -- node[midway] (bm) {} ++(\cubescale*\cubex,0,0) coordinate (c) -- node[midway] (fm) {} cycle + (o) -- node[midway] (su1) {} ++(0,0,-\cubescale*\cubez) coordinate (d) -- ++(0,-\cubescale*\cubey,0) coordinate (e) -- (c) -- cycle + (o) -- (a) -- node[midway] (su2) {} ++(0,0,-\cubescale*\cubez) coordinate (f) -- (d) -- cycle + ($(su1)!0.5!(sl)$) node (sc) {} + ($(su1)!0.5!(su2)$) node (uc) {}; + \path (uc) ++($(b)-(a)$) coordinate(bc); + \draw [opacity=0.3] (f) -- ++(0,-\cubescale*\cubey,0) coordinate(g) (g) -- (e) (g) -- (b); + }, + pics/attsymbol/.style args={#1, #2}{ + code={ + \def\r{0.6*#1} + \draw[black, line width=#2] circle (#1) coordinate (c); + \draw[black, line width=#2] (-\r, -\r) .. controls (\r, -\r, 0) and (-\r, \r) .. (\r, \r); + } + }, + conv3d/.pic={\pic [fill=green!50!white, opacity=0.8] {cuboid={#1}};}, + conv2d/.pic={\pic [fill=magenta!60!white, opacity=0.8] {cuboid={#1}};}, + conv2dlr/.pic={\pic [fill=cyan!60!white, opacity=0.8] {cuboid={#1}};}, + conv2dr/.pic={\pic [fill=blue!50!white, opacity=0.8] {cuboid={#1}};}, + upsample/.pic={\pic [fill=yellow!80!white, opacity=0.8] {cuboid={#1}};}, + maxpool/.pic={\pic [fill=blue!80!white, opacity=0.8] {cuboid={#1}};}, + avgpool/.pic={\pic [fill=teal!80!white, opacity=0.8] {cuboid={#1}};}, + myline/.style={draw=black!50!white, line width=0.4mm, rounded corners}, + myarrow/.style={myline, -{Latex[width=0.15cm, length=0.2cm]}}, + myarrow2/.style={myline, draw=violet!60!white, -{Latex[width=0.15cm, length=0.2cm]}}, + layerparam/.style={rotate=90, anchor=east, font=\scriptsize}, + /cuboid/.search also={/tikz}, + /cuboid/.cd, + width/.store in=\cubex, + height/.store in=\cubey, + depth/.store in=\cubez, + units/.store in=\cubeunits, + scale/.store in=\cubescale, + width=1, + height=1, + depth=1, + units=cm, + scale=1 + } + + % Input branches + \path (0,0,0) coordinate (model) pic {conv3d={width=0.4, height=2.5, depth=2.5}} (sc) node (c1) {} (sl) + ++(0,-3.9,0) pic {conv3d={width=0.4, height=2.5, depth=2.5}} (sc) node (c2) {} + (bm) node[layerparam] {2x(32@128x128x6)} (sl); + + % Input branch arrows + \draw[myarrow] (c2) -- ++(0.95, 0); + \draw[myline] (c1) -- ++(0.45,0) -- ++($(c2)-(c1)$); + + % Encoder + \path ($(sl)+(1.8,0,0)$) pic {conv3d={width=0.4, height=2.5, depth=2.5}} (uc) node (c3) {} (bc) node (c3b) {} (sl) + (bm) node[layerparam] {3x(48@128x128x6)} (sl) + ++(0.6,0,0) pic {maxpool={width=0.3, height=2.0, depth=2.0}} + (bm) node[layerparam] {48@64x64x3} (sl) + % + ++(1.2,0,0) pic {conv3d={width=0.3, height=2.0, depth=2.0}} (uc) node (c4) {} (bc) node (c4b) {} (sl) + (bm) node[layerparam] {3x(96@64x64x3)} (sl) + ++(0.4,0,0) pic {maxpool={width=0.2, height=1.5, depth=1.5}} + (bm) node[layerparam] {96@32x32x2} (sl) + % + ++(0.9,0,0) pic {conv3d={width=0.2, height=1.5, depth=1.5}} (uc) node (c5) {} (bc) node (c5b) {} (sl) + (bm) node[layerparam] {3x(192@32x32x2)} (sl) + ++(0.3,0,0) pic {maxpool={width=0.1, height=1.0, depth=1.0}} + (bm) node[layerparam] {192@16x16x1} (sl) + + % Middle + ++(0.8,0,0) pic {conv2dlr={width=0, height=1.0, depth=1.0}} (uc) node (m1) {} (sl) node (m1b) {} + (bm) node[layerparam] {3x(512@16x16)} (sl) + + % Decoder + ++(0.9,0,0) pic {upsample={width=0, height=1.5, depth=1.5}} + (bm) node[layerparam] {512@32x32} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=1.5, depth=1.5}} + (bm) node[layerparam] {2x(256@32x32)} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=1.5, depth=1.5}} (uc) node (c6) {} (sl) + (bm) node[layerparam] {2x(256@32x32)} (sl) + % + ++(0.8,0,0) pic {upsample={width=0, height=2.0, depth=2.0}} + (bm) node[layerparam] {256@64x64} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=2.0, depth=2.0}} + (bm) node[layerparam] {2x(128@64x64)} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=2.0, depth=2.0}} (uc) node (c7) {} (sl) + (bm) node[layerparam] {2x(128@64x64)} (sl) + % + ++(1.1,0,0) pic {upsample={width=0, height=2.5, depth=2.5}} + (bm) node[layerparam] {128@128x128} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=2.5, depth=2.5}} + (bm) node[layerparam] {2x(64@128x128)} (sl) + ++(0.3,0,0) pic {conv2dlr={width=0, height=2.5, depth=2.5}} (uc) node (c8) {} (sc) node (c9) {} + (bm) node[layerparam] {2x(64@128x128)} (sl); + + % Output branches + \path ($(sl)+(1.3,0,0)$) pic {conv2dlr={width=0, height=2.5, depth=2.5}} + (bm) node[layerparam] {3x(64@128x128)} (sl) + ++(0.3,0,0) pic {conv2d={width=0, height=2.5, depth=2.5}} + (bm) node[layerparam] {1@128x128} (sl); + + % Skip-connections + \def\dy{0.1} + \draw[myarrow] (c5) -- ++(0.0,0.8,0) coordinate (s1) -- ++(2, 0) coordinate (a1); + \path ($(a1)+(0.3, -0.1)$) pic {attsymbol={0.22, 0.8}} ++(0.3, 0) coordinate (a2); + \draw[myarrow2] (m1) -- ($(s1)+(m1b)-(c5b)-(0.0,2*\dy)$) -- ($(a1) - (0, 2*\dy)$); + \draw[myarrow] (a2) -- ++($(c6)-(a2)+(s1)-(c5)-(0,\dy)$) -- (c6); + % + \draw[myarrow] (c4) -- ++(0.0,0.8,0) coordinate (s1) -- ++(4.7, 0) coordinate (a1); + \path ($(a1)+(0.3, -0.1)$) pic {attsymbol={0.22, 0.8}} ++(0.3, 0) coordinate (a2); + \draw[myarrow2] (m1) -- ($(s1)+(m1b)-(c4b)-(0.0,2*\dy)$) -- ($(a1) - (0, 2*\dy)$); + \draw[myarrow] (a2) -- ++($(c7)-(a2)+(s1)-(c4)-(0,\dy)$) -- (c7); + % + \draw[myarrow] (c3) -- ++(0.0,0.8,0) coordinate (s1) -- ++(8.2, 0) coordinate (a1); + \path ($(a1)+(0.3, -0.1)$) pic {attsymbol={0.22, 0.8}} ++(0.3, 0) coordinate (a2); + \draw[myarrow2] (m1) -- ($(s1)+(m1b)-(c3b)-(0.0,2*\dy)$) -- ($(a1) - (0, 2*\dy)$); + \draw[myarrow] (a2) -- ++($(c8)-(a2)+(s1)-(c3)-(0,\dy)$) -- (c8); + + % Legend + \path (model) ++(3.3,2.3,0) coordinate (start) + pic {conv3d={width=0.4, height=0.4, depth=0}} ($(fm)+(-0.05,-0.03)$) node [label=east:{\footnotesize 3D Conv Block (LeakyReLU)}] {} (sl) + ++(0,-0.6,0) pic {conv2dlr={width=0.4, height=0.4, depth=0}} ($(fm)+(-0.05,-0.03)$) node [label=east:{\footnotesize 2D Conv Block (LeakyReLU)}] {} (sl) + ++(0,-0.6,0) pic {conv2d={width=0.4, height=0.4, depth=0}} ($(fm)+(-0.05,-0.01)$) node [label=east:{\footnotesize 2D Conv (No activation)}] {} (bm) + (start) ++(5.5,0,0) pic {maxpool={width=0.4, height=0.4, depth=0}} ($(fm)+(-0.05,-0.01)$) node [label=east:{\footnotesize AvgPool}] {} (sl) + ++(0,-0.6,0) pic {upsample={width=0.4, height=0.4, depth=0}} ($(fm)+(-0.05,-0.03)$) node [label=east:{\footnotesize NN-upsample}] {} (bm) + ++(0,-0.45,0) pic {attsymbol={0.22, 0.8}} ($(c)+(0.15, 0)$) node [label=east:{\footnotesize Attention Gate}] {}; + + \end{tikzpicture} + + +\end{document} \ No newline at end of file diff --git a/papers/ed-afm/figures/ptcda.py b/papers/ed-afm/figures/ptcda.py new file mode 100644 index 0000000..bfaa82d --- /dev/null +++ b/papers/ed-afm/figures/ptcda.py @@ -0,0 +1,214 @@ + +from pathlib import Path + +import imageio.v2 as imageio +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing_sim(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +def apply_preprocessing_exp(X, real_dim): + + # Pick slices + x0_start, x1_start = 2, 0 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X[1] = X[1][..., x1_start:x1_start+6] # Xe + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + X = [x[:,:,6:78] for x in X] + + print(X[0].shape) + + return X + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (144, 104, 19), + "scan_window" : ((2.0, 2.0, 7.0), (20.0, 15.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 5.45, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / "PTCDA" / "mol.xyz"] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = InverseAFMtrainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Get simulation data + X_sim, ref, xyzs = apply_preprocessing_sim(next(iter(trainer))) + X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim] + + # Load experimental data and preprocess + data1 = np.load(data_dir / "PTCDA" / "data_CO_exp.npz") + X1 = data1["data"] + afm_dim1 = (data1["lengthX"], data1["lengthY"]) + + data2 = np.load(data_dir / "PTCDA" / "data_Xe_exp.npz") + X2 = data2["data"] + afm_dim2 = (data2["lengthX"], data2["lengthY"]) + + assert afm_dim1 == afm_dim2 + afm_dim = afm_dim1 + X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim) + X_exp_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_exp] + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Get predictions + with torch.no_grad(): + pred_sim, attentions_sim = model(X_sim_cuda) + pred_exp, attentions_exp = model(X_exp_cuda) + pred_sim = [p.cpu().numpy() for p in pred_sim] + pred_exp = [p.cpu().numpy() for p in pred_exp] + attentions_sim = [a.cpu().numpy() for a in attentions_sim] + attentions_exp = [a.cpu().numpy() for a in attentions_exp] + + # Load Hartree reference + Y_hartree = np.load(data_dir / "PTCDA" / "ESMapHartree.npy") + + # Create figure grid + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.3] + height_ratios = [1, 0.778] + gap = 0.25 + fig = plt.figure(figsize=(fig_width, 5.75*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(1, len(width_ratios), wspace=0.02, hspace=0, width_ratios=width_ratios) + afm_grid = fig_grid[0, 0].subgridspec(2, 1, wspace=0, hspace=gap, height_ratios=height_ratios) + pred_grid = fig_grid[0, 1].subgridspec(2, 2, wspace=0.02, hspace=gap, height_ratios=height_ratios) + cbar_grid = fig_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0) + + # Get axes from grid + afm_sim_axes = afm_grid[0, 0].subgridspec(len(X_sim), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + afm_exp_axes = afm_grid[1, 0].subgridspec(len(X_exp), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + pred_sim_ax, ref_pc_ax, pred_exp_ax, geom_ax = pred_grid.subplots(squeeze=True).flatten() + cbar_ax = cbar_grid.subplots(squeeze=True) + + # Plot AFM + for k, (axes, X) in enumerate(zip([afm_sim_axes, afm_exp_axes], [X_sim, X_exp])): + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + + # Plot AFM slice + im = axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot") + axes[i, j].set_axis_off() + + # Put tip names to the left of the AFM image rows + axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + + # Figure out data limits + vmax = max( + abs(pred_sim[0].min()), abs(pred_sim[0].max()), + abs(pred_exp[0].min()), abs(pred_exp[0].max()), + abs(ref[0].min()), abs(ref[0].max()), + abs(Y_hartree.min()), abs(Y_hartree.max()) + ) + vmin = -vmax + print(vmin, vmax) + + # Plot predictions and references + pred_sim_ax.imshow(pred_sim[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + pred_exp_ax.imshow(pred_exp[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_pc_ax.imshow(ref[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot molecule geometry + xyz_img = np.flipud(imageio.imread(data_dir / "PTCDA" / "mol.png")) + geom_ax.imshow(xyz_img, origin="lower") + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks([-0.1, 0.0, 0.1]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_sim_ax.set_axis_off() + pred_exp_ax.set_axis_off() + ref_pc_ax.set_axis_off() + geom_ax.set_axis_off() + + # Set titles + afm_sim_axes[0, len(X_slices)//2].set_title("AFM simulation", fontsize=fontsize, y=0.90) + afm_exp_axes[0, len(X_slices)//2].set_title("AFM experiment", fontsize=fontsize, y=0.90) + pred_sim_ax.set_title("Sim. prediction", fontsize=fontsize, y=0.95) + pred_exp_ax.set_title("Exp. prediction", fontsize=fontsize, y=0.94) + ref_pc_ax.set_title("Reference", fontsize=fontsize, y=0.95) + + plt.savefig("ptcda.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/ptcda_surface_sim.py b/papers/ed-afm/figures/ptcda_surface_sim.py new file mode 100644 index 0000000..e5ef3af --- /dev/null +++ b/papers/ed-afm/figures/ptcda_surface_sim.py @@ -0,0 +1,113 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import cm + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + fig_width = 150 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Loop over molecules and plot + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.3] + fig = plt.figure(figsize=(fig_width, 2.88*fig_width/sum(width_ratios))) + + # Define ticks for colorbars + ticks = [-0.08, -0.04, 0.00, 0.04, 0.08] + + # Load data + X1 = np.load(data_dir / "PTCDA" / "data_CO_sim.npz")["data"].astype(np.float32) + X2 = np.load(data_dir / "PTCDA" / "data_Xe_sim.npz")["data"].astype(np.float32) + X = [X1[None], X2[None]] + Y_hartree = [np.load(data_dir / "PTCDA" / "ESMapHartree.npy")[None]] + + X, Y_hartree, _ = apply_preprocessing((X, Y_hartree, None)) + with torch.no_grad(): + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + pred, attentions = model(X) + pred = [p.cpu().numpy() for p in pred] + attentions = [a.cpu().numpy() for a in attentions] + X = [x.squeeze(1).cpu().numpy() for x in X] + + # Create plot grid + sample_grid = fig.add_gridspec(1, 3, wspace=0.01, hspace=0, width_ratios=width_ratios) + input_axes = sample_grid[0, 0].subgridspec(len(X), len(X_slices), wspace=0.01, hspace=0.02).subplots(squeeze=False) + pred_ax, ref_ax = sample_grid[0, 1].subgridspec(1, 2, wspace=0.01, hspace=0).subplots(squeeze=True) + cbar_ax = sample_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + # Plot AFM inputs + ims = [] + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + ims.append(input_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")) + input_axes[i, j].set_axis_off() + input_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=input_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred[0].min()), abs(pred[0].max()), + abs(Y_hartree[0][0].min()), abs(Y_hartree[0][0].max()) + ) + vmin = -vmax + + # Plot prediction and references + pred_ax.imshow(pred[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_ax.imshow(Y_hartree[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot ES Map colorbar + plt.rcParams["font.serif"] = "cmr10" + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_ax.set_axis_off() + ref_ax.set_axis_off() + + # Set titles + input_axes[0, len(X_slices)//2].set_title("AFM simulation (Hartree)", fontsize=fontsize, y=0.93) + pred_ax.set_title("Prediction", fontsize=fontsize, y=0.97) + ref_ax.set_title("Reference (Hartree)", fontsize=fontsize, y=0.97) + + plt.savefig("surface_sims_ptcda.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/run_all.sh b/papers/ed-afm/figures/run_all.sh new file mode 100755 index 0000000..a12b182 --- /dev/null +++ b/papers/ed-afm/figures/run_all.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +python sims.py +python ptcda.py +python bcb.py +python water.py +python surface_sims_bcb_water.py +pdflatex model_schem.tex +python stats.py +python stats_spring_constants.py +python esmap_sample.py +pdflatex esmap_schem.tex +python afm_stacks.py +python afm_stacks2.py +python sims_hartree.py +python ptcda_surface_sim.py +python single_tip.py +python sims_Cl.py +python height_dependence.py +python extra_electron.py +python background_gradient.py \ No newline at end of file diff --git a/papers/ed-afm/figures/sims.py b/papers/ed-afm/figures/sims.py new file mode 100644 index 0000000..3962d44 --- /dev/null +++ b/papers/ed-afm/figures/sims.py @@ -0,0 +1,187 @@ + +import string +from pathlib import Path + +import imageio.v2 as imageio +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + molecules = ["NCM", "PTH", "TTF-TDZ"] # Molecules to run + fig_width = 140 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment(i_platform = 0) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (144, 144, 19), + "scan_window" : ((2.0, 2.0, 7.0), (20.0, 20.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 5.25, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / m for m in molecules] + xyz_paths = [m / "mol.xyz" for m in molecules] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = InverseAFMtrainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Set figure + fig_width = 0.1/2.54*fig_width + width_ratios = [3, 6, 8, 0.3] + fig = plt.figure(figsize=(fig_width, 4.05*len(molecules)*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(len(molecules), 1, wspace=0, hspace=0.03) + + # Define ticks for colorbars + ticks = [ + [-0.10, -0.05, 0.00, 0.05, 0.10], + [-0.04, -0.02, 0.00, 0.02, 0.04], + [-0.08, -0.04, 0.00, 0.04, 0.08] + ] + + # Loop over molecules and plot + for ib, batch in enumerate(trainer): + + # Get batch and predict + X, Y_pc, xyzs = apply_preprocessing(batch) + with torch.no_grad(): + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + pred, attentions = model(X) + pred = [p.cpu().numpy() for p in pred] + attentions = [a.cpu().numpy() for a in attentions] + X = [x.squeeze(1).cpu().numpy() for x in X] + + # Load Hartree reference + Y_hartree = np.load(molecules[ib] / "ESMapHartree.npy") + + # Create plot grid + sample_grid = fig_grid[ib, 0].subgridspec(1, len(width_ratios), wspace=0.02, hspace=0, width_ratios=width_ratios) + xyz_ax = sample_grid[0, 0].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + input_axes = sample_grid[0, 1].subgridspec(len(X), len(X_slices), wspace=0.01, hspace=0.02).subplots(squeeze=False) + pred_ax, ref_ax = sample_grid[0, 2].subgridspec(1, 2, wspace=0.01, hspace=0).subplots(squeeze=True) + cbar_ax = sample_grid[0, 3].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + # Set subfigure reference letters + grid_pos = sample_grid.get_grid_positions(fig) + x, y = grid_pos[2][0], (grid_pos[1][0] + grid_pos[0][0]) / 2 + 0.3/len(molecules) + fig.text(x, y, string.ascii_uppercase[ib], fontsize=fontsize) + + # Plot molecule geometry + xyz_ax.imshow(imageio.imread(molecules[ib] / "mol.png")) + + # Plot AFM inputs + ims = [] + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + ims.append(input_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")) + input_axes[i, j].set_axis_off() + input_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=input_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred[0].min()), abs(pred[0].max()), + abs(Y_pc[0].min()), abs(Y_pc[0].max()), + ) + vmin = -vmax + + # Plot prediction and references + pred_ax.imshow(pred[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_ax.imshow(Y_pc[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks[ib]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + xyz_ax.set_axis_off() + pred_ax.set_axis_off() + ref_ax.set_axis_off() + + # Set titles for first row of images + if ib == 0: + input_axes[0, len(X_slices)//2].set_title("AFM simulation", fontsize=fontsize, y=0.90) + pred_ax.set_title("Prediction", fontsize=fontsize, y=0.95) + ref_ax.set_title("Reference", fontsize=fontsize, y=0.95) + + # Calculate relative error metric + rel_abs_err_es = np.mean(np.abs(pred[0] - Y_pc[0])) / np.ptp(Y_pc[0]) + print(f"Relative error: {rel_abs_err_es*100:.2f}%") + + plt.savefig("sims.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/sims_Cl.py b/papers/ed-afm/figures/sims_Cl.py new file mode 100644 index 0000000..e261c9a --- /dev/null +++ b/papers/ed-afm/figures/sims_Cl.py @@ -0,0 +1,191 @@ + + +import string +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +class Trainer(InverseAFMtrainer): + + def on_afm_start(self): + if self.afmulator.iZPP in [8, 54]: + afmulator.scanner.stiffness = np.array([0.25, 0.25, 0.0, 30.0], dtype=np.float32) / -16.0217662 + elif self.afmulator.iZPP == 17: + afmulator.scanner.stiffness = np.array([0.50, 0.50, 0.0, 30.0], dtype=np.float32) / -16.0217662 + else: + raise RuntimeError(f"Unknown tip {self.afmulator.iZPP}") + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe", "Cl"] # AFM tip types + device = "cuda" # Device to run inference on + molecules = ["NCM", "PTH", "TTF-TDZ"] # Molecules to run + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (144, 144, 19), + "scan_window" : ((2.0, 2.0, 7.0), (20.0, 20.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 5.25, + "iZPPs" : [8, 54, 17], # CO, Xe, Cl + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ], [ -0.3, 0, 0, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ], [ 0, 0, 0, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / m / "mol.xyz" for m in molecules] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = Trainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Load models + model_CO_Cl = EDAFMNet(device=device, pretrained_weights="CO-Cl") + model_Xe_Cl = EDAFMNet(device=device, pretrained_weights="Xe-Cl") + + # Set figure + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 12, 0.3] + fig = plt.figure(figsize=(fig_width, 6*len(molecules)*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(len(molecules), 1, wspace=0, hspace=0.03) + + # Define ticks for colorbars + ticks = [ + [-0.10, -0.05, 0.00, 0.05, 0.10], + [-0.04, -0.02, 0.00, 0.02, 0.04], + [-0.06, -0.03, 0.00, 0.03, 0.06] + ] + + # Loop over molecules and plot + for ib, batch in enumerate(trainer): + + # Get batch and predict + X, Y, _ = apply_preprocessing(batch) + X_CO_Cl = [X[0], X[2]] + X_Xe_Cl = [X[1], X[2]] + with torch.no_grad(): + X_CO_Cl_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_CO_Cl] + X_Xe_Cl_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_Xe_Cl] + pred_CO_Cl, attentions_CO_Cl = model_CO_Cl(X_CO_Cl_cuda) + pred_Xe_Cl, attentions_Xe_Cl = model_Xe_Cl(X_Xe_Cl_cuda) + pred_CO_Cl = [p.cpu().numpy() for p in pred_CO_Cl] + pred_Xe_Cl = [p.cpu().numpy() for p in pred_Xe_Cl] + attentions_CO_Cl = [a.cpu().numpy() for a in attentions_CO_Cl] + attentions_Xe_Cl = [a.cpu().numpy() for a in attentions_Xe_Cl] + + # Create plot grid + sample_grid = fig_grid[ib, 0].subgridspec(1, len(width_ratios), wspace=0.02, hspace=0, width_ratios=width_ratios) + input_axes = sample_grid[0, 0].subgridspec(len(X), len(X_slices), wspace=0.01, hspace=0.02).subplots(squeeze=False) + pred_CO_Cl_ax, pred_Xe_Cl_ax = sample_grid[0, 1].subgridspec(1, 2, wspace=0.01, hspace=0).subplots(squeeze=True) + cbar_ax = sample_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + # Set subfigure reference letters + grid_pos = sample_grid.get_grid_positions(fig) + x, y = grid_pos[2][0]-0.04, (grid_pos[1][0] + grid_pos[0][0]) / 2 + 0.3/len(molecules) + 0.01 + fig.text(x, y, string.ascii_uppercase[ib], fontsize=fontsize) + + # Plot AFM inputs + ims = [] + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + ims.append(input_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")) + input_axes[i, j].set_axis_off() + input_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=input_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred_CO_Cl[0].min()), abs(pred_CO_Cl[0].max()), + abs(pred_Xe_Cl[0].min()), abs(pred_Xe_Cl[0].max()) + ) + vmin = -vmax + + # Plot predictions + pred_CO_Cl_ax.imshow(pred_CO_Cl[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + pred_Xe_Cl_ax.imshow(pred_Xe_Cl[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot ES Map colorbar + plt.rcParams["font.serif"] = "cmr10" + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks[ib]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_CO_Cl_ax.set_axis_off() + pred_Xe_Cl_ax.set_axis_off() + + # Set titles for first row of images + if ib == 0: + input_axes[0, len(X_slices)//2].set_title("AFM simulation", fontsize=fontsize, y=0.91) + pred_CO_Cl_ax.set_title("Prediction (Cl-CO)", fontsize=fontsize, y=0.97) + pred_Xe_Cl_ax.set_title("Prediction (Cl-Xe)", fontsize=fontsize, y=0.97) + + plt.savefig("sims_Cl.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/sims_hartree.py b/papers/ed-afm/figures/sims_hartree.py new file mode 100644 index 0000000..1e5ee6f --- /dev/null +++ b/papers/ed-afm/figures/sims_hartree.py @@ -0,0 +1,133 @@ + +import string +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import cm + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + molecules = ["NCM", "PTH", "TTF-TDZ"] # Molecules to run + fig_width = 150 # Figure width in mm + fontsize = 8 + dpi = 300 + + molecules = [data_dir / m for m in molecules] + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Loop over molecules and plot + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.3] + fig = plt.figure(figsize=(fig_width, 4.05*len(molecules)*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(len(molecules), 1, wspace=0, hspace=0.03) + + # Define ticks for colorbars + ticks = [ + [-0.15, -0.10, -0.05, 0.00, 0.05, 0.10, 0.15], + [-0.08, -0.04, 0.00, 0.04, 0.08], + [-0.04, 0.00, 0.04] + ] + + for ib in range(len(molecules)): + + # Load data + X1 = np.load(molecules[ib] / "data_CO.npz")["data"].astype(np.float32) + X2 = np.load(molecules[ib] / "data_Xe.npz")["data"].astype(np.float32) + X = [X1[None], X2[None]] + Y_hartree = [np.load(molecules[ib] / "ESMapHartree.npy")[None]] + + X, Y_hartree, _ = apply_preprocessing((X, Y_hartree, None)) + with torch.no_grad(): + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + pred, attentions = model(X) + pred = [p.cpu().numpy() for p in pred] + attentions = [a.cpu().numpy() for a in attentions] + X = [x.squeeze(1).cpu().numpy() for x in X] + + # Create plot grid + sample_grid = fig_grid[ib, 0].subgridspec(1, 3, wspace=0.01, hspace=0, width_ratios=width_ratios) + input_axes = sample_grid[0, 0].subgridspec(len(X), len(X_slices), wspace=0.01, hspace=0.02).subplots(squeeze=False) + pred_ax, ref_ax = sample_grid[0, 1].subgridspec(1, 2, wspace=0.01, hspace=0).subplots(squeeze=True) + cbar_ax = sample_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + # Set subfigure reference letters + grid_pos = sample_grid.get_grid_positions(fig) + x, y = grid_pos[2][0] - 0.04, (grid_pos[1][0] + grid_pos[0][0]) / 2 + 0.3/len(molecules) + fig.text(x, y, string.ascii_uppercase[ib], fontsize=fontsize) + + # Plot AFM inputs + ims = [] + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + ims.append(input_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")) + input_axes[i, j].set_axis_off() + input_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=input_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred[0].min()), abs(pred[0].max()), + abs(Y_hartree[0][0].min()), abs(Y_hartree[0][0].max()) + ) + vmin = -vmax + + # Plot prediction and references + pred_ax.imshow(pred[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_ax.imshow(Y_hartree[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks[ib]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_ax.set_axis_off() + ref_ax.set_axis_off() + + # Set titles for first row of images + if ib == 0: + input_axes[0, len(X_slices)//2].set_title("AFM simulation (Hartree)", fontsize=fontsize, y=0.93) + pred_ax.set_title("Prediction", fontsize=fontsize, y=0.97) + ref_ax.set_title("Reference (Hartree)", fontsize=fontsize, y=0.97) + + # Calculate relative error metric + rel_abs_err_es = np.mean(np.abs(pred[0] - Y_hartree[0])) / np.ptp(Y_hartree[0]) + print(f"Relative error: {rel_abs_err_es*100:.2f}%") + + plt.savefig("sims_hartree.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/single_tip.py b/papers/ed-afm/figures/single_tip.py new file mode 100644 index 0000000..3f5a57b --- /dev/null +++ b/papers/ed-afm/figures/single_tip.py @@ -0,0 +1,171 @@ + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing_sim(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +def apply_preprocessing_bcb(X, real_dim): + x0_start = 4 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + return X + +def apply_preprocessing_ptcda(X, real_dim): + x0_start = 2 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + X = [x[:,:,6:78] for x in X] + return X + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + device = "cuda" # Device to run inference on + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 7.0), (20.0, 20.0, 8.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 1, + "distAbove" : 5.25, + "iZPPs" : [8], + "Qs" : [[ -10, 20, -10, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / "TTF-TDZ" / "mol.xyz"] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = InverseAFMtrainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Get simulation data + batch = next(iter(trainer)) + X_sim, Y_sim, xyz = apply_preprocessing_sim(batch) + X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim] + + # Load BCB data and preprocess + data_bcb = np.load(data_dir / "BCB" / "data_CO_exp.npz") + X_bcb = data_bcb["data"] + afm_dim_bcb = (data_bcb["lengthX"], data_bcb["lengthY"]) + X_bcb = apply_preprocessing_bcb([X_bcb[None]], afm_dim_bcb) + X_bcb_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_bcb] + + # Load PTCDA data and preprocess + data_ptcda = np.load(data_dir / "PTCDA" / "data_CO_exp.npz") + X_ptcda = data_ptcda["data"] + afm_dim_ptcda = (data_ptcda["lengthX"], data_ptcda["lengthY"]) + X_ptcda = apply_preprocessing_ptcda([X_ptcda[None]], afm_dim_ptcda) + X_ptcda_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_ptcda] + + # Load model + model = EDAFMNet(device=device, pretrained_weights="single-channel") + + # Make predictions + with torch.no_grad(): + pred_sim, attentions_sim = model(X_sim_cuda) + pred_bcb, attentions_bcb = model(X_bcb_cuda) + pred_ptcda, attentions_ptcda = model(X_ptcda_cuda) + pred_sim = [p.cpu().numpy() for p in pred_sim] + pred_bcb = [p.cpu().numpy() for p in pred_bcb] + pred_ptcda = [p.cpu().numpy() for p in pred_ptcda] + attentions_sim = [a.cpu().numpy() for a in attentions_sim] + attentions_bcb = [a.cpu().numpy() for a in attentions_bcb] + attentions_ptcda = [a.cpu().numpy() for a in attentions_ptcda] + + # Make figure + fig_width = 0.1/2.54*fig_width + width_ratios = [4, 4, 6.9] + fig, axes = plt.subplots(1, 3, figsize=(fig_width, 2.6*fig_width/sum(width_ratios)), + gridspec_kw={"width_ratios": width_ratios, "wspace": 0.3}) + + tick_arrays = [ + [-0.03, 0.0, 0.03], + [-0.05, 0.0, 0.05], + [-0.1, 0.0, 0.1] + ] + + # Plot all predictions + for ax, ticks, pred, label in zip(axes, tick_arrays, [pred_sim, pred_bcb, pred_ptcda], ["A", "B", "C"]): + vmax = max(abs(pred[0][0].max()), abs(pred[0][0].min())); vmin = -vmax + ax.imshow(pred[0][0].T, vmin=vmin, vmax=vmax, cmap="coolwarm", origin="lower") + plt.axes(ax) + m = cm.ScalarMappable(cmap=cm.coolwarm) + m.set_array([vmin, vmax]) + cbar = plt.colorbar(m, ax=ax) + cbar.set_ticks(ticks) + cbar.ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + ax.text(-0.1, 0.95, label, horizontalalignment="center", + verticalalignment="center", transform=ax.transAxes, fontsize=fontsize) + ax.set_axis_off() + + # Calculate relative error metric for simulation + rel_abs_err_es = np.mean(np.abs(pred_sim[0] - Y_sim[0])) / np.ptp(Y_sim[0]) + print(f"Relative error: {rel_abs_err_es*100:.2f}%") + + plt.savefig("single_tip_predictions.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/stats.py b/papers/ed-afm/figures/stats.py new file mode 100644 index 0000000..2c81975 --- /dev/null +++ b/papers/ed-afm/figures/stats.py @@ -0,0 +1,122 @@ + +from pathlib import Path +import tarfile +import numpy as np +import matplotlib.pyplot as plt + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +if __name__ == "__main__": + + stats_dir = Path("./stats") # Path to directory where csv files for losses are stored + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + linewidth = 1 + + with tarfile.open(stats_dir / "stats.tar.gz") as f: + f.extractall(stats_dir) + + data_independent = np.loadtxt(stats_dir / "mse_independent.csv", delimiter=",") + data_matched = np.loadtxt(stats_dir / "mse_matched.csv", delimiter=",") + data_independent_Xe = np.loadtxt(stats_dir / "mse_independent_Xe.csv", delimiter=",") + data_matched_Xe = np.loadtxt(stats_dir / "mse_matched_Xe.csv", delimiter=",") + + data_constant = np.loadtxt(stats_dir / "mse_constant.csv", delimiter=",") + data_uniform = np.loadtxt(stats_dir / "mse_uniform.csv", delimiter=",") + data_normal = np.loadtxt(stats_dir / "mse_normal.csv", delimiter=",") + + h = data_independent[:, 0] + data_independent = data_independent[:, 1:] + data_matched = data_matched[:, 1:] + data_independent_Xe = data_independent_Xe[:, 1:] + data_matched_Xe = data_matched_Xe[:, 1:] + + amp = data_constant[:, 0] + data_constant = data_constant[:, 1:] + data_uniform = data_uniform[:, 1:] + data_normal = data_normal[:, 1:] + + dh = h-5.3 + m1 = data_independent.mean(axis=1) + m2 = data_matched.mean(axis=1) + m3 = data_independent_Xe.mean(axis=1) + m4 = data_matched_Xe.mean(axis=1) + + m5 = data_constant.mean(axis=1) + m6 = data_uniform.mean(axis=1) + m7 = data_normal.mean(axis=1) + + p1_05, p1_95 = np.percentile(data_independent, [5, 95], axis=1) + p2_05, p2_95 = np.percentile(data_matched, [5, 95], axis=1) + p3_05, p3_95 = np.percentile(data_independent_Xe, [5, 95], axis=1) + p4_05, p4_95 = np.percentile(data_matched_Xe, [5, 95], axis=1) + + p5_05, p5_95 = np.percentile(data_constant, [5, 95], axis=1) + p6_05, p6_95 = np.percentile(data_uniform, [5, 95], axis=1) + p7_05, p7_95 = np.percentile(data_normal, [5, 95], axis=1) + + r = 0.35 + w = 0.25 + h = w/r + b = 0.18 + fig_width = 0.1/2.54*fig_width + fig = plt.figure(figsize=(fig_width, r*fig_width), dpi=dpi) + ax1 = plt.axes((0.070, b, w, h)) + ax2 = plt.axes((0.415, b, w, h)) + ax3 = plt.axes((0.732, b, w, h)) + + ax1.plot(dh, m1, "b", linewidth=linewidth) + ax1.plot(dh, m2, "r", linewidth=linewidth) + ax1.plot(dh, p1_05, "--b", linewidth=linewidth) + ax1.plot(dh, p1_95, "--b", linewidth=linewidth) + ax1.plot(dh, p2_05, "--r", linewidth=linewidth) + ax1.plot(dh, p2_95, "--r", linewidth=linewidth) + ax1.set_xlabel("$dh$(\AA)", fontsize=fontsize) + ax1.set_ylabel("MSE", fontsize=fontsize) + ax1.ticklabel_format(axis="y", style="sci", scilimits=(-3, 3)) + ax1.set_title("Both tips shifted", fontsize=fontsize+1) + ax1.legend(["Independent tips", "Matched tips"], fontsize=fontsize-1) + ax1.text(-0.26, 1.08, "A", transform=ax1.transAxes) + + ax2.semilogy(dh, m3, "b", linewidth=linewidth) + ax2.plot(dh, m4, "r", linewidth=linewidth) + ax2.plot(dh, p3_05, "--b", linewidth=linewidth) + ax2.plot(dh, p3_95, "--b", linewidth=linewidth) + ax2.plot(dh, p4_05, "--r", linewidth=linewidth) + ax2.plot(dh, p4_95, "--r", linewidth=linewidth) + ax2.set_xlabel("$dh$(\AA)", fontsize=fontsize) + ax2.set_ylabel("MSE", fontsize=fontsize) + ax2.set_title("Only Xe shifted", fontsize=fontsize+1) + ax2.legend(["Independent tips", "Matched tips"], fontsize=fontsize-1) + ax2.set_ylim((1e-6, 1e-2)) + ax2.text(-0.32, 1.08, "B", transform=ax2.transAxes) + + ax3.plot(amp, m5, "b", linewidth=linewidth) + ax3.plot(amp, m6, "r", linewidth=linewidth) + ax3.plot(amp, m7, "k", linewidth=linewidth) + ax3.plot(amp, p5_05, "--b", linewidth=linewidth) + ax3.plot(amp, p5_95, "--b", linewidth=linewidth) + ax3.plot(amp, p6_05, "--r", linewidth=linewidth) + ax3.plot(amp, p6_95, "--r", linewidth=linewidth) + ax3.plot(amp, p7_05, "--k", linewidth=linewidth) + ax3.plot(amp, p7_95, "--k", linewidth=linewidth) + ax3.set_xlabel("Noise amplitude", fontsize=fontsize) + ax3.set_ylabel("MSE", fontsize=fontsize) + ax3.set_title("Noise amp. distributions", fontsize=fontsize+1) + ax3.legend(["$\mathcal{C}(0.08)$", "$\mathcal{U}(0.16)$", "$\mathcal{N}(0.1)$"], + fontsize=fontsize-1) + ax3.text(-0.22, 1.08, "C", transform=ax3.transAxes) + + for ax in [ax1, ax2, ax3]: + ax.tick_params(axis="both", labelsize=fontsize-1) + tx = ax.yaxis.get_offset_text() + tx.set_fontsize(fontsize-1) + tx.set_position((-0.15, 0)) + + plt.savefig("stats.pdf") diff --git a/papers/ed-afm/figures/stats/stats.tar.gz b/papers/ed-afm/figures/stats/stats.tar.gz new file mode 100644 index 0000000..1f4ebed Binary files /dev/null and b/papers/ed-afm/figures/stats/stats.tar.gz differ diff --git a/papers/ed-afm/figures/stats/stats_distance.py b/papers/ed-afm/figures/stats/stats_distance.py new file mode 100644 index 0000000..47f3f2f --- /dev/null +++ b/papers/ed-afm/figures/stats/stats_distance.py @@ -0,0 +1,130 @@ + +import random +import time +from pathlib import Path + +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +def apply_preprocessing(batch): + Xs, Ys, _ = batch + Xs = [x[...,2:8] for x in Xs] + pp.add_norm(Xs) + pp.add_noise(Xs, c=0.08, randomize_amplitude=False) + return Xs, Ys + +if __name__ == "__main__": + + # Indepenent tips model + model_type = "base" # Type of pretrained weights to use + save_file = Path("mse_independent.csv") # File to save MSE values into + + # # Matched tips model + # model_type = "matched-tips" # Type of pretrained weights to use + # save_file = Path("mse_matched.csv") # File to save MSE values into + + device = "cuda" # Device to run inference on + molecules_dir = Path("../../molecules") # Path to molecule database + test_heights = np.linspace(4.9, 5.7, 21) # Test heights to run + n_samples = 3000 # Number of samples to run + + if save_file.exists(): + raise RuntimeError("Save file already exists") + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 6.0), (18.0, 18.0, 7.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 30, + "distAbove" : 5.3, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Set random seed for reproducibility + random.seed(0) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Download molecules if not already there + download_dataset("ED-AFM-molecules", molecules_dir) + + # Define generator + xyz_paths = (molecules_dir / "test").glob("*.xyz") + trainer = InverseAFMtrainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Pick samples + random.shuffle(trainer.molecules) + trainer.molecules = trainer.molecules[:n_samples] + + # Make model + model = EDAFMNet(device=device, pretrained_weights=model_type) + + # Initialize save file + with open(save_file, "w") as f: + pass + + # Calculate MSE at every height for every batch + start_time = time.time() + total_len = len(test_heights)*len(trainer) + for ih, height in enumerate(test_heights): + + print(f"Height = {height:.2f}") + trainer.distAboveActive = height + + mses = [] + for ib, batch in enumerate(trainer): + + X, ref = apply_preprocessing(batch) + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + ref = [torch.from_numpy(r).to(device) for r in ref] + + with torch.no_grad(): + pred, _ = model(X) + pred = pred[0] + + diff = pred - ref[0] + for d in diff: + mses.append((d**2).mean().cpu().numpy()) + + eta = (time.time() - start_time) * (total_len / (ih*len(trainer)+ib+1) - 1) + print(f"Batch {ib+1}/{len(trainer)} - ETA: {eta:.1f}s") + + with open(save_file, "a") as f: + f.write(f"{height:.2f},") + f.write(",".join([str(v) for v in mses])) + f.write("\n") diff --git a/papers/ed-afm/figures/stats/stats_distance_Xe.py b/papers/ed-afm/figures/stats/stats_distance_Xe.py new file mode 100644 index 0000000..cf1381b --- /dev/null +++ b/papers/ed-afm/figures/stats/stats_distance_Xe.py @@ -0,0 +1,141 @@ + +import random +import time +from pathlib import Path + +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + + +class Trainer(InverseAFMtrainer): + + # Override this method to set the Xe tip at a different height + def handle_distance(self): + if self.afmulator.iZPP == 54: + self.distAboveActive = self.distAboveXe + super().handle_distance() + if self.afmulator.iZPP == 54: + self.distAboveActive = self.distAbove + +def apply_preprocessing(batch): + Xs, Ys, _ = batch + Xs = [x[...,2:8] for x in Xs] + pp.add_norm(Xs) + pp.add_noise(Xs, c=0.08, randomize_amplitude=False) + return Xs, Ys + +if __name__ == "__main__": + + # # Independent tips model + # model_type = "base" # Type of pretrained weights to use + # save_file = Path("mse_independent_Xe.csv") # File to save MSE values into + + # Matched tips model + model_type = "matched-tips" # Type of pretrained weights to use + save_file = Path("./mse_matched_Xe.csv") # File to save MSE values into + + device = "cuda" # Device to run inference on + molecules_dir = Path("../../molecules") # Path to molecule database + test_heights = np.linspace(4.9, 5.7, 21) # Test heights to run + n_samples = 3000 # Number of samples to run + + if save_file.exists(): + raise RuntimeError("Save file already exists") + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 6.0), (18.0, 18.0, 7.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 30, + "distAbove" : 5.3, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Set random seed for reproducibility + random.seed(0) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Download molecules if not already there + download_dataset("ED-AFM-molecules", molecules_dir) + + # Define generator + xyz_paths = (molecules_dir / "test").glob("*.xyz") + trainer = Trainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Pick samples + random.shuffle(trainer.molecules) + trainer.molecules = trainer.molecules[:n_samples] + + # Make model + model = EDAFMNet(device=device, pretrained_weights=model_type) + + # Initialize save file + with open(save_file, "w") as f: + pass + + # Calculate MSE at every height for every batch + start_time = time.time() + total_len = len(test_heights)*len(trainer) + for ih, height in enumerate(test_heights): + + print(f"Height = {height:.2f}") + trainer.distAboveXe = height + + mses = [] + for ib, batch in enumerate(trainer): + + X, ref = apply_preprocessing(batch) + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + ref = [torch.from_numpy(r).to(device) for r in ref] + + with torch.no_grad(): + pred, _ = model(X) + pred = pred[0] + + diff = pred - ref[0] + for d in diff: + mses.append((d**2).mean().cpu().numpy()) + + eta = (time.time() - start_time) * (total_len / (ih*len(trainer)+ib+1) - 1) + print(f"Batch {ib+1}/{len(trainer)} - ETA: {eta:.1f}s") + + with open(save_file, "a") as f: + f.write(f"{height:.2f},") + f.write(",".join([str(v) for v in mses])) + f.write("\n") diff --git a/papers/ed-afm/figures/stats/stats_noise.py b/papers/ed-afm/figures/stats/stats_noise.py new file mode 100644 index 0000000..3ce1afb --- /dev/null +++ b/papers/ed-afm/figures/stats/stats_noise.py @@ -0,0 +1,126 @@ + +from pathlib import Path +import random +import time + +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +if __name__ == "__main__": + + # # Normal noise model + # model_type = 'base' # Type of pretrained weights to use + # save_file = Path('mse_normal.csv') # File to save MSE values into + + # # Uniform noise model + # model_type = 'uniform-noise' # Type of pretrained weights to use + # save_file = Path('mse_uniform.csv') # File to save MSE values into + + # Constant noise model + model_type = 'constant-noise' # Type of pretrained weights to use + save_file = Path('mse_constant.csv') # File to save MSE values into + + device = 'cuda' # Device to run inference on + molecules_dir = Path('../../molecules') # Path to molecule database + test_amplitudes = np.linspace(0, 0.2, 21) # Test amplitudes to run + n_samples = 3000 # Number of samples to run + + if save_file.exists(): + raise RuntimeError("Save file already exists") + + afmulator_args = { + 'pixPerAngstrome' : 20, + 'scan_dim' : (128, 128, 20), + 'scan_window' : ((2.0, 2.0, 6.0), (18.0, 18.0, 8.0)), + 'df_steps' : 10, + 'tipR0' : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + 'batch_size' : 30, + 'distAbove' : 5.3, + 'iZPPs' : [8, 54], + 'Qs' : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + 'QZs' : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Set random seed for reproducibility + random.seed(0) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Download molecules if not already there + download_dataset("ED-AFM-molecules", molecules_dir) + + # Define generator + xyz_paths = (molecules_dir / "test").glob("*.xyz") + trainer = InverseAFMtrainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Pick samples + random.shuffle(trainer.molecules) + trainer.molecules = trainer.molecules[:n_samples] + + # Make model + model = EDAFMNet(device=device, pretrained_weights=model_type) + + # Calculate MSE at every height for every batch + start_time = time.time() + total_len = len(test_amplitudes)*len(trainer) + mses = [[] for _ in range(len(test_amplitudes))] + for ib, batch in enumerate(trainer): + + X, ref, _ = batch + X = [x[...,2:8] for x in X] + pp.add_norm(X) + ref = [torch.from_numpy(r).to(device) for r in ref] + + for ia, noise_amp in enumerate(test_amplitudes): + + X_ = [x.copy() for x in X] + pp.add_noise(X_, c=noise_amp, randomize_amplitude=False) + X_ = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_] + + with torch.no_grad(): + pred, _ = model(X_) + pred = pred[0] + + diff = pred - ref[0] + for d in diff: + mses[ia].append((d**2).mean().cpu().numpy()) + + eta = (time.time() - start_time) * (len(trainer) / (ib+1) - 1) + print(f'Batch {ib+1}/{len(trainer)} - ETA: {eta:.1f}s') + + with open(save_file, 'w') as f: + for noise_amp, mse_amp in zip(test_amplitudes, mses): + f.write(f'{noise_amp:.2f},') + f.write(','.join([str(v) for v in mse_amp])) + f.write('\n') diff --git a/papers/ed-afm/figures/stats/stats_spring_constant_lat.py b/papers/ed-afm/figures/stats/stats_spring_constant_lat.py new file mode 100644 index 0000000..fa3e86e --- /dev/null +++ b/papers/ed-afm/figures/stats/stats_spring_constant_lat.py @@ -0,0 +1,135 @@ + +import random +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet +from mlspm.visualization import plot_input + + +def apply_preprocessing(batch): + Xs, Ys, _ = batch + Xs = [x[...,2:8] for x in Xs] + pp.add_norm(Xs) + pp.add_noise(Xs, c=0.08, randomize_amplitude=False) + return Xs, Ys + +if __name__ == "__main__": + + model_type = "base" # Type of pretrained weights to use + save_file = Path("mse_spring_constants_lat.csv") # File to save MSE values into + device = "cuda" # Device to run inference on + molecules_dir = Path("../../molecules") # Path to molecule database + sim_save_dir = Path("./test_sim_lat") # Directory where to save test simulations for inspection + test_constants = np.linspace(0.2, 0.3, 21) # Test lateral spring constants to run + n_samples = 3000 # Number of samples to run + + if save_file.exists(): + raise RuntimeError("Save file already exists") + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 6.0), (18.0, 18.0, 7.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 30, + "distAbove" : 5.3, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + # Set random seed for reproducibility + random.seed(0) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Download molecules if not already there + download_dataset("ED-AFM-molecules", molecules_dir) + + # Define generator + xyz_paths = (molecules_dir / "test").glob("*.xyz") + trainer = InverseAFMtrainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Pick samples + random.shuffle(trainer.molecules) + trainer.molecules = trainer.molecules[:n_samples] + + # Make model + model = EDAFMNet(device=device, pretrained_weights=model_type) + + # Initialize save file + with open(save_file, "w") as f: + pass + sim_save_dir.mkdir(exist_ok=True, parents=True) + + # Calculate MSE at every height for every batch + start_time = time.time() + total_len = len(test_constants)*len(trainer) + for ih, k_lat in enumerate(test_constants): + + print(f"Lateral spring constant = {k_lat:.3f}") + afmulator.scanner.stiffness = np.array([k_lat, k_lat, 0.0, 30.0], dtype=np.float32) / -16.0217662 + + mses = [] + for ib, batch in enumerate(trainer): + + if ib == 0: + print("Saving example simulations...") + for s in range(10): + fig = plot_input(batch[0][0][s]) + plt.savefig(sim_save_dir / f"sim{s}_klat_{k_lat:.3f}.png") + plt.close() + + X, ref = apply_preprocessing(batch) + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + ref = [torch.from_numpy(r).to(device) for r in ref] + + with torch.no_grad(): + pred, _ = model(X) + pred = pred[0] + + diff = pred - ref[0] + for d in diff: + mses.append((d**2).mean().cpu().numpy()) + + eta = (time.time() - start_time) * (total_len / (ih*len(trainer)+ib+1) - 1) + print(f"Batch {ib+1}/{len(trainer)} - ETA: {eta:.1f}s") + + with open(save_file, "a") as f: + f.write(f"{k_lat:.3f},") + f.write(",".join([str(v) for v in mses])) + f.write("\n") diff --git a/papers/ed-afm/figures/stats/stats_spring_constant_rad.py b/papers/ed-afm/figures/stats/stats_spring_constant_rad.py new file mode 100644 index 0000000..67071c6 --- /dev/null +++ b/papers/ed-afm/figures/stats/stats_spring_constant_rad.py @@ -0,0 +1,135 @@ + +import random +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet +from mlspm.visualization import plot_input + + +def apply_preprocessing(batch): + Xs, Ys, _ = batch + Xs = [x[...,2:8] for x in Xs] + pp.add_norm(Xs) + pp.add_noise(Xs, c=0.08, randomize_amplitude=False) + return Xs, Ys + +if __name__ == "__main__": + + model_type = "base" # Type of pretrained weights to use + save_file = Path("mse_spring_constants_rad.csv") # File to save MSE values into + device = "cuda" # Device to run inference on + molecules_dir = Path("../../molecules") # Path to molecule database + sim_save_dir = Path("./test_sim_rad") # Directory where to save test simulations for inspection + test_constants = np.linspace(20, 40, 21) # Test radial spring constants to run + n_samples = 3000 # Number of samples to run + + if save_file.exists(): + raise RuntimeError("Save file already exists") + + afmulator_args = { + "pixPerAngstrome" : 20, + "scan_dim" : (128, 128, 19), + "scan_window" : ((2.0, 2.0, 6.0), (18.0, 18.0, 7.9)), + "df_steps" : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + "batch_size" : 30, + "distAbove" : 5.3, + "iZPPs" : [8, 54], + "Qs" : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + "QZs" : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + # Set random seed for reproducibility + random.seed(0) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Download molecules if not already there + download_dataset("ED-AFM-molecules", molecules_dir) + + # Define generator + xyz_paths = (molecules_dir / "test").glob("*.xyz") + trainer = InverseAFMtrainer(afmulator, aux_maps, xyz_paths, **generator_kwargs) + + # Pick samples + random.shuffle(trainer.molecules) + trainer.molecules = trainer.molecules[:n_samples] + + # Make model + model = EDAFMNet(device=device, pretrained_weights=model_type) + + # Initialize save file + with open(save_file, "w") as f: + pass + sim_save_dir.mkdir(exist_ok=True, parents=True) + + # Calculate MSE at every height for every batch + start_time = time.time() + total_len = len(test_constants)*len(trainer) + for ih, k_rad in enumerate(test_constants): + + print(f"Radial spring constant = {k_rad:.1f}") + afmulator.scanner.stiffness = np.array([0.25, 0.25, 0.0, k_rad], dtype=np.float32) / -16.0217662 + + mses = [] + for ib, batch in enumerate(trainer): + + if ib == 0: + print("Saving example simulations...") + for s in range(10): + fig = plot_input(batch[0][0][s]) + plt.savefig(sim_save_dir / f"sim{s}_krad_{k_rad:.3f}.png") + plt.close() + + X, ref = apply_preprocessing(batch) + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + ref = [torch.from_numpy(r).to(device) for r in ref] + + with torch.no_grad(): + pred, _ = model(X) + pred = pred[0] + + diff = pred - ref[0] + for d in diff: + mses.append((d**2).mean().cpu().numpy()) + + eta = (time.time() - start_time) * (total_len / (ih*len(trainer)+ib+1) - 1) + print(f"Batch {ib+1}/{len(trainer)} - ETA: {eta:.1f}s") + + with open(save_file, "a") as f: + f.write(f"{k_rad:.1f},") + f.write(",".join([str(v) for v in mses])) + f.write("\n") diff --git a/papers/ed-afm/figures/stats_spring_constants.py b/papers/ed-afm/figures/stats_spring_constants.py new file mode 100644 index 0000000..3c0b947 --- /dev/null +++ b/papers/ed-afm/figures/stats_spring_constants.py @@ -0,0 +1,75 @@ + +from pathlib import Path +import tarfile + +import matplotlib.pyplot as plt +import numpy as np + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +if __name__ == "__main__": + + stats_dir = Path("./stats") # Path to directory where csv files for losses are stored + fig_width = 130 # Figure width in mm + fontsize = 8 + dpi = 300 + linewidth = 1 + + with tarfile.open(stats_dir / "stats.tar.gz") as f: + f.extractall(stats_dir) + + data_lat = np.loadtxt(stats_dir / "mse_spring_constants_lat.csv", delimiter=",") + data_rad = np.loadtxt(stats_dir / "mse_spring_constants_rad.csv", delimiter=",") + + k_lat = data_lat[:, 0] + data_lat = data_lat[:, 1:] + + k_rad = data_rad[:, 0] + data_rad = data_rad[:, 1:] + + m1 = data_lat.mean(axis=1) + m2 = data_rad.mean(axis=1) + + print(m1.max()/m1.min()) + + p1_05, p1_95 = np.percentile(data_lat, [5, 95], axis=1) + p2_05, p2_95 = np.percentile(data_rad, [5, 95], axis=1) + + r = 0.55 + w = 0.38 + h = w/r + b = 0.16 + fig_width = 0.1/2.54*fig_width + fig = plt.figure(figsize=(fig_width, r*fig_width), dpi=dpi) + ax1 = plt.axes((0.100, b, w, h)) + ax2 = plt.axes((0.600, b, w, h)) + + ax1.plot(k_lat, m1, "b", linewidth=linewidth) + ax1.plot(k_lat, p1_05, "--b", linewidth=linewidth) + ax1.plot(k_lat, p1_95, "--b", linewidth=linewidth) + ax1.set_xlabel("$k_{\mathrm{lat}}$(N/m)", fontsize=fontsize) + ax1.set_ylabel("MSE", fontsize=fontsize) + ax1.ticklabel_format(axis="y", style="sci", scilimits=(-3, 3)) + ax1.set_title("Lateral spring constant", fontsize=fontsize+1) + ax1.text(-0.23, 1.08, "A", transform=ax1.transAxes) + + ax2.plot(k_rad, m2, "b", linewidth=linewidth) + ax2.plot(k_rad, p2_05, "--b", linewidth=linewidth) + ax2.plot(k_rad, p2_95, "--b", linewidth=linewidth) + ax2.set_xlabel("$k_{\mathrm{rad}}$(N/m)", fontsize=fontsize) + ax2.set_ylabel("MSE", fontsize=fontsize) + ax2.set_title("Radial spring constant", fontsize=fontsize+1) + ax2.text(-0.23, 1.08, "B", transform=ax2.transAxes) + + for ax in [ax1, ax2]: + ax.tick_params(axis="both", labelsize=fontsize-1) + tx = ax.yaxis.get_offset_text() + tx.set_fontsize(fontsize-1) + tx.set_position((-0.15, 0)) + + plt.savefig("stats_spring_constants.pdf") diff --git a/papers/ed-afm/figures/surface_sims_bcb_water.py b/papers/ed-afm/figures/surface_sims_bcb_water.py new file mode 100644 index 0000000..9346631 --- /dev/null +++ b/papers/ed-afm/figures/surface_sims_bcb_water.py @@ -0,0 +1,131 @@ + + +import string +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import cm + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +def apply_preprocessing(batch): + + X, Y, xyzs = batch + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ["CO", "Xe"] # AFM tip types + device = "cuda" # Device to run inference on + molecules = ["BCB", "Water"] # Molecules to run + fig_width = 150 # Figure width in mm + fontsize = 8 + dpi = 300 + + molecules = [data_dir / m for m in molecules] + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Load model + model = EDAFMNet(device=device, pretrained_weights="base") + + # Set figure + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.3] + fig = plt.figure(figsize=(fig_width, 4.05*len(molecules)*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(len(molecules), 1, wspace=0, hspace=0.03) + + # Define ticks for colorbars + ticks = [ + [-0.04, -0.02, 0.00, 0.02, 0.04], + [-0.10, -0.05, 0.0, 0.05, 0.10] + ] + + # Loop over molecules and plot + for ib in range(len(molecules)): + + # Load data + X1 = np.load(molecules[ib] / "data_CO_sim.npz")["data"].astype(np.float32) + X2 = np.load(molecules[ib] / "data_Xe_sim.npz")["data"].astype(np.float32) + X = [X1[None], X2[None]] + Y_hartree = [np.load(molecules[ib] / "ESMapHartree.npy")[None]] + + X, Y_hartree, _ = apply_preprocessing((X, Y_hartree, None)) + with torch.no_grad(): + X = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X] + pred, attentions = model(X) + pred = [p.cpu().numpy() for p in pred] + attentions = [a.cpu().numpy() for a in attentions] + X = [x.squeeze(1).cpu().numpy() for x in X] + + # Create plot grid + sample_grid = fig_grid[ib, 0].subgridspec(1, 3, wspace=0.01, hspace=0, width_ratios=width_ratios) + input_axes = sample_grid[0, 0].subgridspec(len(X), len(X_slices), wspace=0.01, hspace=0.02).subplots(squeeze=False) + pred_ax, ref_ax = sample_grid[0, 1].subgridspec(1, 2, wspace=0.01, hspace=0).subplots(squeeze=True) + cbar_ax = sample_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0).subplots(squeeze=True) + + # Set subfigure reference letters + grid_pos = sample_grid.get_grid_positions(fig) + x, y = grid_pos[2][0] - 0.04, (grid_pos[1][0] + grid_pos[0][0]) / 2 + 0.3/len(molecules) + fig.text(x, y, string.ascii_uppercase[ib], fontsize=fontsize) + + # Plot AFM inputs + ims = [] + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + ims.append(input_axes[i, j].imshow(x[0,:,:,s].T, origin="lower", cmap="afmhot")) + input_axes[i, j].set_axis_off() + input_axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment="center", + verticalalignment="center", transform=input_axes[i, 0].transAxes, + rotation="vertical", fontsize=fontsize) + + # Figure out data limits + vmax = max( + abs(pred[0].min()), abs(pred[0].max()), + abs(Y_hartree[0][0].min()), abs(Y_hartree[0][0].max()) + ) + vmin = -vmax + print(vmin, vmax) + + # Plot prediction and references + pred_ax.imshow(pred[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + ref_ax.imshow(Y_hartree[0][0].T, origin="lower", cmap="coolwarm", vmin=vmin, vmax=vmax) + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks(ticks[ib]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label("V/Å", fontsize=fontsize) + + # Turn off axes ticks + pred_ax.set_axis_off() + ref_ax.set_axis_off() + + # Set titles for first row of images + if ib == 0: + input_axes[0, len(X_slices)//2].set_title("AFM simulation (Hartree)", fontsize=fontsize, y=0.93) + pred_ax.set_title("Prediction", fontsize=fontsize, y=0.97) + ref_ax.set_title("Reference (Hartree)", fontsize=fontsize, y=0.97) + + plt.savefig("surface_sims_bcb_water.pdf", bbox_inches="tight", dpi=dpi) diff --git a/papers/ed-afm/figures/water.py b/papers/ed-afm/figures/water.py new file mode 100644 index 0000000..ee78301 --- /dev/null +++ b/papers/ed-afm/figures/water.py @@ -0,0 +1,229 @@ + +from pathlib import Path + +import imageio.v2 as imageio +import matplotlib.pyplot as plt +import numpy as np +import ppafm.ml.AuxMap as aux +import ppafm.ocl.field as FFcl +import ppafm.ocl.oclUtils as oclu +import ppafm.ocl.relax as oclr +import torch +from matplotlib import cm +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +import mlspm.preprocessing as pp +from mlspm.datasets import download_dataset +from mlspm.models import EDAFMNet + +# # Set matplotlib font rendering to use LaTex +# plt.rcParams.update({ +# "text.usetex": True, +# "font.family": "serif", +# "font.serif": ["Computer Modern Roman"] +# }) + +np.random.seed(0) + +class Trainer(InverseAFMtrainer): + + # Override this method to set the Xe tip further + def handle_distance(self): + if self.afmulator.iZPP == 54: + self.distAboveActive += 1.0 + super().handle_distance() + + # Override position handling to center on the non-Cu atoms + def handle_positions(self): + sw = self.afmulator.scan_window + scan_center = np.array([sw[1][0] + sw[0][0], sw[1][1] + sw[0][1]]) / 2 + self.xyzs[:,:2] += scan_center - self.xyzs[self.Zs != 29,:2].mean(axis=0) + +def apply_preprocessing_sim(batch): + + X, Y, xyzs = batch + + X = [x[..., 2:8] for x in X] + + pp.add_norm(X) + np.random.seed(0) + pp.add_noise(X, c=0.08) + + return X, Y, xyzs + +def apply_preprocessing_exp(X, real_dim): + + # Pick slices + x0_start, x1_start = 12, 12 + X[0] = X[0][..., x0_start:x0_start+6] # CO + X[1] = X[1][..., x1_start:x1_start+6] # Xe + + X = pp.interpolate_and_crop(X, real_dim) + pp.add_norm(X) + + # Crop + X = [x[:, 36:-20, 26:-30] for x in X] + + print(X[0].shape) + + return X + +if __name__ == "__main__": + + data_dir = Path("./edafm-data") # Path to data + X_slices = [0, 3, 5] # Which AFM slices to plot + tip_names = ['CO', 'Xe'] # AFM tip types + device = 'cuda' # Device to run inference on + fig_width = 160 # Figure width in mm + fontsize = 8 + dpi = 300 + + # Download data if not already there + download_dataset("ED-AFM-data", data_dir) + + # Initialize OpenCL environment on GPU + env = oclu.OCLEnvironment( i_platform = 0 ) + FFcl.init(env) + oclr.init(env) + + afmulator_args = { + 'pixPerAngstrome' : 20, + 'scan_dim' : (144, 144, 19), + 'scan_window' : ((2.0, 2.0, 7.0), (20.0, 20.0, 8.9)), + 'df_steps' : 10, + "tipR0" : [0.0, 0.0, 4.0] + } + + generator_kwargs = { + 'batch_size' : 1, + 'distAbove' : 5.3, + 'iZPPs' : [8, 54], + 'Qs' : [[ -10, 20, -10, 0 ], [ 30, -60, 30, 0 ]], + 'QZs' : [[ 0.1, 0, -0.1, 0 ], [ 0.1, 0, -0.1, 0 ]] + } + + # Paths to molecule xyz files + molecules = [data_dir / 'Water' / 'mol.xyz'] + + # Define AFMulator + afmulator = AFMulator(**afmulator_args) + afmulator.npbc = (0,0,0) + + # Define AuxMaps + aux_maps = [ + aux.ESMapConstant( + scan_dim = afmulator.scan_dim[:2], + scan_window = [afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height = 4.0, + vdW_cutoff = -2.0, + Rpp = 1.0 + ) + ] + + # Define generator + trainer = Trainer(afmulator, aux_maps, molecules, **generator_kwargs) + + # Get simulation data + X_sim, ref, xyzs = apply_preprocessing_sim(next(iter(trainer))) + X_sim_cuda = [torch.from_numpy(x).unsqueeze(1).to(device) for x in X_sim] + + # Load experimental data and preprocess + data1 = np.load(data_dir / 'Water' / 'data_CO_exp.npz') + X1 = data1['data'] + afm_dim1 = (data1['lengthX'], data1['lengthY']) + + data2 = np.load(data_dir / 'Water' / 'data_Xe_exp.npz') + X2 = data2['data'] + afm_dim2 = (data2['lengthX'], data2['lengthY']) + + print(X1.shape, X2.shape) + assert afm_dim1 == afm_dim2 + afm_dim = afm_dim1 + X_exp = apply_preprocessing_exp([X1[None], X2[None]], afm_dim) + X_exp_cuda = [torch.from_numpy(x.astype(np.float32)).unsqueeze(1).to(device) for x in X_exp] + + # Load model + model = EDAFMNet(device=device, pretrained_weights='base') + + # Get predictions + with torch.no_grad(): + pred_sim, attentions_sim = model(X_sim_cuda) + pred_exp, attentions_exp = model(X_exp_cuda) + pred_sim = [p.cpu().numpy() for p in pred_sim] + pred_exp = [p.cpu().numpy() for p in pred_exp] + attentions_sim = [a.cpu().numpy() for a in attentions_sim] + attentions_exp = [a.cpu().numpy() for a in attentions_exp] + + # Create figure grid + fig_width = 0.1/2.54*fig_width + width_ratios = [6, 8, 0.3] + height_ratios = [1, 1] + gap = 0.20 + fig = plt.figure(figsize=(fig_width, 8.7*fig_width/sum(width_ratios))) + fig_grid = fig.add_gridspec(1, len(width_ratios), wspace=0.02, hspace=0, width_ratios=width_ratios) + afm_grid = fig_grid[0, 0].subgridspec(2, 1, wspace=0, hspace=gap, height_ratios=height_ratios) + pred_grid = fig_grid[0, 1].subgridspec(2, 2, wspace=0.02, hspace=gap, height_ratios=height_ratios) + cbar_grid = fig_grid[0, 2].subgridspec(1, 1, wspace=0, hspace=0) + + # Get axes from grid + afm_sim_axes = afm_grid[0, 0].subgridspec(len(X_sim), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + afm_exp_axes = afm_grid[1, 0].subgridspec(len(X_exp), len(X_slices), wspace=0.01, hspace=0.01).subplots(squeeze=False) + pred_sim_ax, ref_pc_ax, pred_exp_ax, geom_ax = pred_grid.subplots(squeeze=True).flatten() + cbar_ax = cbar_grid.subplots(squeeze=True) + + # Plot AFM + for k, (axes, X) in enumerate(zip([afm_sim_axes, afm_exp_axes], [X_sim, X_exp])): + for i, x in enumerate(X): + for j, s in enumerate(X_slices): + + # Plot AFM slice + im = axes[i, j].imshow(x[0,:,:,s].T, origin='lower', cmap='afmhot') + axes[i, j].set_axis_off() + + # Put tip names to the left of the AFM image rows + axes[i, 0].text(-0.1, 0.5, tip_names[i], horizontalalignment='center', + verticalalignment='center', transform=axes[i, 0].transAxes, + rotation='vertical', fontsize=fontsize) + + + # Figure out data limits + vmax = max( + abs(pred_sim[0].min()), abs(pred_sim[0].max()), + abs(pred_exp[0].min()), abs(pred_exp[0].max()), + abs(ref[0].min()), abs(ref[0].max()) + ) + vmin = -vmax + print(ref[0].min()) + + # Plot predictions and references + pred_sim_ax.imshow(pred_sim[0][0].T, origin='lower', cmap='coolwarm', vmin=vmin, vmax=vmax) + pred_exp_ax.imshow(pred_exp[0][0].T, origin='lower', cmap='coolwarm', vmin=vmin, vmax=vmax) + ref_pc_ax.imshow(ref[0][0].T, origin='lower', cmap='coolwarm', vmin=vmin, vmax=vmax) + + # Plot molecule geometry + xyz_img = np.flipud(imageio.imread(data_dir / 'Water' / 'mol.png')) + geom_ax.imshow(xyz_img, origin='lower') + + # Plot ES Map colorbar + m_es = cm.ScalarMappable(cmap=cm.coolwarm) + m_es.set_array((vmin, vmax)) + cbar = plt.colorbar(m_es, cax=cbar_ax) + cbar.set_ticks([-0.1, -0.05, 0.0, 0.05, 0.1]) + cbar_ax.tick_params(labelsize=fontsize-1) + cbar.set_label('V/Å', fontsize=fontsize) + + # Turn off axes ticks + pred_sim_ax.set_axis_off() + pred_exp_ax.set_axis_off() + ref_pc_ax.set_axis_off() + geom_ax.set_axis_off() + + # Set titles + afm_sim_axes[0, len(X_slices)//2].set_title('AFM simulation', fontsize=fontsize, y=0.91) + afm_exp_axes[0, len(X_slices)//2].set_title('AFM experiment', fontsize=fontsize, y=0.91) + pred_sim_ax.set_title('Sim. prediction', fontsize=fontsize, y=0.96) + pred_exp_ax.set_title('Exp. prediction', fontsize=fontsize, y=0.96) + ref_pc_ax.set_title('Reference', fontsize=fontsize, y=0.96) + + plt.savefig('water.pdf', bbox_inches='tight', dpi=dpi) diff --git a/papers/ed-afm/generate_data.py b/papers/ed-afm/generate_data.py new file mode 100644 index 0000000..801ba5d --- /dev/null +++ b/papers/ed-afm/generate_data.py @@ -0,0 +1,117 @@ +import time +from math import ceil +from pathlib import Path + +import numpy as np +from ppafm.common import eVA_Nm +from ppafm.ml.AuxMap import ESMapConstant +from ppafm.ml.Generator import InverseAFMtrainer +from ppafm.ocl.AFMulator import AFMulator + +from mlspm.data_generation import TarWriter +from mlspm.datasets import download_dataset + + +class Trainer(InverseAFMtrainer): + + def on_afm_start(self): + # Use different lateral stiffness for Cl than CO and Xe + if self.afmulator.iZPP in [8, 54]: + afmulator.scanner.stiffness = np.array([0.25, 0.25, 0.0, 30.0], dtype=np.float32) / -eVA_Nm + elif self.afmulator.iZPP == 17: + afmulator.scanner.stiffness = np.array([0.5, 0.5, 0.0, 30.0], dtype=np.float32) / -eVA_Nm + else: + raise RuntimeError(f"Unknown tip {self.afmulator.iZPP}") + + # Override to randomize tip distance and probe tilt + def handle_distance(self): + self.randomize_distance(delta=0.25) + self.randomize_tip(max_tilt=0.5) + super().handle_distance() + + +if __name__ == "__main__": + + # Path where molecule geometry files are saved + mol_dir = Path("./molecules") + + # Directory where to save data + data_dir = Path(f"./data/") + + # Define simulator and image descriptor parameters + scan_window = ((0, 0, 6.0), (23.875, 23.875, 7.9)) + scan_dim = (192, 192, 19) + afmulator = AFMulator(pixPerAngstrome=10, scan_dim=scan_dim, scan_window=scan_window, tipR0=[0.0, 0.0, 4.0]) + aux_maps = [ + ESMapConstant( + scan_dim=afmulator.scan_dim[:2], + scan_window=[afmulator.scan_window[0][:2], afmulator.scan_window[1][:2]], + height=4.0, + vdW_cutoff=-2.0, + Rpp=1.0, + ) + ] + generator_arguments = { + "afmulator": afmulator, + "aux_maps": aux_maps, + "batch_size": 1, + "distAbove": 5.25, + "iZPPs": [8, 54, 17], # CO, Xe, Cl + "Qs": [[-10, 20, -10, 0], [30, -60, 30, 0], [-0.3, 0, 0, 0]], + "QZs": [[0.1, 0, -0.1, 0], [0.1, 0, -0.1, 0], [0, 0, 0, 0]], + } + + # Number of tar file shards for each set + target_shard_count = 8 + + # Make sure the save directory exists + data_dir.mkdir(exist_ok=True, parents=True) + + # Download the dataset. The extraction may take a while since there are ~235k files. + download_dataset("ED-AFM-molecules", mol_dir) + + # Paths to molecule xyz files + train_paths = list((mol_dir / "train").glob("*.xyz")) + val_paths = list((mol_dir / "validation").glob("*.xyz")) + test_paths = list((mol_dir / "test").glob("*.xyz")) + + # Generate dataset + start_time = time.perf_counter() + counter = 1 + for mode, molecules in zip(["train", "val", "test"], [train_paths, val_paths, test_paths]): + + # Construct generator + generator = Trainer(paths=molecules, **generator_arguments) + + # Generate data + max_count = ceil(len(generator) / target_shard_count) + start_gen = time.perf_counter() + with TarWriter(data_dir, f"{data_dir.name}-K-0_{mode}", max_count=max_count) as tar_writer: + for i, (X, Y, xyz) in enumerate(generator): + + # Get rid of the batch dimension + X = [x[0] for x in X] + Y = [y[0] for y in Y] + xyz = xyz[0] + + # Save information of the simulation parameters into the xyz comment line + amp = generator.afmulator.amplitude + R0 = generator.afmulator.tipR0 + kxy = generator.afmulator.scanner.stiffness[0] + sw = generator.afmulator.scan_window + comment_str = f"Scan window: [{sw[0]}, {sw[1]}], Amplitude: {amp}, tip R0: {R0}, kxy: {kxy}" + + # Write the sample to a tar file + tar_writer.add_sample(X, xyz, Y=Y, comment_str=comment_str) + + if i % 100 == 0: + elapsed = time.perf_counter() - start_gen + eta = elapsed / (i + 1) * (len(generator) - i) + print( + f"{mode} sample {i}/{len(generator)}, writing to `{tar_writer.ft.name}`, " + f"Elapsed: {elapsed:.2f}s, ETA: {eta:.2f}s" + ) + + print(f"Done with {mode} - Elapsed time: {time.perf_counter() - start_gen:.1f}") + + print("Total time taken: %d" % (time.perf_counter() - start_time)) diff --git a/papers/ed-afm/run_train.sh b/papers/ed-afm/run_train.sh new file mode 100755 index 0000000..d621dd0 --- /dev/null +++ b/papers/ed-afm/run_train.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Number of GPUs and the number of samples per batch per GPU (total batch size = N_GPU x BATCH_SIZE). +N_GPU=1 +BATCH_SIZE=30 + +# Number of parallel workers per GPU for loading data from disk. +N_WORKERS=8 + +export OMP_NUM_THREADS=1 + +torchrun \ + --standalone \ + --nnodes 1 \ + --nproc_per_node $N_GPU \ + --max_restarts 0 \ + train.py \ + --run_dir ./train \ + --data_dir ./data \ + --urls_train "data-K-0_train_{0..7}.tar" \ + --urls_val "data-K-0_val_{0..7}.tar" \ + --urls_test "data-K-0_test_{0..7}.tar" \ + --random_seed 0 \ + --train True \ + --test True \ + --predict True \ + --epochs 50 \ + --num_workers $N_WORKERS \ + --batch_size $BATCH_SIZE \ + --avg_best_epochs 3 \ + --pred_batches 3 \ + --lr 1e-4 \ + --loss_labels "ES" \ + --loss_weights 1.0 \ + --timings diff --git a/papers/ed-afm/train.py b/papers/ed-afm/train.py new file mode 100644 index 0000000..9f14dbe --- /dev/null +++ b/papers/ed-afm/train.py @@ -0,0 +1,357 @@ +import os +import random +import time +from pathlib import Path + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import webdataset as wds +import yaml +from torch import optim +from torch.distributed.algorithms.join import Join +from torch.nn.parallel import DistributedDataParallel + +import mlspm.data_loading as dl +import mlspm.preprocessing as pp +import mlspm.visualization as vis +from mlspm import utils +from mlspm.cli import parse_args +from mlspm.image.models import EDAFMNet +from mlspm.logging import LossLogPlot, SyncedLoss +from mlspm.losses import WeightedMSELoss + + +def make_model(device, cfg): + model = EDAFMNet(device=device) + criterion = WeightedMSELoss(cfg["loss_weights"]) + optimizer = optim.Adam(model.parameters(), lr=cfg["lr"]) + lr_decay_rate = 1e-5 + lr_decay = optim.lr_scheduler.LambdaLR(optimizer, lambda b: 1.0 / (1.0 + lr_decay_rate * b)) + return model, criterion, optimizer, lr_decay + + +def apply_preprocessing(batch): + + X = batch["X"] + Y = batch["Y"] + xyz = batch["xyz"] + + X = [X[0], X[1]] # Pick CO and Xe + X = [x[:, :, :, 2:8] for x in X] + pp.rand_shift_xy_trend(X, max_layer_shift=0.02, max_total_shift=0.04) + + X, Y = pp.add_rotation_reflection(X, Y, reflections=True, multiple=3, crop=(128, 128)) + X, Y = pp.random_crop(X, Y, min_crop=0.75, max_aspect=1.25) + + pp.add_norm(X, per_layer=True) + pp.add_gradient(X, c=0.3) + pp.add_noise(X, c=0.1, randomize_amplitude=True, normal_amplitude=True) + pp.add_cutout(X, n_holes=5) + + return X, Y, xyz + + +def make_webDataloader(cfg, mode="train"): + assert mode in ["train", "val", "test"], mode + + shard_list = dl.ShardList( + cfg[f"urls_{mode}"], + base_path=cfg["data_dir"], + world_size=cfg["world_size"], + rank=cfg["global_rank"], + substitute_param=(mode == "train"), + log=Path(cfg["run_dir"]) / "shards.log", + ) + + dataset = wds.WebDataset(shard_list) + dataset.pipeline.pop() + if mode == "train": + dataset.append(wds.shuffle(10)) # Shuffle order of shards + dataset.append(wds.tariterators.tarfile_to_samples()) # Gather files inside tar files into samples + dataset.append(wds.split_by_worker) # Use a different subset of samples in shards in different workers + if mode == "train": + dataset.append(wds.shuffle(100)) # Shuffle samples within a worker process + dataset.append(wds.decode("pill", wds.autodecode.basichandlers, dl.decode_xyz)) # Decode image and xyz files + dataset.append(dl.rotate_and_stack()) # Combine separate images into a stack + dataset.append(dl.batched(cfg["batch_size"])) # Gather samples into batches + dataset = dataset.map(apply_preprocessing) # Preprocess batch + + dataloader = wds.WebLoader( + dataset, + num_workers=cfg["num_workers"], + batch_size=None, # Batching is done in the WebDataset + pin_memory=True, + collate_fn=dl.default_collate, + persistent_workers=False, + ) + + return dataset, dataloader + + +def batch_to_device(batch, device): + X, ref, xyz = batch + X = [x.to(device) for x in X] + ref = [y.to(device) for y in ref] + return X, ref, xyz + + +def run(cfg): + print(f'Starting on global rank {cfg["global_rank"]}, local rank {cfg["local_rank"]}\n', flush=True) + + # Initialize the distributed environment. + dist.init_process_group(cfg["comm_backend"]) + + start_time = time.perf_counter() + + if cfg["global_rank"] == 0: + # Create run directory + if not os.path.exists(cfg["run_dir"]): + os.makedirs(cfg["run_dir"]) + dist.barrier() + + # Define model, optimizer, and loss + model, criterion, optimizer, lr_decay = make_model(cfg["local_rank"], cfg) + + print(f'World size = {cfg["world_size"]}') + print(f"Trainable parameters: {utils.count_parameters(model)}") + + # Setup checkpointing and load a checkpoint if available + checkpointer = utils.Checkpointer( + model, + optimizer, + additional_data={"lr_params": lr_decay}, + checkpoint_dir=os.path.join(cfg["run_dir"], "Checkpoints/"), + keep_last_epoch=True, + ) + init_epoch = checkpointer.epoch + + # Setup logging + log_file = open(os.path.join(cfg["run_dir"], "batches.log"), "a") + loss_logger = LossLogPlot( + log_path=os.path.join(cfg["run_dir"], "loss_log.csv"), + plot_path=os.path.join(cfg["run_dir"], "loss_history.png"), + loss_labels=cfg["loss_labels"], + loss_weights=cfg["loss_weights"], + print_interval=cfg["print_interval"], + init_epoch=init_epoch, + stream=log_file, + ) + + # Wrap model in DistributedDataParallel. + model = DistributedDataParallel(model, device_ids=[cfg["local_rank"]], find_unused_parameters=False) + + if cfg["train"]: + # Create datasets and dataloaders + _, train_loader = make_webDataloader(cfg, "train") + _, val_loader = make_webDataloader(cfg, "val") + + if cfg["global_rank"] == 0: + if init_epoch <= cfg["epochs"]: + print(f"\n ========= Starting training from epoch {init_epoch}") + else: + print("Model already trained") + + for epoch in range(init_epoch, cfg["epochs"] + 1): + + if cfg["global_rank"] == 0: + print(f"\n === Epoch {epoch}") + + # Train + if cfg["timings"] and cfg["global_rank"] == 0: + t0 = time.perf_counter() + + model.train() + with Join([model, loss_logger.get_joinable("train")]): + for ib, batch in enumerate(train_loader): + + # Transfer batch to device + X, ref, _ = batch_to_device(batch, cfg["local_rank"]) + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t1 = time.perf_counter() + + # Forward + pred, _ = model(X) + losses = criterion(pred, ref) + loss = losses[0] + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t2 = time.perf_counter() + + # Backward + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_decay.step() + + # Log losses + loss_logger.add_train_loss(loss) + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t3 = time.perf_counter() + print(f"(Train) Load Batch/Forward/Backward: {t1-t0:6f}/{t2-t1:6f}/{t3-t2:6f}") + t0 = t3 + + # Validate + if cfg["global_rank"] == 0: + val_start = time.perf_counter() + if cfg["timings"]: + t0 = val_start + + model.eval() + with Join([loss_logger.get_joinable("val")]): + with torch.no_grad(): + for ib, batch in enumerate(val_loader): + + # Transfer batch to device + X, ref, _ = batch_to_device(batch, cfg["local_rank"]) + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t1 = time.perf_counter() + + # Forward + pred, _ = model(X) + losses = criterion(pred, ref) + + loss_logger.add_val_loss(losses[0]) + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t2 = time.perf_counter() + print(f"(Val) Load Batch/Forward: {t1-t0:6f}/{t2-t1:6f}") + t0 = t2 + + # Write average losses to log and report to terminal + loss_logger.next_epoch() + + # Save checkpoint + checkpointer.next_epoch(loss_logger.val_losses[-1][0]) + + # Return to best epoch + checkpointer.revert_to_best_epoch() + + # Return to best epoch, and save model weights + dist.barrier() + checkpointer.revert_to_best_epoch() + if cfg["global_rank"] == 0: + torch.save(model.module.state_dict(), save_path := os.path.join(cfg["run_dir"], "best_model.pth")) + print(f"\nModel saved to {save_path}") + print(f"Best validation loss on epoch {checkpointer.best_epoch}: {checkpointer.best_loss}") + print( + f'Average of best {cfg["avg_best_epochs"]} validation losses: ' + f'{np.sort(loss_logger.val_losses[:, 0])[:cfg["avg_best_epochs"]].mean()}' + ) + + if cfg["test"] or cfg["predict"]: + _, test_loader = make_webDataloader(cfg, "test") + + if cfg["test"]: + + if cfg["global_rank"] == 0: + print(f"\n ========= Testing with model from epoch {checkpointer.best_epoch}") + + eval_losses = SyncedLoss(num_losses=len(loss_logger.loss_labels)) + eval_start = time.perf_counter() + if cfg["timings"] and cfg["global_rank"] == 0: + t0 = eval_start + + model.eval() + with Join([eval_losses]): + with torch.no_grad(): + for ib, batch in enumerate(test_loader): + + # Transfer batch to device + X, ref, xyz = batch_to_device(batch, cfg["local_rank"]) + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t1 = time.perf_counter() + + # Forward + pred, _ = model(X) + losses = criterion(pred, ref) + eval_losses.append(losses[0]) + + if (ib + 1) % cfg["print_interval"] == 0: + print(f"Test Batch {ib+1}") + + if cfg["timings"] and cfg["global_rank"] == 0: + torch.cuda.synchronize() + t2 = time.perf_counter() + print(f"(Test) t0/Load Batch/Forward: {t1-t0:6f}/{t2-t1:6f}/") + t0 = t2 + + # Average losses and print + eval_loss = eval_losses.mean() + print(f"Test set loss: {loss_logger.loss_str(eval_loss)}") + + # Save test set loss to file + with open(os.path.join(cfg["run_dir"], "test_loss.txt"), "w") as f: + f.write(";".join([str(l) for l in eval_loss])) + + if cfg["predict"] and cfg["global_rank"] == 0: + + # Make predictions + print(f'\n ========= Predict on {cfg["pred_batches"]} batches from the test set') + counter = 0 + pred_dir = os.path.join(cfg["run_dir"], "predictions/") + + with torch.no_grad(): + for ib, batch in enumerate(test_loader): + + if ib >= cfg["pred_batches"]: + break + + # Transfer batch to device + X, ref, xyz = batch_to_device(batch, cfg["local_rank"]) + + # Forward + pred, _ = model(X) + + # Data back to host + X = [x.squeeze(1).cpu().numpy() for x in X] + pred = [p.cpu().numpy() for p in pred] + ref = [r.cpu().numpy() for r in ref] + + # Save xyzs + utils.batch_write_xyzs(xyz, outdir=pred_dir, start_ind=counter) + + # Visualize input AFM images and predictions + vis.make_input_plots(X, outdir=pred_dir, start_ind=counter) + vis.make_prediction_plots(pred, ref, descriptors=cfg["loss_labels"], outdir=pred_dir, start_ind=counter) + + counter += len(X[0]) + + print(f'Done at rank {cfg["global_rank"]}. Total time: {time.perf_counter() - start_time:.0f}s') + + log_file.close() + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + + # Get config + cfg = parse_args() + run_dir = Path(cfg["run_dir"]) + run_dir.mkdir(exist_ok=True, parents=True) + with open(run_dir / "config.yaml", "w") as f: + # Remember settings + yaml.safe_dump(cfg, f) + + # Set random seeds + torch.manual_seed(cfg["random_seed"]) + random.seed(cfg["random_seed"]) + np.random.seed(cfg["random_seed"]) + + # Start run + cfg["world_size"] = int(os.environ["WORLD_SIZE"]) + cfg["global_rank"] = int(os.environ["RANK"]) + cfg["local_rank"] = int(os.environ["LOCAL_RANK"]) + run(cfg) diff --git a/tests/test_models.py b/tests/test_models.py index 70f1dac..af7acc8 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -256,6 +256,7 @@ def test_GraphImgNet(): # fmt:on + def test_ASDAFMNet(): import torch @@ -272,3 +273,39 @@ def test_ASDAFMNet(): assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128) assert ys[1].min() >= 0.0 assert ys[2].min() >= 0.0 + + +def test_AttentionUnet(): + + import torch + from mlspm.image.models import AttentionUNet + + torch.manual_seed(0) + + device = "cpu" + + for z_in in range(6, 15): + + model = AttentionUNet( + z_in=z_in, + n_in=2, + n_out=3, + in_channels=1, + merge_block_channels=[2], + conv3d_block_channels=[2, 4, 8], + conv2d_block_channels=[32], + attention_channels= [4, 8, 6], + pool_z_strides=[2, 1, 2], + device=device + ) + + x = [torch.rand((5, 1, 128, 128, z_in)).to(device), torch.rand((5, 1, 128, 128, z_in)).to(device)] + ys, att = model(x) + + assert len(ys) == 3 + assert ys[0].shape == ys[1].shape == ys[2].shape == (5, 128, 128) + + assert len(att) == 3 + assert att[0].shape == (5, 32, 32) + assert att[1].shape == (5, 64, 64) + assert att[2].shape == (5, 128, 128)