Skip to content

Commit

Permalink
notes on the envelope function
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Oct 23, 2024
1 parent f762f22 commit d032727
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions modelforge/tests/test_dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_envelope():
# Forward pass
outputs = envelope(inputs)
assert outputs.shape == inputs.shape
assert torch.allclose(outputs, torch.tensor([1.7109, 0.2539, 0.0000, 0.0000]), rtol=1e-3)
assert torch.allclose(
outputs, torch.tensor([1.7109, 0.2539, 0.0000, 0.0000]), rtol=1e-3
)

# Script the model for optimization and deployment
scripted_envelope = torch.jit.script(envelope)
Expand All @@ -70,12 +72,26 @@ def test_envelope():
print(outputs_scripted)
assert torch.allclose(outputs, outputs_scripted)

# ----------------------------------------- #
# test for correct output computation
exponent=6 # Envelop function receives exponent = 5 but takes its increment and uses exponent = 6 -> why?
d_ij=0.5
u_05 = 1 - (exponent+1)*(exponent+2)/2*d_ij**exponent + exponent*(exponent+2)*d_ij**(exponent+1)- exponent*(exponent+1)/2*d_ij**(exponent+2)
u_05/=d_ij # this should not be done, but the Envelope function does this

exponent = (
5 + 1
) # NOTE: Envelop function receives exponent = 5 but takes its increment and uses exponent = 6 FIXME: that seems strange ?
# start with test for float:
d_ij = 0.5
# generate envelope function of d_ij value
u_05 = (
1
- (exponent + 1) * (exponent + 2) / 2 * d_ij**exponent
+ exponent * (exponent + 2) * d_ij ** (exponent + 1)
- exponent * (exponent + 1) / 2 * d_ij ** (exponent + 2)
)
u_05 /= d_ij # NOTE: this is not in the paper, but in the DimNet++ implementation

# NOTE: this test passes, but only if you divide by d_ij at the end, which is not in the paper, but in the DimNet++ implementation
u_05 = torch.tensor([u_05], dtype=torch.float32)

assert torch.allclose(u_05, outputs[0], rtol=1e-3)


Expand All @@ -94,12 +110,16 @@ def test_bessel_basis():

# Sample input tensor of distances
num_pairs = 100
d_ij = torch.linspace(0, radial_cutoff, steps=num_pairs).unsqueeze(-1) # Shape: (100,1)
d_ij = torch.linspace(0, radial_cutoff, steps=num_pairs).unsqueeze(
-1
) # Shape: (100,1)

# Forward pass
outputs = bessel_layer(d_ij) # Shape: (100, num_radial)
shape_tensor = torch.randn(num_pairs,num_radial) #output from besser_layer should have this size
assert shape_tensor.shape == outputs.shape # Should print: torch.Size([100, 6])
shape_tensor = torch.randn(
num_pairs, num_radial
) # output from besser_layer should have this size
assert shape_tensor.shape == outputs.shape # Should print: torch.Size([100, 6])


def test_representation():
Expand Down

0 comments on commit d032727

Please sign in to comment.