Skip to content

Commit

Permalink
Merge pull request #319 from laserkelvin/mace-interface-changes
Browse files Browse the repository at this point in the history
`MACEWrapper` interface changes
  • Loading branch information
laserkelvin authored Nov 18, 2024
2 parents 028f44e + 6465b17 commit e446eea
Showing 1 changed file with 67 additions and 3 deletions.
70 changes: 67 additions & 3 deletions matsciml/models/pyg/mace/wrapper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,54 @@ def __init__(
embedding_kwargs: Any = None,
encoder_only: bool = True,
readout_method: str | Callable = "add",
atomic_energies: dict[int, float] | list | torch.Tensor | None = None,
disable_forces: bool = True,
**mace_kwargs,
) -> None:
"""
Initializes a wrapper for MACE architectures.
This wrapper integrates MACE models into the ``matsciml`` pipeline
by ensuring that the inputs are what MACE expects, and that the
outputs are what tasks expect to see.
Additional ``mace_kwargs`` are passed into the model constructor
after validating them with what is expected by that particular
variant.
Parameters
----------
atom_embedding_dim : int
Embedding dimensionality for atoms.
mace_module : Type[MACE], default MACE
Reference to the MACE architecture class. Defaults to
the ``MACE`` class, but can be swapped out for references
to e.g. ``ScaleShiftMACE``.
num_atom_embedding : int, default 100
Number of atoms expected to be trained on for this dataset.
This is common to other ``matsciml`` models/wrappers, and
refers to the maximum atomic number to include in the modeling;
this differs slightly from how MACE treats the periodic table.
embedding_kwargs : Any, default None
Unused by MACE models, as we do not use an embedding table.
encoder_only : bool, default True
Not intended for use, but kept for continuity with ``matsciml``
models/wrappers.
readout_method : str | Callable, default 'add'
Method or string for the node reduction to obtain graph-level
energies/properties. If a string is passed, we use this to
map to the ``global_<readout_method>_pool`` function in PyG.
atomic_energies : dict[int, float] | list | torch.Tensor | None, default None
If None, uses the ``free_ion_energy_table`` function to obtain ionization
energies to use as a basis for the atom bias. If a dictionary is passed,
the keys should correspond to the atomic number, and value the associated
atomic energy. This then gets mapped to a tensor where unmapped values are
ones.
disable_forces : bool, default True
If set to ``True``, force computation by MACE is disabled.
The default value is set to ensure backwards and general
task compatibility.
"""
if embedding_kwargs is not None:
logger.warning("`embedding_kwargs` is not used for MACE models.")
super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only)
Expand Down Expand Up @@ -95,12 +140,31 @@ def __init__(
# pack stuff into the mace kwargs
mace_kwargs["num_elements"] = num_atom_embedding
mace_kwargs["hidden_irreps"] = hidden_irreps
mace_kwargs["atomic_numbers"] = list(range(1, num_atom_embedding + 1))
if "atomic_energies" not in mace_kwargs:
if atomic_energies is None:
logger.warning(
"No ``atomic_energies`` provided, defaulting to total ionization energy."
)
mace_kwargs["atomic_energies"] = free_ion_energy_table(num_atom_embedding)
atomic_energies = free_ion_energy_table(num_atom_embedding)
if isinstance(atomic_energies, dict):
max_atom_num = max(list(atomic_energies.keys()))
if max_atom_num > num_atom_embedding:
logger.warning(
"atomic_energies contains higher atom number than num_atom_embedding;"
" setting a larger value for the latter."
)
num_atom_embedding = max_atom_num
temp_tensor = torch.ones(num_atom_embedding).double()
# iterate through the atomic numbers and map to values
for index, value in atomic_energies.items():
temp_tensor[index] = value
atomic_energies = temp_tensor
if isinstance(atomic_energies, list):
assert (
len(atomic_energies) == num_atom_embedding
), "Mismatch in number of atomic energies and expected atom table."
atomic_energies = torch.Tensor(atomic_energies).double()
mace_kwargs["atomic_numbers"] = list(range(1, num_atom_embedding + 1))
mace_kwargs["atomic_energies"] = atomic_energies
# check to make sure everything that's required is present in ``mace_kwargs``
for key in __mace_required_args + __mace_submodule_required_args:
if key not in mace_kwargs:
Expand Down

0 comments on commit e446eea

Please sign in to comment.