-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #315 from laserkelvin/structured-model-output-type
Structured model output type
- Loading branch information
Showing
2 changed files
with
173 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from __future__ import annotations | ||
|
||
import torch | ||
import pytest | ||
from pydantic import ValidationError | ||
|
||
from matsciml.common import types | ||
|
||
|
||
def test_output_validation(): | ||
"""Simple unit test to make sure a minimal configuration works.""" | ||
types.ModelOutput(batch_size=16, embeddings=None, node_energies=None) | ||
|
||
|
||
def test_invalid_embeddings(): | ||
"""Type invalid embeddings""" | ||
with pytest.raises(ValidationError): | ||
types.ModelOutput(batch_size=8, embeddings="aggaga") | ||
|
||
|
||
def test_valid_embeddings(): | ||
embeddings = types.Embeddings( | ||
system_embedding=torch.rand(64, 32), point_embedding=torch.rand(162, 32) | ||
) | ||
types.ModelOutput(batch_size=64, embeddings=embeddings) | ||
|
||
|
||
def test_incorrect_force_shape(): | ||
"""This passes a force tensor with too many dimensions""" | ||
with pytest.raises(ValidationError): | ||
types.ModelOutput(batch_size=8, forces=torch.rand(32, 4, 3)) | ||
|
||
|
||
def test_consistency_check_pass(): | ||
types.ModelOutput( | ||
batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(32, 1) | ||
) | ||
|
||
|
||
def test_consistency_check_fail(): | ||
with pytest.raises(RuntimeError): | ||
# check mismatching node energies and forces | ||
types.ModelOutput( | ||
batch_size=8, forces=torch.rand(32, 3), node_energies=torch.rand(64, 1) | ||
) | ||
with pytest.raises(RuntimeError): | ||
# check mismatch in number of energies and batch size | ||
types.ModelOutput(batch_size=4, total_energy=torch.rand(16, 1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters