diff --git a/docs/source/models.rst b/docs/source/models.rst
index dd7dac5b..417f1a8b 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -19,6 +19,42 @@
 The model interfaces we have depend on their original implementation: currently,
 we support models from DGL, PyG, and point clouds (specifically GALA).
 
+Implementing interfaces
+-----------------------
+
+.. note::
+    This section is primarily a detail for developers. If you are a user
+    and aren't interested in implementation details, feel free to skip ahead.
+
+As mentioned earlier, models in the Open MatSciML Toolkit come in two flavors:
+wrappers of upstream implementations, or self-contained implementations.
+Ultimately, model outputs should be standardized in one of two ways: every model
+``_forward`` call should return either an ``Embeddings`` object (at a minimum)
+or a ``ModelOutput`` object. The latter is implemented with ``pydantic`` and
+therefore takes advantage of data validation workflows, including standardizing
+and checking tensor shapes, and is currently the **recommended** way to return
+model outputs. It also gives wrapper models the flexibility to produce outputs
+with their own algorithms while still being used seamlessly throughout the
+pipeline; an example of this can be found in ``MACEWrapper``. The ``ModelOutput``
+class also includes an ``embeddings`` field, which keeps it compatible with the
+traditional Open MatSciML Toolkit workflow of leveraging one or more output heads.
+
+
+.. autoclass:: matsciml.common.types.Embeddings
+    :members:
+
+
+.. autoclass:: matsciml.common.types.ModelOutput
+    :members:
+
+
+.. important::
+    Training tasks and workflows should branch based on the presence of either
+    object, taking ``ModelOutput`` as the priority. For specific tasks, we can
+    check whether properties are set (e.g. ``total_energy`` and ``forces``) and,
+    if they are not, pass the ``embeddings`` to the output heads.
+
+
 PyG models
 ----------
 
diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index 68bd2464..27d10e54 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -21,7 +21,13 @@
 
 from matsciml.common import package_registry
 from matsciml.common.registry import registry
-from matsciml.common.types import AbstractGraph, BatchDict, DataDict, Embeddings
+from matsciml.common.types import (
+    AbstractGraph,
+    BatchDict,
+    DataDict,
+    Embeddings,
+    ModelOutput,
+)
 from matsciml.models.common import OutputHead
 from matsciml.modules.normalizer import Normalizer
 from matsciml.models import losses as matsciml_losses
@@ -205,7 +211,8 @@ def num_params(self):
 
 
 class AbstractTask(ABC, pl.LightningModule):
-    # TODO the intention is for this class to supersede AbstractEnergyModel for DGL
+    __skip_output_heads__ = False  # allows wrapper models to bypass output heads
+
     def __init__(
         self,
         atom_embedding_dim: int,
@@ -258,7 +265,7 @@ def read_batch(self, batch: BatchDict) -> DataDict:
     def read_batch_size(self, batch: BatchDict) -> int | None: ...
 
     @abstractmethod
-    def _forward(self, *args, **kwargs) -> Embeddings:
+    def _forward(self, *args, **kwargs) -> Embeddings | ModelOutput:
         """
         Implements the actual logic of the architecture. Given a set of input
         features, produce outputs/predictions from the model.
@@ -270,7 +277,7 @@ def _forward(self, *args, **kwargs) -> Embeddings:
         """
         ...
 
-    def forward(self, batch: BatchDict) -> Embeddings:
+    def forward(self, batch: BatchDict) -> Embeddings | ModelOutput:
         """
         Given a batch structure, extract out data and pass it into the
         neural network architecture. This implements the 'forward' method
@@ -285,16 +292,20 @@ def forward(self, batch: BatchDict) -> Embeddings:
 
         Returns
         -------
-        Embeddings
-            Data structure containing system/graph and point/node level embeddings.
+        Embeddings | ModelOutput
+            For models that do not have their own output mechanism, this
+            emits a data structure containing system/graph and point/node
+            level embeddings. If the model provides its own outputs, they
+            should be packaged in the ``ModelOutput`` data structure.
         """
         input_data = self.read_batch(batch)
         outputs = self._forward(**input_data)
-        # raise an error to help spot models that have not yet been refactored
-        if not isinstance(outputs, Embeddings):
-            raise ValueError(
-                "Encoder did not return `Embeddings` data structure: please refactor your model!",
-            )
+        if not self.__skip_output_heads__:
+            # raise an error to help spot models that have not yet been refactored
+            if not isinstance(outputs, Embeddings):
+                raise ValueError(
+                    "Encoder did not return `Embeddings` data structure: please refactor your model!",
+                )
         return outputs
 
 
@@ -373,7 +384,7 @@ def _forward(
         mask: torch.Tensor | None = None,
         sizes: list[int] | None = None,
         **kwargs,
-    ) -> Embeddings:
+    ) -> Embeddings | ModelOutput:
         """
         Sets expected patterns for args for point cloud based modeling, whereby
         the bare minimum expected data are 'pos' and 'pc_features' akin to graph
@@ -516,7 +527,7 @@ def _forward(
         edge_feats: torch.Tensor | None = None,
         graph_feats: torch.Tensor | None = None,
         **kwargs,
-    ) -> Embeddings:
+    ) -> Embeddings | ModelOutput:
         """
         Sets args/kwargs for the expected components of a graph-based
         model. At the bare minimum, we expect some kind of abstract
@@ -728,9 +739,17 @@ def __init__(
             # convert to a module dict for consistent API usage
             loss_func = nn.ModuleDict({key: value for key, value in loss_func.items()})
         self.loss_func = loss_func
-        default_heads = {"act_last": None, "hidden_dim": 128}
-        default_heads.update(output_kwargs)
-        self.output_kwargs = default_heads
+        # only add output kwargs if we are going to use them
+        if not encoder.__skip_output_heads__:
+            default_heads = {"act_last": None, "hidden_dim": 128}
+            default_heads.update(output_kwargs)
+            self.output_kwargs = default_heads
+        else:
+            # emit warning to user if the kwargs aren't being used
+            if output_kwargs:
+                logger.warning(
+                    f"Specified encoder {encoder.__class__.__name__} skips output heads; ignoring output kwargs."
+                )
         self.normalize_kwargs = normalize_kwargs
         self.task_keys = task_keys
         self._task_loss_scaling = kwargs.get("task_loss_scaling", {})
@@ -763,7 +782,14 @@ def task_keys(self, values: set | list[str] | None) -> None:
         # if we're setting task keys we have enough to initialize
         # the output heads
         if not self.has_initialized:
-            self.output_heads = self._make_output_heads()
+            if not self.encoder.__skip_output_heads__:
+                self.output_heads = self._make_output_heads()
+            else:
+                # keeping the keys to allow functionality that doesn't need
+                # the actual weights
+                self.output_heads = nn.ModuleDict(
+                    {key: None for key in self._task_keys}
+                )
             self.normalizers = self._make_normalizers()
         # homogenize it into a dictionary mapping
         if isinstance(self.loss_func, nn.Module) and not isinstance(
@@ -882,10 +908,29 @@ def has_rnn(self) -> bool:
     def forward(
         self,
         batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
-    ) -> dict[str, torch.Tensor]:
-        embeddings = self.encoder(batch)
-        batch["embeddings"] = embeddings
-        outputs = self.process_embedding(embeddings)
+    ) -> dict[str, torch.Tensor] | ModelOutput:
+        encoder_outputs = self.encoder(batch)
+        # in the case that the model does not produce its own outputs,
+        # we will pass them through the output heads.
+        if not self.encoder.__skip_output_heads__:
+            if not isinstance(encoder_outputs, (Embeddings, dict, ModelOutput)):
+                raise RuntimeError(
+                    f"Encoder model must emit a dict, `ModelOutput`, or `Embeddings` object. Got {encoder_outputs} instead."
+                )
+            if isinstance(encoder_outputs, Embeddings):
+                batch["embeddings"] = encoder_outputs
+                outputs = self.process_embedding(encoder_outputs)
+            elif isinstance(encoder_outputs, ModelOutput):
+                batch["embeddings"] = encoder_outputs.embeddings
+                outputs = self.process_embedding(encoder_outputs.embeddings)
+            else:
+                # here we assume that the encoder model is predicting directly
+                outputs = encoder_outputs
+                # optionally, if we still have embeddings, keep them
+                if "embeddings" in encoder_outputs:
+                    batch["embeddings"] = encoder_outputs["embeddings"]
+        else:
+            outputs = encoder_outputs
         return outputs
 
     def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]:
@@ -1056,7 +1101,13 @@ def _compute_losses(
             loss_func = self.loss_func[key]
             # determine if we need additional arguments
             loss_func_signature = signature(loss_func.forward).parameters
-            kwargs = {"input": predictions[key], "target": target_val}
+            # TODO refactor this once outputs are homogenized
+            if isinstance(predictions, dict):
+                kwargs = {"input": predictions[key], "target": target_val}
+            else:
+                kwargs = {"input": getattr(predictions, key), "target": target_val}
+            if not isinstance(kwargs["input"], torch.Tensor):
+                raise KeyError(f"Expected model to produce output with key {key}.")
             # pack atoms per graph information too
             if "atoms_per_graph" in loss_func_signature:
                 if graph := batch.get("graph", None):
@@ -1871,10 +1922,38 @@ def forward(
             fa_pos.requires_grad_(True)
         elif isinstance(fa_pos, list):
             [f_p.requires_grad_(True) for f_p in fa_pos]
+        # check to see if embeddings were stashed away
         if "embeddings" in batch:
             embeddings = batch.get("embeddings")
         else:
-            embeddings = self.encoder(batch)
+            encoder_outputs = self.encoder(batch)
+            if not isinstance(encoder_outputs, (Embeddings, dict, ModelOutput)):
+                raise RuntimeError(
+                    f"Encoder model must emit a dict, `ModelOutput`, or `Embeddings` object. Got {encoder_outputs} instead."
+                )
+            # sets the embeddings variable
+            if isinstance(encoder_outputs, Embeddings):
+                embeddings = encoder_outputs
+            # for BYO output head cases
+            elif isinstance(encoder_outputs, ModelOutput):
+                # this checks to make sure we have the expected attributes
+                for key in ["total_energy", "forces"]:
+                    if getattr(encoder_outputs, key, None) is None:
+                        raise RuntimeError(
+                            f"Model {self.encoder.__class__.__name__} is not emitting {key} in model outputs."
+                        )
+                # map the outputs as expected by the task
+                return {
+                    "energy": encoder_outputs.total_energy,
+                    "force": encoder_outputs.forces,
+                }
+            # in the alternative case we assume the encoder is emitting predictions
+            else:
+                for key in ["energy", "force"]:
+                    assert (
+                        key in encoder_outputs
+                    ), f"Expected {key} to be in encoder outputs."
+                return encoder_outputs
         outputs = self.process_embedding(embeddings, batch, pos, fa_rot, fa_pos)
         return outputs
 
diff --git a/matsciml/models/pyg/__init__.py b/matsciml/models/pyg/__init__.py
index a6b334a4..e6ed251b 100644
--- a/matsciml/models/pyg/__init__.py
+++ b/matsciml/models/pyg/__init__.py
@@ -19,10 +19,10 @@
 # load models if we have PyG installed
 if _has_pyg:
     from matsciml.models.pyg.egnn import EGNN
+    from matsciml.models.pyg.mace import MACE, ScaleShiftMACE, MACEWrapper
     from matsciml.models.pyg.faenet import FAENet
-    from matsciml.models.pyg.mace import MACE, ScaleShiftMACE
 
-    __all__ = ["EGNN", "FAENet", "MACE", "ScaleShiftMACE"]
+    __all__ = ["CGCNN", "EGNN", "FAENet", "MACE", "ScaleShiftMACE", "MACEWrapper"]
 
 # these packages need additional pyg dependencies
 if package_registry["torch_sparse"] and package_registry["torch_scatter"]:
@@ -39,7 +39,7 @@
     from matsciml.models.pyg.schnet import SchNetWrap  # noqa: F401
     from matsciml.models.pyg.cgcnn import CGCNN  # noqa: F401
 
-    __all__.extend(["ForceNet", "SchNetWrap", "FAENet", "CGCNN"])
+    __all__.extend(["ForceNet", "SchNetWrap", "CGCNN"])
 else:
     logger.warning(
         "Missing torch_scatter; ForceNet, SchNet, and FAENet models will not be available."
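The base.py changes above establish a contract between encoders and training tasks: an encoder that sets ``__skip_output_heads__ = True`` brings its own output mechanism and is expected to package its predictions in a ``ModelOutput`` rather than returning bare ``Embeddings``. The sketch below (not part of the diff) illustrates that packaging step from the encoder side; the field names mirror the ``ModelOutput`` usage in the ``MACEWrapper`` hunks that follow, while the shape comments and the helper name are assumptions.

# Illustrative sketch only: how a "bring your own output head" encoder could
# package its predictions. Shapes in the comments are assumptions; ModelOutput's
# pydantic validation is the source of truth for what it accepts.
import torch

from matsciml.common.types import Embeddings, ModelOutput


def package_predictions(
    system_embeddings: torch.Tensor,  # assumed (batch_size, embedding_dim)
    node_embeddings: torch.Tensor,  # assumed (num_nodes, embedding_dim)
    total_energy: torch.Tensor,  # assumed (batch_size, 1)
    forces: torch.Tensor,  # assumed (num_nodes, 3)
) -> ModelOutput:
    """Bundle predictions so tasks can consume them without output heads."""
    embeddings = Embeddings(system_embeddings, node_embeddings)
    return ModelOutput(
        batch_size=system_embeddings.size(0),
        total_energy=total_energy,
        forces=forces,
        # retaining embeddings keeps the output usable by tasks that still
        # rely on the traditional output-head workflow
        embeddings=embeddings,
    )

A hypothetical helper like this would be called at the end of an encoder's ``_forward``, which is exactly where ``MACEWrapper._forward`` constructs its ``ModelOutput`` in the next file.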
diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py
index e47c4184..022891e8 100644
--- a/matsciml/models/pyg/mace/wrapper/model.py
+++ b/matsciml/models/pyg/mace/wrapper/model.py
@@ -11,7 +11,13 @@
 from mendeleev import element
 
 from matsciml.models.base import AbstractPyGModel
-from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings
+from matsciml.common.types import (
+    BatchDict,
+    DataDict,
+    AbstractGraph,
+    ModelOutput,
+    Embeddings,
+)
 from matsciml.common.registry import registry
 from matsciml.common.inspection import get_model_required_args, get_model_all_args
 
@@ -47,6 +53,8 @@ def free_ion_energy_table(num_elements: int = 100) -> torch.Tensor:
 
 @registry.register_model("MACEWrapper")
 class MACEWrapper(AbstractPyGModel):
+    __skip_output_heads__ = True  # MACE is BYO output head due to expansion
+
     def __init__(
         self,
         atom_embedding_dim: int,
@@ -55,6 +63,7 @@ def __init__(
         embedding_kwargs: Any = None,
         encoder_only: bool = True,
         readout_method: str | Callable = "add",
+        disable_forces: bool = True,
         **mace_kwargs,
     ) -> None:
         if embedding_kwargs is not None:
@@ -164,6 +173,7 @@ def read_batch(self, batch: BatchDict) -> DataDict:
                 "node_feats": one_hot_atoms,
                 "cell": batch["cell"],
                 "shifts": batch["offsets"],
+                "unit_shifts": batch["unit_offsets"],
             }
         )
         return data
@@ -174,7 +184,7 @@ def _forward(
         node_feats: torch.Tensor,
         pos: torch.Tensor,
         **kwargs,
-    ) -> Embeddings:
+    ) -> ModelOutput:
         """
         Takes arguments in the standardized format, and passes them into MACE
         with some redundant mapping.
@@ -192,8 +202,8 @@ def _forward(
 
         Returns
         -------
-        Embeddings
-            MatSciML ``Embeddings`` structure
+        ModelOutput
+            MatSciML ``ModelOutput`` data structure.
         """
         # repack data into MACE format
         mace_data = {
@@ -202,17 +212,26 @@ def _forward(
             "ptr": graph.ptr,
             "cell": kwargs["cell"],
             "shifts": kwargs["shifts"],
+            "unit_shifts": kwargs["unit_shifts"],
             "batch": graph.batch,
             "edge_index": graph.edge_index,
         }
         outputs = self.encoder(
             mace_data,
             training=self.training,
-            compute_force=False,
+            compute_force=not self.hparams["disable_forces"],
             compute_virials=False,
             compute_stress=False,
-            compute_displacement=False,
+            compute_displacement=not self.hparams["disable_forces"],
         )
         node_embeddings = outputs["node_feats"]
         graph_embeddings = self.readout(node_embeddings, graph.batch)
-        return Embeddings(graph_embeddings, node_embeddings)
+        embeddings = Embeddings(graph_embeddings, node_embeddings)
+        output = ModelOutput(
+            batch_size=graph.batch_size,
+            forces=outputs["forces"],
+            total_energy=outputs["energy"],
+            node_energies=outputs["node_energy"],
+            embeddings=embeddings,
+        )
+        return output
diff --git a/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py b/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py
index 21c669a7..39f3f917 100644
--- a/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py
+++ b/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py
@@ -97,10 +97,18 @@ def test_model_forward_nograd(dset_class_name: str, mace_architecture: MACEWrapp
     batch = next(iter(loader))
     # run the model without gradient tracking
     with torch.no_grad():
-        embeddings = mace_architecture(batch)
+        model_output = mace_architecture(batch)
+        embeddings = model_output.embeddings
     # returns embeddings, and runs numerical checks
     for z in [embeddings.system_embedding, embeddings.point_embedding]:
         assert torch.isreal(z).all()
         assert ~torch.isnan(z).all()  # check there are no NaNs
         assert torch.isfinite(z).all()
         assert torch.all(torch.abs(z) <= 1000)  # ensure reasonable values
+    # check that energies are finite
+    graph_energies = model_output.total_energy
+    assert torch.isreal(graph_energies).all()
+    assert torch.isfinite(graph_energies).all()
+    node_energies = model_output.node_energies
+    assert torch.isreal(node_energies).all()
+    assert torch.isfinite(node_energies).all()
diff --git a/matsciml/models/tests/test_tasks.py b/matsciml/models/tests/test_tasks.py
index 3152fe1c..a58f2fd4 100644
--- a/matsciml/models/tests/test_tasks.py
+++ b/matsciml/models/tests/test_tasks.py
@@ -2,6 +2,9 @@
 
 import pytest
 import lightning.pytorch as pl
+import e3nn
+import torch
+import mace
 
 from matsciml.datasets import MaterialsProjectDataset
 from matsciml.datasets.transforms import (
@@ -11,7 +14,7 @@
     FrameAveraging,
 )
 from matsciml.lightning.data_utils import MatSciMLDataModule
-from matsciml.models import PLEGNNBackbone, FAENet
+from matsciml.models import PLEGNNBackbone, FAENet, MACEWrapper
 from matsciml.models.base import (
     ForceRegressionTask,
     GradFreeForceRegressionTask,
@@ -69,6 +72,33 @@ def faenet_config():
     return {"encoder_class": FAENet, "encoder_kwargs": model_args}
 
 
+@pytest.fixture
+def mace_config():
+    model_args = {
+        "r_max": 6.0,
+        "radial_type": "bessel",
+        "distance_transform": None,
+        "num_polynomial_cutoff": 5.0,
+        "num_interactions": 2,
+        "num_bessel": 8,
+        "num_atom_embedding": 100,
+        "max_ell": 3,
+        "gate": torch.nn.SiLU(),
+        "interaction_cls": mace.modules.blocks.RealAgnosticResidualInteractionBlock,
+        "interaction_cls_first": mace.modules.blocks.RealAgnosticResidualInteractionBlock,
+        "correlation": 3,
+        "avg_num_neighbors": 31.0,
+        "atomic_inter_scale": 0.21,
+        "atomic_inter_shift": 0.0,
+        "atom_embedding_dim": 128,
+        "MLP_irreps": e3nn.o3.Irreps("16x0e"),
+        "hidden_irreps": e3nn.o3.Irreps("128x0e + 128x1o"),
+        "mace_module": mace.modules.ScaleShiftMACE,
+        "disable_forces": False,
+    }
+    return {"encoder_class": MACEWrapper, "encoder_kwargs": model_args}
+
+
 def test_force_regression(egnn_config):
     devset = MatSciMLDataModule.from_devset(
         "S2EFDataset",
@@ -91,6 +121,27 @@ def test_force_regression(egnn_config):
         assert f"train_{key}" in trainer.logged_metrics
 
 
+def test_force_regression_byo_output(mace_config):
+    devset = MatSciMLDataModule.from_devset(
+        "S2EFDataset",
+        dset_kwargs={
+            "transforms": [
+                PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True),
+                PointCloudToGraphTransform(
+                    "pyg",
+                    node_keys=["pos", "atomic_numbers"],
+                ),
+            ],
+        },
+    )
+    task = ForceRegressionTask(**mace_config)
+    trainer = pl.Trainer(max_steps=5, logger=False, enable_checkpointing=False)
+    trainer.fit(task, datamodule=devset)
+    # make sure losses are tracked
+    for key in ["energy", "force"]:
+        assert f"train_{key}" in trainer.logged_metrics
+
+
 def test_fa_force_regression(faenet_config):
     devset = MatSciMLDataModule.from_devset(
         "S2EFDataset",
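The new ``test_force_regression_byo_output`` above exercises the task-side branch added to ``ForceRegressionTask.forward``. As a reading aid, here is a condensed sketch (not code from the diff) of that dispatch: consume ``ModelOutput`` predictions directly when the encoder brings its own output head, accept a plain dict of predictions, or hand bare ``Embeddings`` back to the caller for the usual output heads. The helper name is hypothetical; only the ``Embeddings`` and ``ModelOutput`` types come from the library.

# Condensed sketch of the dispatch implemented in ForceRegressionTask.forward;
# dispatch_encoder_outputs is a hypothetical illustration, not library code.
from __future__ import annotations

import torch

from matsciml.common.types import Embeddings, ModelOutput


def dispatch_encoder_outputs(
    encoder_outputs: Embeddings | ModelOutput | dict[str, torch.Tensor],
) -> dict[str, torch.Tensor] | Embeddings:
    """Return energy/force predictions directly, or embeddings for output heads."""
    if isinstance(encoder_outputs, ModelOutput):
        # BYO output head: the encoder must have produced both quantities
        if encoder_outputs.total_energy is None or encoder_outputs.forces is None:
            raise RuntimeError("ModelOutput is missing total_energy and/or forces.")
        return {"energy": encoder_outputs.total_energy, "force": encoder_outputs.forces}
    if isinstance(encoder_outputs, dict):
        # the encoder predicted directly into a plain dict
        return {"energy": encoder_outputs["energy"], "force": encoder_outputs["force"]}
    # bare Embeddings: the caller routes these through its output heads
    return encoder_outputs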