Skip to content

Commit

Permalink
Remove train_from_scratch config option (#275)
Browse files Browse the repository at this point in the history
Instead of having to specify `train_from_scratch` in the config file, training will proceed from an existing model weights file if this is given as an argument to `casanovo train`.

Fixes #263.
  • Loading branch information
bittremieux authored Jan 9, 2024
1 parent 3c2d3f5 commit f01c607
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased]

### Changed

- Instead of having to specify `train_from_scratch` in the config file, training will proceed from an existing model weights file if this is given as an argument to `casanovo train`.

## [4.0.0] - 2023-12-22

### Added
Expand Down
1 change: 0 additions & 1 deletion casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class Config:
top_match=int,
max_epochs=int,
num_sanity_val_steps=int,
train_from_scratch=bool,
save_top_k=int,
model_save_folder_path=str,
val_check_interval=int,
Expand Down
2 changes: 0 additions & 2 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ train_batch_size: 32
max_epochs: 30
# Number of validation steps to run before training begins
num_sanity_val_steps: 0
# Set to "False" to further train a pre-trained Casanovo model
train_from_scratch: True
# Calculate peptide and amino acid precision during training. This
# is expensive, so we recommend against it.
calculate_precision: False
Expand Down
20 changes: 10 additions & 10 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,16 +251,16 @@ def initialize_model(self, train: bool) -> None:
calculate_precision=self.config.calculate_precision,
)

from_scratch = (
self.config.train_from_scratch,
self.model_filename is None,
)
if train and any(from_scratch):
self.model = Spec2Pep(**model_params)
return
elif self.model_filename is None:
logger.error("A model file must be provided")
raise ValueError("A model file must be provided")
if self.model_filename is None:
# Train a model from scratch if no model file is provided.
if train:
self.model = Spec2Pep(**model_params)
return
# Else we're not training, so a model file must be provided.
else:
logger.error("A model file must be provided")
raise ValueError("A model file must be provided")
# Else a model file is provided (to continue training or for inference).

if not Path(self.model_filename).exists():
logger.error(
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def tiny_config(tmp_path):
"weight_decay": 1e-5,
"train_batch_size": 32,
"num_sanity_val_steps": 0,
"train_from_scratch": True,
"calculate_precision": False,
"residues": {
"G": 57.021464,
Expand Down
36 changes: 20 additions & 16 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,39 @@
from casanovo.denovo.model_runner import ModelRunner


def test_initialize_model(tmp_path):
"""Test that"""
def test_initialize_model(tmp_path, mgf_small):
"""Test initializing a new or existing model."""
config = Config()
config.train_from_scratch = False
# No model filename given, so train from scratch.
ModelRunner(config=config).initialize_model(train=True)

# No model filename given during inference = error.
with pytest.raises(ValueError):
ModelRunner(config=config).initialize_model(train=False)

with pytest.raises(FileNotFoundError):
runner = ModelRunner(config=config, model_filename="blah")
runner.initialize_model(train=True)

# Non-existing model filename given during inference = error.
with pytest.raises(FileNotFoundError):
runner = ModelRunner(config=config, model_filename="blah")
runner.initialize_model(train=False)

# This should work now:
config.train_from_scratch = True
runner = ModelRunner(config=config, model_filename="blah")
# Train a quick model.
config.max_epochs = 1
config.n_layers = 1
ckpt = tmp_path / "existing.ckpt"
with ModelRunner(config=config) as runner:
runner.train([mgf_small], [mgf_small])
runner.trainer.save_checkpoint(ckpt)

# Resume training from previous model.
runner = ModelRunner(config=config, model_filename=str(ckpt))
runner.initialize_model(train=True)

# But this should still fail:
with pytest.raises(FileNotFoundError):
runner = ModelRunner(config=config, model_filename="blah")
runner.initialize_model(train=False)
# Inference with previous model.
runner = ModelRunner(config=config, model_filename=str(ckpt))
runner.initialize_model(train=False)

# If the model initialization throws an EOFError, then the Spec2Pep model
# has tried to load the weights:
# has tried to load the weights.
weights = tmp_path / "blah"
weights.touch()
with pytest.raises(EOFError):
Expand All @@ -43,7 +47,7 @@ def test_initialize_model(tmp_path):


def test_save_and_load_weights(tmp_path, mgf_small, tiny_config):
"""Test saving aloading weights"""
"""Test saving and loading weights"""
config = Config(tiny_config)
config.max_epochs = 1
config.n_layers = 1
Expand Down

0 comments on commit f01c607

Please sign in to comment.