diff --git a/matsciml/common/tests/test_types.py b/matsciml/common/tests/test_types.py new file mode 100644 index 00000000..f7836948 --- /dev/null +++ b/matsciml/common/tests/test_types.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import torch +import pytest +from pydantic import ValidationError + +from matsciml.common import types + + +def test_output_validation(): + """Simple unit test to make sure a minimal configuration works.""" + types.ModelOutput(batch_size=16, embeddings=None, node_energies=None) + + +def test_invalid_embeddings(): + """Type invalid embeddings""" + with pytest.raises(ValidationError): + types.ModelOutput(batch_size=8, embeddings="aggaga") + + +def test_valid_embeddings(): + embeddings = types.Embeddings( + system_embedding=torch.rand(64, 32), point_embedding=torch.rand(162, 32) + ) + types.ModelOutput(batch_size=64, embeddings=embeddings) + + +def test_incorrect_force_shape(): + """This passes a force tensor with too many dimensions""" + with pytest.raises(ValidationError): + types.ModelOutput(batch_size=8, forces=torch.rand(32, 4, 3)) + + +def test_consistency_check_pass(): + types.ModelOutput( + batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(32, 1) + ) + + +def test_consistency_check_fail(): + with pytest.raises(RuntimeError): + # check mismatching node energies and forces + types.ModelOutput( + batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(64, 1) + ) + with pytest.raises(RuntimeError): + # check mismatch in number of energies and batch size + types.ModelOutput(batch_size=4, total_energy=torch.rand(16, 1)) diff --git a/matsciml/common/types.py b/matsciml/common/types.py index ad32431b..be9642fe 100644 --- a/matsciml/common/types.py +++ b/matsciml/common/types.py @@ -1,9 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Callable, Union import torch +from pydantic import ConfigDict, Field, field_validator, BaseModel, model_validator from matsciml.common import package_registry @@ -44,7 +45,7 @@ class Embeddings: system_embedding: torch.Tensor | None = None point_embedding: torch.Tensor | None = None reduction: str | Callable | None = None - reduction_kwargs: dict[str, str | float] = field(default_factory=dict) + reduction_kwargs: dict[str, str | float] = Field(default_factory=dict) @property def num_points(self) -> int: @@ -90,3 +91,125 @@ def reduce_point_embeddings( system_embeddings = reduction(self.point_embedding, **self.reduction_kwargs) self.system_embedding = system_embeddings return system_embeddings + + +class ModelOutput(BaseModel): + """ + Standardized output data structure out of models. + + The advantage of doing is to standardize keys, as well + as to standardize shapes the are produced by models; + i.e. remove unused dimensions using ``pydantic`` + validation mechanisms. + """ + + batch_size: int + embeddings: Embeddings | None = None + node_energies: torch.Tensor | None = None + total_energy: torch.Tensor | None = None + forces: torch.Tensor | None = None + stresses: torch.Tensor | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("total_energy", mode="before") + @classmethod + def standardize_total_energy( + cls, values: torch.Tensor | None + ) -> torch.Tensor | None: + """ + Check to ensure the total energy tensor being passed + is ultimately scalar. + + Parameters + ---------- + values : torch.Tensor + Tensor holding energy values for each graph/system + within a batch. + + Returns + ------- + torch.Tensor + 1-D tensor containing energies for each graph/system + within a batch. + + Raises + ------ + ValueError: + If after running ``squeeze`` on the input tensor, the + dimensions are still greater than one we raise a + ``ValueError``. + """ + if isinstance(values, torch.Tensor): + # drop all redundant dimensions + values = values.squeeze() + # last step is an assertion check for QA + if values.ndim != 1: + raise ValueError( + f"Expected graph/system energies to be scalar; got shape {values.shape}" + ) + return values + + @field_validator("forces", mode="after") + @classmethod + def check_force_shape(cls, forces: torch.Tensor | None) -> torch.Tensor | None: + """ + Check to ensure that the force tensor has the expected + shape. Runs after the type checking by ``pydantic``. + + Parameters + ---------- + forces : torch.Tensor + Force tensor to check. + + Returns + ------- + torch.Tensor + Validated force tensor without modifications. + + Raises + ------ + ValueError: + If the dimensionality of the tensor is not 2D, and/or + if the last dimensionality of the tensor is not 3-long. + """ + if isinstance(forces, torch.Tensor): + if forces.ndim != 2: + raise ValueError(f"Expected force tensor to be 2D; got {forces.shape}.") + if forces.size(-1) != 3: + raise ValueError( + f"Expected last dimension of forces to be length 3; got {forces.shape}." + ) + return forces + + @model_validator(mode="after") + def consistency_checks(self) -> ModelOutput: + """ + Performs general consistency checks based on what data is provided. + + Raises + ------ + RuntimeError: + If the number of node energies and forces are mismatched; + if the number of predicted system/graph energies do not + match the batch size. + """ + if isinstance(self.node_energies, torch.Tensor) and isinstance( + self.forces, torch.Tensor + ): + if not self.node_energies.size(0) == self.forces.size(0): + raise RuntimeError( + f"Expected node energies and forces to be same shape; got {self.node_energies.shape} node energies and {self.forces.shape} forces." + ) + if isinstance(self.total_energy, torch.Tensor): + if not len(self.total_energy) == self.batch_size: + raise RuntimeError( + f"Batch size ({self.batch_size}) and energy mismatch ({len(self.total_energy)})." + ) + if isinstance(self.embeddings, Embeddings): + if not self.embeddings.system_embedding.size(0) == self.batch_size: + raise RuntimeError( + f"Expected {self.batch_size} system embeddings; got {self.embeddings.system_embedding.size(0)}." + ) + + return self