diff --git a/lightshow/omnixas/README.md b/lightshow/omnixas/README.md new file mode 100644 index 00000000..862f8c06 --- /dev/null +++ b/lightshow/omnixas/README.md @@ -0,0 +1,27 @@ +# Omnixas Integration in Lightshow + +## Description + +Self-contained modules to predcit site-XAS [`pymatgen`](https://pymatgen.org/) structures using Omnixas models. + +## Models + +- `v1.1.1`: Available in [Omnixas Repository](https://github.com/AI-multimodal/OmniXAS/tree/main/models) + +## Input + +- Input is `pymatgen` `Structure` object. Useful I/O methods for `pymatgen` structures are documented in [pyamtgen documentation](https://pymatgen.org/pymatgen.core.html#pymatgen.core.structure.IStructure) which includes `from_file`, `from_id`, `from_sites`, `from_spacegroup`, `from_str`. + +## Supported File Formats + +- use [`from_file`](https://pymatgen.org/pymatgen.core.html#pymatgen.core.structure.IMolecule.from_file) to load `CIF`, `POSCAR`, `CONTCAR`, `CHGCAR`, `LOCPOT`, `vasprun.xml`, `CSSR`, `Netcdf` and `pymatgen's JSON-serialized structures`. + +## Usage + +```python +import matplotlib.pyplot as plt +material_structure_file = "mp-1005792/POSCAR" +strucutre = PymatgenStructure.from_file(material_structure_file) +spectrum = XASModel(element="Cu", type="FEFF").predict(strucutre, 8) +plt.plot(spectrum) +``` diff --git a/lightshow/omnixas/lightshow.py b/lightshow/omnixas/lightshow.py new file mode 100644 index 00000000..1915b37d --- /dev/null +++ b/lightshow/omnixas/lightshow.py @@ -0,0 +1,164 @@ +# %% +from functools import cache +from typing import List + +import numpy as np +import torch +import yaml +from lightning import LightningModule +from loguru import logger +from matgl import load_model +from matgl.ext.pymatgen import Structure2Graph +from matgl.graph.compute import ( + compute_pair_vector_and_distance, + compute_theta_and_phi, + create_line_graph, +) +from matgl.utils.cutoff import polynomial_cutoff +from matplotlib import pyplot as plt +from pymatgen.core import Structure as PymatgenStructure +from torch import nn + + +class XASBlock(nn.Sequential): + DROPOUT = 0.5 + + def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int): + dims = [input_dim] + hidden_dims + [output_dim] + layers = [] + for i, (w1, w2) in enumerate(zip(dims[:-1], dims[1:])): + layers.append(nn.Linear(w1, w2)) + if i < len(dims) - 2: # not last layer + layers.append(nn.BatchNorm1d(w2)) + layers.append(nn.SiLU()) + layers.append(nn.Dropout(self.DROPOUT)) + else: + layers.append(nn.Softplus()) # last layer + super().__init__(*layers) + + +class XASBlockModule(LightningModule): + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, x): + return self.model(x) + + @classmethod + def load( + cls, + element: str, + type: str, + pattern="models/xasblock/v1.1.1/{element}_{type}.ckpt", # TODO: hardcoded path + ): + path = pattern.format(element=element, type=type) + logger.info(f"Loading XASBlock model from {path}") + model = XASBlock( + input_dim=64, + hidden_dims=[500, 500, 550], # TODO: hardcoded dims for version v1.1.1 + output_dim=141, + ) + module = cls.load_from_checkpoint(checkpoint_path=path, model=model) + module.eval() + module.freeze() + return module + + +class M3GNetFeaturizer: + def __init__(self, model=None, n_blocks=None): + self.model = model or M3GNetFeaturizer._load_m3gnet() + self.model.eval() + self.n_blocks = n_blocks or self.model.n_blocks + + def featurize( + self, + structure: PymatgenStructure, + ): + graph_converter = Structure2Graph(self.model.element_types, self.model.cutoff) + g, state_attr = graph_converter.get_graph(structure) + + node_types = g.ndata["node_type"] + bond_vec, bond_dist = compute_pair_vector_and_distance(g) + + g.edata["bond_vec"] = bond_vec.to(g.device) + g.edata["bond_dist"] = bond_dist.to(g.device) + + with torch.no_grad(): + expanded_dists = self.model.bond_expansion(g.edata["bond_dist"]) + + l_g = create_line_graph(g, self.model.threebody_cutoff) + + l_g.apply_edges(compute_theta_and_phi) + g.edata["rbf"] = expanded_dists + three_body_basis = self.model.basis_expansion(l_g) + three_body_cutoff = polynomial_cutoff( + g.edata["bond_dist"], self.model.threebody_cutoff + ) + node_feat, edge_feat, state_feat = self.model.embedding( + node_types, g.edata["rbf"], state_attr + ) + + for i in range(self.n_blocks): + edge_feat = self.model.three_body_interactions[i]( + g, + l_g, + three_body_basis, + three_body_cutoff, + node_feat, + edge_feat, + ) + edge_feat, node_feat, state_feat = self.model.graph_layers[i]( + g, edge_feat, node_feat, state_feat + ) + return np.array(node_feat.detach().numpy()) + + @cache + @staticmethod + def _load_m3gnet(path="models/M3GNet-MP-2021.2.8-PES"): # TODO: hardcoded path + logger.info(f"Loading m3gnet model from {path}") + model = load_model(path).model + model.eval() + return model + + +class M3GNetSiteFeaturizer(M3GNetFeaturizer): + def featurize(self, structure: PymatgenStructure, site_index: int): + return super().featurize(structure)[site_index] + + +class XASModel: + featurizer = M3GNetSiteFeaturizer() + + def __init__(self, element: str, type: str): + self.element = element + self.type = type + self.model = XASBlockModule.load(element=element, type=type) + + def _get_feature( + self, + structure: PymatgenStructure, + site_index: int, + ): + return self.featurizer.featurize(structure, site_index) + + def predict( + self, + structure: PymatgenStructure, + site_index: int, + ): + feature = self._get_feature(structure, site_index) + spectrum = self.model(torch.tensor(feature).unsqueeze(0)) + return spectrum.detach().numpy().squeeze() + + +if __name__ == "__main__": + material_structure_file = ( + "examples/material/mp-1005792/POSCAR" # TODO: hardcoded path + ) + strucutre = PymatgenStructure.from_file(material_structure_file) + spectrum = XASModel(element="Cu", type="FEFF").predict(strucutre, 8) + plt.plot(spectrum) + plt.show() + +# %%