Skip to content

Commit

Permalink
Merge branch 'main' into ref-add-ani2x-test
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm authored Oct 23, 2024
2 parents 2b03d32 + d6be993 commit 685de67
Show file tree
Hide file tree
Showing 18 changed files with 69 additions and 52 deletions.
2 changes: 1 addition & 1 deletion modelforge/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .spice2 import SPICE2Dataset
from .spice1openff import SPICE1OpenFFDataset
from .phalkethoh import PhAlkEthOHDataset
from .dataset import DatasetFactory, DataModule, NNPInput
from .dataset import DatasetFactory, DataModule
from enum import Enum


Expand Down
13 changes: 5 additions & 8 deletions modelforge/jax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from modelforge.dataset import NNPInput
from modelforge.utils.prop import NNPInput


def nnpinput_flatten(nnpinput: NNPInput):
Expand Down Expand Up @@ -42,12 +42,9 @@ def convert_NNPInput_to_jax(nnp_input: NNPInput):
nnp_input.box_vectors = convert_to_jax(nnp_input.box_vectors)
nnp_input.is_periodic = convert_to_jax(nnp_input.is_periodic)

if nnp_input.pair_list is not None:
nnp_input.pair_list = convert_to_jax(nnp_input.pair_list)

if nnp_input.per_atom_partial_charge is not None:
nnp_input.per_atom_partial_charge = convert_to_jax(
nnp_input.per_atom_partial_charge
)
nnp_input.pair_list = convert_to_jax(nnp_input.pair_list)
nnp_input.per_atom_partial_charge = convert_to_jax(
nnp_input.per_atom_partial_charge
)

return nnp_input
3 changes: 1 addition & 2 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from loguru import logger as log
from torch import nn

from modelforge.utils.prop import SpeciesAEV
from modelforge.utils.prop import SpeciesAEV, NNPInput

from modelforge.dataset.dataset import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, List, Union

from modelforge.potential.utils import DenseWithCustomDist
from modelforge.dataset import NNPInput
from modelforge.utils.prop import NNPInput


class AddPerMoleculeValue(nn.Module):
Expand Down
14 changes: 5 additions & 9 deletions modelforge/potential/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,20 +790,15 @@ def _calculate_interacting_pairs(
Atom positions. Shape: [nr_systems, nr_atoms, 3].
atomic_subsystem_indices : torch.Tensor
Indices identifying atoms in subsystems. Shape: [nr_atoms].
pair_indices : Optional[torch.Tensor]
Precomputed pair indices. If None, will compute pair indices.
pair_indices : torch.Tensor
Precomputed pair indices.
Returns
-------
PairListOutputs
A dataclass containing 'pair_indices', 'd_ij' (distances), and 'r_ij' (displacement vectors).
"""

if pair_indices is None:
pair_indices = self.enumerate_all_pairs(
atomic_subsystem_indices,
)

r_ij = self.calculate_r_ij(pair_indices, positions)
d_ij = self.calculate_d_ij(r_ij)

Expand Down Expand Up @@ -838,8 +833,9 @@ def forward(self, data: Union[NNPInput, NamedTuple]) -> PairlistData:
# general input manipulation
positions = data.positions
atomic_subsystem_indices = data.atomic_subsystem_indices
# calculate pairlist if none is provided
if data.pair_list is None:
# calculate pairlist if it is not provided

if data.pair_list is None or data.pair_list.shape[0] == 0:
# note, we set the flag for unique pairs when instantiated in the constructor
# and thus this call will return unique pairs if requested.
pair_list = self.pairlist.enumerate_all_pairs(atomic_subsystem_indices)
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from loguru import logger as log
from openff.units import unit

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData
from .utils import DenseWithCustomDist

Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from loguru import logger as log
from torch import nn

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData
from .utils import Dense

Expand Down
1 change: 0 additions & 1 deletion modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,6 @@ def generate_potential(

jax = import_("jax")
from modelforge.jax import nnpinput_flatten, nnpinput_unflatten
from modelforge.dataset import NNPInput

# registering NNPInput multiple times will result in a
# ValueError
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from loguru import logger as log

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData

from .utils import DenseWithCustomDist, scatter_softmax
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from loguru import logger as log

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from modelforge.potential import CosineAttenuationFunction, TensorNetRadialBasisFunction

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn
from openff.units import unit

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput


@dataclass(frozen=False)
Expand Down
2 changes: 1 addition & 1 deletion modelforge/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def generate_uniform_quaternion(u=None):
def rotation_matrix_from_quaternion(quaternion):
"""Compute a 3x3 rotation matrix from a given quaternion (4-vector).
Adapted from the numpy implementation in openmm-tools
Adapted from the numpy implementation in modelforgeopenmm-tools
https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py
Expand Down
2 changes: 1 addition & 1 deletion modelforge/tests/precalculated_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def setup_single_methane_input():
)
E = torch.tensor([0.0], requires_grad=True)
atomic_subsystem_indices = torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32)
from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

modelforge_methane = NNPInput(
atomic_numbers=atomic_numbers,
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/test_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setup_methane():
[0, 0, 0, 0, 0], dtype=torch.int32, device=device
)

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

nnp_input = NNPInput(
atomic_numbers=torch.tensor([6, 1, 1, 1, 1], device=device),
Expand Down Expand Up @@ -80,7 +80,7 @@ def setup_two_methanes():
)

atomic_numbers = mf_species
from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

nnp_input = NNPInput(
atomic_numbers=atomic_numbers,
Expand Down
20 changes: 8 additions & 12 deletions modelforge/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@

@pytest.fixture(scope="session")
def prep_temp_dir(tmp_path_factory):
import uuid

filename = str(uuid.uuid4())

fn = tmp_path_factory.mktemp(f"test_dataset_temp")
fn = tmp_path_factory.mktemp("test_dataset_temp")
return fn


Expand Down Expand Up @@ -258,7 +254,7 @@ def test_file_existence_after_initialization(
def test_caching(prep_temp_dir):
import contextlib

local_cache_dir = str(prep_temp_dir) + "/data_test"
local_cache_dir = str(prep_temp_dir)
from modelforge.dataset.qm9 import QM9Dataset

data = QM9Dataset(version_select="nc_1000_v0", local_cache_dir=local_cache_dir)
Expand Down Expand Up @@ -334,7 +330,7 @@ def test_metadata_validation(prep_temp_dir):
which is used to validate if we can use .npz file, or we need to
regenerate it."""

local_cache_dir = str(prep_temp_dir) + "/data_test"
local_cache_dir = str(prep_temp_dir)

from modelforge.dataset.qm9 import QM9Dataset

Expand Down Expand Up @@ -471,7 +467,7 @@ def test_data_item_format_of_datamodule(
"""Test the format of individual data items in the dataset."""
from typing import Dict

local_cache_dir = str(prep_temp_dir) + "/data_test"
local_cache_dir = str(prep_temp_dir)

dm = datamodule_factory(
dataset_name=dataset_name,
Expand Down Expand Up @@ -761,7 +757,7 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir):
"""
Test the DatasetDownloader functionality.
"""
local_cache_dir = str(prep_temp_dir) + "/data_test"
local_cache_dir = str(prep_temp_dir)

dataset = dataset_factory(
dataset_name=dataset_name, local_cache_dir=local_cache_dir
Expand All @@ -773,15 +769,15 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir):


@pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names())
def test_numpy_dataset_assignment(dataset_name):
def test_numpy_dataset_assignment(dataset_name, prep_temp_dir):
"""
Test if the numpy_dataset attribute is correctly assigned after processing or loading.
"""
from modelforge.dataset import _ImplementedDatasets

factory = DatasetFactory()
data = _ImplementedDatasets.get_dataset_class(dataset_name)(
version_select="nc_1000_v0"
version_select="nc_1000_v0", local_cache_dir=str(prep_temp_dir)
)
factory._load_or_process_data(data)

Expand Down Expand Up @@ -1025,7 +1021,7 @@ def test_function_of_self_energy(dataset_name, datamodule_factory, prep_temp_dir


def test_shifting_center_of_mass_to_origin(prep_temp_dir):
local_cache_dir = str(prep_temp_dir) + "/data_test"
local_cache_dir = str(prep_temp_dir)

from modelforge.dataset.dataset import initialize_datamodule
from openff.units.elements import MASSES
Expand Down
34 changes: 34 additions & 0 deletions modelforge/tests/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,3 +1107,37 @@ def test_loading_from_checkpoint_file():

model = load_inference_model_from_checkpoint(chkp_file)
assert model is not None


@pytest.mark.parametrize(
"potential_name", _Implemented_NNPs.get_all_neural_network_names()
)
def test_saving_torchscript(potential_name, single_batch_with_batchsize, prep_temp_dir):
batch = single_batch_with_batchsize(
batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir)
)
from modelforge.potential.potential import NeuralNetworkPotentialFactory
from modelforge.utils.misc import load_configs_into_pydantic_models

config = load_configs_into_pydantic_models(potential_name.lower(), "qm9")

# read default parameters
potential = NeuralNetworkPotentialFactory.generate_potential(
use="inference",
potential_parameter=config["potential"],
potential_seed=42,
)
potential_jit = torch.jit.script(potential)
filename = f"{str(prep_temp_dir)}/{potential_name.lower()}_qm9_jit.pt"

potential_jit.save(filename)

output = potential(batch.nnp_input)
output_jit1 = potential_jit(batch.nnp_input)

# load the model
jit_model = torch.jit.load(filename)
output_jit2 = jit_model(batch.nnp_input)

assert torch.allclose(output["per_system_energy"], output_jit1["per_system_energy"])
assert torch.allclose(output["per_system_energy"], output_jit2["per_system_energy"])
12 changes: 4 additions & 8 deletions modelforge/utils/prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
per_system_total_charge: torch.Tensor,
box_vectors: torch.Tensor = torch.zeros(3, 3),
is_periodic: torch.Tensor = torch.tensor([False]),
pair_list: torch.Tensor = None,
per_atom_partial_charge: torch.Tensor = None,
pair_list: torch.Tensor = torch.tensor([]),
per_atom_partial_charge: torch.Tensor = torch.tensor([]),
):
self.atomic_numbers = atomic_numbers
self.positions = positions
Expand Down Expand Up @@ -103,12 +103,8 @@ def to_device(self, device: torch.device):
self.per_system_total_charge = self.per_system_total_charge.to(device)
self.box_vectors = self.box_vectors.to(device)
self.is_periodic = self.is_periodic.to(device)

if self.pair_list is not None:
self.pair_list = self.pair_list.to(device)

if self.per_atom_partial_charge is not None:
self.per_atom_partial_charge = self.per_atom_partial_charge.to(device)
self.pair_list = self.pair_list.to(device)
self.per_atom_partial_charge = self.per_atom_partial_charge.to(device)

return self

Expand Down

0 comments on commit 685de67

Please sign in to comment.