Skip to content

Commit

Permalink
add dataset auto-conversion to all main entry scripts (#132)
Browse files: browse the repository at this point in the history
* add dataset auto-conversion to all main entry scripts

* add document_type to statistical metric configs

* enable show_as_markdown in statistical metric configs

* re-add conversion.py

* remove conversion.py
  • Loading branch information
ArneBinder authored Sep 19, 2023
1 parent df0f950 commit 30dd394
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 20 deletions.
3 changes: 2 additions & 1 deletion configs/metric/count_entity_labels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ title: entity label distribution
field: entities
labels: ???
show_histogram: true
# show_as_markdown: true
show_as_markdown: true
document_type: pytorch_ie.documents.TextDocumentWithLabeledEntities
3 changes: 2 additions & 1 deletion configs/metric/count_text_characters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ _target_: pytorch_ie.metrics.statistics.FieldLengthCollector
title: text length (characters)
field: text
show_histogram: true
# show_as_markdown: true
show_as_markdown: true
document_type: pytorch_ie.documents.TextBasedDocument
4 changes: 3 additions & 1 deletion configs/metric/count_text_tokens.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ text_field: text
tokenizer: bert-base-uncased
tokenizer_kwargs:
  add_special_tokens: false
# strict_span_conversion: false
show_histogram: true
# show_as_markdown: true
show_as_markdown: true
document_type: pytorch_ie.documents.TextBasedDocument
3 changes: 3 additions & 0 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
log.info(f"Instantiating taskmodule <{cfg.taskmodule._target_}>")
taskmodule: TaskModule = hydra.utils.instantiate(cfg.taskmodule, _convert_="partial")

# auto-convert the dataset if the taskmodule specifies a document type
dataset = taskmodule.convert_dataset(dataset)

# Init pytorch-ie datamodule
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
datamodule: PieDataModule = hydra.utils.instantiate(
Expand Down
3 changes: 3 additions & 0 deletions src/evaluate_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def evaluate_documents(cfg: DictConfig) -> Tuple[dict, dict]:
log.info(f"Instantiating metric <{cfg.metric._target_}>")
metric: DocumentMetric = hydra.utils.instantiate(cfg.metric, _convert_="partial")

# auto-convert the dataset if the metric specifies a document type
dataset = metric.convert_dataset(dataset)

# Init lightning loggers
loggers = utils.instantiate_dict_entries(cfg, "logger")

Expand Down
3 changes: 3 additions & 0 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def predict(cfg: DictConfig) -> Tuple[dict, dict]:
pipeline.device
)

# auto-convert the dataset if the taskmodule specifies a document type
dataset = pipeline.taskmodule.convert_dataset(dataset)

# Init the serializer
serializer: Optional[DocumentSerializer] = None
if cfg.get("serializer") and cfg.serializer.get("_target_"):
Expand Down
18 changes: 1 addition & 17 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,7 @@ def train(cfg: DictConfig) -> Tuple[dict, dict]:
)

# auto-convert the dataset if the taskmodule specifies a document type
if taskmodule.document_type is not None:
if issubclass(dataset.document_type, taskmodule.document_type):
log.info(
f"the dataset is already of the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
else:
log.info(
f"convert the dataset to the document type that is specified by the taskmodule: "
f"{taskmodule.document_type}"
)
dataset = dataset.to_document_type(taskmodule.document_type)
else:
log.warning(
"The taskmodule does not specify a document type. The dataset can not be automatically converted "
"to a document type."
)
dataset = taskmodule.convert_dataset(dataset)

# Init pytorch-ie datamodule
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
Expand Down

0 comments on commit 30dd394

Please sign in to comment.