Merge pull request #315 from laserkelvin/structured-model-output-type
Structured model output type
laserkelvin authored Nov 12, 2024
2 parents 2caaee3 + f2dc62f commit 406bcbf
Showing 2 changed files with 173 additions and 2 deletions.
48 changes: 48 additions & 0 deletions matsciml/common/tests/test_types.py
@@ -0,0 +1,48 @@
from __future__ import annotations

import torch
import pytest
from pydantic import ValidationError

from matsciml.common import types


def test_output_validation():
    """Simple unit test to make sure a minimal configuration works."""
    types.ModelOutput(batch_size=16, embeddings=None, node_energies=None)


def test_invalid_embeddings():
    """Pass an invalid type for the embeddings field."""
    with pytest.raises(ValidationError):
        types.ModelOutput(batch_size=8, embeddings="aggaga")


def test_valid_embeddings():
    embeddings = types.Embeddings(
        system_embedding=torch.rand(64, 32), point_embedding=torch.rand(162, 32)
    )
    types.ModelOutput(batch_size=64, embeddings=embeddings)


def test_incorrect_force_shape():
    """This passes a force tensor with too many dimensions."""
    with pytest.raises(ValidationError):
        types.ModelOutput(batch_size=8, forces=torch.rand(32, 4, 3))


def test_consistency_check_pass():
    types.ModelOutput(
        batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(32, 1)
    )


def test_consistency_check_fail():
    with pytest.raises(RuntimeError):
        # check mismatching node energies and forces
        types.ModelOutput(
            batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(64, 1)
        )
    with pytest.raises(RuntimeError):
        # check mismatch in number of energies and batch size
        types.ModelOutput(batch_size=4, total_energy=torch.rand(16, 1))
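
Assuming a standard development checkout with ``pytest`` available, the new test module can be run directly with ``pytest matsciml/common/tests/test_types.py``.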
127 changes: 125 additions & 2 deletions matsciml/common/types.py
@@ -1,9 +1,10 @@
from __future__ import annotations

-from dataclasses import dataclass, field
+from dataclasses import dataclass
from typing import Callable, Union

import torch
+from pydantic import ConfigDict, Field, field_validator, BaseModel, model_validator

from matsciml.common import package_registry

@@ -44,7 +45,7 @@ class Embeddings:
    system_embedding: torch.Tensor | None = None
    point_embedding: torch.Tensor | None = None
    reduction: str | Callable | None = None
-    reduction_kwargs: dict[str, str | float] = field(default_factory=dict)
+    reduction_kwargs: dict[str, str | float] = Field(default_factory=dict)

    @property
    def num_points(self) -> int:
@@ -90,3 +91,125 @@ def reduce_point_embeddings(
        system_embeddings = reduction(self.point_embedding, **self.reduction_kwargs)
        self.system_embedding = system_embeddings
        return system_embeddings


class ModelOutput(BaseModel):
    """
    Standardized output data structure from models.

    The advantage of doing so is to standardize keys, as well
    as the shapes that are produced by models; i.e. remove
    unused dimensions using ``pydantic`` validation mechanisms.
    """

    batch_size: int
    embeddings: Embeddings | None = None
    node_energies: torch.Tensor | None = None
    total_energy: torch.Tensor | None = None
    forces: torch.Tensor | None = None
    stresses: torch.Tensor | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_validator("total_energy", mode="before")
    @classmethod
    def standardize_total_energy(
        cls, values: torch.Tensor | None
    ) -> torch.Tensor | None:
        """
        Check to ensure the total energy tensor being passed
        is ultimately scalar per graph/system.

        Parameters
        ----------
        values : torch.Tensor
            Tensor holding energy values for each graph/system
            within a batch.

        Returns
        -------
        torch.Tensor
            1-D tensor containing energies for each graph/system
            within a batch.

        Raises
        ------
        ValueError
            If, after running ``squeeze`` on the input tensor,
            the number of dimensions is still greater than one.
        """
        if isinstance(values, torch.Tensor):
            # drop all redundant dimensions
            values = values.squeeze()
            # last step is an assertion check for QA
            if values.ndim != 1:
                raise ValueError(
                    f"Expected graph/system energies to be scalar; got shape {values.shape}"
                )
        return values

    @field_validator("forces", mode="after")
    @classmethod
    def check_force_shape(cls, forces: torch.Tensor | None) -> torch.Tensor | None:
        """
        Check to ensure that the force tensor has the expected
        shape. Runs after the type checking by ``pydantic``.

        Parameters
        ----------
        forces : torch.Tensor
            Force tensor to check.

        Returns
        -------
        torch.Tensor
            Validated force tensor without modifications.

        Raises
        ------
        ValueError
            If the tensor is not 2D, or if its last dimension
            is not of length 3.
        """
        if isinstance(forces, torch.Tensor):
            if forces.ndim != 2:
                raise ValueError(f"Expected force tensor to be 2D; got {forces.shape}.")
            if forces.size(-1) != 3:
                raise ValueError(
                    f"Expected last dimension of forces to be length 3; got {forces.shape}."
                )
        return forces

    @model_validator(mode="after")
    def consistency_checks(self) -> ModelOutput:
        """
        Performs general consistency checks based on what data is provided.

        Raises
        ------
        RuntimeError
            If the number of node energies and forces are mismatched,
            or if the number of predicted system/graph energies does
            not match the batch size.
        """
        if isinstance(self.node_energies, torch.Tensor) and isinstance(
            self.forces, torch.Tensor
        ):
            if not self.node_energies.size(0) == self.forces.size(0):
                raise RuntimeError(
                    f"Expected node energies and forces to be same shape; got {self.node_energies.shape} node energies and {self.forces.shape} forces."
                )
        if isinstance(self.total_energy, torch.Tensor):
            if not len(self.total_energy) == self.batch_size:
                raise RuntimeError(
                    f"Batch size ({self.batch_size}) and energy mismatch ({len(self.total_energy)})."
                )
        if isinstance(self.embeddings, Embeddings):
            if not self.embeddings.system_embedding.size(0) == self.batch_size:
                raise RuntimeError(
                    f"Expected {self.batch_size} system embeddings; got {self.embeddings.system_embedding.size(0)}."
                )

        return self
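
Below is a minimal usage sketch, not part of this commit, of how a model's forward pass might package its predictions with ``ModelOutput``; the tensor shapes and the ``num_nodes`` name are illustrative assumptions:

import torch

from matsciml.common import types

batch_size = 8
num_nodes = 32  # hypothetical total node count across the batch

output = types.ModelOutput(
    batch_size=batch_size,
    # the trailing singleton dimension is removed by the
    # ``standardize_total_energy`` validator, leaving shape (8,)
    total_energy=torch.rand(batch_size, 1),
    # forces must be 2D with a last dimension of length 3
    forces=torch.rand(num_nodes, 3),
    # one energy per node, consistent with the first dimension of forces
    node_energies=torch.rand(num_nodes, 1),
)
assert output.total_energy.shape == (batch_size,)

Passing a malformed field instead, e.g. ``forces=torch.rand(num_nodes, 4, 3)``, raises a ``pydantic.ValidationError``, while mismatched node or batch counts raise a ``RuntimeError``, mirroring the tests above.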
