diff --git a/src/evaluate.py b/src/evaluate.py index dec9ed0..6163178 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -76,6 +76,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]: log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>") taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial") + # auto-convert the dataset if the taskmodule specifies a document type + dataset = taskmodule.convert_dataset(dataset) + # Init pytorch-ie datamodule log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>") datamodule: PieDataModule = hydra.utils.instantiate( diff --git a/src/evaluate_documents.py b/src/evaluate_documents.py index 0df8586..808acf2 100644 --- a/src/evaluate_documents.py +++ b/src/evaluate_documents.py @@ -70,6 +70,9 @@ def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]: log.info(f"Instantiating metric <{cfg.metric._target_}>") metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial") + # auto-convert the dataset if the metric specifies a document type + dataset = metric.convert_dataset(dataset) + # Init lightning loggers loggers = utils.instantiate_dict_entries(cfg, "logger") diff --git a/src/predict.py b/src/predict.py index d05c8cd..d65ee6b 100644 --- a/src/predict.py +++ b/src/predict.py @@ -84,6 +84,9 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]: pipeline.device ) + # auto-convert the dataset if the taskmodule specifies a document type + dataset = pipeline.taskmodule.convert_dataset(dataset) + # Init the serializer serializer: Optional[DocumentSerializer] = None if cfg.get("serializer") and cfg.serializer.get("_target_"): diff --git a/src/train.py b/src/train.py index 4411264..8e8c836 100644 --- a/src/train.py +++ b/src/train.py @@ -107,23 +107,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]: ) # auto-convert the dataset if the taskmodule specifies a document type - if taskmodule.document_type is not None: - if issubclass(dataset.document_type, taskmodule.document_type): - log.info( - f"the dataset is already of the document type that is specified by the taskmodule: " - f"{taskmodule.document_type}" - ) - else: - log.info( - f"convert the dataset to the document type that is specified by the taskmodule: " - f"{taskmodule.document_type}" - ) - dataset = dataset.to_document_type(taskmodule.document_type) - else: - log.warning( - "The taskmodule does not specify a document type. The dataset can not be automatically converted " - "to a document type." - ) + dataset = taskmodule.convert_dataset(dataset) # Init pytorch-ie datamodule log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")