Commit

simplify log_hyperparameters() and log best_checkpoint and checkpoint_dir (#153)
ArneBinder authored Jan 18, 2024
1 parent ef6ebdf commit 2130659
Showing 4 changed files with 59 additions and 43 deletions.
requirements.txt: 2 changes (1 addition & 1 deletion)
@@ -1,7 +1,7 @@
# --------- pytorch-ie --------- #
pytorch-ie>=0.28.0,<0.30.0
pie-datasets>=0.8.1,<0.9.0
pie-modules>=0.8.0,<0.9.0
pie-modules>=0.9.0,<0.10.0

# --------- hydra --------- #
hydra-core>=1.3.0
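
A quick way to check that an existing environment already satisfies the tightened pie-modules pin before re-running training. A minimal sketch, not part of the repository; it assumes the packaging distribution is importable, which it normally is since pip depends on it:

from importlib.metadata import version

from packaging.specifiers import SpecifierSet

# the new pin from requirements.txt
required = SpecifierSet(">=0.9.0,<0.10.0")
installed = version("pie-modules")

if installed not in required:
    raise RuntimeError(
        f"pie-modules {installed} does not satisfy {required}; run: pip install -r requirements.txt"
    )
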
src/evaluate.py: 2 changes (1 addition & 1 deletion)
@@ -107,7 +107,7 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:

if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
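
The call site no longer needs the bundled object_dict (or a Trainer at all): every argument of the refactored helper is optional, and extra values pass through **kwargs. A minimal, self-contained sketch of the new calling convention, assuming the helper is importable as src.utils.logging_utils; the CSVLogger directory and the extra ckpt_path entry are illustrative:

from pytorch_lightning.loggers import CSVLogger

from src.utils.logging_utils import log_hyperparameters

csv_logger = CSVLogger(save_dir="/tmp/hparams_demo")

# model and taskmodule may be omitted; plain dicts work for config,
# and any extra keyword argument is logged as well (keys get the "_" prefix).
log_hyperparameters(
    logger=[csv_logger],
    config={"seed": 42, "pipeline_type": "evaluation"},
    ckpt_path="logs/checkpoints/last.ckpt",  # illustrative extra field
)
csv_logger.save()
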
src/train.py: 16 changes (13 additions & 3 deletions)
@@ -33,13 +33,14 @@
# https://github.com/ashleve/pyrootutils
# ------------------------------------------------------------------------------------ #

import os
from typing import Any, Dict, List, Optional, Tuple

import hydra
import pytorch_lightning as pl
from hydra.utils import get_class
from omegaconf import DictConfig
from pie_datasets import DatasetDict
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pytorch_ie.core import PyTorchIEModel, TaskModule
from pytorch_ie.models import * # noqa: F403
from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
@@ -121,7 +122,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
log.info(f"Instantiating model <{cfg.model._target_}>")
# get additional model arguments
additional_model_kwargs: Dict[str, Any] = {}
model_cls = get_class(cfg.model["_target_"])
model_cls = hydra.utils.get_class(cfg.model["_target_"])
# NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
# SEE EXAMPLES BELOW.
if issubclass(model_cls, RequiresNumClasses):
@@ -134,6 +135,9 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
if isinstance(taskmodule, ChangesTokenizerVocabSize):
additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)

if issubclass(model_cls, RequiresTaskmoduleConfig):
additional_model_kwargs["taskmodule_config"] = taskmodule.config

# initialize the model
model: PyTorchIEModel = hydra.utils.instantiate(
cfg.model, _convert_="partial", **additional_model_kwargs
@@ -160,7 +164,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:

if logger:
log.info("Logging hyperparameters!")
utils.log_hyperparameters(object_dict)
utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg)

if cfg.model_save_dir is not None:
log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]")
@@ -177,6 +181,12 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
best_ckpt_path = trainer.checkpoint_callback.best_model_path
if best_ckpt_path != "":
log.info(f"Best ckpt path: {best_ckpt_path}")
best_checkpoint_file = os.path.basename(best_ckpt_path)
utils.log_hyperparameters(
logger=logger,
best_checkpoint=best_checkpoint_file,
checkpoint_dir=trainer.checkpoint_callback.dirpath,
)

if not cfg.trainer.get("fast_dev_run"):
if cfg.model_save_dir is not None:
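
The new branch in train.py means a model only has to declare the RequiresTaskmoduleConfig interface for the prepared taskmodule's config to be injected as a constructor argument. A self-contained sketch of that dispatch with toy stand-ins; the real code resolves model_cls from cfg.model._target_ and instantiates it via hydra, and RequiresTaskmoduleConfig is assumed here to behave like a plain marker interface, as the issubclass check implies:

from typing import Any, Dict, Type


class RequiresTaskmoduleConfigMarker:
    """Toy stand-in for pie_modules.models.interface.RequiresTaskmoduleConfig."""


class ToyTaskModule:
    """Toy stand-in exposing the .config attribute the dispatch relies on."""

    @property
    def config(self) -> Dict[str, Any]:
        return {"taskmodule_type": "ToyTaskModule", "labels": ["PER", "ORG"]}


class ToyModel(RequiresTaskmoduleConfigMarker):
    """Toy model that opts in and therefore accepts taskmodule_config."""

    def __init__(self, taskmodule_config: Dict[str, Any]):
        self.taskmodule_config = taskmodule_config


def build_additional_model_kwargs(model_cls: Type, taskmodule: ToyTaskModule) -> Dict[str, Any]:
    # mirrors train.py: the check runs on the class, before the model is instantiated
    additional_model_kwargs: Dict[str, Any] = {}
    if issubclass(model_cls, RequiresTaskmoduleConfigMarker):
        additional_model_kwargs["taskmodule_config"] = taskmodule.config
    return additional_model_kwargs


model = ToyModel(**build_additional_model_kwargs(ToyModel, ToyTaskModule()))
print(model.taskmodule_config["labels"])  # -> ['PER', 'ORG']

This is also why logging_utils.py below only logs taskmodule.config separately for models that do not implement the interface: for models that do, the injected taskmodule_config is logged with the rest of the model's hyperparameters, as the backwards-compatibility comment in the new function notes.
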
src/utils/logging_utils.py: 82 changes (44 additions & 38 deletions)
@@ -1,6 +1,11 @@
import logging
from importlib.util import find_spec
from typing import List, Optional, Union

from omegaconf import DictConfig, OmegaConf
from pie_modules.models.interface import RequiresTaskmoduleConfig
from pytorch_ie import PyTorchIEModel, TaskModule
from pytorch_lightning.loggers import Logger
from pytorch_lightning.utilities import rank_zero_only


@@ -22,7 +27,14 @@ def get_pylogger(name=__name__) -> logging.Logger:


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
def log_hyperparameters(
logger: Optional[List[Logger]] = None,
config: Optional[Union[dict, DictConfig]] = None,
model: Optional[PyTorchIEModel] = None,
taskmodule: Optional[TaskModule] = None,
key_prefix: str = "_",
**kwargs,
) -> None:
"""Controls which config parts are saved by lightning loggers.
Additional saves:
@@ -31,48 +43,42 @@ def log_hyperparameters(object_dict: dict) -> None:

hparams = {}

cfg = object_dict["cfg"]
model = object_dict["model"]
taskmodule = object_dict["taskmodule"]
trainer = object_dict["trainer"]

if not trainer.logger:
if not logger:
log.warning("Logger not found! Skipping hyperparameter logging...")
return

# choose which parts of hydra config will be saved to loggers
# here we use the taskmodule/model config how it is after preparation/initialization
hparams["taskmodule"] = taskmodule._config()
hparams["model"] = model._config()

# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
hparams["model/params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)

hparams["dataset"] = cfg["dataset"]
hparams["trainer"] = cfg["trainer"]

hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")

hparams["pipeline_type"] = cfg.get("pipeline_type")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")

hparams["monitor_metric"] = cfg.get("monitor_metric")
hparams["monitor_mode"] = cfg.get("monitor_mode")

hparams["model_save_dir"] = cfg.get("model_save_dir")
# this is just for backwards compatibility: usually, the taskmodule_config should be passed to
# the model and, thus, be logged there automatically
if model is not None and not isinstance(model, RequiresTaskmoduleConfig):
if taskmodule is None:
raise ValueError(
"If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!"
)
# here we use the taskmodule/model config how it is after preparation/initialization
hparams["taskmodule_config"] = taskmodule.config

if model is not None:
# save number of model parameters
hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters())
hparams[f"{key_prefix}num_params/trainable"] = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
hparams[f"{key_prefix}num_params/non_trainable"] = sum(
p.numel() for p in model.parameters() if not p.requires_grad
)

if config is not None:
hparams[f"{key_prefix}config"] = (
OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config
)

# add additional hparams
for k, v in kwargs.items():
hparams[f"{key_prefix}{k}"] = v

# send hparams to all loggers
for logger in trainer.loggers:
logger.log_hyperparams(hparams)
for current_logger in logger:
current_logger.log_hyperparams(hparams)


def close_loggers() -> None:
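
A side note on the config handling in the new function: DictConfig objects are converted to plain containers with resolve=True before being handed to the loggers, so interpolations are materialized at logging time, and everything the helper writes is namespaced with the default key_prefix of "_" (e.g. _config, _num_params/total, and, from the second call in train.py, _best_checkpoint and _checkpoint_dir arriving through **kwargs). A small self-contained illustration of the conversion step, with made-up values:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "paths": {"output_dir": "logs/train"},
        "checkpoint_dir": "${paths.output_dir}/checkpoints",
        "seed": 42,
    }
)

# this is what ends up under hparams["_config"] when a DictConfig is passed
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["checkpoint_dir"])  # -> logs/train/checkpoints
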
