Skip to content

Commit

Permalink
adopt schnet for multiple output heads
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 16, 2024
1 parent f2e71d6 commit 672d2cc
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 5 deletions.
11 changes: 11 additions & 0 deletions modelforge/potential/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ class ActivationFunctionName(CaseInsensitiveEnum):
ELU = "ELU"


class OutputTypeEnum(CaseInsensitiveEnum):
scalar = "scalar"
vector = "vector"


class PredictedPropertiesParameter(ParametersBase):
name: str
type: OutputTypeEnum


# this enum will tell us if we need to pass additional parameters to the activation function
class ActivationFunctionParamsEnum(CaseInsensitiveEnum):
ReLU = "None"
Expand Down Expand Up @@ -170,6 +180,7 @@ class Featurization(ParametersBase):
shared_interactions: bool
activation_function_parameter: ActivationFunctionConfig
featurization: Featurization
predicted_properties: List[PredictedPropertiesParameter] # Add the outputs here

converted_units = field_validator("maximum_interaction_radius")(
_convert_str_to_unit
Expand Down
34 changes: 30 additions & 4 deletions modelforge/potential/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
shared_interactions: bool,
activation_function: Type[torch.nn.Module],
maximum_interaction_radius: unit.Quantity,
predicted_properties: List[Dict[str, str]],
) -> None:

log.debug("Initializing the SchNet architecture.")
Expand Down Expand Up @@ -132,6 +133,26 @@ def __init__(
1,
),
)
# Initialize output layers based on configuration
self.output_layers = nn.ModuleDict()
for property in predicted_properties:
output_name = property["name"]
output_type = property["type"]
output_dimension = (
1 if output_type == "scalar" else 3
) # vector means 3D output

self.output_layers[output_name] = nn.Sequential(
DenseWithCustomDist(
number_of_per_atom_features,
number_of_per_atom_features,
activation_function=self.activation_function,
),
DenseWithCustomDist(
number_of_per_atom_features,
output_dimension,
),
)

def _model_specific_input_preparation(
self, data: "NNPInput", pairlist_output: PairListOutputs
Expand Down Expand Up @@ -197,14 +218,17 @@ def compute_properties(
atomic_embedding + v
) # Update per atom features given the environment

E_i = self.energy_layer(atomic_embedding).squeeze(1)

return {
"per_atom_energy": E_i,
results = {
"per_atom_scalar_representation": atomic_embedding,
"atomic_subsystem_indices": data.atomic_subsystem_indices,
}

# Compute all specified outputs
for output_name, output_layer in self.output_layers.items():
results[output_name] = output_layer(atomic_embedding).squeeze(-1)

return results


class SchNETInteractionModule(nn.Module):
"""
Expand Down Expand Up @@ -442,6 +466,7 @@ def __init__(
activation_function_parameter: Dict,
shared_interactions: bool,
postprocessing_parameter: Dict[str, Dict[str, bool]],
predicted_properties: List[Dict[str, str]],
dataset_statistic: Optional[Dict[str, float]] = None,
potential_seed: Optional[int] = None,
) -> None:
Expand All @@ -465,6 +490,7 @@ def __init__(
shared_interactions=shared_interactions,
activation_function=activation_function,
maximum_interaction_radius=_convert_str_to_unit(maximum_interaction_radius),
predicted_properties=predicted_properties,
)

def _config_prior(self):
Expand Down
13 changes: 13 additions & 0 deletions modelforge/tests/data/potential_defaults/schnet.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ maximum_interaction_radius = "5.0 angstrom"
number_of_interaction_modules = 3
number_of_filters = 32
shared_interactions = false
predicted_properties = [
{ name = "per_atom_energy", type = "scalar" },
{ name = "per_atom_charge", type = "scalar" },
# we could also define per_atom_force to be consistent?
]

[potential.core_parameter.activation_function_parameter]
activation_function_name = "ShiftedSoftplus"
Expand All @@ -16,10 +21,18 @@ properties_to_featurize = ['atomic_number']
maximum_atomic_number = 101
number_of_per_atom_features = 32


[potential.postprocessing_parameter]
[potential.postprocessing_parameter.per_atom_energy]
normalize = true
from_atom_to_molecule_reduction = true
keep_per_atom_property = true

[potential.postprocessing_parameter.per_atom_charge]
normalize = true
from_atom_to_molecule_reduction = false
keep_per_atom_property = true


[potential.postprocessing_parameter.general_postprocessing_operation]
calculate_molecular_self_energy = true
6 changes: 5 additions & 1 deletion modelforge/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def load_configs_into_pydantic_models(potential_name: str, dataset_name: str):
dataset_config_dict = toml.load(dataset_path)
potential_config_dict = toml.load(potential_path)
runtime_config_dict = toml.load(runtime_path)

print(potential_config_dict)
potential_name = potential_config_dict["potential"]["potential_name"]

from modelforge.potential import _Implemented_NNP_Parameters
Expand Down Expand Up @@ -446,6 +446,10 @@ def test_forward_pass(
# test that we get an energie per molecule
assert len(output["per_molecule_energy"]) == nr_of_mols

# check that per-atom charge/energies are present
assert "per_atom_energy" in output
assert "per_atom_charge" in output

# the batch consists of methane (CH4) and amamonium (NH3)
# which have chemically equivalent hydrogens at the minimum geometry.
# This has to be reflected in the atomic energies E_i, which
Expand Down

0 comments on commit 672d2cc

Please sign in to comment.