diff --git a/configs/train.yaml b/configs/train.yaml index 699f9df..abf5af6 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -61,6 +61,9 @@ monitor_mode: "max" # seed for random number generators in pytorch, numpy and python.random seed: null +# path to pretrained pytorch-ie model that updates the weights of base model with pretrained pie model +pretrained_pie_model_path: null + # simply provide checkpoint path to resume training ckpt_path: null diff --git a/src/train.py b/src/train.py index 88ef652..ebe26e9 100644 --- a/src/train.py +++ b/src/train.py @@ -40,7 +40,10 @@ import pytorch_lightning as pl from omegaconf import DictConfig from pie_datasets import DatasetDict +from pie_modules.models import * # noqa: F403 from pie_modules.models.interface import RequiresTaskmoduleConfig +from pie_modules.taskmodules import * # noqa: F403 +from pytorch_ie import AutoModel from pytorch_ie.core import PyTorchIEModel, TaskModule from pytorch_ie.models import * # noqa: F403 from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses @@ -140,9 +143,28 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: # initialize the model model: PyTorchIEModel = hydra.utils.instantiate( - cfg.model, _convert_="partial", **additional_model_kwargs + cfg.model, + _convert_="partial", + # In the case of loading weights from a pretrained PIE model, we do not need to download the base (transformer) model in the model constructors. We disable that by passing is_from_pretrained=True in these cases. + is_from_pretrained=cfg.get("pretrained_pie_model_path", None) is not None, + **additional_model_kwargs, ) + if cfg.get("pretrained_pie_model_path", None) is not None: + pie_model = AutoModel.from_pretrained(cfg["pretrained_pie_model_path"]) + loaded_state_dict = pie_model.state_dict() + has_prefix_mapping = cfg.get("pretrained_pie_model_prefix_mapping", None) is not None + if has_prefix_mapping: + state_dict_to_load = {} + for prefix_from, prefix_to in cfg["pretrained_pie_model_prefix_mapping"].items(): + for name, value in loaded_state_dict.items(): + if name.startswith(prefix_from): + new_name = prefix_to + name[len(prefix_from) :] + state_dict_to_load[new_name] = value + else: + state_dict_to_load = loaded_state_dict + model.load_state_dict(state_dict_to_load, strict=not has_prefix_mapping) + log.info("Instantiating callbacks...") callbacks: List[Callback] = utils.instantiate_dict_entries(cfg, key="callbacks")