From d03272706b870d7be82546bd7f3bcb996e28c324 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 23 Oct 2024 11:38:45 +0200 Subject: [PATCH] notes on the envelope function --- modelforge/tests/test_dimenet.py | 36 +++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/modelforge/tests/test_dimenet.py b/modelforge/tests/test_dimenet.py index 43a3f957..f3dcc6c6 100644 --- a/modelforge/tests/test_dimenet.py +++ b/modelforge/tests/test_dimenet.py @@ -60,7 +60,9 @@ def test_envelope(): # Forward pass outputs = envelope(inputs) assert outputs.shape == inputs.shape - assert torch.allclose(outputs, torch.tensor([1.7109, 0.2539, 0.0000, 0.0000]), rtol=1e-3) + assert torch.allclose( + outputs, torch.tensor([1.7109, 0.2539, 0.0000, 0.0000]), rtol=1e-3 + ) # Script the model for optimization and deployment scripted_envelope = torch.jit.script(envelope) @@ -70,12 +72,26 @@ def test_envelope(): print(outputs_scripted) assert torch.allclose(outputs, outputs_scripted) + # ----------------------------------------- # # test for correct output computation - exponent=6 # Envelop function receives exponent = 5 but takes its increment and uses exponent = 6 -> why? - d_ij=0.5 - u_05 = 1 - (exponent+1)*(exponent+2)/2*d_ij**exponent + exponent*(exponent+2)*d_ij**(exponent+1)- exponent*(exponent+1)/2*d_ij**(exponent+2) - u_05/=d_ij # this should not be done, but the Envelope function does this + + exponent = ( + 5 + 1 + ) # NOTE: Envelop function receives exponent = 5 but takes its increment and uses exponent = 6 FIXME: that seems strange ? + # start with test for float: + d_ij = 0.5 + # generate envelope function of d_ij value + u_05 = ( + 1 + - (exponent + 1) * (exponent + 2) / 2 * d_ij**exponent + + exponent * (exponent + 2) * d_ij ** (exponent + 1) + - exponent * (exponent + 1) / 2 * d_ij ** (exponent + 2) + ) + u_05 /= d_ij # NOTE: this is not in the paper, but in the DimNet++ implementation + + # NOTE: this test passes, but only if you divide by d_ij at the end, which is not in the paper, but in the DimNet++ implementation u_05 = torch.tensor([u_05], dtype=torch.float32) + assert torch.allclose(u_05, outputs[0], rtol=1e-3) @@ -94,12 +110,16 @@ def test_bessel_basis(): # Sample input tensor of distances num_pairs = 100 - d_ij = torch.linspace(0, radial_cutoff, steps=num_pairs).unsqueeze(-1) # Shape: (100,1) + d_ij = torch.linspace(0, radial_cutoff, steps=num_pairs).unsqueeze( + -1 + ) # Shape: (100,1) # Forward pass outputs = bessel_layer(d_ij) # Shape: (100, num_radial) - shape_tensor = torch.randn(num_pairs,num_radial) #output from besser_layer should have this size - assert shape_tensor.shape == outputs.shape # Should print: torch.Size([100, 6]) + shape_tensor = torch.randn( + num_pairs, num_radial + ) # output from besser_layer should have this size + assert shape_tensor.shape == outputs.shape # Should print: torch.Size([100, 6]) def test_representation():