Commit
all tests passing, fixed featurization bug
wiederm committed Sep 13, 2024
1 parent a47ec23 commit cdc6ca3
Showing 11 changed files with 63 additions and 69 deletions.
2 changes: 2 additions & 0 deletions modelforge/potential/models.py
@@ -962,6 +962,8 @@ def setup_potential(
**potential_parameter.core_parameter.model_dump()
)

# pop the properties_to_featurize entry from the potential_parameter.postprocessing_parameter dictionary

postprocessing = PostProcessing(
postprocessing_parameter=potential_parameter.postprocessing_parameter.model_dump(),
dataset_statistic=remove_units_from_dataset_statistics(dataset_statistic),
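The comment added above is only a note; a purely hypothetical, standalone sketch of what popping that key could look like (the commit adds the comment, not this code; the example dict values are taken from the tensornet.toml defaults further down):

# Hypothetical sketch of the note above: drop the featurization key before the
# dictionary is passed on, since PostProcessing does not consume it.
postprocessing_parameter = {
    "properties_to_featurize": ["atomic_number"],  # e.g. as in tensornet.toml below
    "per_atom_energy": {"normalize": True, "from_atom_to_molecule_reduction": True},
}
postprocessing_parameter.pop("properties_to_featurize", None)
assert "properties_to_featurize" not in postprocessing_parameter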
45 changes: 15 additions & 30 deletions modelforge/potential/parameters.py
@@ -12,9 +12,8 @@
computed_field,
Field,
)
from openff.units import unit
from typing import Union, List, Optional, Type
from modelforge.utils.units import _convert_str_to_unit, _convert_str_to_unit_length
from typing import Union, Optional, Type, List
from modelforge.utils.units import _convert_str_to_unit_length
from enum import Enum

import torch
@@ -44,7 +43,8 @@ class ParametersBase(BaseModel):
)


# for the activation functions we have defined alpha and negative slope are the two parameters that are possible
# for the activation functions we have defined alpha and negative slope are the
# two parameters that are possible
class ActivationFunctionParamsAlpha(BaseModel):
alpha: Optional[float] = None

@@ -53,6 +53,16 @@ class ActivationFunctionParamsNegativeSlope(BaseModel):
negative_slope: Optional[float] = None


class AtomicNumber(BaseModel):
maximum_atomic_number: int = 101
number_of_per_atom_features: int = 32


class Featurization(BaseModel):
properties_to_featurize: List[str]
atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)


class ActivationFunctionName(CaseInsensitiveEnum):
ReLU = "ReLU"
CeLU = "CeLU"
@@ -160,20 +170,13 @@ class PostProcessingParameter(ParametersBase):

class SchNetParameters(ParametersBase):
class CoreParameter(ParametersBase):
class Featurization(BaseModel):
class AtomicNumber(BaseModel):
maximum_atomic_number: int = 101
number_of_per_atom_features: int = 32

atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)

number_of_radial_basis_functions: int
maximum_interaction_radius: float
number_of_interaction_modules: int
number_of_filters: int
shared_interactions: bool
activation_function_parameter: ActivationFunctionConfig
featurization: Featurization = Field(default_factory=Featurization)
featurization: Featurization

converted_units = field_validator("maximum_interaction_radius", mode="before")(
_convert_str_to_unit_length
@@ -221,12 +224,6 @@ class PostProcessingParameter(ParametersBase):

class PaiNNParameters(ParametersBase):
class CoreParameter(ParametersBase):
class Featurization(BaseModel):
class AtomicNumber(BaseModel):
maximum_atomic_number: int = 101
number_of_per_atom_features: int = 32

atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)

number_of_radial_basis_functions: int
maximum_interaction_radius: float
@@ -254,12 +251,6 @@ class PostProcessingParameter(ParametersBase):

class PhysNetParameters(ParametersBase):
class CoreParameter(ParametersBase):
class Featurization(BaseModel):
class AtomicNumber(BaseModel):
maximum_atomic_number: int = 101
number_of_per_atom_features: int = 32

atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)

number_of_radial_basis_functions: int
maximum_interaction_radius: float
@@ -286,12 +277,6 @@ class PostProcessingParameter(ParametersBase):

class SAKEParameters(ParametersBase):
class CoreParameter(ParametersBase):
class Featurization(BaseModel):
class AtomicNumber(BaseModel):
maximum_atomic_number: int = 101
number_of_per_atom_features: int = 32

atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)

number_of_radial_basis_functions: int
maximum_interaction_radius: float
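For orientation, a minimal sketch of how the newly module-level Featurization/AtomicNumber models behave; the class definitions mirror the lines added above, while the usage at the bottom is illustrative only (pydantic v2, as used elsewhere in this file):

# Sketch: the two models mirror the diff above; only the usage below is new.
from typing import List
from pydantic import BaseModel, Field, ValidationError


class AtomicNumber(BaseModel):
    maximum_atomic_number: int = 101
    number_of_per_atom_features: int = 32


class Featurization(BaseModel):
    properties_to_featurize: List[str]
    atomic_number: AtomicNumber = Field(default_factory=AtomicNumber)


# properties_to_featurize is now required; atomic_number falls back to its defaults
feat = Featurization(properties_to_featurize=["atomic_number"])
assert feat.atomic_number.number_of_per_atom_features == 32

try:
    Featurization()  # missing properties_to_featurize -> validation error
except ValidationError as err:
    print(err)

Because SchNetParameters.CoreParameter now declares featurization without a default (see the change above), omitting the featurization block in a config fails validation instead of silently falling back to defaults.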
3 changes: 2 additions & 1 deletion modelforge/potential/schnet.py
@@ -2,7 +2,7 @@
SchNet neural network potential for modeling quantum interactions.
"""

from typing import Dict, Type
from typing import Dict, Type, List

import torch
import torch.nn as nn
@@ -14,6 +14,7 @@
class SchNetCore(torch.nn.Module):
def __init__(
self,
properties_to_featurize: List[str],
featurization: Dict[str, Dict[str, int]],
number_of_radial_basis_functions: int,
number_of_interaction_modules: int,
55 changes: 31 additions & 24 deletions modelforge/potential/utils.py
@@ -215,7 +215,7 @@ class FeaturizeInput(nn.Module):
----------
_SUPPORTED_FEATURIZATION_TYPES : List[str]
The list of supported featurization types.
nuclear_charge_embedding : Embedding
atomic_number_embedding : Embedding
The embedding layer for nuclear charges.
append_to_embedding_tensor : nn.ModuleList
The list of modules to append to the embedding tensor.
@@ -242,7 +242,9 @@ class FeaturizeInput(nn.Module):
"spin_state",
]

def __init__(self, featurization_config: Dict[str, Dict[str, int]]) -> None:
def __init__(
self, featurization_config: Dict[str, Union[List[str], Dict[str, int]]]
) -> None:
"""
Initialize the FeaturizeInput class.
@@ -279,30 +281,33 @@ def __init__(self, featurization_config: Dict[str, Dict[str, int]]) -> None:
self.registered_embedding_operations: List[str] = []

self.increase_dim_of_embedded_tensor: int = 0

base_embedding_dim = int(featurization_config["atomic_number"][
"number_of_per_atom_features"
])
properties_to_featurize = featurization_config["properties_to_featurize"]
# iterate through the supported featurization types and check if one of
# these is requested
for featurization in self._SUPPORTED_FEATURIZATION_TYPES:
for featurization in properties_to_featurize:

# embed nuclear charges
# embed atomic number
if (
featurization == "atomic_number"
and featurization in featurization_config
and featurization in self._SUPPORTED_FEATURIZATION_TYPES
):
self.nuclear_charge_embedding = torch.nn.Embedding(
self.atomic_number_embedding = torch.nn.Embedding(
int(featurization_config[featurization]["maximum_atomic_number"]),
int(
featurization_config[featurization][
"number_of_per_atom_features"
]
),
)
self.registered_embedding_operations.append("nuclear_charge_embedding")
self.registered_embedding_operations.append("atomic_number")

# add total charge to embedding vector
if (
elif (
featurization == "per_molecule_total_charge"
and featurization in featurization_config
and featurization in self._SUPPORTED_FEATURIZATION_TYPES
):
# transform output o f embedding with shape (nr_atoms,
# nr_features) to (nr_atoms, nr_features + 1). The added
@@ -315,26 +320,32 @@ def __init__(self, featurization_config: Dict[str, Dict[str, int]]) -> None:
self.registered_appended_properties.append("total_charge")

# add partial charge to embedding vector
if (
elif (
featurization == "per_atom_partial_charge"
and featurization in featurization_config
): # transform output o f embedding with shape (nr_atoms, nr_features) to (nr_atoms, nr_features + 1).
# #The added features is the total charge (which will be transformed to a per-atom property)
and featurization in self._SUPPORTED_FEATURIZATION_TYPES
): # transform output of embedding with shape (nr_atoms, nr_features) to (nr_atoms, nr_features + 1).
# #The added features is the total charge (which will be
# transformed to a per-atom property)
self.append_to_embedding_tensor.append(
AddPerAtomValue("partial_charge")
)
self.increase_dim_of_embedded_tensor += 1
self.append_to_embedding_tensor("partial_charge")

else:
raise RuntimeError(
f"Unsupported featurization type {featurization}. Supported types are {self._SUPPORTED_FEATURIZATION_TYPES}"
)

# if only nuclear charges are embedded no mixing is performed
self.mixing: Union[nn.Identity, DenseWithCustomDist]
if self.increase_dim_of_embedded_tensor == 0:
self.mixing = nn.Identity()
else:
self.mixing = DenseWithCustomDist(
int(featurization_config["number_of_per_atom_features"])
base_embedding_dim
+ self.increase_dim_of_embedded_tensor,
int(featurization_config["number_of_per_atom_features"]),
base_embedding_dim,
)

def forward(self, data: NNPInputTuple) -> torch.Tensor:
@@ -353,19 +364,15 @@ def forward(self, data: NNPInputTuple) -> torch.Tensor:
"""

atomic_numbers = data.atomic_numbers
embedded_nuclear_charges = self.nuclear_charge_embedding(atomic_numbers)
categorial_embedding = self.atomic_number_embedding(atomic_numbers)

for additional_embedding in self.embeddings:
embedded_nuclear_charges = additional_embedding(
embedded_nuclear_charges, data
)
categorial_embedding = additional_embedding(categorial_embedding, data)

for append_embedding_vector in self.append_to_embedding_tensor:
embedded_nuclear_charges = append_embedding_vector(
embedded_nuclear_charges, data
)
categorial_embedding = append_embedding_vector(categorial_embedding, data)

return self.mixing(embedded_nuclear_charges)
return self.mixing(categorial_embedding)


import torch.nn.functional as F
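A minimal usage sketch of the reworked FeaturizeInput, assuming a working modelforge install; the dict layout is what Featurization.model_dump() produces (compare test_nn.py below), and the values are the defaults from this diff:

# Sketch: config layout mirrors Featurization.model_dump(); values are the schnet defaults.
import torch
from modelforge.potential.utils import FeaturizeInput

featurization_config = {
    "properties_to_featurize": ["atomic_number"],
    "atomic_number": {"maximum_atomic_number": 101, "number_of_per_atom_features": 32},
}
featurize = FeaturizeInput(featurization_config)

# With only atomic_number featurized no extra feature dimensions are appended,
# so the mixing layer stays an identity (see the increase_dim_of_embedded_tensor logic above).
assert "atomic_number" in featurize.registered_embedding_operations
assert isinstance(featurize.mixing, torch.nn.Identity)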
2 changes: 1 addition & 1 deletion modelforge/tests/data/potential_defaults/ani2x.toml
@@ -11,7 +11,7 @@ minimum_interaction_radius_for_angular_features = "0.8 angstrom"
angular_dist_divisions = 8

[potential.core_parameter.activation_function_parameter]
activation_function_name = "CeLU" # for the original ANI behavior please stick with CeLu since the alpha parameter is currently hard coded and might lead to different behavior when another activation function is used.
activation_function_name = "CeLU"

[potential.core_parameter.activation_function_parameter.activation_function_arguments]
alpha = 0.1
1 change: 1 addition & 0 deletions modelforge/tests/data/potential_defaults/painn.toml
@@ -12,6 +12,7 @@ shared_filters = false
activation_function_name = "SiLU"

[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
[potential.core_parameter.featurization.atomic_number]
maximum_atomic_number = 101
number_of_per_atom_features = 32
1 change: 1 addition & 0 deletions modelforge/tests/data/potential_defaults/physnet.toml
@@ -12,6 +12,7 @@ number_of_modules = 5
activation_function_name = "ShiftedSoftplus"

[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
[potential.core_parameter.featurization.atomic_number]
maximum_atomic_number = 101
number_of_per_atom_features = 32
1 change: 1 addition & 0 deletions modelforge/tests/data/potential_defaults/sake.toml
@@ -12,6 +12,7 @@ number_of_spatial_attention_heads = 4
activation_function_name = "SiLU"

[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
[potential.core_parameter.featurization.atomic_number]
maximum_atomic_number = 101
number_of_per_atom_features = 11
1 change: 1 addition & 0 deletions modelforge/tests/data/potential_defaults/schnet.toml
@@ -12,6 +12,7 @@ shared_interactions = false
activation_function_name = "ShiftedSoftplus"

[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
[potential.core_parameter.featurization.atomic_number]
maximum_atomic_number = 101
number_of_per_atom_features = 32
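To show how the new key flows from TOML into the pydantic layer, a small sketch (the TOML fragment copies the schnet defaults above; Featurization is imported from the updated parameters.py; Python 3.11+ assumed for tomllib):

# Sketch: parse a featurization block like the one added above and validate it.
import tomllib  # stdlib, Python >= 3.11

from modelforge.potential.parameters import Featurization

toml_text = """
[potential.core_parameter.featurization]
properties_to_featurize = ['atomic_number']
[potential.core_parameter.featurization.atomic_number]
maximum_atomic_number = 101
number_of_per_atom_features = 32
"""

cfg = tomllib.loads(toml_text)
feat = Featurization(**cfg["potential"]["core_parameter"]["featurization"])
print(feat.model_dump())
# -> {'properties_to_featurize': ['atomic_number'],
#     'atomic_number': {'maximum_atomic_number': 101, 'number_of_per_atom_features': 32}}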
1 change: 1 addition & 0 deletions modelforge/tests/data/potential_defaults/tensornet.toml
@@ -14,6 +14,7 @@ equivariance_invariance_group = "O(3)"
activation_function_name = "SiLU"

[potential.postprocessing_parameter]
properties_to_featurize = ['atomic_number']
[potential.postprocessing_parameter.per_atom_energy]
normalize = true
from_atom_to_molecule_reduction = true
20 changes: 7 additions & 13 deletions modelforge/tests/test_nn.py
@@ -16,22 +16,19 @@ def test_embedding(single_batch_with_batchsize):
config = load_configs_into_pydantic_models(f"{model_name.lower()}", "qm9")
featurization_config = config["potential"].core_parameter.featurization.model_dump()

# featurize the atomic input (default is only nuclear charge embedding)
# featurize the atomic input (default is only atomic number embedding)
from modelforge.potential.utils import FeaturizeInput

featurize_input_module = FeaturizeInput(featurization_config)

# mixing module should be the identidy operation since only nuclear charge
# is used
# mixing module should be the identity operation since only atomic number
# embedding is used
mixing_module = featurize_input_module.mixing
assert mixing_module.__module__ == "torch.nn.modules.linear"
mixing_module_name = str(mixing_module)

# only nucreal charges embedded
assert (
"nuclear_charge_embedding"
in featurize_input_module.registered_embedding_operations
)
# only atomic number embedded
assert "atomic_number" in featurize_input_module.registered_embedding_operations
assert len(featurize_input_module.registered_embedding_operations) == 1
# no mixing
assert "Identity()" in mixing_module_name
@@ -40,11 +37,8 @@ def test_embedding(single_batch_with_batchsize):
featurization_config["properties_to_featurize"].append("per_molecule_total_charge")
featurize_input_module = FeaturizeInput(featurization_config)

# only nuclear charges embedded
assert (
"nuclear_charge_embedding"
in featurize_input_module.registered_embedding_operations
)
# only atomic number embedded
assert "atomic_number" in featurize_input_module.registered_embedding_operations
assert len(featurize_input_module.registered_embedding_operations) == 1
# total charge is added to feature vector
assert "total_charge" in featurize_input_module.registered_appended_properties
