From 2130659034b1d76668acce0ff94f81cadb1e025f Mon Sep 17 00:00:00 2001 From: ArneBinder Date: Thu, 18 Jan 2024 22:06:52 +0100 Subject: [PATCH] simplify log_hyperparameters() and log best_checkpoint and checkpoint_dir (#153) --- requirements.txt | 2 +- src/evaluate.py | 2 +- src/train.py | 16 ++++++-- src/utils/logging_utils.py | 82 ++++++++++++++++++++------------------ 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/requirements.txt b/requirements.txt index 5dbb44d..6994afa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # --------- pytorch-ie --------- # pytorch-ie>=0.28.0,<0.30.0 pie-datasets>=0.8.1,<0.9.0 -pie-modules>=0.8.0,<0.9.0 +pie-modules>=0.9.0,<0.10.0 # --------- hydra --------- # hydra-core>=1.3.0 diff --git a/src/evaluate.py b/src/evaluate.py index 7df1db3..ab12529 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -107,7 +107,7 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: if logger: log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) + utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg) log.info("Starting testing!") trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) diff --git a/src/train.py b/src/train.py index b3de31c..88ef652 100644 --- a/src/train.py +++ b/src/train.py @@ -33,13 +33,14 @@ # https://github.com/ashleve/pyrootutils # ------------------------------------------------------------------------------------ # +import os from typing import Any, Dict, List, Optional, Tuple import hydra import pytorch_lightning as pl -from hydra.utils import get_class from omegaconf import DictConfig from pie_datasets import DatasetDict +from pie_modules.models.interface import RequiresTaskmoduleConfig from pytorch_ie.core import PyTorchIEModel, TaskModule from pytorch_ie.models import * # noqa: F403 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses @@ -121,7 +122,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: log.info(f"Instantiating model <{cfg.model._target_}>") # get additional model arguments additional_model_kwargs: Dict[str, Any] = {} - model_cls = get_class(cfg.model["_target_"]) + model_cls = hydra.utils.get_class(cfg.model["_target_"]) # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE! # SEE EXAMPLES BELOW. if issubclass(model_cls, RequiresNumClasses): @@ -134,6 +135,9 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: if isinstance(taskmodule, ChangesTokenizerVocabSize): additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer) + if issubclass(model_cls, RequiresTaskmoduleConfig): + additional_model_kwargs["taskmodule_config"] = taskmodule.config + # initialize the model model: PyTorchIEModel = hydra.utils.instantiate( cfg.model, _convert_="partial", **additional_model_kwargs @@ -160,7 +164,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: if logger: log.info("Logging hyperparameters!") - utils.log_hyperparameters(object_dict) + utils.log_hyperparameters(logger=logger, model=model, taskmodule=taskmodule, config=cfg) if cfg.model_save_dir is not None: log.info(f"Save taskmodule to {cfg.model_save_dir} [push_to_hub={cfg.push_to_hub}]") @@ -177,6 +181,12 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: best_ckpt_path = trainer.checkpoint_callback.best_model_path if best_ckpt_path != "": log.info(f"Best ckpt path: {best_ckpt_path}") + best_checkpoint_file = os.path.basename(best_ckpt_path) + utils.log_hyperparameters( + logger=logger, + best_checkpoint=best_checkpoint_file, + checkpoint_dir=trainer.checkpoint_callback.dirpath, + ) if not cfg.trainer.get("fast_dev_run"): if cfg.model_save_dir is not None: diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py index b889221..2dfa9d2 100644 --- a/src/utils/logging_utils.py +++ b/src/utils/logging_utils.py @@ -1,6 +1,11 @@ import logging from importlib.util import find_spec +from typing import List, Optional, Union +from omegaconf import DictConfig, OmegaConf +from pie_modules.models.interface import RequiresTaskmoduleConfig +from pytorch_ie import PyTorchIEModel, TaskModule +from pytorch_lightning.loggers import Logger from pytorch_lightning.utilities import rank_zero_only @@ -22,7 +27,14 @@ def get_pylogger(name=__name__) -> logging.Logger: @rank_zero_only -def log_hyperparameters(object_dict: dict) -> None: +def log_hyperparameters( + logger: Optional[List[Logger]] = None, + config: Optional[Union[dict, DictConfig]] = None, + model: Optional[PyTorchIEModel] = None, + taskmodule: Optional[TaskModule] = None, + key_prefix: str = "_", + **kwargs, +) -> None: """Controls which config parts are saved by lightning loggers. Additional saves: @@ -31,48 +43,42 @@ def log_hyperparameters(object_dict: dict) -> None: hparams = {} - cfg = object_dict["cfg"] - model = object_dict["model"] - taskmodule = object_dict["taskmodule"] - trainer = object_dict["trainer"] - - if not trainer.logger: + if not logger: log.warning("Logger not found! Skipping hyperparameter logging...") return - # choose which parts of hydra config will be saved to loggers - # here we use the taskmodule/model config how it is after preparation/initialization - hparams["taskmodule"] = taskmodule._config() - hparams["model"] = model._config() - - # save number of model parameters - hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) - hparams["model/params/trainable"] = sum( - p.numel() for p in model.parameters() if p.requires_grad - ) - hparams["model/params/non_trainable"] = sum( - p.numel() for p in model.parameters() if not p.requires_grad - ) - - hparams["dataset"] = cfg["dataset"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["pipeline_type"] = cfg.get("pipeline_type") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - - hparams["monitor_metric"] = cfg.get("monitor_metric") - hparams["monitor_mode"] = cfg.get("monitor_mode") - - hparams["model_save_dir"] = cfg.get("model_save_dir") + # this is just for backwards compatibility: usually, the taskmodule_config should be passed to + # the model and, thus, be logged there automatically + if model is not None and not isinstance(model, RequiresTaskmoduleConfig): + if taskmodule is None: + raise ValueError( + "If model is not an instance of RequiresTaskmoduleConfig, taskmodule must be passed!" + ) + # here we use the taskmodule/model config how it is after preparation/initialization + hparams["taskmodule_config"] = taskmodule.config + + if model is not None: + # save number of model parameters + hparams[f"{key_prefix}num_params/total"] = sum(p.numel() for p in model.parameters()) + hparams[f"{key_prefix}num_params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams[f"{key_prefix}num_params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + if config is not None: + hparams[f"{key_prefix}config"] = ( + OmegaConf.to_container(config, resolve=True) if OmegaConf.is_config(config) else config + ) + + # add additional hparams + for k, v in kwargs.items(): + hparams[f"{key_prefix}{k}"] = v # send hparams to all loggers - for logger in trainer.loggers: - logger.log_hyperparams(hparams) + for current_logger in logger: + current_logger.log_hyperparams(hparams) def close_loggers() -> None: