diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index ec8f3414..af435b76 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -212,6 +212,9 @@ def _format_atoms(self, atoms: Atoms) -> DataDict: data_dict["pos"] = pos data_dict["atomic_numbers"] = atomic_numbers data_dict["cell"] = cell + # ptr and batch are usually expected by MACE even if it's a single graph + data_dict["ptr"] = torch.tensor([0]) + data_dict["batch"] = torch.zeros((pos.size(0))) return data_dict def _format_pipeline(self, atoms: Atoms) -> DataDict: