Skip to content

Commit

Permalink
updated docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenD-UCB committed Mar 1, 2023
1 parent 9486f99 commit 68373c7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
41 changes: 30 additions & 11 deletions chgnet/model/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ class Fourier(nn.Module):
Fourier Expansion for angle feautures
"""
def __init__(self, order: int = 5, learnable: bool = False):
"""
Initialize the Fourier expansion
Args:
order (int): the maximum order, refer to the N in eq 1 in CHGNet paper
Default = 5
learnable (bool): whether to set the frequencies learnanle
Default = False
"""
super().__init__()
self.order = order
# Initialize frequencies at canonical
Expand All @@ -27,31 +36,36 @@ def forward(self, x):
result[:, 0] = 1 / torch.sqrt(torch.tensor([2]))
tmp = torch.outer(x, self.frequencies)
result[:, 1 : self.order + 1] = torch.sin(tmp)
result[:, self.order + 1 :] = torch.cos(tmp)
result[:, self.order + 1:] = torch.cos(tmp)
return result / np.sqrt(np.pi)


class RadialBessel(torch.nn.Module):
"""
1D Bessel Basis
from: https://github.com/TUM-DAML/gemnet_pytorch/
Parameters
----------
num_radial: int
Controls maximum frequency.
cutoff: float
Cutoff distance in Angstrom.
envelope_exponent: int = 5
Exponent of the envelope function.
"""

def __init__(
self,
num_radial: int,
num_radial: int = 9,
cutoff: float = 5,
learnable: bool = False,
smooth_cutoff: int = None,
smooth_cutoff: int = 5,
):
"""
Initialize the SmoothRBF function
Args:
num_radial (int): Controls maximum frequency
Default = 9
cutoff (float): Cutoff distance in Angstrom.
Default = 5
learnable (bool): whether to set the frequencies learnanle
Default = False
smooth_cutoff (int): smooth cutoff strength
Default = 5
"""
super().__init__()
self.num_radial = num_radial
self.inv_cutoff = 1 / cutoff
Expand Down Expand Up @@ -104,6 +118,7 @@ def __init__(
"""
Gaussian Expansion
expand a scalar feature to a soft-one-hot feature vector
Args:
min (float): minimum Gaussian center value
max (float): maximum Gaussian center value
Expand All @@ -122,8 +137,10 @@ def __init__(
def expand(self, features: Tensor) -> Tensor:
"""
Apply Gaussian filter to a feature Tensor
Args:
features (torch.Tensor): tensor of features [n]
Returns:
expanded features (torch.Tensor): tensor of Gaussian distances [n, dim]
where the expanded dimension will be (dmax-dmin)/step + 1
Expand Down Expand Up @@ -159,8 +176,10 @@ def __init__(self, cutoff: float = 5, cutoff_coeff: float = 5):
def forward(self, r: Tensor) -> Tensor:
"""
Polynomial cutoff function
Args:
r (Tensor): radius distance tensor
Returns:
polynomial cutoff functions: decaying from 1 at r=0 to 0 at r=cutoff
"""
Expand Down
2 changes: 1 addition & 1 deletion chgnet/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class AngleEncoder(nn.Module):
Encode an angle given the two bond vectors using Fourier Expansion.
"""

def __init__(self, num_angular: int = 21, learnable: bool = True):
def __init__(self, num_angular: int = 9, learnable: bool = True):
"""
Initialize the angle encoder
Expand Down

0 comments on commit 68373c7

Please sign in to comment.