diff --git a/casanovo/config.py b/casanovo/config.py index 0274a1b1..c07073d6 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -83,7 +83,17 @@ def __init__(self, config_file: Optional[str] = None): else: with Path(config_file).open() as f_in: self._user_config = yaml.safe_load(f_in) - + # check for missing entries in config file + if len(self._user_config.keys()) < len(self._params.keys()): + keys_set = set(self._params.keys()) + users_set = set(self._user_config.keys()) + missing = list(keys_set - users_set) + raise KeyError(f"Missing expected entry {missing}") + # detect unrecognized config file entries + keys = list(self._params.keys()) + for key, val in self._user_config.items(): + if key not in keys: + raise KeyError(f"Unrecognized config file entry {key}") # Validate: for key, val in self._config_types.items(): self.validate_param(key, val) diff --git a/tests/conftest.py b/tests/conftest.py index 6dcda9c9..267dfa0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,7 +188,7 @@ def tiny_config(tmp_path): """A config file for a tiny model.""" cfg = { "n_head": 2, - "dim_feedfoward": 10, + "dim_feedforward": 10, "n_layers": 1, "warmup_iters": 1, "max_iters": 1, @@ -196,6 +196,62 @@ def tiny_config(tmp_path): "val_check_interval": 1, "model_save_folder_path": str(tmp_path), "accelerator": "cpu", + "precursor_mass_tol": 5, + "isotope_error_range": [0, 1], + "min_peptide_len": 6, + "predict_batch_size": 1024, + "n_beams": 1, + "top_match": 1, + "devices": None, + "random_seed": 454, + "n_log": 1, + "tb_summarywriter": None, + "save_top_k": 5, + "n_peaks": 150, + "min_mz": 50.0, + "max_mz": 2500.0, + "min_intensity": 0.01, + "remove_precursor_tol": 2.0, + "max_charge": 10, + "dim_model": 512, + "dropout": 0.0, + "dim_intensity": None, + "max_length": 100, + "learning_rate": 5e-4, + "weight_decay": 1e-5, + "train_batch_size": 32, + "num_sanity_val_steps": 0, + "train_from_scratch": True, + "calculate_precision": False, + "residues": { + "G": 57.021464, + "A": 71.037114, + "S": 87.032028, + "P": 97.052764, + "V": 99.068414, + "T": 101.047670, + "C+57.021": 160.030649, + "L": 113.084064, + "I": 113.084064, + "N": 114.042927, + "D": 115.026943, + "Q": 128.058578, + "K": 128.094963, + "E": 129.042593, + "M": 131.040485, + "H": 137.058912, + "F": 147.068414, + "R": 156.101111, + "Y": 163.063329, + "W": 186.079313, + "M+15.995": 147.035400, + "N+0.984": 115.026943, + "Q+0.984": 129.042594, + "+42.011": 42.010565, + "+43.006": 43.005814, + "-17.027": -17.026549, + "+43.006-17.027": 25.980265, + }, } cfg_file = tmp_path / "config.yml" diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 1e2ef338..fd8ed22e 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -1,5 +1,7 @@ """Test configuration loading""" from casanovo.config import Config +import pytest +import yaml def test_default(): @@ -11,7 +13,7 @@ def test_default(): assert config.file == "default" -def test_override(tmp_path): +def test_override(tmp_path, tiny_config): """Test overriding the default""" yml = tmp_path / "test.yml" with yml.open("w+") as f_out: @@ -26,12 +28,13 @@ def test_override(tmp_path): """ ) - config = Config(yml) - assert config.random_seed == 42 - assert config["random_seed"] == 42 - assert config.accelerator == "auto" - assert config.top_match == 3 - assert len(config.residues) == 4 - for i, residue in enumerate("WOUT", 1): - assert config["residues"][residue] == i - assert config.file == str(yml) + with open(tiny_config, "r") as read_file: + contents = yaml.safe_load(read_file) + contents["random_seed_"] = 354 + + with open("output.yml", "w") as write_file: + yaml.safe_dump(contents, write_file) + with pytest.raises(KeyError): + config = Config("output.yml") + with pytest.raises(KeyError): + config = Config(yml)