Skip to content

Commit

Permalink
further simplify
Browse files: browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Mar 13, 2024
1 parent de93d87 commit 03383ba
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 27 deletions.
28 changes: 16 additions & 12 deletions src/pipeline/ner_re_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,12 @@ def add_annotations_from_other_documents(

def process_pipeline_steps(
documents: Sequence[Document],
processors: Dict[str, Callable[[Document], Optional[Document]]],
inplace: bool = False,
):
if not inplace:
documents = [type(doc).fromdict(doc.asdict()) for doc in documents]
processors: Dict[str, Callable[[Sequence[Document]], Optional[Sequence[Document]]]],
) -> Sequence[Document]:

# do the actual inference
# call the processors in the order they are provided
for step_name, processor in processors.items():
print(f"process {step_name} ...")
logger.info(f"process {step_name} ...")
processed_documents = processor(documents)
if processed_documents is not None:
documents = processed_documents
Expand Down Expand Up @@ -133,12 +130,20 @@ def __init__(
):
self.processor_kwargs[inference_pipeline]["show_progress_bar"] = show_progress_bar

def __call__(self, documents: Sequence[Document], inplace: bool = False):
def __call__(self, documents: Sequence[Document], inplace: bool = False) -> Sequence[Document]:

docs = [type(doc).fromdict(doc.asdict()) for doc in documents]
input_docs: Sequence[Document]
# we need to keep the original documents to add the gold data back
original_docs: Sequence[Document]
if inplace:
input_docs = documents
original_docs = [doc.copy() for doc in documents]
else:
input_docs = [doc.copy() for doc in documents]
original_docs = documents

docs_with_predictions = process_pipeline_steps(
documents=docs,
documents=input_docs,
processors={
"clear_annotations": partial(
process_documents,
Expand Down Expand Up @@ -179,11 +184,10 @@ def __call__(self, documents: Sequence[Document], inplace: bool = False):
),
"re_add_gold_data": partial(
add_annotations_from_other_documents,
other_docs=documents,
other_docs=original_docs,
layer_names=[self.entity_layer, self.relation_layer],
**self.processor_kwargs.get("re_add_gold_data", {}),
),
},
inplace=inplace,
)
return docs_with_predictions
20 changes: 5 additions & 15 deletions tests/unit/pipeline/test_ner_re_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,27 +186,17 @@ def documents_processor(documents) -> TextBasedDocument:
return documents


@pytest.mark.parametrize("inplace", [True, False])
def test_process_pipeline_steps(document, inplace):
def test_process_pipeline_steps(document):
original_spans = document["labeled_spans"]
assert len(original_spans) == 3

docs = process_pipeline_steps(
process_pipeline_steps(
documents=[document],
processors={"add_span": documents_processor},
inplace=inplace,
)

doc = docs[0]

if inplace:
original_spans = document["labeled_spans"]
assert len(original_spans) == 4
else:
original_spans = document["labeled_spans"]
assert len(original_spans) == 3
spans = doc["labeled_spans"]
assert len(spans) == 4
original_spans = document["labeled_spans"]
assert len(original_spans) == 4


def test_add_annotations_from_other_documents(document, document_with_relations):
Expand Down Expand Up @@ -309,7 +299,7 @@ def test_ner_re_pipeline():
docs = pipeline(documents=[document])
assert len(docs) == 1

doc = docs[0]
doc: TextDocumentWithEntitiesAndRelations = docs[0]

# gold entities and relations
gold_entities = doc.entities
Expand Down

0 comments on commit 03383ba

Please sign in to comment.