Skip to content

Commit

Permalink
Issue error for unrecognized/missing config file entry (#257)
Browse files Browse the repository at this point in the history
* added unit tests to raise exceptions when unrecognized/missing file entry

* fixed lint issue

* fix lint issue

* fixed failing unit test

* lint issue

* lint

* lint issue

---------

Co-authored-by: Isha Gokhale <[email protected]>
  • Loading branch information
ishagokhale and Isha Gokhale authored Nov 9, 2023
1 parent 7721962 commit 235420f
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 12 deletions.
12 changes: 11 additions & 1 deletion casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,17 @@ def __init__(self, config_file: Optional[str] = None):
else:
with Path(config_file).open() as f_in:
self._user_config = yaml.safe_load(f_in)

# check for missing entries in config file
if len(self._user_config.keys()) < len(self._params.keys()):
keys_set = set(self._params.keys())
users_set = set(self._user_config.keys())
missing = list(keys_set - users_set)
raise KeyError(f"Missing expected entry {missing}")
# detect unrecognized config file entries
keys = list(self._params.keys())
for key, val in self._user_config.items():
if key not in keys:
raise KeyError(f"Unrecognized config file entry {key}")
# Validate:
for key, val in self._config_types.items():
self.validate_param(key, val)
Expand Down
58 changes: 57 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,70 @@ def tiny_config(tmp_path):
"""A config file for a tiny model."""
cfg = {
"n_head": 2,
"dim_feedfoward": 10,
"dim_feedforward": 10,
"n_layers": 1,
"warmup_iters": 1,
"max_iters": 1,
"max_epochs": 20,
"val_check_interval": 1,
"model_save_folder_path": str(tmp_path),
"accelerator": "cpu",
"precursor_mass_tol": 5,
"isotope_error_range": [0, 1],
"min_peptide_len": 6,
"predict_batch_size": 1024,
"n_beams": 1,
"top_match": 1,
"devices": None,
"random_seed": 454,
"n_log": 1,
"tb_summarywriter": None,
"save_top_k": 5,
"n_peaks": 150,
"min_mz": 50.0,
"max_mz": 2500.0,
"min_intensity": 0.01,
"remove_precursor_tol": 2.0,
"max_charge": 10,
"dim_model": 512,
"dropout": 0.0,
"dim_intensity": None,
"max_length": 100,
"learning_rate": 5e-4,
"weight_decay": 1e-5,
"train_batch_size": 32,
"num_sanity_val_steps": 0,
"train_from_scratch": True,
"calculate_precision": False,
"residues": {
"G": 57.021464,
"A": 71.037114,
"S": 87.032028,
"P": 97.052764,
"V": 99.068414,
"T": 101.047670,
"C+57.021": 160.030649,
"L": 113.084064,
"I": 113.084064,
"N": 114.042927,
"D": 115.026943,
"Q": 128.058578,
"K": 128.094963,
"E": 129.042593,
"M": 131.040485,
"H": 137.058912,
"F": 147.068414,
"R": 156.101111,
"Y": 163.063329,
"W": 186.079313,
"M+15.995": 147.035400,
"N+0.984": 115.026943,
"Q+0.984": 129.042594,
"+42.011": 42.010565,
"+43.006": 43.005814,
"-17.027": -17.026549,
"+43.006-17.027": 25.980265,
},
}

cfg_file = tmp_path / "config.yml"
Expand Down
23 changes: 13 additions & 10 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test configuration loading"""
from casanovo.config import Config
import pytest
import yaml


def test_default():
Expand All @@ -11,7 +13,7 @@ def test_default():
assert config.file == "default"


def test_override(tmp_path):
def test_override(tmp_path, tiny_config):
"""Test overriding the default"""
yml = tmp_path / "test.yml"
with yml.open("w+") as f_out:
Expand All @@ -26,12 +28,13 @@ def test_override(tmp_path):
"""
)

config = Config(yml)
assert config.random_seed == 42
assert config["random_seed"] == 42
assert config.accelerator == "auto"
assert config.top_match == 3
assert len(config.residues) == 4
for i, residue in enumerate("WOUT", 1):
assert config["residues"][residue] == i
assert config.file == str(yml)
with open(tiny_config, "r") as read_file:
contents = yaml.safe_load(read_file)
contents["random_seed_"] = 354

with open("output.yml", "w") as write_file:
yaml.safe_dump(contents, write_file)
with pytest.raises(KeyError):
config = Config("output.yml")
with pytest.raises(KeyError):
config = Config(yml)

0 comments on commit 235420f

Please sign in to comment.