Skip to content

Commit

Permalink
omnixas for lightshow
Browse files Browse the repository at this point in the history
  • Loading branch information
shubharajkharel authored and matthewcarbone committed Nov 20, 2024
1 parent ad51db9 commit cb81dda
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 0 deletions.
27 changes: 27 additions & 0 deletions lightshow/omnixas/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Omnixas Integration in Lightshow

## Description

Self-contained modules to predict site-XAS spectra for [`pymatgen`](https://pymatgen.org/) structures using Omnixas models.

## Models

- `v1.1.1`: Available in [Omnixas Repository](https://github.com/AI-multimodal/OmniXAS/tree/main/models)

## Input

- Input is a `pymatgen` `Structure` object. Useful I/O methods for `pymatgen` structures are documented in the [pymatgen documentation](https://pymatgen.org/pymatgen.core.html#pymatgen.core.structure.IStructure), which includes `from_file`, `from_id`, `from_sites`, `from_spacegroup`, `from_str`.

## Supported File Formats

- Use [`from_file`](https://pymatgen.org/pymatgen.core.html#pymatgen.core.structure.IStructure.from_file) to load `CIF`, `POSCAR`, `CONTCAR`, `CHGCAR`, `LOCPOT`, `vasprun.xml`, `CSSR`, `Netcdf` and `pymatgen's JSON-serialized structures`.

## Usage

```python
import matplotlib.pyplot as plt
from pymatgen.core import Structure as PymatgenStructure

from lightshow.omnixas.lightshow import XASModel

material_structure_file = "mp-1005792/POSCAR"
structure = PymatgenStructure.from_file(material_structure_file)
# Predict the spectrum for the site at index 8 with the Cu / FEFF model.
spectrum = XASModel(element="Cu", type="FEFF").predict(structure, 8)
plt.plot(spectrum)
```
164 changes: 164 additions & 0 deletions lightshow/omnixas/lightshow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# %%
from functools import cache
from typing import List

import numpy as np
import torch
import yaml
from lightning import LightningModule
from loguru import logger
from matgl import load_model
from matgl.ext.pymatgen import Structure2Graph
from matgl.graph.compute import (
compute_pair_vector_and_distance,
compute_theta_and_phi,
create_line_graph,
)
from matgl.utils.cutoff import polynomial_cutoff
from matplotlib import pyplot as plt
from pymatgen.core import Structure as PymatgenStructure
from torch import nn


class XASBlock(nn.Sequential):
    """Feed-forward stack built from ``nn.Linear`` layers.

    Every layer except the last is followed by ``BatchNorm1d``, ``SiLU`` and
    ``Dropout(DROPOUT)``; the last linear layer is followed by ``Softplus``.
    """

    # Dropout probability applied after every non-final layer.
    DROPOUT = 0.5

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
        """Assemble the layer stack.

        Args:
            input_dim: Width of the input features.
            hidden_dims: Widths of the hidden layers, in order.
            output_dim: Width of the output.
        """
        widths = [input_dim] + hidden_dims + [output_dim]
        final_idx = len(widths) - 2  # index of the last Linear layer
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx >= final_idx:
                # Output head: only an activation, no norm/dropout.
                modules.append(nn.Softplus())
                continue
            modules.append(nn.BatchNorm1d(fan_out))
            modules.append(nn.SiLU())
            modules.append(nn.Dropout(self.DROPOUT))
        super().__init__(*modules)


class XASBlockModule(LightningModule):
    """Lightning module that forwards its input through a wrapped model."""

    def __init__(self, model: nn.Module):
        """Store ``model`` as the network used by :meth:`forward`."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        output = self.model(x)
        return output

    @classmethod
    def load(
        cls,
        element: str,
        type: str,
        pattern="models/xasblock/v1.1.1/{element}_{type}.ckpt",  # TODO: hardcoded path
    ):
        """Load a checkpoint and return it in eval mode with frozen weights.

        Args:
            element: Value substituted for ``{element}`` in ``pattern``.
            type: Value substituted for ``{type}`` in ``pattern``.
            pattern: Format string giving the checkpoint path.

        Returns:
            The module restored from the checkpoint, after ``eval()`` and
            ``freeze()`` have been called on it.
        """
        path = pattern.format(element=element, type=type)
        logger.info(f"Loading XASBlock model from {path}")
        # TODO: hardcoded dims for version v1.1.1
        architecture = {
            "input_dim": 64,
            "hidden_dims": [500, 500, 550],
            "output_dim": 141,
        }
        network = XASBlock(**architecture)
        restored = cls.load_from_checkpoint(checkpoint_path=path, model=network)
        restored.eval()
        restored.freeze()
        return restored


class M3GNetFeaturizer:
    """Extract node features from a structure using the layers of an M3GNet model."""

    def __init__(self, model=None, n_blocks=None):
        """Set up the featurizer.

        Args:
            model: Model whose layers are used for featurization. When
                ``None``, the model is loaded via :meth:`_load_m3gnet`.
            n_blocks: Number of interaction blocks to apply. A falsy value
                (``None`` or ``0``) selects ``model.n_blocks``.
        """
        # Compare against None explicitly instead of relying on the
        # truthiness of a model object.
        if model is None:
            model = M3GNetFeaturizer._load_m3gnet()
        self.model = model
        self.model.eval()
        self.n_blocks = n_blocks or self.model.n_blocks

    def featurize(
        self,
        structure: PymatgenStructure,
    ):
        """Return the node features of ``structure`` as a NumPy array.

        The structure is converted to a graph, bond and three-body terms are
        computed, and ``self.n_blocks`` interaction/graph-layer blocks of the
        model are applied to the embedded features.

        Args:
            structure: Structure to featurize.

        Returns:
            ``numpy.ndarray`` copy of the final node features (presumably one
            row per site; the site-level subclass indexes the first axis by
            site index).
        """
        graph_converter = Structure2Graph(self.model.element_types, self.model.cutoff)
        g, state_attr = graph_converter.get_graph(structure)

        node_types = g.ndata["node_type"]
        bond_vec, bond_dist = compute_pair_vector_and_distance(g)

        g.edata["bond_vec"] = bond_vec.to(g.device)
        g.edata["bond_dist"] = bond_dist.to(g.device)

        # Inference only: no gradients are needed for any of the steps below.
        with torch.no_grad():
            expanded_dists = self.model.bond_expansion(g.edata["bond_dist"])

            l_g = create_line_graph(g, self.model.threebody_cutoff)

            l_g.apply_edges(compute_theta_and_phi)
            g.edata["rbf"] = expanded_dists
            three_body_basis = self.model.basis_expansion(l_g)
            three_body_cutoff = polynomial_cutoff(
                g.edata["bond_dist"], self.model.threebody_cutoff
            )
            node_feat, edge_feat, state_feat = self.model.embedding(
                node_types, g.edata["rbf"], state_attr
            )

            for i in range(self.n_blocks):
                edge_feat = self.model.three_body_interactions[i](
                    g,
                    l_g,
                    three_body_basis,
                    three_body_cutoff,
                    node_feat,
                    edge_feat,
                )
                edge_feat, node_feat, state_feat = self.model.graph_layers[i](
                    g, edge_feat, node_feat, state_feat
                )
        # ``Tensor.numpy()`` only works for CPU tensors, so move the result
        # to the CPU first (a no-op when it already lives there).
        return np.array(node_feat.detach().cpu().numpy())

    # ``staticmethod`` must be the outermost decorator: wrapping a
    # staticmethod object in ``cache`` fails on Python < 3.10 (staticmethod
    # objects are not callable there) and would pass the instance as ``path``
    # when the attribute is accessed through an instance.
    @staticmethod
    @cache
    def _load_m3gnet(path="models/M3GNet-MP-2021.2.8-PES"):  # TODO: hardcoded path
        """Load the model stored at ``path`` and return it in eval mode.

        Results are cached per ``path``, so repeated calls reuse the same
        model object.
        """
        logger.info(f"Loading m3gnet model from {path}")
        model = load_model(path).model
        model.eval()
        return model


class M3GNetSiteFeaturizer(M3GNetFeaturizer):
    """Featurizer that returns the features of a single site."""

    def featurize(self, structure: PymatgenStructure, site_index: int):
        """Featurize ``structure`` and return the entry at ``site_index``."""
        all_site_features = super().featurize(structure)
        return all_site_features[site_index]


class XASModel:
    """Predict a site spectrum by feeding site features to an XASBlock model.

    NOTE(review): ``featurizer`` is created when the class body executes, so
    the featurizer's default model is loaded as soon as this class is defined.
    """

    # Shared by all instances.
    featurizer = M3GNetSiteFeaturizer()

    def __init__(self, element: str, type: str):
        """Load the model for the given ``element`` / ``type`` pair."""
        self.element = element
        self.type = type
        self.model = XASBlockModule.load(element=element, type=type)

    def _get_feature(
        self,
        structure: PymatgenStructure,
        site_index: int,
    ):
        """Return the featurizer output for ``site_index`` of ``structure``."""
        site_feature = self.featurizer.featurize(structure, site_index)
        return site_feature

    def predict(
        self,
        structure: PymatgenStructure,
        site_index: int,
    ):
        """Return the predicted spectrum for one site as a NumPy array.

        Args:
            structure: Structure containing the site.
            site_index: Index of the site to predict for.

        Returns:
            The model output with singleton dimensions squeezed out.
        """
        site_feature = self._get_feature(structure, site_index)
        # Add a leading batch axis before calling the model.
        batch = torch.tensor(site_feature).unsqueeze(0)
        prediction = self.model(batch)
        return np.squeeze(prediction.detach().numpy())


if __name__ == "__main__":
    # Demo: load an example structure, predict the spectrum for site 8 with
    # the Cu / FEFF model, and plot it.
    material_structure_file = "examples/material/mp-1005792/POSCAR"  # TODO: hardcoded path
    structure = PymatgenStructure.from_file(material_structure_file)
    xas_model = XASModel(element="Cu", type="FEFF")
    spectrum = xas_model.predict(structure, 8)
    plt.plot(spectrum)
    plt.show()

# %%

0 comments on commit cb81dda

Please sign in to comment.