Update pytorch-ie to 0.23 (#125)
* use DatasetDict.to_document_type in convert_documents

* fix _TPU_AVAILABLE

* upgrade pytorch-ie to >=0.23.0

* adjust train.py: use interface classes and auto-convert the dataset
ArneBinder authored Sep 13, 2023
1 parent 8e26c0d commit 8c8ebef
Showing 4 changed files with 47 additions and 31 deletions.
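For context on the first and last commit-message bullets: DatasetDict.to_document_type converts every split of a pytorch-ie dataset to a target document type. A minimal sketch of a direct call, assuming pytorch-ie >= 0.23 is installed and the dataset provides a converter for the (purely illustrative) target type:

from pytorch_ie import DatasetDict
# illustrative target type; any Document subclass the dataset can convert to works
from pytorch_ie.documents import TextDocumentWithLabeledSpans


def convert_dataset(dataset: DatasetDict) -> DatasetDict:
    # Converts all splits to the requested document type; raises if the
    # dataset does not know how to convert to that type.
    return dataset.to_document_type(TextDocumentWithLabeledSpans)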
3 changes: 2 additions & 1 deletion configs/dataset/_convert_documents.yaml
@@ -1,2 +1,3 @@
 convert_documents:
-  _processor_: pytorch_ie.DatasetDict.map
+  _processor_: pytorch_ie.DatasetDict.to_document_type
+  document_type: ???
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 # --------- pytorch-ie --------- #
-pytorch-ie>=0.19.0,<1.0.0
+pytorch-ie>=0.23.0,<1.0.0
 # pie-utils provides some useful helper methods for pytorch-ie,
 # e.g. document processors or span utils (convert span annotations
 # to sequence encodings such as BIO, IO or BIOUL, and back).
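The pie-utils comment above mentions converting span annotations to sequence encodings such as BIO. As a concept illustration only (this is not the pie-utils API), a token-level span-to-BIO encoder can be sketched as:

from typing import List, Tuple


def spans_to_bio(num_tokens: int, spans: List[Tuple[int, int, str]]) -> List[str]:
    """Encode (start, end, label) token spans as BIO tags; end is exclusive."""
    tags = ["O"] * num_tokens
    for start, end, label in spans:
        tags[start] = f"B-{label}"
        for i in range(start + 1, end):
            tags[i] = f"I-{label}"
    return tags


# e.g. spans_to_bio(5, [(1, 3, "PER")]) -> ["O", "B-PER", "I-PER", "O", "O"]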
69 changes: 42 additions & 27 deletions src/train.py
@@ -41,7 +41,10 @@
 from omegaconf import DictConfig
 from pytorch_ie import DatasetDict
 from pytorch_ie.core import PyTorchIEModel, TaskModule
-from pytorch_ie.models import TransformerTokenClassificationModel
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
+from pytorch_ie.taskmodules import *  # noqa: F403
+from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import Logger

@@ -92,49 +95,61 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
     if cfg.get("seed"):
         pl.seed_everything(cfg.seed, workers=True)
 
-    # Init pytorch-ie dataset
-    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
-    dataset: DatasetDict = hydra.utils.instantiate(cfg.dataset, _convert_="partial")
-
     # Init pytorch-ie taskmodule
     log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
     taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")
 
+    # Init pytorch-ie dataset
+    log.info(f"Instantiating dataset <{cfg.dataset._target_}>")
+    dataset: DatasetDict = hydra.utils.instantiate(
+        cfg.dataset,
+        _convert_="partial",
+    )
+
+    # auto-convert the dataset if the taskmodule specifies a document type
+    if taskmodule.document_type is not None:
+        if issubclass(dataset.document_type, taskmodule.document_type):
+            log.info(
+                f"the dataset is already of the document type that is specified by the taskmodule: "
+                f"{taskmodule.document_type}"
+            )
+        else:
+            log.info(
+                f"convert the dataset to the document type that is specified by the taskmodule: "
+                f"{taskmodule.document_type}"
+            )
+            dataset = dataset.to_document_type(taskmodule.document_type)
+    else:
+        log.warning(
+            "The taskmodule does not specify a document type. The dataset can not be automatically converted "
+            "to a document type."
+        )
+
     # Init pytorch-ie datamodule
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
     datamodule: PieDataModule = hydra.utils.instantiate(
         cfg.datamodule, dataset=dataset, taskmodule=taskmodule, _convert_="partial"
     )
     # Use the train dataset split to prepare the taskmodule
-    taskmodule.prepare(dataset["train"])
+    taskmodule.prepare(dataset[datamodule.train_split])
 
     # Init the pytorch-ie model
     log.info(f"Instantiating model <{cfg.model._target_}>")
     # get additional model arguments
     additional_model_kwargs: Dict[str, Any] = {}
     model_cls = get_class(cfg.model["_target_"])
-    # NOTE: DEFINE THE additional_model_kwargs IF YOU WANT TO USE ANOTHER MODEL! SEE EXAMPLES BELOW.
-    if model_cls == TransformerTokenClassificationModel:
+    # NOTE: MODIFY THE additional_model_kwargs IF YOUR MODEL REQUIRES ANY MORE PARAMETERS FROM THE TASKMODULE!
+    # SEE EXAMPLES BELOW.
+    if issubclass(model_cls, RequiresNumClasses):
         additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    # elif model_cls == pytorch_ie.models.TransformerSpanClassificationModel:
-    #     additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    #     max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
-    #     additional_model_kwargs["t_total"] = int(
-    #         max_train_steps / float(cfg["datamodule"]["batch_size"])
-    #     )
-    # elif model_cls == pytorch_ie.models.TransformerTextClassificationModel:
-    #     additional_model_kwargs["num_classes"] = len(taskmodule.label_to_id)
-    #     max_train_steps = cfg["trainer"]["max_epochs"] * datamodule.num_train
-    #     additional_model_kwargs["t_total"] = int(
-    #         max_train_steps / float(cfg["datamodule"]["batch_size"])
-    #     )
-    # elif model_cls == pytorch_ie.models.TransformerSeq2SeqModel:
-    #     pass
-    else:
-        raise Exception(
-            f"unknown model class: {model_cls.__name__}. Please adjust the train.py script for that class, i.e. "
-            f"define how to set additional_model_kwargs for your model."
-        )
+    if issubclass(model_cls, RequiresModelNameOrPath):
+        if "model_name_or_path" not in cfg.model:
+            raise Exception(
+                f"Please specify model_name_or_path in the model config for {model_cls.__name__}."
+            )
+    if isinstance(taskmodule, ChangesTokenizerVocabSize):
+        additional_model_kwargs["tokenizer_vocab_size"] = len(taskmodule.tokenizer)
 
     # initialize the model
     model: PyTorchIEModel = hydra.utils.instantiate(
         cfg.model, _convert_="partial", **additional_model_kwargs
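The RequiresNumClasses, RequiresModelNameOrPath, and ChangesTokenizerVocabSize imports act as plain marker interfaces: instead of comparing against one hard-coded model class, train.py now asks whether the configured class declares a requirement. A self-contained sketch of that pattern, with stand-in classes rather than the actual pytorch_ie definitions:

from typing import Any, Dict


class RequiresNumClasses:  # stand-in marker interface, not the pytorch_ie class
    """Marker: the model constructor expects a num_classes argument."""


class MyTokenClassificationModel(RequiresNumClasses):  # stand-in model
    def __init__(self, num_classes: int):
        self.num_classes = num_classes


model_cls = MyTokenClassificationModel
additional_model_kwargs: Dict[str, Any] = {}
if issubclass(model_cls, RequiresNumClasses):
    # in train.py this value comes from len(taskmodule.label_to_id)
    additional_model_kwargs["num_classes"] = 7
model = model_cls(**additional_model_kwargs)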
4 changes: 2 additions & 2 deletions tests/helpers/package_available.py
@@ -1,7 +1,7 @@
 import platform
 
 import pkg_resources
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils
+from lightning_fabric.accelerators import TPUAccelerator
 
 
 def _package_available(package_name: str) -> bool:
@@ -12,7 +12,7 @@ def _package_available(package_name: str) -> bool:
         return False
 
 
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+_TPU_AVAILABLE = TPUAccelerator.is_available()
 
 _IS_WINDOWS = platform.system() == "Windows"
 
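Flags like _TPU_AVAILABLE are typically consumed as pytest skip conditions. A minimal sketch of such a guard, reusing the lightning_fabric call from the diff above (the test itself is hypothetical):

import pytest
from lightning_fabric.accelerators import TPUAccelerator

_TPU_AVAILABLE = TPUAccelerator.is_available()


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires a TPU device")
def test_train_on_tpu():
    ...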