diff --git a/modelforge/dataset/__init__.py b/modelforge/dataset/__init__.py index b339f412..570dd7cb 100644 --- a/modelforge/dataset/__init__.py +++ b/modelforge/dataset/__init__.py @@ -7,7 +7,7 @@ from .spice2 import SPICE2Dataset from .spice1openff import SPICE1OpenFFDataset from .phalkethoh import PhAlkEthOHDataset -from .dataset import DatasetFactory, DataModule, NNPInput +from .dataset import DatasetFactory, DataModule from enum import Enum diff --git a/modelforge/jax.py b/modelforge/jax.py index ce47ce44..2140af6f 100644 --- a/modelforge/jax.py +++ b/modelforge/jax.py @@ -1,4 +1,4 @@ -from modelforge.dataset import NNPInput +from modelforge.utils.prop import NNPInput def nnpinput_flatten(nnpinput: NNPInput): @@ -42,12 +42,9 @@ def convert_NNPInput_to_jax(nnp_input: NNPInput): nnp_input.box_vectors = convert_to_jax(nnp_input.box_vectors) nnp_input.is_periodic = convert_to_jax(nnp_input.is_periodic) - if nnp_input.pair_list is not None: - nnp_input.pair_list = convert_to_jax(nnp_input.pair_list) - - if nnp_input.per_atom_partial_charge is not None: - nnp_input.per_atom_partial_charge = convert_to_jax( - nnp_input.per_atom_partial_charge - ) + nnp_input.pair_list = convert_to_jax(nnp_input.pair_list) + nnp_input.per_atom_partial_charge = convert_to_jax( + nnp_input.per_atom_partial_charge + ) return nnp_input diff --git a/modelforge/potential/ani.py b/modelforge/potential/ani.py index 141da7ef..803dbabe 100644 --- a/modelforge/potential/ani.py +++ b/modelforge/potential/ani.py @@ -12,9 +12,8 @@ from loguru import logger as log from torch import nn -from modelforge.utils.prop import SpeciesAEV +from modelforge.utils.prop import SpeciesAEV, NNPInput -from modelforge.dataset.dataset import NNPInput from modelforge.potential.neighbors import PairlistData diff --git a/modelforge/potential/featurization.py b/modelforge/potential/featurization.py index eb25a84f..9d9b8cb6 100644 --- a/modelforge/potential/featurization.py +++ b/modelforge/potential/featurization.py @@ 
-3,7 +3,7 @@ from typing import Dict, List, Union from modelforge.potential.utils import DenseWithCustomDist -from modelforge.dataset import NNPInput +from modelforge.utils.prop import NNPInput class AddPerMoleculeValue(nn.Module): diff --git a/modelforge/potential/neighbors.py b/modelforge/potential/neighbors.py index 556b2dd1..b9b6ae90 100644 --- a/modelforge/potential/neighbors.py +++ b/modelforge/potential/neighbors.py @@ -790,8 +790,8 @@ def _calculate_interacting_pairs( Atom positions. Shape: [nr_systems, nr_atoms, 3]. atomic_subsystem_indices : torch.Tensor Indices identifying atoms in subsystems. Shape: [nr_atoms]. - pair_indices : Optional[torch.Tensor] - Precomputed pair indices. If None, will compute pair indices. + pair_indices : torch.Tensor + Precomputed pair indices. Returns ------- @@ -799,11 +799,6 @@ def _calculate_interacting_pairs( A dataclass containing 'pair_indices', 'd_ij' (distances), and 'r_ij' (displacement vectors). """ - if pair_indices is None: - pair_indices = self.enumerate_all_pairs( - atomic_subsystem_indices, - ) - r_ij = self.calculate_r_ij(pair_indices, positions) d_ij = self.calculate_d_ij(r_ij) @@ -838,8 +833,9 @@ def forward(self, data: Union[NNPInput, NamedTuple]) -> PairlistData: # general input manipulation positions = data.positions atomic_subsystem_indices = data.atomic_subsystem_indices - # calculate pairlist if none is provided - if data.pair_list is None: + # calculate pairlist if it is not provided + + if data.pair_list is None or data.pair_list.shape[0] == 0: # note, we set the flag for unique pairs when instantiated in the constructor # and thus this call will return unique pairs if requested. 
pair_list = self.pairlist.enumerate_all_pairs(atomic_subsystem_indices) diff --git a/modelforge/potential/painn.py b/modelforge/potential/painn.py index 23022275..77824d40 100644 --- a/modelforge/potential/painn.py +++ b/modelforge/potential/painn.py @@ -9,7 +9,7 @@ from loguru import logger as log from openff.units import unit -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput from modelforge.potential.neighbors import PairlistData from .utils import DenseWithCustomDist diff --git a/modelforge/potential/physnet.py b/modelforge/potential/physnet.py index 12ee27e0..a5224222 100644 --- a/modelforge/potential/physnet.py +++ b/modelforge/potential/physnet.py @@ -8,7 +8,7 @@ from loguru import logger as log from torch import nn -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput from modelforge.potential.neighbors import PairlistData from .utils import Dense diff --git a/modelforge/potential/potential.py b/modelforge/potential/potential.py index 670aaa0a..6ee74c7d 100644 --- a/modelforge/potential/potential.py +++ b/modelforge/potential/potential.py @@ -620,7 +620,6 @@ def generate_potential( jax = import_("jax") from modelforge.jax import nnpinput_flatten, nnpinput_unflatten - from modelforge.dataset import NNPInput # registering NNPInput multiple times will result in a # ValueError diff --git a/modelforge/potential/sake.py b/modelforge/potential/sake.py index 1e9b1333..948b4736 100644 --- a/modelforge/potential/sake.py +++ b/modelforge/potential/sake.py @@ -8,7 +8,7 @@ import torch.nn as nn from loguru import logger as log -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput from modelforge.potential.neighbors import PairlistData from .utils import DenseWithCustomDist, scatter_softmax diff --git a/modelforge/potential/schnet.py b/modelforge/potential/schnet.py index fd22d3ae..cf2ce1b9 100644 --- a/modelforge/potential/schnet.py +++ 
b/modelforge/potential/schnet.py @@ -8,7 +8,7 @@ import torch.nn as nn from loguru import logger as log -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput from modelforge.potential.neighbors import PairlistData diff --git a/modelforge/potential/tensornet.py b/modelforge/potential/tensornet.py index 5c60334c..26099bcc 100644 --- a/modelforge/potential/tensornet.py +++ b/modelforge/potential/tensornet.py @@ -9,7 +9,7 @@ from modelforge.potential import CosineAttenuationFunction, TensorNetRadialBasisFunction -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput from modelforge.potential.neighbors import PairlistData diff --git a/modelforge/potential/utils.py b/modelforge/potential/utils.py index 852cebb7..f9a37670 100644 --- a/modelforge/potential/utils.py +++ b/modelforge/potential/utils.py @@ -10,7 +10,7 @@ import torch.nn as nn from openff.units import unit -from modelforge.dataset.dataset import NNPInput +from modelforge.utils.prop import NNPInput @dataclass(frozen=False) diff --git a/modelforge/tests/conftest.py b/modelforge/tests/conftest.py index f2427304..c0572ad4 100644 --- a/modelforge/tests/conftest.py +++ b/modelforge/tests/conftest.py @@ -288,7 +288,7 @@ def generate_uniform_quaternion(u=None): def rotation_matrix_from_quaternion(quaternion): """Compute a 3x3 rotation matrix from a given quaternion (4-vector). 
- Adapted from the numpy implementation in openmm-tools + Adapted from the numpy implementation in openmm-tools https://github.com/choderalab/openmmtools/blob/main/openmmtools/mcmc.py diff --git a/modelforge/tests/precalculated_values.py b/modelforge/tests/precalculated_values.py index 90cf4ade..24aa1620 100644 --- a/modelforge/tests/precalculated_values.py +++ b/modelforge/tests/precalculated_values.py @@ -106,7 +106,7 @@ def setup_single_methane_input(): ) E = torch.tensor([0.0], requires_grad=True) atomic_subsystem_indices = torch.tensor([0, 0, 0, 0, 0], dtype=torch.int32) - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput modelforge_methane = NNPInput( atomic_numbers=atomic_numbers, diff --git a/modelforge/tests/test_ani.py b/modelforge/tests/test_ani.py index 3aa006f6..0eaf4bfd 100644 --- a/modelforge/tests/test_ani.py +++ b/modelforge/tests/test_ani.py @@ -31,7 +31,7 @@ def setup_methane(): [0, 0, 0, 0, 0], dtype=torch.int32, device=device ) - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput nnp_input = NNPInput( atomic_numbers=torch.tensor([6, 1, 1, 1, 1], device=device), @@ -76,7 +76,7 @@ def setup_two_methanes(): ) atomic_numbers = mf_species - from modelforge.dataset.dataset import NNPInput + from modelforge.utils.prop import NNPInput nnp_input = NNPInput( atomic_numbers=atomic_numbers, diff --git a/modelforge/tests/test_dataset.py b/modelforge/tests/test_dataset.py index 87bf8be5..9f64e4f2 100644 --- a/modelforge/tests/test_dataset.py +++ b/modelforge/tests/test_dataset.py @@ -13,11 +13,7 @@ @pytest.fixture(scope="session") def prep_temp_dir(tmp_path_factory): - import uuid - - filename = str(uuid.uuid4()) - - fn = tmp_path_factory.mktemp(f"test_dataset_temp") + fn = tmp_path_factory.mktemp("test_dataset_temp") return fn @@ -258,7 +254,7 @@ def test_file_existence_after_initialization( def test_caching(prep_temp_dir): import contextlib - local_cache_dir = 
str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.qm9 import QM9Dataset data = QM9Dataset(version_select="nc_1000_v0", local_cache_dir=local_cache_dir) @@ -334,7 +330,7 @@ def test_metadata_validation(prep_temp_dir): which is used to validate if we can use .npz file, or we need to regenerate it.""" - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.qm9 import QM9Dataset @@ -471,7 +467,7 @@ def test_data_item_format_of_datamodule( """Test the format of individual data items in the dataset.""" from typing import Dict - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) dm = datamodule_factory( dataset_name=dataset_name, @@ -728,7 +724,7 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir): """ Test the DatasetDownloader functionality. """ - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) dataset = dataset_factory( dataset_name=dataset_name, local_cache_dir=local_cache_dir @@ -740,7 +736,7 @@ def test_dataset_downloader(dataset_name, dataset_factory, prep_temp_dir): @pytest.mark.parametrize("dataset_name", _ImplementedDatasets.get_all_dataset_names()) -def test_numpy_dataset_assignment(dataset_name): +def test_numpy_dataset_assignment(dataset_name, prep_temp_dir): """ Test if the numpy_dataset attribute is correctly assigned after processing or loading. 
""" @@ -748,7 +744,7 @@ def test_numpy_dataset_assignment(dataset_name): factory = DatasetFactory() data = _ImplementedDatasets.get_dataset_class(dataset_name)( - version_select="nc_1000_v0" + version_select="nc_1000_v0", local_cache_dir=str(prep_temp_dir) ) factory._load_or_process_data(data) @@ -992,7 +988,7 @@ def test_function_of_self_energy(dataset_name, datamodule_factory, prep_temp_dir def test_shifting_center_of_mass_to_origin(prep_temp_dir): - local_cache_dir = str(prep_temp_dir) + "/data_test" + local_cache_dir = str(prep_temp_dir) from modelforge.dataset.dataset import initialize_datamodule from openff.units.elements import MASSES diff --git a/modelforge/tests/test_potentials.py b/modelforge/tests/test_potentials.py index 85e4558e..43792f86 100644 --- a/modelforge/tests/test_potentials.py +++ b/modelforge/tests/test_potentials.py @@ -1107,3 +1107,37 @@ def test_loading_from_checkpoint_file(): model = load_inference_model_from_checkpoint(chkp_file) assert model is not None + + +@pytest.mark.parametrize( + "potential_name", _Implemented_NNPs.get_all_neural_network_names() +) +def test_saving_torchscript(potential_name, single_batch_with_batchsize, prep_temp_dir): + batch = single_batch_with_batchsize( + batch_size=1, dataset_name="QM9", local_cache_dir=str(prep_temp_dir) + ) + from modelforge.potential.potential import NeuralNetworkPotentialFactory + from modelforge.utils.misc import load_configs_into_pydantic_models + + config = load_configs_into_pydantic_models(potential_name.lower(), "qm9") + + # read default parameters + potential = NeuralNetworkPotentialFactory.generate_potential( + use="inference", + potential_parameter=config["potential"], + potential_seed=42, + ) + potential_jit = torch.jit.script(potential) + filename = f"{str(prep_temp_dir)}/{potential_name.lower()}_qm9_jit.pt" + + potential_jit.save(filename) + + output = potential(batch.nnp_input) + output_jit1 = potential_jit(batch.nnp_input) + + # load the model + jit_model = 
torch.jit.load(filename) + output_jit2 = jit_model(batch.nnp_input) + + assert torch.allclose(output["per_system_energy"], output_jit1["per_system_energy"]) + assert torch.allclose(output["per_system_energy"], output_jit2["per_system_energy"]) diff --git a/modelforge/utils/prop.py b/modelforge/utils/prop.py index 2a0e64be..8deb1c64 100644 --- a/modelforge/utils/prop.py +++ b/modelforge/utils/prop.py @@ -48,8 +48,8 @@ def __init__( per_system_total_charge: torch.Tensor, box_vectors: torch.Tensor = torch.zeros(3, 3), is_periodic: torch.Tensor = torch.tensor([False]), - pair_list: torch.Tensor = None, - per_atom_partial_charge: torch.Tensor = None, + pair_list: torch.Tensor = torch.tensor([]), + per_atom_partial_charge: torch.Tensor = torch.tensor([]), ): self.atomic_numbers = atomic_numbers self.positions = positions @@ -60,7 +60,6 @@ def __init__( self.box_vectors = box_vectors self.is_periodic = is_periodic - # Validate inputs self._validate_inputs() @@ -95,7 +94,6 @@ def _validate_inputs(self): "The size of atomic_subsystem_indices and the first dimension of positions must match" ) - def to_device(self, device: torch.device): """Move all tensors in this instance to the specified device.""" @@ -105,12 +103,8 @@ def to_device(self, device: torch.device): self.per_system_total_charge = self.per_system_total_charge.to(device) self.box_vectors = self.box_vectors.to(device) self.is_periodic = self.is_periodic.to(device) - - if self.pair_list is not None: - self.pair_list = self.pair_list.to(device) - - if self.per_atom_partial_charge is not None: - self.per_atom_partial_charge = self.per_atom_partial_charge.to(device) + self.pair_list = self.pair_list.to(device) + self.per_atom_partial_charge = self.per_atom_partial_charge.to(device) return self @@ -191,11 +185,11 @@ def to_dtype(self, dtype: torch.dtype): class BatchData: nnp_input: NNPInput metadata: Metadata - + def to( self, device: torch.device, - ): # NOTE: this is required to move the data to device + ): # NOTE: 
this is required to move the data to device """Move all data in this batch to the specified device and dtype.""" self.nnp_input = self.nnp_input.to_device(device=device) self.metadata = self.metadata.to_device(device=device)