Introducing structured model outputs #316

Merged
merged 35 commits on Nov 12, 2024

Commits (35)
6dc2090
refactor: making output head construction quasi-optional
laserkelvin Oct 17, 2024
b1b1749
fix: using moduledict instead of regular dict to pass assertion
laserkelvin Oct 17, 2024
25d4321
refactor: making forward pass predictions if not embeddings
laserkelvin Oct 17, 2024
8cfc5e5
refactor: stashing embeddings if included in encoder outputs
laserkelvin Oct 17, 2024
1873f56
refactor: adding check to validate encoder outputs
laserkelvin Oct 17, 2024
f2dbeb1
refactor: return encoder predictions directly for force regression
laserkelvin Oct 17, 2024
8abd7a9
Merge branch 'main' into byo-outputs
laserkelvin Nov 8, 2024
9724f89
refactor: added private variable for abstract tasks to signal byo out…
laserkelvin Nov 8, 2024
64c9aa5
refactor: updating forward signature with non-embedding output
laserkelvin Nov 8, 2024
78446c3
Merge branch 'structured-model-output-type' into byo-outputs
laserkelvin Nov 8, 2024
09e10fe
refactor: introducing model output data structure
laserkelvin Nov 8, 2024
fc28949
Merge branch 'structured-model-output-type' into byo-outputs
laserkelvin Nov 8, 2024
9e70a70
refactor: updated output signatures to include ModelOutput
laserkelvin Nov 8, 2024
87e0c43
refactor: allowing high level forward call to skip output heads
laserkelvin Nov 8, 2024
87c7d92
refactor: mapping model outputs to process embeddings
laserkelvin Nov 8, 2024
c075a55
refactor: removing kwarg for using encoder_predictions in favor of de…
laserkelvin Nov 8, 2024
d9ba520
refactor: changing forward signature to emit ModelOutput or dict
laserkelvin Nov 8, 2024
a8db278
refactor: using getattr instead of dict lookup to support dict or Mod…
laserkelvin Nov 8, 2024
74a14e2
refactor: adding dict case in addition to getattr
laserkelvin Nov 8, 2024
d5d88b2
refactor: making MACE wrapper emit model output structure
laserkelvin Nov 8, 2024
69565a9
refactor: flipping default state for forces in MACE to be disabled
laserkelvin Nov 8, 2024
8d1a882
Merge branch 'structured-model-output-type' into byo-outputs
laserkelvin Nov 8, 2024
51d1f52
test: updating MACE wrapper test to work for model outputs
laserkelvin Nov 9, 2024
d155f1c
refactor: updating force regression to allow ModelOutput
laserkelvin Nov 9, 2024
408844b
fix: added missing macewrapper to pyg model namespace
laserkelvin Nov 9, 2024
2829cdd
fix: mapping unit shifts into MACE kwargs
laserkelvin Nov 9, 2024
af0d9f1
test: adding unit test for BYO outputs represented by MACE
laserkelvin Nov 9, 2024
60513a9
docs: added documentation about using modeloutput data structure
laserkelvin Nov 12, 2024
3a091c2
refactor: adding model output key check in ForceRegressionTask
laserkelvin Nov 12, 2024
52dfeff
refactor: skipping output kwargs if encoder doesn't use output heads
laserkelvin Nov 12, 2024
f2bbf30
docs: updated design intention with model outputs
laserkelvin Nov 12, 2024
2c0b6b8
Merge branch 'main' into byo-outputs
laserkelvin Nov 12, 2024
abaf72f
fix: correcting missing FAENet import in pyg namespace
laserkelvin Nov 12, 2024
2e4a5ba
fix: correcting model to encoder in expected attribute name
laserkelvin Nov 12, 2024
a1e98ba
Merge branch 'main' into byo-outputs
laserkelvin Nov 12, 2024
36 changes: 36 additions & 0 deletions docs/source/models.rst
@@ -19,6 +19,42 @@ The model interfaces we have depend on their original implementation:
currently, we support models from DGL, PyG, and point clouds (specifically
GALA).

Implementing interfaces
-----------------------

.. note::
This section is primarily a detail for developers. If you are a user
and aren't interested in implementation details, feel free to skip ahead.

As mentioned earlier, models in the Open MatSciML Toolkit come in two flavors:
wrappers of upstream implementations, or self-contained implementations.
Ultimately, model outputs should be standardized in one of two ways: every
model's ``_forward`` call should return either an ``Embeddings`` object (at the
minimum) or a ``ModelOutput`` object. The latter is implemented with ``pydantic``
and therefore benefits from data validation workflows, including standardizing
and checking tensor shapes, and is currently the **recommended** way to return
model outputs. It also gives wrapper models the flexibility to produce their own
outputs with their own algorithms while still being used seamlessly throughout
the pipeline; an example of this can be found in the ``MACEWrapper``. The
``ModelOutput`` class also includes an ``embeddings`` field, which keeps it
compatible with the traditional Open MatSciML Toolkit workflow of using one or
more output heads.
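
For illustration, a minimal sketch of what a ``ModelOutput``-returning
``_forward`` could look like in a wrapper model; ``self.wrapped_model`` and the
``preds`` keys below are placeholders rather than a real API:

.. code-block:: python

   from matsciml.common.types import Embeddings, ModelOutput

   def _forward(self, graph, node_feats, pos, **kwargs) -> ModelOutput:
       # call into the wrapped architecture; the key names are illustrative
       preds = self.wrapped_model(node_feats=node_feats, pos=pos, **kwargs)
       # graph/system-level and node/point-level embeddings, respectively
       embeddings = Embeddings(preds["graph_feats"], preds["node_feats"])
       return ModelOutput(
           batch_size=graph.batch_size,
           total_energy=preds["energy"],
           forces=preds["forces"],
           embeddings=embeddings,
       )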


.. autoclass:: matsciml.common.types.Embeddings
:members:


.. autoclass:: matsciml.common.types.ModelOutput
:members:


.. important::
Training tasks and workflows should branch based on the presence of either
object, taking ``ModelOutput`` as the priority. For specific tasks, we can
check whether properties are set (e.g. ``total_energy`` and ``forces``); if
they are not, the ``embeddings`` should be passed to the output heads.
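
A minimal sketch of this branching logic (illustrative only; ``encoder`` and
``output_heads`` stand in for the corresponding task attributes, and the
``"energy"`` key is just an example task key):

.. code-block:: python

   outputs = encoder(batch)
   if isinstance(outputs, ModelOutput) and outputs.total_energy is not None:
       # the model made its own prediction, so use it directly
       energy = outputs.total_energy
   else:
       # otherwise fall back to an output head acting on the embeddings
       embeddings = outputs.embeddings if isinstance(outputs, ModelOutput) else outputs
       energy = output_heads["energy"](embeddings.system_embedding)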


PyG models
----------

125 changes: 102 additions & 23 deletions matsciml/models/base.py
@@ -21,7 +21,13 @@

from matsciml.common import package_registry
from matsciml.common.registry import registry
from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
from matsciml.common.types import (
AbstractGraph,
BatchDict,
DataDict,
Embeddings,
ModelOutput,
)
from matsciml.models.common import OutputHead
from matsciml.modules.normalizer import Normalizer
from matsciml.models import losses as matsciml_losses
@@ -205,7 +211,8 @@ def num_params(self):


class AbstractTask(ABC, pl.LightningModule):
# TODO the intention is for this class to supersede AbstractEnergyModel for DGL
__skip_output_heads__ = False # allows wrapper models to bypass output heads

def __init__(
self,
atom_embedding_dim: int,
@@ -258,7 +265,7 @@ def read_batch(self, batch: BatchDict) -> DataDict:
def read_batch_size(self, batch: BatchDict) -> int | None: ...

@abstractmethod
def _forward(self, *args, **kwargs) -> Embeddings:
def _forward(self, *args, **kwargs) -> Embeddings | ModelOutput:
"""
Implements the actual logic of the architecture. Given a set
of input features, produce outputs/predictions from the model.
@@ -270,7 +277,7 @@ def _forward(self, *args, **kwargs) -> Embeddings:
"""
...

def forward(self, batch: BatchDict) -> Embeddings:
def forward(self, batch: BatchDict) -> Embeddings | ModelOutput:
"""
Given a batch structure, extract out data and pass it into the
neural network architecture. This implements the 'forward' method
Expand All @@ -285,16 +292,20 @@ def forward(self, batch: BatchDict) -> Embeddings:

Returns
-------
Embeddings
Data structure containing system/graph and point/node level embeddings.
Embeddings | ModelOutput
For models that do not have their own output mechanism, this
emits a data structure containing system/graph and point/node
level embeddings. If a model produces its own outputs, they
should be packaged in the ``ModelOutput`` data structure.
"""
input_data = self.read_batch(batch)
outputs = self._forward(**input_data)
# raise an error to help spot models that have not yet been refactored
if not isinstance(outputs, Embeddings):
raise ValueError(
"Encoder did not return `Embeddings` data structure: please refactor your model!",
)
if not self.__skip_output_heads__:
# raise an error to help spot models that have not yet been refactored
if not isinstance(outputs, Embeddings):
raise ValueError(
"Encoder did not return `Embeddings` data structure: please refactor your model!",
)
return outputs


@@ -373,7 +384,7 @@ def _forward(
mask: torch.Tensor | None = None,
sizes: list[int] | None = None,
**kwargs,
) -> Embeddings:
) -> Embeddings | ModelOutput:
"""
Sets expected patterns for args for point cloud based modeling, whereby
the bare minimum expected data are 'pos' and 'pc_features' akin to graph
@@ -516,7 +527,7 @@ def _forward(
edge_feats: torch.Tensor | None = None,
graph_feats: torch.Tensor | None = None,
**kwargs,
) -> Embeddings:
) -> Embeddings | ModelOutput:
"""
Sets args/kwargs for the expected components of a graph-based
model. At the bare minimum, we expect some kind of abstract
@@ -728,9 +739,17 @@ def __init__(
# convert to a module dict for consistent API usage
loss_func = nn.ModuleDict({key: value for key, value in loss_func.items()})
self.loss_func = loss_func
default_heads = {"act_last": None, "hidden_dim": 128}
default_heads.update(output_kwargs)
self.output_kwargs = default_heads
# only add output kwargs if we are going to use them
if not encoder.__skip_output_heads__:
default_heads = {"act_last": None, "hidden_dim": 128}
default_heads.update(output_kwargs)
self.output_kwargs = default_heads
else:
# emit warning to user if the kwargs aren't being used
if output_kwargs:
logger.warning(
f"Specified encoder {encoder.__class__.__name__} skips output heads; ignoring output kwargs."
)
self.normalize_kwargs = normalize_kwargs
self.task_keys = task_keys
self._task_loss_scaling = kwargs.get("task_loss_scaling", {})
@@ -763,7 +782,14 @@ def task_keys(self, values: set | list[str] | None) -> None:
# if we're setting task keys we have enough to initialize
# the output heads
if not self.has_initialized:
self.output_heads = self._make_output_heads()
if not self.encoder.__skip_output_heads__:
self.output_heads = self._make_output_heads()
else:
# keeping the keys to allow functionality that doesn't need
# the actual weights
self.output_heads = nn.ModuleDict(
{key: None for key in self._task_keys}
)
self.normalizers = self._make_normalizers()
# homogenize it into a dictionary mapping
if isinstance(self.loss_func, nn.Module) and not isinstance(
@@ -882,10 +908,29 @@ def has_rnn(self) -> bool:
def forward(
self,
batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
embeddings = self.encoder(batch)
batch["embeddings"] = embeddings
outputs = self.process_embedding(embeddings)
) -> dict[str, torch.Tensor] | ModelOutput:
encoder_outputs = self.encoder(batch)
# in the case that the model does not produce its own outputs,
# we pass the embeddings through the output heads.
if not self.encoder.__skip_output_heads__:
if not isinstance(encoder_outputs, (Embeddings, dict, ModelOutput)):
raise RuntimeError(
f"Encoder model must emit a dict, `ModelOutput`, or `Embeddings` object. Got {encoder_outputs} instead."
)
if isinstance(encoder_outputs, Embeddings):
batch["embeddings"] = encoder_outputs
outputs = self.process_embedding(encoder_outputs)
elif isinstance(encoder_outputs, ModelOutput):
batch["embeddings"] = encoder_outputs.embeddings
outputs = self.process_embedding(encoder_outputs.embeddings)
else:
# here we assume that encoder model is predicting directly
outputs = encoder_outputs
# optionally if we still have embeddings, keep em
if "embeddings" in encoder_outputs:
batch["embeddings"] = encoder_outputs["embeddings"]
else:
outputs = encoder_outputs
return outputs

def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]:
@@ -1056,7 +1101,13 @@ def _compute_losses(
loss_func = self.loss_func[key]
# determine if we need additional arguments
loss_func_signature = signature(loss_func.forward).parameters
kwargs = {"input": predictions[key], "target": target_val}
# TODO refactor this once outputs are homogenized
if isinstance(predictions, dict):
kwargs = {"input": predictions[key], "target": target_val}
else:
kwargs = {"input": getattr(predictions, key), "target": target_val}
if not isinstance(kwargs["input"], torch.Tensor):
raise KeyError(f"Expected model to produce output with key {key}.")
# pack atoms per graph information too
if "atoms_per_graph" in loss_func_signature:
if graph := batch.get("graph", None):
@@ -1871,10 +1922,38 @@ def forward(
fa_pos.requires_grad_(True)
elif isinstance(fa_pos, list):
[f_p.requires_grad_(True) for f_p in fa_pos]
# check to see if embeddings were stashed away
if "embeddings" in batch:
embeddings = batch.get("embeddings")
else:
embeddings = self.encoder(batch)
encoder_outputs = self.encoder(batch)
if not isinstance(encoder_outputs, (Embeddings, dict, ModelOutput)):
raise RuntimeError(
f"Encoder model must emit a dict, `ModelOutput`, or `Embeddings` object. Got {encoder_outputs} instead."
)
# sets the embeddings variable
if isinstance(encoder_outputs, Embeddings):
embeddings = encoder_outputs
# for BYO output head cases
elif isinstance(encoder_outputs, ModelOutput):
Review comment (Collaborator):
A comment that may not be super important to consider right at this moment, but it should be noted: some workflows, such as serving models with OpenKIM's kusp to run benchmarks, require that node energies are present. Some models (such as MACE) can output these directly, and they may be worth hanging onto as well.

Reply (Collaborator Author):
I think they're included in ModelOutput; in the wrapper output they're stashed as node_energies. Is that what you mean?

# this checks to make sure we have the expected attributes
for key in ["total_energy", "forces"]:
if getattr(encoder_outputs, key, None) is None:
raise RuntimeError(
f"Model {self.encoder.__class__.__name__} is not emitting {key} in model outputs."
)
# map the outputs as expected by the task
return {
"energy": encoder_outputs.total_energy,
"force": encoder_outputs.forces,
}
# in the alternative case we assume the encoder is emitting predictions
else:
for key in ["energy", "force"]:
assert (
key in encoder_outputs
), f"Expected {key} to be in encoder outputs."
return encoder_outputs

outputs = self.process_embedding(embeddings, batch, pos, fa_rot, fa_pos)
return outputs
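Following up on the review discussion above about node energies: a brief, hypothetical sketch of how a consumer could read them back from a ModelOutput emitted by a BYO-output encoder (names such as `model` and `batch` are placeholders):

# hypothetical consumer code; `model` is assumed to be an encoder with
# __skip_output_heads__ = True, e.g. a MACEWrapper built with disable_forces=False
output = model(batch)                  # forward() returns a ModelOutput here
node_energies = output.node_energies   # per-atom energies, e.g. for OpenKIM/kusp serving
total_energy = output.total_energy     # system-level energies
embeddings = output.embeddings         # still available for any output heads
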
6 changes: 3 additions & 3 deletions matsciml/models/pyg/__init__.py
@@ -19,10 +19,10 @@
# load models if we have PyG installed
if _has_pyg:
from matsciml.models.pyg.egnn import EGNN
from matsciml.models.pyg.mace import MACE, ScaleShiftMACE, MACEWrapper
from matsciml.models.pyg.faenet import FAENet
from matsciml.models.pyg.mace import MACE, ScaleShiftMACE

__all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"]
__all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE", "MACEWrapper"]

# these packages need additional pyg dependencies
if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
@@ -39,7 +39,7 @@
from matsciml.models.pyg.schnet import SchNetWrap # noqa: F401
from matsciml.models.pyg.cgcnn import CGCNN # noqa: F401

__all__.extend(["ForceNet", "SchNetWrap", "FAENet", "CGCNN"])
__all__.extend(["ForceNet", "SchNetWrap", "CGCNN"])
else:
logger.warning(
"Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
33 changes: 26 additions & 7 deletions matsciml/models/pyg/mace/wrapper/model.py
@@ -11,7 +11,13 @@
from mendeleev import element

from matsciml.models.base import AbstractPyGModel
from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings
from matsciml.common.types import (
BatchDict,
DataDict,
AbstractGraph,
ModelOutput,
Embeddings,
)
from matsciml.common.registry import registry
from matsciml.common.inspection import get_model_required_args, get_model_all_args

@@ -47,6 +53,8 @@ def free_ion_energy_table(num_elements: int = 100) -> torch.Tensor:

@registry.register_model("MACEWrapper")
class MACEWrapper(AbstractPyGModel):
__skip_output_heads__ = True # MACE is BYO output head due to expansion

def __init__(
self,
atom_embedding_dim: int,
@@ -55,6 +63,7 @@ def __init__(
embedding_kwargs: Any = None,
encoder_only: bool = True,
readout_method: str | Callable = "add",
disable_forces: bool = True,
**mace_kwargs,
) -> None:
if embedding_kwargs is not None:
@@ -164,6 +173,7 @@ def read_batch(self, batch: BatchDict) -> DataDict:
"node_feats": one_hot_atoms,
"cell": batch["cell"],
"shifts": batch["offsets"],
"unit_shifts": batch["unit_offsets"],
}
)
return data
@@ -174,7 +184,7 @@ def _forward(
node_feats: torch.Tensor,
pos: torch.Tensor,
**kwargs,
) -> Embeddings:
) -> ModelOutput:
"""
Takes arguments in the standardized format, and passes them into MACE
with some redundant mapping.
@@ -192,8 +202,8 @@

Returns
-------
Embeddings
MatSciML ``Embeddings`` structure
ModelOutput
MatSciML ``ModelOutput`` data structure.
"""
# repack data into MACE format
mace_data = {
Expand All @@ -202,17 +212,26 @@ def _forward(
"ptr": graph.ptr,
"cell": kwargs["cell"],
"shifts": kwargs["shifts"],
"unit_shifts": kwargs["unit_shifts"],
"batch": graph.batch,
"edge_index": graph.edge_index,
}
outputs = self.encoder(
mace_data,
training=self.training,
compute_force=False,
compute_force=not self.hparams["disable_forces"],
compute_virials=False,
compute_stress=False,
compute_displacement=False,
compute_displacement=not self.hparams["disable_forces"],
)
node_embeddings = outputs["node_feats"]
graph_embeddings = self.readout(node_embeddings, graph.batch)
return Embeddings(graph_embeddings, node_embeddings)
embeddings = Embeddings(graph_embeddings, node_embeddings)
output = ModelOutput(
batch_size=graph.batch_size,
forces=outputs["forces"],
total_energy=outputs["energy"],
node_energies=outputs["node_energy"],
embeddings=embeddings,
)
return output
10 changes: 9 additions & 1 deletion matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py
@@ -97,10 +97,18 @@ def test_model_forward_nograd(dset_class_name: str, mace_architecture: MACEWrapper
batch = next(iter(loader))
# run the model without gradient tracking
with torch.no_grad():
embeddings = mace_architecture(batch)
model_output = mace_architecture(batch)
embeddings = model_output.embeddings
# returns embeddings, and runs numerical checks
for z in [embeddings.system_embedding, embeddings.point_embedding]:
assert torch.isreal(z).all()
assert not torch.isnan(z).any()  # check there are no NaNs
assert torch.isfinite(z).all()
assert torch.all(torch.abs(z) <= 1000) # ensure reasonable values
# check energies as finite
graph_energies = model_output.total_energy
assert torch.isreal(graph_energies).all()
assert torch.isfinite(graph_energies).all()
node_energies = model_output.node_energies
assert torch.isreal(node_energies).all()
assert torch.isfinite(node_energies).all()