Upgrade to pytorch-lightning 2 #230

Closed · wants to merge 1 commit

6 changes: 2 additions & 4 deletions casanovo/casanovo.py
@@ -15,10 +15,8 @@
 import click
 import github
 import requests
-import torch
-import tqdm
 import yaml
-from pytorch_lightning.lite import LightningLite
+from lightning import Fabric
 
 from . import __version__
 from . import utils
@@ -127,7 +125,7 @@ def main(
     # Read parameters from the config file.
     config = Config(config)
 
-    LightningLite.seed_everything(seed=config["random_seed"], workers=True)
+    Fabric.seed_everything(seed=config["random_seed"], workers=True)
 
     # Download model weights if these were not specified (except when training).
     if model is None and mode != "train":
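
LightningLite was removed in Lightning 2 and its role is taken over by Fabric, which exposes the same seeding helper as a static method, so no Fabric instance is needed. A minimal sketch of the new call (the seed value 42 is an arbitrary example):

    from lightning import Fabric

    # Seeds Python's random module, NumPy, and torch; workers=True also derives
    # distinct seeds for dataloader worker processes.
    Fabric.seed_everything(seed=42, workers=True)
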
1 change: 1 addition & 0 deletions casanovo/config.py
@@ -50,6 +50,7 @@ class Config:
         dropout=float,
         dim_intensity=int,
         max_length=int,
+        residues=dict,  # note, this key is special-cased and type is ignored
         n_log=int,
         tb_summarywriter=str,
         warmup_iters=int,
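
The new comment flags residues as special-cased: its value is a nested mapping of amino acid tokens to masses, so applying the declared type as a cast would mangle it. A hypothetical sketch of how such a validation loop might skip it (config_types and the loop are illustrative, not Casanovo's actual code):

    config_types = dict(max_length=int, residues=dict, n_log=int)

    config = {"max_length": "100", "residues": {"G": 57.02146}, "n_log": 10}
    for key, cast in config_types.items():
        if key == "residues":
            continue  # special-cased: keep the nested dict as-is
        config[key] = cast(config[key])
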
30 changes: 13 additions & 17 deletions casanovo/denovo/model.py
@@ -724,8 +724,8 @@ def training_step(
         pred = pred[:, :-1, :].reshape(-1, self.decoder.vocab_size + 1)
         loss = self.celoss(pred, truth.flatten())
         self.log(
-            "CELoss",
-            {mode: loss.detach()},
+            f"CELoss/{mode}",
+            loss.detach(),
             on_step=False,
             on_epoch=True,
             sync_dist=True,
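
Lightning 2 removed support for passing a dictionary value to self.log(), which is why the mode is folded into the metric name here. A minimal sketch of the new pattern in an otherwise empty module (the 0.5 loss is a placeholder):

    import torch
    from lightning.pytorch import LightningModule

    class Demo(LightningModule):
        def training_step(self, batch, batch_idx):
            loss = torch.as_tensor(0.5)  # placeholder loss
            # Lightning 1.x allowed: self.log("CELoss", {"train": loss})
            # Lightning 2.x takes one scalar per call, with the mode in the name;
            # the "/" also groups train/valid curves into one TensorBoard panel.
            self.log("CELoss/train", loss, on_step=False, on_epoch=True)
            return loss
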
@@ -765,13 +765,11 @@ def validation_step(
         )
         log_args = dict(on_step=False, on_epoch=True, sync_dist=True)
         self.log(
-            "Peptide precision at coverage=1",
-            {"valid": pep_precision},
+            "Peptide precision at coverage=1/valid",
+            pep_precision,
             **log_args,
         )
-        self.log(
-            "AA precision at coverage=1", {"valid": aa_precision}, **log_args
-        )
+        self.log("AA precision at coverage=1/valid", aa_precision, **log_args)
 
         return loss
 
@@ -824,7 +822,7 @@ def on_train_epoch_end(self) -> None:
         """
         Log the training loss at the end of each epoch.
         """
-        train_loss = self.trainer.callback_metrics["CELoss"]["train"].detach()
+        train_loss = self.trainer.callback_metrics["CELoss/train"].detach()
         metrics = {
             "step": self.trainer.global_step,
             "train": train_loss,
@@ -839,20 +837,18 @@ def on_validation_epoch_end(self) -> None:
         callback_metrics = self.trainer.callback_metrics
         metrics = {
             "step": self.trainer.global_step,
-            "valid": callback_metrics["CELoss"]["valid"].detach(),
+            "valid": callback_metrics["CELoss/valid"].detach(),
             "valid_aa_precision": callback_metrics[
-                "AA precision at coverage=1"
-            ]["valid"].detach(),
+                "AA precision at coverage=1/valid"
+            ].detach(),
             "valid_pep_precision": callback_metrics[
-                "Peptide precision at coverage=1"
-            ]["valid"].detach(),
+                "Peptide precision at coverage=1/valid"
+            ].detach(),
         }
         self._history.append(metrics)
         self._log_history()
 
-    def on_predict_epoch_end(
-        self, results: List[List[Tuple[np.ndarray, List[str], torch.Tensor]]]
-    ) -> None:
+    def on_predict_epoch_end(self) -> None:
         """
         Write the predicted peptide sequences and amino acid scores to the
         output file.
@@ -868,7 +864,7 @@ def on_predict_epoch_end(
             peptide_score,
             aa_scores,
         ) in itertools.chain.from_iterable(
-            itertools.chain.from_iterable(results)
+            self.trainer.predict_loop.predictions
         ):
             if len(peptide) == 0:
                 continue
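
The signature change is mandated by Lightning 2, where on_predict_epoch_end no longer receives the accumulated results; they are read off the trainer's prediction loop instead, as the diff does with self.trainer.predict_loop.predictions. A minimal sketch of the new hook shape, assuming each batch prediction is an iterable of per-spectrum result tuples:

    import itertools
    from lightning.pytorch import LightningModule

    class Demo(LightningModule):
        def on_predict_epoch_end(self) -> None:
            # predictions holds one entry per predict batch; flatten them
            # into a single stream of per-spectrum results.
            for result in itertools.chain.from_iterable(
                self.trainer.predict_loop.predictions
            ):
                ...  # write each result to the output file
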
28 changes: 15 additions & 13 deletions casanovo/denovo/model_runner.py
@@ -14,6 +14,8 @@
 from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex
 from pytorch_lightning.strategies import DDPStrategy
 
+from lightning.pytorch.accelerators import find_usable_cuda_devices
+
 from .. import utils
 from ..data import ms_io
 from ..denovo.dataloaders import DeNovoDataModule
@@ -96,8 +98,12 @@ def _execute_existing(
             model_filename,
         )
         raise FileNotFoundError("Could not find the trained model weights")
+    map_location = None
+    if torch.cuda.device_count() == 0:
+        map_location = "cpu"
     model = Spec2Pep().load_from_checkpoint(
         model_filename,
+        map_location=map_location,
         dim_model=config["dim_model"],
         n_head=config["n_head"],
         dim_feedforward=config["dim_feedforward"],
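
The map_location guard makes the checkpoint loadable on CPU-only machines, where tensors saved from a GPU would otherwise fail to deserialize. The same idea as a standalone sketch (the checkpoint path is a hypothetical example):

    import torch

    # Remap GPU-saved tensors onto the CPU when no CUDA device is present;
    # map_location=None keeps the device placement stored in the checkpoint.
    map_location = "cpu" if torch.cuda.device_count() == 0 else None
    checkpoint = torch.load("casanovo.ckpt", map_location=map_location)
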
@@ -158,7 +164,6 @@ def _execute_existing(
     # Create the Trainer object.
     trainer = pl.Trainer(
         accelerator="auto",
-        auto_select_gpus=True,
         devices=_get_devices(config["no_gpu"]),
         logger=config["logger"],
         max_epochs=config["max_epochs"],
@@ -304,7 +309,6 @@ def train(
 
     trainer = pl.Trainer(
         accelerator="auto",
-        auto_select_gpus=True,
         callbacks=callbacks,
         devices=_get_devices(config["no_gpu"]),
         enable_checkpointing=config["save_model"],
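
Both Trainer calls drop auto_select_gpus because that flag was removed in Lightning 2.0; device selection now goes through the devices argument. A rough sketch of the replacement, using the lightning-namespace import (the arguments shown are illustrative):

    from lightning.pytorch import Trainer

    # Lightning 1.x: Trainer(auto_select_gpus=True, devices=-1)
    # Lightning 2.x: "auto" uses every available device of the chosen
    # accelerator; find_usable_cuda_devices() can supply explicit GPU
    # indices instead.
    trainer = Trainer(accelerator="auto", devices="auto")
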
@@ -352,7 +356,7 @@ def _get_peak_filenames(
     ]
 
 
-def _get_strategy() -> Optional[DDPStrategy]:
+def _get_strategy() -> Union[DDPStrategy, str]:
     """
     Get the strategy for the Trainer.
 
@@ -362,16 +366,16 @@ def _get_strategy() -> Optional[DDPStrategy]:
     Returns
     -------
-    Optional[DDPStrategy]
+    Union[DDPStrategy, str]
         The strategy parameter for the Trainer.
     """
     if torch.cuda.device_count() > 1:
         return DDPStrategy(find_unused_parameters=False, static_graph=True)
 
-    return None
+    return "auto"
 
 
-def _get_devices(no_gpu: bool) -> Union[int, str]:
+def _get_devices(no_gpu: bool) -> Union[List[int], str]:
     """
     Get the number of GPUs/CPUs for the Trainer to use.
 
@@ -382,16 +386,14 @@ def _get_devices(no_gpu: bool) -> Union[int, str]:
     Returns
     -------
-    Union[int, str]
-        The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning
-        determine the appropriate number of devices.
+    Union[List[int], str]
+        A list of CUDA GPU devices to use, or "auto" to let PyTorch Lightning
+        determine the appropriate number of devices.
     """
     if not no_gpu and any(
         operator.attrgetter(device + ".is_available")(torch)()
         for device in ("cuda",)
     ):
-        return -1
-    elif not (n_workers := utils.n_workers()):
-        return "auto"
-    else:
-        return n_workers
+        return find_usable_cuda_devices()
+    return "auto"
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
     "pandas",
     "psutil",
     "PyGithub",
-    "pytorch-lightning>=1.7,<2.0",
+    "lightning>=2.0.0",
     "PyYAML",
     "requests",
     "scikit-learn",
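
With the dependency switched from pytorch-lightning to the consolidated lightning distribution, the trainer API lives in the lightning.pytorch namespace. A sketch of the corresponding import style (the pl alias mirrors how model_runner.py refers to the Trainer):

    import lightning.pytorch as pl
    from lightning.pytorch import LightningModule, Trainer

    assert pl.Trainer is Trainer  # same classes, new package namespace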