diff --git a/modelforge/potential/aimnet2.py b/modelforge/potential/aimnet2.py index 1fee22ac..744e4aea 100644 --- a/modelforge/potential/aimnet2.py +++ b/modelforge/potential/aimnet2.py @@ -314,6 +314,8 @@ def calculate_contributions( # Accumulate the vector contributions using index_add_ vector_contributions.index_add_(0, idx_j, vector_prot_step2) + if torch.isnan(vector_contributions).any(): + raise ValueError("NaN values detected in vector_contributions.") # Step 3: Compute the Euclidean Norm for each atom vector_norms = torch.norm(