Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing NNPInput to allow us to write torch script models #295

Merged
merged 5 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modelforge/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .spice2 import SPICE2Dataset
from .spice1openff import SPICE1OpenFFDataset
from .phalkethoh import PhAlkEthOHDataset
from .dataset import DatasetFactory, DataModule, NNPInput
from .dataset import DatasetFactory, DataModule
from enum import Enum


Expand Down
19 changes: 12 additions & 7 deletions modelforge/jax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from modelforge.dataset import NNPInput
from modelforge.utils.prop import NNPInput


def nnpinput_flatten(nnpinput: NNPInput):
Expand Down Expand Up @@ -42,12 +42,17 @@ def convert_NNPInput_to_jax(nnp_input: NNPInput):
nnp_input.box_vectors = convert_to_jax(nnp_input.box_vectors)
nnp_input.is_periodic = convert_to_jax(nnp_input.is_periodic)

if nnp_input.pair_list is not None:
nnp_input.pair_list = convert_to_jax(nnp_input.pair_list)
nnp_input.pair_list = convert_to_jax(nnp_input.pair_list)
nnp_input.per_atom_partial_charge = convert_to_jax(
nnp_input.per_atom_partial_charge
)

if nnp_input.per_atom_partial_charge is not None:
nnp_input.per_atom_partial_charge = convert_to_jax(
nnp_input.per_atom_partial_charge
)
# if nnp_input.pair_list is not None:
# nnp_input.pair_list = convert_to_jax(nnp_input.pair_list)
#
# if nnp_input.per_atom_partial_charge is not None:
# nnp_input.per_atom_partial_charge = convert_to_jax(
# nnp_input.per_atom_partial_charge
# )

chrisiacovella marked this conversation as resolved.
Show resolved Hide resolved
return nnp_input
3 changes: 1 addition & 2 deletions modelforge/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from loguru import logger as log
from torch import nn

from modelforge.utils.prop import SpeciesAEV
from modelforge.utils.prop import SpeciesAEV, NNPInput

from modelforge.dataset.dataset import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/featurization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, List, Union

from modelforge.potential.utils import DenseWithCustomDist
from modelforge.dataset import NNPInput
from modelforge.utils.prop import NNPInput


class AddPerMoleculeValue(nn.Module):
Expand Down
14 changes: 5 additions & 9 deletions modelforge/potential/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,20 +790,15 @@ def _calculate_interacting_pairs(
Atom positions. Shape: [nr_systems, nr_atoms, 3].
atomic_subsystem_indices : torch.Tensor
Indices identifying atoms in subsystems. Shape: [nr_atoms].
pair_indices : Optional[torch.Tensor]
Precomputed pair indices. If None, will compute pair indices.
pair_indices : torch.Tensor
Precomputed pair indices.

Returns
-------
PairListOutputs
A dataclass containing 'pair_indices', 'd_ij' (distances), and 'r_ij' (displacement vectors).
"""

if pair_indices is None:
pair_indices = self.enumerate_all_pairs(
atomic_subsystem_indices,
)

r_ij = self.calculate_r_ij(pair_indices, positions)
d_ij = self.calculate_d_ij(r_ij)

Expand Down Expand Up @@ -838,8 +833,9 @@ def forward(self, data: Union[NNPInput, NamedTuple]) -> PairlistData:
# general input manipulation
positions = data.positions
atomic_subsystem_indices = data.atomic_subsystem_indices
# calculate pairlist if none is provided
if data.pair_list is None:
# calculate pairlist if it is not provided

if data.pair_list is None or data.pair_list.shape[0] == 0:
# note, we set the flag for unique pairs when instantiated in the constructor
# and thus this call will return unique pairs if requested.
pair_list = self.pairlist.enumerate_all_pairs(atomic_subsystem_indices)
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/painn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from loguru import logger as log
from openff.units import unit

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData
from .utils import DenseWithCustomDist

Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/physnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from loguru import logger as log
from torch import nn

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData
from .utils import Dense

Expand Down
1 change: 0 additions & 1 deletion modelforge/potential/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,6 @@ def generate_potential(

jax = import_("jax")
from modelforge.jax import nnpinput_flatten, nnpinput_unflatten
from modelforge.dataset import NNPInput

# registering NNPInput multiple times will result in a
# ValueError
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from loguru import logger as log

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData

from .utils import DenseWithCustomDist, scatter_softmax
Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.nn as nn
from loguru import logger as log

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from modelforge.potential import CosineAttenuationFunction, TensorNetRadialBasisFunction

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput
from modelforge.potential.neighbors import PairlistData


Expand Down
2 changes: 1 addition & 1 deletion modelforge/potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.nn as nn
from openff.units import unit

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput


@dataclass(frozen=False)
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def generate_uniform_quaternion(u=None):
"""
Generates a uniform normalized quaternion.

Adapted from numpy implementation in openmm-tools
Adapted from numpy implementation in modelforgeopenmm-tools
chrisiacovella marked this conversation as resolved.
Show resolved Hide resolved
https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py

Parameters
Expand Down Expand Up @@ -288,7 +288,7 @@ def generate_uniform_quaternion(u=None):
def rotation_matrix_from_quaternion(quaternion):
"""Compute a 3x3 rotation matrix from a given quaternion (4-vector).

Adapted from the numpy implementation in openmm-tools
Adapted from the numpy implementation in modelforgeopenmm-tools

https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py

Expand Down
2 changes: 1 addition & 1 deletion modelforge/tests/precalculated_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def setup_single_methane_input():
)
E = torch.tensor([0.0], requires_grad=True)
atomic_subsystem_indices = torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32)
from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

modelforge_methane = NNPInput(
atomic_numbers=atomic_numbers,
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/test_ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def setup_methane():
[0, 0, 0, 0, 0], dtype=torch.int32, device=device
)

from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

nnp_input = NNPInput(
atomic_numbers=torch.tensor([6, 1, 1, 1, 1], device=device),
Expand Down Expand Up @@ -76,7 +76,7 @@ def setup_two_methanes():
)

atomic_numbers = mf_species
from modelforge.dataset.dataset import NNPInput
from modelforge.utils.prop import NNPInput

nnp_input = NNPInput(
atomic_numbers=atomic_numbers,
Expand Down
34 changes: 34 additions & 0 deletions modelforge/tests/test_potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,3 +1107,37 @@ def test_loading_from_checkpoint_file():

model = load_inference_model_from_checkpoint(chkp_file)
assert model is not None


@pytest.mark.parametrize(
"potential_name", _Implemented_NNPs.get_all_neural_network_names()
)
def test_saving_torchscript(potential_name, single_batch_with_batchsize, prep_temp_dir):
batch = single_batch_with_batchsize(
batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir)
)
from modelforge.potential.potential import NeuralNetworkPotentialFactory
from modelforge.utils.misc import load_configs_into_pydantic_models

config = load_configs_into_pydantic_models("ani2x", "qm9")
chrisiacovella marked this conversation as resolved.
Show resolved Hide resolved

# read default parameters
potential = NeuralNetworkPotentialFactory.generate_potential(
use="inference",
potential_parameter=config["potential"],
potential_seed=42,
)
potential_jit = torch.jit.script(potential)
filename = f"{str(prep_temp_dir)}/{potential_name.lower()}_qm9_jit.pt"

potential_jit.save(filename)

output = potential(batch.nnp_input)
output_jit1 = potential_jit(batch.nnp_input)

# load the model
jit_model = torch.jit.load(filename)
output_jit2 = jit_model(batch.nnp_input)

assert torch.allclose(output["per_system_energy"], output_jit1["per_system_energy"])
assert torch.allclose(output["per_system_energy"], output_jit2["per_system_energy"])
22 changes: 11 additions & 11 deletions modelforge/utils/prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(
per_system_total_charge: torch.Tensor,
box_vectors: torch.Tensor = torch.zeros(3, 3),
is_periodic: torch.Tensor = torch.tensor([False]),
pair_list: torch.Tensor = None,
per_atom_partial_charge: torch.Tensor = None,
pair_list: torch.Tensor = torch.tensor([]),
per_atom_partial_charge: torch.Tensor = torch.tensor([]),
):
self.atomic_numbers = atomic_numbers
self.positions = positions
Expand All @@ -60,7 +60,6 @@ def __init__(
self.box_vectors = box_vectors
self.is_periodic = is_periodic


# Validate inputs
self._validate_inputs()

Expand Down Expand Up @@ -95,7 +94,6 @@ def _validate_inputs(self):
"The size of atomic_subsystem_indices and the first dimension of positions must match"
)


def to_device(self, device: torch.device):
"""Move all tensors in this instance to the specified device."""

Expand All @@ -105,12 +103,14 @@ def to_device(self, device: torch.device):
self.per_system_total_charge = self.per_system_total_charge.to(device)
self.box_vectors = self.box_vectors.to(device)
self.is_periodic = self.is_periodic.to(device)
self.pair_list = self.pair_list.to(device)
self.per_atom_partial_charge = self.per_atom_partial_charge.to(device)

if self.pair_list is not None:
self.pair_list = self.pair_list.to(device)

if self.per_atom_partial_charge is not None:
self.per_atom_partial_charge = self.per_atom_partial_charge.to(device)
# if not self.pair_list is None:
# self.pair_list = self.pair_list.to(device)
#
# if not self.per_atom_partial_charge is None:
# self.per_atom_partial_charge = self.per_atom_partial_charge.to(device)
chrisiacovella marked this conversation as resolved.
Show resolved Hide resolved

return self

Expand Down Expand Up @@ -191,11 +191,11 @@ def to_dtype(self, dtype: torch.dtype):
class BatchData:
nnp_input: NNPInput
metadata: Metadata

def to(
self,
device: torch.device,
): # NOTE: this is required to move the data to device
): # NOTE: this is required to move the data to device
"""Move all data in this batch to the specified device and dtype."""
self.nnp_input = self.nnp_input.to_device(device=device)
self.metadata = self.metadata.to_device(device=device)
Expand Down
Loading