Skip to content

Commit

Permalink
Added test for cli.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Nov 24, 2023
1 parent a1f5f3b commit cce2956
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
8 changes: 6 additions & 2 deletions mlspm/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from typing import Optional


def _bool_type(value):
Expand All @@ -9,10 +10,13 @@ def _bool_type(value):
raise KeyError(f"`{value}` can't be interpreted as a boolean.")


def parse_args() -> dict:
def parse_args(argv: Optional[list[str]] = None) -> dict:
"""
Parse some useful CLI arguments for use in training scripts.
Arguments:
argv: List of argument values. Defaults to ``sys.argv``.
Returns:
A dictionary of the argument values.
"""
Expand Down Expand Up @@ -68,5 +72,5 @@ def parse_args() -> dict:
parser.add_argument(
"--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3."
)
args = parser.parse_args()
args = parser.parse_args(argv)
return vars(args)
16 changes: 16 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

import pytest
import mlspm.cli

def test_parse_args():
from mlspm.cli import parse_args

args = parse_args(["--train", "false", "--predict", "False", '--test', "true", "--classes", "1,2,3", "4,5,6"])

assert args["train"] == False
assert args["predict"] == False
assert args["test"] == True
assert args["classes"] == [[1, 2, 3], [4, 5, 6]]

with pytest.raises(KeyError):
parse_args(["--train", "fals"])
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def test_GraphImgNet():

# Test that the edges are not directional
node_features_reverse, edge_features_reverse = model.mpnn(pos, x_afm, edges_combined[[1, 0]])
assert torch.allclose(node_features, node_features_reverse)
assert torch.allclose(edge_features, edge_features_reverse)
assert torch.allclose(node_features, node_features_reverse, rtol=1e-4, atol=1e-6)
assert torch.allclose(edge_features, edge_features_reverse, rtol=1e-4, atol=1e-6)

# Test whole model
model.afm_cutoff = 0.8
Expand Down

0 comments on commit cce2956

Please sign in to comment.