diff --git a/mlspm/cli.py b/mlspm/cli.py index 5770f66..409287f 100644 --- a/mlspm/cli.py +++ b/mlspm/cli.py @@ -1,4 +1,5 @@ import argparse +from typing import Optional def _bool_type(value): @@ -9,10 +10,13 @@ def _bool_type(value): raise KeyError(f"`{value}` can't be interpreted as a boolean.") -def parse_args() -> dict: +def parse_args(argv: Optional[list[str]] = None) -> dict: """ Parse some useful CLI arguments for use in training scripts. + Arguments: + argv: List of argument values. Defaults to ``sys.argv``. + Returns: A dictionary of the argument values. """ @@ -68,5 +72,5 @@ def parse_args() -> dict: parser.add_argument( "--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3." ) - args = parser.parse_args() + args = parser.parse_args(argv) return vars(args) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..c374fa3 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,16 @@ + +import pytest +import mlspm.cli + +def test_parse_args(): + from mlspm.cli import parse_args + + args = parse_args(["--train", "false", "--predict", "False", '--test', "true", "--classes", "1,2,3", "4,5,6"]) + + assert args["train"] == False + assert args["predict"] == False + assert args["test"] == True + assert args["classes"] == [[1, 2, 3], [4, 5, 6]] + + with pytest.raises(KeyError): + parse_args(["--train", "fals"]) diff --git a/tests/test_models.py b/tests/test_models.py index 03e785f..69ad5f2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -170,8 +170,8 @@ def test_GraphImgNet(): # Test that the edges are not directional node_features_reverse, edge_features_reverse = model.mpnn(pos, x_afm, edges_combined[[1, 0]]) - assert torch.allclose(node_features, node_features_reverse) - assert torch.allclose(edge_features, edge_features_reverse) + assert torch.allclose(node_features, node_features_reverse, rtol=1e-4, atol=1e-6) + assert torch.allclose(edge_features, edge_features_reverse, rtol=1e-4, atol=1e-6) # Test whole model model.afm_cutoff = 0.8