merge commit with FAENet updates
laserkelvin committed Apr 22, 2024
2 parents 92da14e + 6f0a9e1 commit 01bc048
Showing 3 changed files with 57 additions and 82 deletions.
58 changes: 9 additions & 49 deletions examples/model_demos/faenet_pyg.py
@@ -4,36 +4,34 @@

from matsciml.datasets.transforms import (
FrameAveraging,
GraphToGraphTransform,
PointCloudToGraphTransform,
UnitCellCalculator,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import ScalarRegressionTask
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg import FAENet

"""
This example script runs through a fast development run of the IS2RE devset
in combination with a PyG implementation of FAENet.
"""

# construct IS2RE relaxed energy regression with PyG implementation of FAENet
task = ScalarRegressionTask(
# construct S2EF force regression task with PyG implementation of FAENet
task = ForceRegressionTask(
encoder_class=FAENet,
encoder_kwargs={
"average_frame_embeddings": True,
"average_frame_embeddings": False, # set to false for use with FA transform
"pred_as_dict": False,
"hidden_dim": 128,
"out_dim": 64,
"out_dim": 128,
"tag_hidden_channels": 0,
},
output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
task_keys=["energy_relaxed"],
output_kwargs={"lazy": False, "input_dim": 128, "hidden_dim": 128},
task_keys=["force"],
)

# ### matsciml devset for OCP are serialized with DGL - this transform goes between the two frameworks
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
"S2EFDataset",
dset_kwargs={
"transforms": [
PointCloudToGraphTransform(
@@ -49,42 +47,4 @@

# run a quick training loop
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)


########################################################################################
########################################################################################


# construct Materials Project band gap regression with PyG implementation of FAENet
task = ScalarRegressionTask(
encoder_class=FAENet,
encoder_kwargs={
"pred_as_dict": False,
"hidden_dim": 128,
"out_dim": 64,
"tag_hidden_channels": 0,
"input_dim": 128,
},
output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
task_keys=["band_gap"],
)

dm = MatSciMLDataModule.from_devset(
"MaterialsProjectDataset",
dset_kwargs={
"transforms": [
UnitCellCalculator(),
PointCloudToGraphTransform(
"pyg",
cutoff_dist=20.0,
node_keys=["pos", "atomic_numbers"],
),
FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
],
},
)

# run a quick training loop
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
trainer.fit(task, datamodule=dm)
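
For reference, the updated example assembled from the added lines reads roughly as follows. This is a sketch, not a verbatim copy of the new file: the lightning import and the transform arguments (cutoff_dist, node_keys, fa_method) sit inside the collapsed hunks above, so they are assumed here to mirror the removed Materials Project example.

# Sketch of examples/model_demos/faenet_pyg.py after this commit.
# The import alias below and the transform arguments are assumptions
# carried over from the removed second example; the collapsed hunks may differ.
import pytorch_lightning as pl

from matsciml.datasets.transforms import (
    FrameAveraging,
    PointCloudToGraphTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg import FAENet

# S2EF force regression with the PyG FAENet encoder; frame averaging is applied
# as a data transform, so the encoder does not average frame embeddings itself
task = ForceRegressionTask(
    encoder_class=FAENet,
    encoder_kwargs={
        "average_frame_embeddings": False,  # set to false for use with FA transform
        "pred_as_dict": False,
        "hidden_dim": 128,
        "out_dim": 128,
        "tag_hidden_channels": 0,
    },
    output_kwargs={"lazy": False, "input_dim": 128, "hidden_dim": 128},
    task_keys=["force"],
)

dm = MatSciMLDataModule.from_devset(
    "S2EFDataset",
    dset_kwargs={
        "transforms": [
            PointCloudToGraphTransform(
                "pyg",
                cutoff_dist=20.0,  # assumed; mirrors the removed example
                node_keys=["pos", "atomic_numbers"],
            ),
            FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
        ],
    },
)

# run a quick training loop
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
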
69 changes: 46 additions & 23 deletions matsciml/models/base.py
@@ -1282,7 +1282,6 @@ def _compute_losses(
losses[key] = self.loss_func(predictions[key], target_val) * (
coefficient / predictions[key].numel()
)

total_loss: torch.Tensor = sum(losses.values())
return {"loss": total_loss, "log": losses}

@@ -1536,15 +1535,19 @@ def forward(
pos: torch.Tensor = graph.ndata.get("pos")
# for frame averaging
fa_rot = graph.ndata.get("fa_rot", None)
fa_pos = graph.ndata.get("fa_pos", None)
else:
# otherwise assume it's PyG
pos: torch.Tensor = graph.pos
# for frame averaging
fa_rot = getattr(graph, "fa_rot", None)
fa_pos = getattr(graph, "fa_pos", None)
else:
# assume point cloud otherwise
pos: torch.Tensor = batch.get("pos")
# no frame averaging architecture yet for point clouds
fa_rot = None
fa_pos = None
if pos is None:
raise ValueError(
"No atomic positions were found in batch - neither as standalone tensor nor graph.",
@@ -1557,53 +1560,73 @@
raise ValueError(
f"'pos' data is required for force calculation, but isn't a tensor or a list of tensors: {type(pos)}.",
)
if isinstance(fa_pos, torch.Tensor):
fa_pos.requires_grad_(True)
elif isinstance(fa_pos, list):
[f_p.requires_grad_(True) for f_p in fa_pos]
if "embeddings" in batch:
embeddings = batch.get("embeddings")
else:
embeddings = self.encoder(batch)
outputs = self.process_embedding(embeddings, pos, fa_rot)
natoms = batch.get("natoms", None)
outputs = self.process_embedding(embeddings, pos, fa_rot, fa_pos, natoms)
return outputs

def process_embedding(
self,
embeddings: Embeddings,
pos: torch.Tensor,
fa_rot: None | torch.Tensor = None,
fa_pos: None | torch.Tensor = None,
natoms: None | torch.Tensor = None,
) -> dict[str, torch.Tensor]:
outputs = {}
energy = self.output_heads["energy"](embeddings.system_embedding)
# now use autograd for force calculation
force = (
-1
* torch.autograd.grad(
energy,
pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)

def energy_and_force(
pos: torch.Tensor, system_embedding: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
energy = self.output_heads["energy"](system_embedding)
# now use autograd for force calculation
force = (
-1
* torch.autograd.grad(
energy,
pos,
grad_outputs=torch.ones_like(energy),
create_graph=True,
)[0]
)
return energy, force

if fa_pos is None:
energy, force = energy_and_force(pos, embeddings.system_embedding)
else:
energy = []
force = []
for idx, pos in enumerate(fa_pos):
frame_embedding = embeddings.system_embedding[:, idx, :]
frame_energy, frame_force = energy_and_force(pos, frame_embedding)
force.append(frame_force)
energy.append(frame_energy)

# check to see if we are frame averaging
if isinstance(fa_rot, torch.Tensor):
natoms = pos.size(0)
if fa_rot is not None:
all_forces = []
# loop over each frame prediction, and transform to guarantee
# equivariance of frame averaging method
for frame_idx, frame_rot in fa_rot:
natoms = natoms.squeeze(-1).to(int)
for frame_idx, frame_rot in enumerate(fa_rot):
repeat_rot = torch.repeat_interleave(
frame_rot,
natoms,
dim=0,
).to(self.device)
rotated_forces = (
force[:, frame_idx, :]
.view(-1, 1, 3)
.bmm(
repeat_rot.transpose(1, 2),
)
force[frame_idx].view(-1, 1, 3).bmm(repeat_rot.transpose(1, 2))
)
all_forces.append(rotated_forces.view(natoms, 3))
all_forces.append(rotated_forces)
# combine all the force data into a single tensor
force = torch.stack(all_forces, dim=1)
force = torch.cat(all_forces, dim=1)
# reduce outputs to what are expected shapes
outputs["force"] = reduce(
force,
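
The per-frame handling in ForceRegressionTask is the core of this change: when the FrameAveraging transform supplies fa_pos/fa_rot, the task enables gradients on each frame's positions, computes an energy and force per frame via autograd, rotates each frame's forces back with the corresponding rotation, and averages over frames. A minimal sketch of that force post-processing follows; tensor shapes and the einops reduction pattern are assumptions, since the diff truncates before the arguments to reduce(...).

# Sketch of the frame-averaged force post-processing added to
# ForceRegressionTask.process_embedding. Assumed shapes: each entry of
# `frame_forces` is [total_atoms, 3], each per-frame rotation in `fa_rot`
# is [num_graphs, 3, 3], and `natoms` holds per-graph atom counts.
import torch
from einops import reduce  # assumed source of the `reduce` used in base.py


def frame_averaged_forces(
    frame_forces: list[torch.Tensor],
    fa_rot: list[torch.Tensor],
    natoms: torch.Tensor,
) -> torch.Tensor:
    natoms = natoms.squeeze(-1).to(int)
    all_forces = []
    for frame_idx, frame_rot in enumerate(fa_rot):
        # broadcast each graph's rotation to all of its atoms
        repeat_rot = torch.repeat_interleave(frame_rot, natoms, dim=0)
        # rotate this frame's forces back into the original reference frame
        rotated = frame_forces[frame_idx].view(-1, 1, 3).bmm(repeat_rot.transpose(1, 2))
        all_forces.append(rotated)
    # concatenate to [total_atoms, num_frames, 3], then average over frames;
    # the reduction pattern is an assumption (the diff cuts off here)
    force = torch.cat(all_forces, dim=1)
    return reduce(force, "n f d -> n d", "mean")
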
12 changes: 2 additions & 10 deletions matsciml/models/pyg/faenet/faenet.py
@@ -2,6 +2,7 @@
FAENet: Frame Averaging Equivariant graph neural Network
Simple, scalable and expressive model for property prediction on 3D atomic systems.
"""

from __future__ import annotations

from copy import deepcopy
@@ -354,17 +355,8 @@ def first_forward(
Returns:
(dict): predicted energy, forces and final atomic hidden states
"""
if self.training:
mode = "train"
else:
mode = "inference"
preproc = True
preproc = False
data = graph

# energy gradient w.r.t. positions will be computed
if mode == "train" or self.regress_forces == "from_energy":
data.pos.requires_grad_(True)

# produce final embeddings after going through model
embeddings = self.energy_forward(data, preproc)
return embeddings
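
On the encoder side, FAENet.first_forward no longer branches on training mode or calls data.pos.requires_grad_(True); gradient tracking now lives in ForceRegressionTask.forward (see the base.py hunk above), and preprocessing is skipped. Consolidating only the surviving lines of the interleaved hunk, the method body after this commit reads roughly as below; the signature is abbreviated, since the hunk header truncates it.

# Consolidated from the new-side lines of the hunk above; docstring omitted
# and the parameter list assumed from the `data = graph` assignment.
def first_forward(self, graph):
    preproc = False
    data = graph
    # produce final embeddings after going through model
    embeddings = self.energy_forward(data, preproc)
    return embeddings
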
