From 03383ba7e153955703309439bf6250a09a56c8e0 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Mar 2024 14:31:29 +0100 Subject: [PATCH] further simplify --- src/pipeline/ner_re_pipeline.py | 28 ++++++++++++--------- tests/unit/pipeline/test_ner_re_pipeline.py | 20 ++++----------- 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/pipeline/ner_re_pipeline.py b/src/pipeline/ner_re_pipeline.py index ea32584..b781ef7 100644 --- a/src/pipeline/ner_re_pipeline.py +++ b/src/pipeline/ner_re_pipeline.py @@ -70,15 +70,12 @@ def add_annotations_from_other_documents( def process_pipeline_steps( documents: Sequence[Document], - processors: Dict[str, Callable[[Document], Optional[Document]]], - inplace: bool = False, -): - if not inplace: - documents = [type(doc).fromdict(doc.asdict()) for doc in documents] + processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]], +) -> Sequence[Document]: - # do the actual inference + # call the processors in the order they are provided for step_name, processor in processors.items(): - print(f"process {step_name} ...") + logger.info(f"process {step_name} ...") processed_documents = processor(documents) if processed_documents is not None: documents = processed_documents @@ -133,12 +130,20 @@ def __init__( ): self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar - def __call__(self, documents: Sequence[Document], inplace: bool = False): + def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]: - docs = [type(doc).fromdict(doc.asdict()) for doc in documents] + input_docs: Sequence[Document] + # we need to keep the original documents to add the gold data back + original_docs: Sequence[Document] + if inplace: + input_docs = documents + original_docs = [doc.copy() for doc in documents] + else: + input_docs = [doc.copy() for doc in documents] + original_docs = documents docs_with_predictions = process_pipeline_steps( - documents=docs, + documents=input_docs, processors={ "clear_annotations": partial( process_documents, @@ -179,11 +184,10 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False): ), "re_add_gold_data": partial( add_annotations_from_other_documents, - other_docs=documents, + other_docs=original_docs, layer_names=[self.entity_layer, self.relation_layer], **self.processor_kwargs.get("re_add_gold_data", {}), ), }, - inplace=inplace, ) return docs_with_predictions diff --git a/tests/unit/pipeline/test_ner_re_pipeline.py b/tests/unit/pipeline/test_ner_re_pipeline.py index 559d58c..2d5465c 100644 --- a/tests/unit/pipeline/test_ner_re_pipeline.py +++ b/tests/unit/pipeline/test_ner_re_pipeline.py @@ -186,27 +186,17 @@ def documents_processor(documents) -> TextBasedDocument: return documents -@pytest.mark.parametrize("inplace", [True, False]) -def test_process_pipeline_steps(document, inplace): +def test_process_pipeline_steps(document): original_spans = document["labeled_spans"] assert len(original_spans) == 3 - docs = process_pipeline_steps( + process_pipeline_steps( documents=[document], processors={"add_span": documents_processor}, - inplace=inplace, ) - doc = docs[0] - - if inplace: - original_spans = document["labeled_spans"] - assert len(original_spans) == 4 - else: - original_spans = document["labeled_spans"] - assert len(original_spans) == 3 - spans = doc["labeled_spans"] - assert len(spans) == 4 + original_spans = document["labeled_spans"] + assert len(original_spans) == 4 def test_add_annotations_from_other_documents(document, document_with_relations): @@ -309,7 +299,7 @@ def test_ner_re_pipeline(): docs = pipeline(documents=[document]) assert len(docs) == 1 - doc = docs[0] + doc: TextDocumentWithEntitiesAndRelations = docs[0] # gold entities and relations gold_entities = doc.entities