Commit
Merge pull request #236 from choderalab/bugfix-physnet
Bugfix PhysNet
wiederm authored Aug 19, 2024
2 parents 741f390 + 18b42cd commit 8c5fa6f
Showing 1 changed file with 80 additions and 129 deletions.
209 changes: 80 additions & 129 deletions modelforge/potential/physnet.py
@@ -3,32 +3,40 @@
"""

from dataclasses import dataclass, field
from typing import Dict, Optional, Union, List, Dict, Type
from typing import Dict, List, Optional, Type, Union

import torch
from loguru import logger as log
from openff.units import unit
from torch import nn
from .models import PairListOutputs, NNPInput, BaseNetwork, CoreNetwork

from modelforge.potential.utils import NeuralNetworkData
from modelforge.potential.utils import NeuralNetworkData, shared_config_prior
from modelforge.utils.io import import_
from modelforge.utils.units import _convert_str_to_unit

from .models import BaseNetwork, CoreNetwork, NNPInput, PairListOutputs
from .utils import Dense


@dataclass
class PhysNetNeuralNetworkData(NeuralNetworkData):
"""
A dataclass to structure the inputs for PhysNet-based neural network potentials,
facilitating the efficient and structured representation of atomic systems for
energy computation and property prediction within the PhysNet framework.
A dataclass to structure the inputs for PhysNet-based neural network
potentials, facilitating the efficient and structured representation of
atomic systems for energy computation and property prediction within the
PhysNet framework.
Attributes
----------
atomic_embedding : torch.Tensor
A 2D tensor containing embeddings or features for each atom, derived from atomic numbers or other properties.
Shape: [num_atoms, embedding_dim].
A 2D tensor containing embeddings or features for each atom, derived
from atomic numbers or other properties. Shape: [num_atoms,
embedding_dim].
f_ij : Optional[torch.Tensor]
A tensor representing the radial basis function (RBF) expansion applied to distances between atom pairs,
capturing the local chemical environment. Will be added after initialization. Shape: [num_pairs, num_rbf].
A tensor representing the radial basis function (RBF) expansion applied
to distances between atom pairs, capturing the local chemical
environment. Will be added after initialization. Shape: [num_pairs,
num_rbf].
"""

atomic_embedding: Optional[torch.Tensor] = field(default=None)
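
For orientation, a rough shape sketch of the two PhysNet-specific fields documented above; the sizes are illustrative assumptions, and the inherited base-class fields are omitted:

    import torch

    num_atoms, num_pairs = 5, 8          # hypothetical batch
    embedding_dim, num_rbf = 32, 16      # hypothetical feature sizes

    atomic_embedding = torch.randn(num_atoms, embedding_dim)  # [num_atoms, embedding_dim]
    f_ij = torch.randn(num_pairs, num_rbf)  # [num_pairs, num_rbf], attached after initialization
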
@@ -64,9 +72,10 @@ def __init__(
self.cutoff_module = CosineAttenuationFunction(maximum_interaction_radius)

# Initialize radial symmetry function module
from .utils import PhysNetRadialBasisFunction
from modelforge.potential.utils import FeaturizeInput

from .utils import PhysNetRadialBasisFunction

self.featurize_input = FeaturizeInput(featurization_config)

self.radial_symmetry_function_module = PhysNetRadialBasisFunction(
@@ -100,52 +109,23 @@ def forward(self, data: Type[PhysNetNeuralNetworkData]) -> Dict[str, torch.Tenso
}


class GatingModule(nn.Module):
def __init__(self, number_of_atom_basis: int):
"""
Initializes a gating module that optionally applies a sigmoid gating mechanism to input features.
Parameters
----------
number_of_atom_basis : int
The dimensionality of the input (and output) features.
"""
super().__init__()
self.gate = nn.Parameter(torch.ones(number_of_atom_basis))

def forward(self, x: torch.Tensor, activation_fn: bool = False) -> torch.Tensor:
"""
Apply gating to the input tensor.
Parameters:
-----------
x : torch.Tensor
The input tensor to gate.
Returns:
--------
torch.Tensor
The gated input tensor.
"""
gating_signal = torch.sigmoid(self.gate)
return gating_signal * x
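
A minimal sketch of how this gate (removed by this commit) acts on features; the shapes are illustrative assumptions, not taken from the commit:

    import torch

    gate = GatingModule(number_of_atom_basis=4)
    x = torch.randn(3, 4)   # three atoms, four features
    y = gate(x)             # sigmoid(gate) scales each feature channel
    assert y.shape == x.shape
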


from .utils import DenseWithCustomDist


class PhysNetResidual(nn.Module):
"""
Implements a preactivation residual block as described in Equation 4 of the PhysNet paper.
Implements a preactivation residual block as described in Equation 4 of the
PhysNet paper.
The block refines atomic feature vectors by adding a residual component computed through two linear transformations and a non-linear activation function (Softplus). This setup enhances gradient flow and supports effective deep network training by employing a preactivation scheme.
The block refines atomic feature vectors by adding a residual component
computed through two linear transformations and a non-linear activation
function (Softplus). This setup enhances gradient flow and supports
effective deep network training by employing a preactivation scheme.
Parameters
----------
input_dim : int
Dimensionality of the input feature vector.
output_dim : int
Dimensionality of the output feature vector, which typically matches the input dimension.
Dimensionality of the output feature vector, which typically matches the
input dimension.
activation_function : Type[torch.nn.Module]
The activation function to be used in the residual block.
"""
@@ -158,27 +138,28 @@ def __init__(
):
super().__init__()
# Initialize dense layers and residual connection
self.dense = DenseWithCustomDist(
input_dim, output_dim, activation_function=activation_function

self.dense = nn.Sequential(
activation_function,
Dense(input_dim, output_dim, activation_function),
Dense(output_dim, output_dim),
)
self.residual = DenseWithCustomDist(output_dim, output_dim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass of the ResidualBlock.
Parameters:
-----------
x: torch.Tensor
Parameters
----------
x : torch.Tensor
Input tensor containing feature vectors of atoms.
Returns:
--------
Returns
-------
torch.Tensor
Output tensor after applying the residual block operations.
"""
# update x with residual
return x + self.residual(self.dense(x))
return x + self.dense(x)
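
Functionally, and assuming Dense here is a linear layer followed by its activation, the new nn.Sequential computes the preactivation form of Equation 4. A plain-tensor sketch (softplus stands in for the configurable activation):

    import torch
    import torch.nn.functional as F

    def preactivation_residual(x, w1, b1, w2, b2):
        # x_{l+2} = x_l + W_{l+1} * sigma(W_l * sigma(x_l) + b_l) + b_{l+1}
        h = F.softplus(x)               # preactivation of the input
        h = F.softplus(h @ w1.T + b1)   # first dense layer + activation
        return x + (h @ w2.T + b2)      # second dense layer + residual connection
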


class PhysNetInteractionModule(nn.Module):
@@ -219,21 +200,19 @@ def __init__(
)

# Initialize networks for processing atomic embeddings of i and j atoms
self.interaction_i = DenseWithCustomDist(
self.interaction_i = Dense(
number_of_per_atom_features,
number_of_per_atom_features,
activation_function=activation_function,
)
self.interaction_j = DenseWithCustomDist(
self.interaction_j = Dense(
number_of_per_atom_features,
number_of_per_atom_features,
activation_function=activation_function,
)

# Initialize processing network
self.process_v = DenseWithCustomDist(
number_of_per_atom_features, number_of_per_atom_features
)
self.process_v = Dense(number_of_per_atom_features, number_of_per_atom_features)

# Initialize residual blocks
self.residuals = nn.ModuleList(
@@ -253,69 +232,68 @@

def forward(self, data: PhysNetNeuralNetworkData) -> torch.Tensor:
"""
Processes input tensors through the interaction module, applying Gaussian Logarithm Attention to modulate
the influence of pairwise distances on the interaction features, followed by aggregation to update atomic embeddings.
Processes input tensors through the interaction module, applying
Gaussian Logarithm Attention to modulate the influence of pairwise
distances on the interaction features, followed by aggregation to update
atomic embeddings.
Parameters
----------
data : PhysNetNeuralNetworkData
Input data containing pair indices, distances, and atomic embeddings.
Input data containing pair indices, distances, and atomic
embeddings.
Returns
-------
torch.Tensor
Updated atomic feature representations incorporating interaction information.
Updated atomic feature representations incorporating interaction
information.
"""
# Equation 6: formation of the proto-message ṽ_i for an atom i:
# ṽ_i = σ(Wl_I * x_i^l + bl_I) + Σ_j (G_g * Wl * σ(Wl_J * x_j^l + bl_J) * g(r_ij))
#
# Equation 6 implementation overview: ṽ_i = x_i_prime +
# sum_over_j(x_j_prime * f_ij_prime) where:
# - x_i_prime and x_j_prime are the features of atoms i and j,
# respectively, processed through separate networks.
# - f_ij_prime represents the modulated radial basis functions (f_ij) by
# the Gaussian Logarithm Attention weights.

# extract relevant variables
idx_i, idx_j = data.pair_indices
f_ij = data.f_ij
x = data.atomic_embedding
idx_i, idx_j = data.pair_indices # pair_indices: (2, nr_of_pairs)
f_ij = data.f_ij # (nr_of_pairs, number_of_radial_basis_functions)

# # Apply activation to atomic embeddings
xa = self.dropout(self.activation_function(x))
per_atom_embedding = self.activation_function(
data.atomic_embedding
) # (nr_of_atoms_in_batch, number_of_per_atom_features)

# calculate attention weights and transform to
# input shape: (number_of_pairs, number_of_radial_basis_functions)
# output shape: (number_of_pairs, number_of_per_atom_features)
g = self.attention_mask(f_ij)

# Calculate contribution of central atom
x_i = self.interaction_i(xa)
# Calculate contribution of central atom i
per_atom_updated_embedding = self.interaction_i(per_atom_embedding)

# Calculate contribution of neighbor atom
x_j = self.interaction_j(xa)
# Gather the results according to idx_j
x_j = x_j[idx_j]
# Multiply the gathered features by g
x_j_modulated = x_j * g
# Aggregate modulated contributions for each atom i
x_j_prime = torch.zeros_like(x_i)
x_j_prime.scatter_add_(
0, idx_i.unsqueeze(-1).expand(-1, x_j_modulated.size(-1)), x_j_modulated
per_interaction_embededding_for_atom_j = (
self.interaction_j(per_atom_embedding[idx_j]) * g
)

per_atom_updated_embedding.scatter_add_(
0,
idx_i.unsqueeze(-1).expand(
-1, per_interaction_embededding_for_atom_j.shape[-1]
),
per_interaction_embededding_for_atom_j,
)

# Draft proto message v_tilde
m = x_i + x_j_prime
# shape of m (nr_of_atoms_in_batch, 1)
# Equation 4: Preactivation Residual Block Implementation
# xl+2_i = xl_i + Wl+1 * sigma(Wl * xl_i + bl) + bl+1
# apply residual blocks
for residual in self.residuals:
m = residual(
m
per_atom_updated_embedding = residual(
per_atom_updated_embedding
) # shape (nr_of_atoms_in_batch, number_of_per_atom_features)
m = self.activation_function(m)
x = self.gate * x + self.process_v(m)
return x

per_atom_updated_embedding = self.activation_function(
per_atom_updated_embedding
)

per_atom_embedding = self.gate * per_atom_embedding + self.process_v(
per_atom_updated_embedding
)
return per_atom_embedding
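
The aggregation step above follows a standard scatter-add message-passing pattern. A self-contained sketch with illustrative shapes, where the attention mask and dense layers are replaced by precomputed tensors:

    import torch

    nr_of_atoms, nr_of_pairs, features = 5, 8, 4
    x_i = torch.randn(nr_of_atoms, features)    # per-atom term of Eq. 6
    msg_j = torch.randn(nr_of_pairs, features)  # modulated neighbor features x_j' * g
    idx_i = torch.randint(0, nr_of_atoms, (nr_of_pairs,))  # receiving atom of each pair

    v_tilde = x_i.scatter_add(
        0, idx_i.unsqueeze(-1).expand(-1, features), msg_j
    )  # v_tilde[i] = x_i[i] + sum of msg_j over pairs whose receiver is i
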


class PhysNetOutput(nn.Module):
@@ -359,7 +337,6 @@ def __init__(
number_of_per_atom_features,
number_of_atomic_properties,
weight_init=torch.nn.init.zeros_,
bias=False,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -433,25 +410,6 @@ def forward(self, data: PhysNetNeuralNetworkData) -> Dict[str, torch.Tensor]:
Dict[str, torch.Tensor]
Dictionary containing predictions and updated embeddings.
"""
# The PhysNet module is a sequence of interaction modules and residual modules.
# x_1, ..., x_N
# |
# v
# ┌─────────────┐
# │ interaction │ <-- g(d_ij)
# └─────────────┘
# │
# v
# ┌───────────┐
# │ residual │
# └───────────┘
# ┌───────────┐
# │ residual │
# └───────────┘
# ┌───────────┐ │
# │ output │<-----│
# └───────────┘ │
# v

# calculate the interaction
v = self.interaction(data)
@@ -645,13 +603,6 @@ def compute_properties(
}


from .models import NNPInput, BaseNetwork
from typing import List
from modelforge.utils.units import _convert_str_to_unit
from modelforge.utils.io import import_
from modelforge.potential.utils import shared_config_prior


class PhysNet(BaseNetwork):
"""
Implementation of the PhysNet neural network potential.
