Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Aug 26, 2024
1 parent 701122d commit 42dc20a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
7 changes: 5 additions & 2 deletions modelforge/tests/test_parameter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ def test_training_parameter_model():
with pytest.raises(ValidationError):
training_parameters.splitting_strategy.dataset_split = [0.7, 0.1, 0.1, 0.1]

# this will throw an error because the datafile has 2 entries for the loss_property dictionary
# this will throw an error because the datafile has 1 entries for the loss_property dictionary
with pytest.raises(ValidationError):
training_parameters.loss_parameter.loss_property = ["per_molecule_energy"]
training_parameters.loss_parameter.loss_property = [
"per_molecule_energy",
"per_atom_force",
]
2 changes: 1 addition & 1 deletion modelforge/tests/test_sake.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def test_model_invariance(single_batch_with_batchsize):
],
)
# get methane input
batch = single_batch_with_batchsize(batch_size=1)
batch = single_batch_with_batchsize(batch_size=1, dataset_name="QM9")
methane = batch.nnp_input

rotation_matrix = torch.tensor([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
Expand Down
4 changes: 2 additions & 2 deletions modelforge/tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_error_calculation(single_batch_with_batchsize):
1
) # FIXME : fi
reference_E_error = torch.mean(scale_squared_error)
assert torch.allclose(E_error, reference_E_error)
assert torch.allclose(torch.mean(E_error), reference_E_error)

# test error for property with shape (nr_of_atoms, 3)
error = FromPerAtomToPerMoleculeSquaredError()
Expand All @@ -170,7 +170,7 @@ def test_error_calculation(single_batch_with_batchsize):
reference_F_error = torch.mean(
per_mol_error / (3 * data.metadata.atomic_subsystem_counts.unsqueeze(1))
)
assert torch.allclose(F_error, reference_F_error)
assert torch.allclose(torch.mean(F_error), reference_F_error)


def test_loss(single_batch_with_batchsize):
Expand Down

0 comments on commit 42dc20a

Please sign in to comment.