diff --git a/configs/pipeline/ner_re_pipeline.yaml b/configs/pipeline/ner_re_pipeline.yaml
index 5923d35..b3b2292 100644
--- a/configs/pipeline/ner_re_pipeline.yaml
+++ b/configs/pipeline/ner_re_pipeline.yaml
@@ -1,6 +1,8 @@
 _target_: src.pipeline.NerRePipeline
 ner_model_path: ???
 re_model_path: ???
+entity_layer: labeled_spans
+relation_layer: binary_relations
 
 # some settings for the ner / re inference
 show_progress_bar: true
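The two new keys mirror the new required constructor arguments of `NerRePipeline` and must name annotation layers that exist on the processed document type. Below is a minimal sketch of how such a config might be consumed, assuming Hydra-style instantiation (which the `_target_` key suggests); the model paths are hypothetical placeholders for the mandatory `???` values:

```python
# Minimal sketch (assumption: the config is instantiated via Hydra, as `_target_` suggests).
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "src.pipeline.NerRePipeline",
        "ner_model_path": "path/to/ner-model",  # hypothetical placeholder for ???
        "re_model_path": "path/to/re-model",  # hypothetical placeholder for ???
        "entity_layer": "labeled_spans",
        "relation_layer": "binary_relations",
        "show_progress_bar": True,
    }
)
pipeline = instantiate(cfg)  # equivalent to calling NerRePipeline(...) directly
```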
{step_name} ...") + logger.info(f"process {step_name} ...") processed_documents = processor(documents) if processed_documents is not None: documents = processed_documents @@ -121,6 +101,8 @@ def __init__( self, ner_model_path: str, re_model_path: str, + entity_layer: str, + relation_layer: str, device: Optional[int] = None, batch_size: Optional[int] = None, show_progress_bar: Optional[bool] = None, @@ -129,6 +111,8 @@ def __init__( self.ner_model_path = ner_model_path self.re_model_path = re_model_path self.processor_kwargs = processor_kwargs or {} + self.entity_layer = entity_layer + self.relation_layer = relation_layer # set some values for the inference processors, if provided for inference_pipeline in ["ner_pipeline", "re_pipeline"]: if inference_pipeline not in self.processor_kwargs: @@ -146,18 +130,25 @@ def __init__( ): self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar - def __call__(self, documents: Sequence[Document], inplace: bool = False): + def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]: - if not inplace: - documents = [type(doc).fromdict(doc.asdict()) for doc in documents] + input_docs: Sequence[Document] + # we need to keep the original documents to add the gold data back + original_docs: Sequence[Document] + if inplace: + input_docs = documents + original_docs = [doc.copy() for doc in documents] + else: + input_docs = [doc.copy() for doc in documents] + original_docs = documents docs_with_predictions = process_pipeline_steps( - documents=documents, + documents=input_docs, processors={ "clear_annotations": partial( process_documents, processor=clear_annotation_layers, - layer_names=["entities", "relations"], + layer_names=[self.entity_layer, self.relation_layer], **self.processor_kwargs.get("clear_annotations", {}), ), "ner_pipeline": AutoPipeline.from_pretrained( @@ -166,7 +157,7 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False): "use_predicted_entities": partial( process_documents, processor=move_annotations_from_predictions, - layer_names=["entities"], + layer_names=[self.entity_layer], **self.processor_kwargs.get("use_predicted_entities", {}), ), # "create_candidate_relations": partial( @@ -182,19 +173,19 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False): "clear_candidate_relations": partial( process_documents, processor=clear_annotation_layers, - layer_names=["relations"], + layer_names=[self.relation_layer], **self.processor_kwargs.get("clear_candidate_relations", {}), ), "move_entities_to_predictions": partial( process_documents, processor=move_annotations_to_predictions, - layer_names=["entities"], + layer_names=[self.entity_layer], **self.processor_kwargs.get("move_entities_to_predictions", {}), ), "re_add_gold_data": partial( add_annotations_from_other_documents, - other_docs=documents, - layer_names=["entities", "relations"], + other_docs=original_docs, + layer_names=[self.entity_layer, self.relation_layer], **self.processor_kwargs.get("re_add_gold_data", {}), ), }, diff --git a/tests/unit/pipeline/__init__.py b/tests/unit/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/pipeline/test_ner_re_pipeline.py b/tests/unit/pipeline/test_ner_re_pipeline.py new file mode 100644 index 0000000..2d5465c --- /dev/null +++ b/tests/unit/pipeline/test_ner_re_pipeline.py @@ -0,0 +1,351 @@ +import dataclasses + +import pytest +from pie_modules.annotations import BinaryRelation, LabeledSpan +from 
diff --git a/tests/unit/pipeline/__init__.py b/tests/unit/pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/pipeline/test_ner_re_pipeline.py b/tests/unit/pipeline/test_ner_re_pipeline.py
new file mode 100644
index 0000000..2d5465c
--- /dev/null
+++ b/tests/unit/pipeline/test_ner_re_pipeline.py
@@ -0,0 +1,351 @@
+import dataclasses
+from typing import Sequence
+
+import pytest
+from pie_modules.annotations import BinaryRelation, LabeledSpan
+from pie_modules.documents import TextBasedDocument
+from pytorch_ie import AnnotationLayer
+from pytorch_ie.core import annotation_field
+
+from src.pipeline.ner_re_pipeline import (
+    NerRePipeline,
+    add_annotations_from_other_documents,
+    clear_annotation_layers,
+    move_annotations_from_predictions,
+    move_annotations_to_predictions,
+    process_documents,
+    process_pipeline_steps,
+)
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansAndBinaryRelations(TextBasedDocument):
+    labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+    binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")
+
+
+@pytest.fixture
+def document():
+    document = TextDocumentWithLabeledSpansAndBinaryRelations(
+        text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp"
+    )
+    document.labeled_spans.predictions.extend(
+        [
+            LabeledSpan(start=0, end=5, label="PERSON"),
+            LabeledSpan(start=44, end=48, label="ORGANIZATION"),
+        ]
+    )
+
+    assert str(document.labeled_spans.predictions[0]) == "Harry"
+    assert str(document.labeled_spans.predictions[1]) == "DFKI"
+
+    document.labeled_spans.extend(
+        [
+            LabeledSpan(start=0, end=5, label="PERSON"),
+            LabeledSpan(start=15, end=30, label="LOCATION"),
+            LabeledSpan(start=44, end=48, label="ORGANIZATION"),
+        ]
+    )
+    assert str(document.labeled_spans[0]) == "Harry"
+    assert str(document.labeled_spans[1]) == "Berlin, Germany"
+    assert str(document.labeled_spans[2]) == "DFKI"
+
+    return document
+
+
+@pytest.fixture
+def document_with_relations(document):
+
+    document = document.copy()
+
+    document.binary_relations.predictions.extend(
+        [
+            BinaryRelation(
+                head=document.labeled_spans[0],
+                tail=document.labeled_spans[2],
+                label="per:employee_of",
+            ),
+        ]
+    )
+
+    document.binary_relations.extend(
+        [
+            BinaryRelation(
+                head=document.labeled_spans[0],
+                tail=document.labeled_spans[2],
+                label="per:employee_of",
+            ),
+        ]
+    )
+
+    return document
+
+
+def test_clear_annotation_layers(document):
+    original_entities = document["labeled_spans"]
+    assert len(original_entities) == 3
+
+    original_predictions = document["labeled_spans"].predictions
+    assert len(original_predictions) == 2
+
+    clear_annotation_layers(
+        document,
+        layer_names=["labeled_spans"],
+    )
+
+    new_entities = document["labeled_spans"]
+    assert len(new_entities) == 0
+
+    predictions = document["labeled_spans"].predictions
+    assert len(predictions) == 2
+
+    # clear predictions
+    clear_annotation_layers(
+        document,
+        layer_names=["labeled_spans"],
+        predictions=True,
+    )
+
+    new_entities = document["labeled_spans"]
+    assert len(new_entities) == 0
+
+    new_predictions = document["labeled_spans"].predictions
+    assert len(new_predictions) == 0
+
+
+def test_move_annotations_from_predictions(document):
+    original_entities = document["labeled_spans"]
+    assert len(original_entities) == 3
+
+    original_predictions = document["labeled_spans"].predictions
+    assert len(original_predictions) == 2
+
+    move_annotations_from_predictions(
+        document,
+        layer_names=["labeled_spans"],
+    )
+
+    new_entities = document["labeled_spans"]
+    assert len(new_entities) == 2
+
+    new_predictions = document["labeled_spans"].predictions
+    assert len(new_predictions) == 0
+
+
+def test_move_annotations_to_predictions(document):
+    original_entities = document["labeled_spans"]
+    assert len(original_entities) == 3
+
+    original_predictions = document["labeled_spans"].predictions
+    assert len(original_predictions) == 2
+
+    move_annotations_to_predictions(
+        document,
layer_names=["labeled_spans"], + ) + + new_entities = document["labeled_spans"] + assert len(new_entities) == 0 + + new_predictions = document["labeled_spans"].predictions + assert len(new_predictions) == 3 + + +def document_processor(document) -> TextBasedDocument: + doc = document.copy() + doc["labeled_spans"].append(LabeledSpan(start=0, end=0, label="empty")) + return doc + + +def none_processor(document) -> None: + return None + + +def test_process_documents(document): + result = process_documents( + documents=[document], + processor=document_processor, + ) + doc = result[0] + + spans = doc["labeled_spans"] + assert len(spans) == 4 + + result = process_documents( + documents=[document], + processor=none_processor, + ) + doc = result[0] + + spans = doc["labeled_spans"] + assert len(spans) == 3 + + +def documents_processor(documents) -> TextBasedDocument: + for doc in documents: + doc["labeled_spans"].append(LabeledSpan(start=0, end=0, label="empty")) + return documents + + +def test_process_pipeline_steps(document): + original_spans = document["labeled_spans"] + assert len(original_spans) == 3 + + process_pipeline_steps( + documents=[document], + processors={"add_span": documents_processor}, + ) + + original_spans = document["labeled_spans"] + assert len(original_spans) == 4 + + +def test_add_annotations_from_other_documents(document, document_with_relations): + + original_relations = document_with_relations["binary_relations"] + assert len(original_relations) == 1 + original_relations_predictions = document_with_relations["binary_relations"].predictions + assert len(original_relations_predictions) == 1 + + add_annotations_from_other_documents( + docs=[document], other_docs=[document_with_relations], layer_names=["binary_relations"] + ) + + relations = document["binary_relations"] + assert len(relations) == 1 + + # from predictions + + add_annotations_from_other_documents( + docs=[document], + other_docs=[document_with_relations], + layer_names=["binary_relations"], + from_predictions=True, + ) + + relations = document["binary_relations"] + assert len(relations) == 1 + + assert relations[0] == document_with_relations["binary_relations"].predictions[0] + + # to predictions + + add_annotations_from_other_documents( + docs=[document], + other_docs=[document_with_relations], + layer_names=["binary_relations"], + to_predictions=True, + ) + + relations = document["binary_relations"].predictions + assert len(relations) == 1 + + assert relations[0] == document_with_relations["binary_relations"].predictions[0] + + +@dataclasses.dataclass +class TextDocumentWithEntitiesAndRelations(TextBasedDocument): + entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") + + +@pytest.mark.slow +def test_ner_re_pipeline(): + # These imports register the respective taskmodules and models for NER and RE + from pytorch_ie.models import ( + TransformerSpanClassificationModel, + TransformerTextClassificationModel, + ) + from pytorch_ie.taskmodules import ( + TransformerRETextClassificationTaskModule, + TransformerSpanClassificationTaskModule, + ) + + document = TextDocumentWithEntitiesAndRelations( + text="Harry lives in Berlin, Germany. 
He works at DFKI.", id="tmp" + ) + + document.entities.extend( + [ + LabeledSpan(start=0, end=5, label="PER"), + LabeledSpan(start=15, end=30, label="LOC"), + LabeledSpan(start=44, end=48, label="ORG"), + ] + ) + assert str(document.entities[0]) == "Harry" + assert str(document.entities[1]) == "Berlin, Germany" + assert str(document.entities[2]) == "DFKI" + + document.relations.extend( + [ + BinaryRelation( + head=document.entities[0], + tail=document.entities[2], + label="per:employee_of", + ), + ] + ) + re_pipeline_args = {"taskmodule_kwargs": {"create_relation_candidates": True}} + + pipeline = NerRePipeline( + ner_model_path="pie/example-ner-spanclf-conll03", + re_model_path="pie/example-re-textclf-tacred", + entity_layer="entities", + relation_layer="relations", + device=-1, + batch_size=1, + re_pipeline=re_pipeline_args, + show_progress_bar=False, + ) + docs = pipeline(documents=[document]) + assert len(docs) == 1 + + doc: TextDocumentWithEntitiesAndRelations = docs[0] + + # gold entities and relations + gold_entities = doc.entities + assert len(gold_entities) == 3 + gold_relations = doc.relations + assert len(gold_relations) == 1 + + # predicted entities and relations + predicted_entities = doc.entities.predictions + assert len(predicted_entities) == 4 + + assert str(predicted_entities[0]) == "Harry" + assert predicted_entities[0].label == "PER" + + assert str(predicted_entities[1]) == "Berlin" + assert predicted_entities[1].label == "LOC" + + assert str(predicted_entities[2]) == "Germany" + assert predicted_entities[2].label == "LOC" + + assert str(predicted_entities[3]) == "DFKI" + assert predicted_entities[3].label == "ORG" + + predicted_relations = doc.relations.predictions + assert len(predicted_relations) == 6 + + assert str(predicted_relations[0].head) == "Harry" + assert str(predicted_relations[0].tail) == "Berlin" + assert predicted_relations[0].label == "per:cities_of_residence" + + assert str(predicted_relations[1].head) == "Harry" + assert str(predicted_relations[1].tail) == "Germany" + assert predicted_relations[1].label == "per:countries_of_residence" + + assert str(predicted_relations[2].head) == "Harry" + assert str(predicted_relations[2].tail) == "DFKI" + assert predicted_relations[2].label == "per:employee_of" + + assert str(predicted_relations[3].head) == "Berlin" + assert str(predicted_relations[3].tail) == "Harry" + assert predicted_relations[3].label == "per:cities_of_residence" + + assert str(predicted_relations[4].head) == "Germany" + assert str(predicted_relations[4].tail) == "Harry" + assert predicted_relations[4].label == "per:countries_of_residence" + + assert str(predicted_relations[5].head) == "DFKI" + assert str(predicted_relations[5].tail) == "Harry" + assert predicted_relations[5].label == "per:employee_of"