diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index af435b76..62e05a5a 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -17,6 +17,7 @@ ) from matsciml.datasets.transforms.base import AbstractDataTransform from matsciml.interfaces.ase import multitask as mt +from matsciml.datasets.utils import concatenate_keys __all__ = ["MatSciMLCalculator"] @@ -83,10 +84,12 @@ class MatSciMLCalculator(Calculator): def __init__( self, - task_module: ScalarRegressionTask - | GradFreeForceRegressionTask - | ForceRegressionTask - | MultiTaskLitModule, + task_module: ( + ScalarRegressionTask + | GradFreeForceRegressionTask + | ForceRegressionTask + | MultiTaskLitModule + ), transforms: list[AbstractDataTransform | Callable] | None = None, restart=None, label=None, @@ -94,6 +97,8 @@ def __init__( directory=".", conversion_factor: float | dict[str, float] = 1.0, multitask_strategy: str | Callable | mt.AbstractStrategy = "AverageTasks", + output_map: dict[str, str] | None = None, + matsciml_model: bool = True, **kwargs, ): """ @@ -144,33 +149,39 @@ def __init__( to ``ase``. If a single ``float`` is passed, we assume that the conversion is applied to the energy output. Each factor is multiplied with the result. + output_map : dict[str, str] | None, default None + Specifies how model outputs should be mapped to the Calculator's expected + results. For example {"ase_expected": "model_output"} -> {"forces": "force"} + matsciml_model : bool, default True + Flag indicating whether the model was trained with matsciml or not. """ super().__init__( restart, label=label, atoms=atoms, directory=directory, **kwargs ) - assert isinstance( - task_module, - ( - ForceRegressionTask, - ScalarRegressionTask, - GradFreeForceRegressionTask, - MultiTaskLitModule, - ), - ), f"Expected task to be one that is capable of energy/force prediction. Got {task_module.__type__}." 
- if isinstance(task_module, MultiTaskLitModule): - assert any( - [ - isinstance( - subtask, - ( - ForceRegressionTask, - ScalarRegressionTask, - GradFreeForceRegressionTask, - ), - ) - for subtask in task_module.task_list - ] - ), "Expected at least one subtask to be energy/force predictor." + if matsciml_model: + assert isinstance( + task_module, + ( + ForceRegressionTask, + ScalarRegressionTask, + GradFreeForceRegressionTask, + MultiTaskLitModule, + ), + ), f"Expected task to be one that is capable of energy/force prediction. Got {type(task_module)}." + if isinstance(task_module, MultiTaskLitModule): + assert any( + [ + isinstance( + subtask, + ( + ForceRegressionTask, + ScalarRegressionTask, + GradFreeForceRegressionTask, + ), + ) + for subtask in task_module.task_list + ] + ), "Expected at least one subtask to be energy/force predictor." self.task_module = task_module self.transforms = transforms self.conversion_factor = conversion_factor @@ -182,6 +193,18 @@ def __init__( ) multitask_strategy = cls_name() self.multitask_strategy = multitask_strategy + self.matsciml_model = matsciml_model + self.output_map = dict( + zip(self.implemented_properties, self.implemented_properties) + ) + if output_map is not None: + for k, v in output_map.items(): + if k not in self.output_map: + raise KeyError( + f"Specified key {k} is not one of the implemented_properties of this calculator: {self.implemented_properties}" + ) + else: + self.output_map[k] = v @property def conversion_factor(self) -> dict[str, float]: @@ -212,9 +235,8 @@ def _format_atoms(self, atoms: Atoms) -> DataDict: data_dict["pos"] = pos data_dict["atomic_numbers"] = atomic_numbers data_dict["cell"] = cell - # ptr and batch are usually expected by MACE even if it's a single graph - data_dict["ptr"] = torch.tensor([0]) - data_dict["batch"] = torch.zeros((pos.size(0))) + data_dict["frac_coords"] = torch.from_numpy(atoms.get_scaled_positions()) + data_dict["natoms"] = pos.size(0) return data_dict def 
_format_pipeline(self, atoms: Atoms) -> DataDict: @@ -230,10 +252,6 @@ def _format_pipeline(self, atoms: Atoms) -> DataDict: """ # initial formatting to get something akin to dataset outputs data_dict = self._format_atoms(atoms) - # type cast into the type expected by the model - data_dict = recursive_type_cast( - data_dict, self.dtype, ignore_keys=["atomic_numbers"], convert_numpy=True - ) # now run through the same transform pipeline as for datasets if self.transforms: for transform in self.transforms: @@ -248,24 +266,39 @@ def calculate( ) -> None: # retrieve atoms even if not passed Calculator.calculate(self, atoms) - # get into format ready for matsciml model - data_dict = self._format_pipeline(atoms) - # run the data structure through the model - output = self.task_module.predict(data_dict) + if self.matsciml_model: + # get into format ready for matsciml model + data_dict = self._format_pipeline(atoms) + # concatenate_keys batches data and adds some attributes that may be expected, like ptr. 
+ data_dict = concatenate_keys([data_dict]) + # type cast into the type expected by the model + data_dict = recursive_type_cast( + data_dict, + self.dtype, + ignore_keys=["atomic_numbers"], + convert_numpy=True, + ) + # run the data structure through the model + output = self.task_module.predict(data_dict) + else: + output = self.task_module.forward(atoms) if isinstance(self.task_module, MultiTaskLitModule): # use a more complicated parser for multitasks results = self.multitask_strategy(output, self.task_module) self.results = results else: - # add outputs to self.results as expected by ase - if "energy" in output: - self.results["energy"] = output["energy"].detach().item() - if "force" in output: - self.results["forces"] = output["force"].detach().numpy() - if "stress" in output: - self.results["stress"] = output["stress"].detach().numpy() - if "dipole" in output: - self.results["dipole"] = output["dipole"].detach().numpy() + # add outputs to self.results as expected by ase, as specified by ``properties`` + # "ase_properties" are those in ``properties``. + for ase_property in properties: + model_property = self.output_map[ase_property] + model_output = output.get(model_property, None) + if model_output is not None: + self.results[ase_property] = model_output.detach().numpy() + else: + raise KeyError( + f"Expected model to return {model_property} as an output." + ) + if len(self.results) == 0: raise RuntimeError( f"No expected properties were written. 
Output dict: {output}" diff --git a/matsciml/interfaces/ase/tests/test_ase_calc.py b/matsciml/interfaces/ase/tests/test_ase_calc.py index fabf3303..aefb09ef 100644 --- a/matsciml/interfaces/ase/tests/test_ase_calc.py +++ b/matsciml/interfaces/ase/tests/test_ase_calc.py @@ -14,6 +14,12 @@ ForceRegressionTask, ) from matsciml.models.pyg import EGNN +from types import MethodType + +import matgl +import torch +from matgl.ext.ase import Atoms2Graph + np.random.seed(21516136) @@ -48,7 +54,9 @@ def test_egnn_energy_forces(egnn_config: dict, test_pbc: Atoms, pbc_transform: l task = ForceRegressionTask( encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32} ) - calc = MatSciMLCalculator(task, transforms=pbc_transform) + calc = MatSciMLCalculator( + task, transforms=pbc_transform, output_map={"forces": "force"} + ) atoms = test_pbc.copy() atoms.calc = calc energy = atoms.get_potential_energy() @@ -62,8 +70,49 @@ def test_egnn_dynamics(egnn_config: dict, test_pbc: Atoms, pbc_transform: list): task = ForceRegressionTask( encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32} ) - calc = MatSciMLCalculator(task, transforms=pbc_transform) + calc = MatSciMLCalculator( + task, transforms=pbc_transform, output_map={"forces": "force"} + ) atoms = test_pbc.copy() atoms.calc = calc dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log") dyn.run(3) + + +def test_matgl(): + matgl_model = matgl.load_model("CHGNet-MPtrj-2024.2.13-PES-11M") + + def forward(self, atoms): + graph_converter = Atoms2Graph( + element_types=matgl_model.model.element_types, + cutoff=matgl_model.model.cutoff, + ) + graph, lattice, state_feats_default = graph_converter.get_graph(atoms) + graph.edata["pbc_offshift"] = torch.matmul( + graph.edata["pbc_offset"], lattice[0] + ) + graph.ndata["pos"] = graph.ndata["frac_coords"] @ lattice[0] + state_feats = torch.tensor(state_feats_default) + total_energies, forces, stresses, *others = self.matgl_forward( + 
graph, lattice, state_feats + ) + output = {} + output["energy"] = total_energies + output["forces"] = forces + output["stress"] = stresses + return output + + matgl_model.matgl_forward = matgl_model.forward + matgl_model.forward = MethodType(forward, matgl_model) + + calc = MatSciMLCalculator(matgl_model, matsciml_model=False) + pos = np.random.normal(0.0, 1.0, size=(10, 3)) + # Using a different atoms object due to pretrained model atom embedding expecting + # a different range of atomic numbers. + atomic_numbers = np.random.randint(1, 94, size=(10,)) + atoms = Atoms(numbers=atomic_numbers, positions=pos) + atoms.calc = calc + energy = atoms.get_potential_energy() + assert np.isfinite(energy) + forces = atoms.get_forces() + assert np.isfinite(forces).all()