Skip to content

Commit

Permalink
add dataset auto-conversion to all main entry scripts
Browse files: browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 19, 2023
1 parent 689fa3c commit c40f012
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 17 deletions.
3 changes: 3 additions & 0 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

# auto-convert the dataset if the taskmodule specifies a document type
dataset = taskmodule.convert_dataset(dataset)

# Init pytorch-ie datamodule
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
datamodule: PieDataModule = hydra.utils.instantiate(
Expand Down
3 changes: 3 additions & 0 deletions src/evaluate_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
log.info(f"Instantiating metric <{cfg.metric._target_}>")
metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")

# auto-convert the dataset if the metric specifies a document type
dataset = metric.convert_dataset(dataset)

# Init lightning loggers
loggers = utils.instantiate_dict_entries(cfg, "logger")

Expand Down
3 changes: 3 additions & 0 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
pipeline.device
)

# auto-convert the dataset if the taskmodule specifies a document type
dataset = pipeline.taskmodule.convert_dataset(dataset)

# Init the serializer
serializer: Optional[DocumentSerializer] = None
if cfg.get("serializer") and cfg.serializer.get("_target_"):
Expand Down
18 changes: 1 addition & 17 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
)

# auto-convert the dataset if the taskmodule specifies a document type
if taskmodule.document_type is not None:
if issubclass(dataset.document_type, taskmodule.document_type):
log.info(
f"the dataset is already of the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
else:
log.info(
f"convert the dataset to the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
dataset = dataset.to_document_type(taskmodule.document_type)
else:
log.warning(
"The taskmodule does not specify a document type. The dataset can not be automatically converted "
"to a document type."
)
dataset = taskmodule.convert_dataset(dataset)

# Init pytorch-ie datamodule
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
Expand Down

0 comments on commit c40f012

Please sign in to comment.