Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function to detect repeated entities #33

Merged
merged 12 commits into from
Dec 18, 2024
75 changes: 10 additions & 65 deletions anonipy/anonymize/extractors/multi_extractor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List, Set, Tuple, Iterable

from typing import List, Tuple
import itertools

from spacy import displacy
from spacy.tokens import Doc

from ...definitions import Entity
from ...utils.colors import get_label_color
from ..helpers import merge_entities

from .interface import ExtractorInterface

Expand All @@ -27,7 +27,7 @@ class MultiExtractor:
>>> PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH),
>>> ]
>>> extractor = MultiExtractor(extractors)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
[(Doc, [Entity]), (Doc, [Entity])], [Entity]

Attributes:
Expand Down Expand Up @@ -67,24 +67,27 @@ def __init__(self, extractors: List[ExtractorInterface]):
self.extractors = extractors

def __call__(
self, text: str
self, text: str, detect_repeats: bool = False
) -> Tuple[List[Tuple[Doc, List[Entity]]], List[Entity]]:
"""Extract the entities fron the text using the provided extractors.

Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
[(Doc, [Entity]), (Doc, [Entity])], [Entity]

Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.

Returns:
The list of extractor outputs containing the tuple (spacy document, extracted entities).
The list of joint entities.

"""

extractor_outputs = [e(text) for e in self.extractors]
joint_entities = self._merge_entities(extractor_outputs)
extractor_outputs = [e(text, detect_repeats) for e in self.extractors]
joint_entities = merge_entities(extractor_outputs)

return extractor_outputs, joint_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -112,61 +115,3 @@ def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
return displacy.render(
doc, style="ent", options=options, page=page, jupyter=jupyter
)

def _merge_entities(
self, extractor_outputs: List[Tuple[Doc, List[Entity]]]
) -> List[Entity]:
"""Merges the entities returned by the extractors.

Args:
extractor_outputs: The list of extractor outputs.

Returns:
The merged entities list.

"""

if len(extractor_outputs) == 0:
return []
if len(extractor_outputs) == 1:
return extractor_outputs[0][1]

joint_entities = self._filter_entities(
list(
itertools.chain.from_iterable(
[entity[1] for entity in extractor_outputs]
)
)
)
return joint_entities

def _filter_entities(self, entities: Iterable[Entity]) -> List[Entity]:
"""Filters the entities based on their start and end indices.

Args:
entities: The entities to filter.

Returns:
The filtered entities.

"""

def get_sort_key(entity):
return (
entity.end_index - entity.start_index,
-entity.start_index,
)

sorted_entities = sorted(entities, key=get_sort_key, reverse=True)
result = []
seen_tokens: Set[int] = set()
for entity in sorted_entities:
# Check for end - 1 here because boundaries are inclusive
if (
entity.start_index not in seen_tokens
and entity.end_index - 1 not in seen_tokens
):
result.append(entity)
seen_tokens.update(range(entity.start_index, entity.end_index))
result = sorted(result, key=lambda entity: entity.start_index)
return result
56 changes: 13 additions & 43 deletions anonipy/anonymize/extractors/ner_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from spacy.tokens import Doc, Span
from spacy.language import Language

from ..helpers import convert_spacy_to_entity
from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities
from ...utils.regex import regex_mapping
from ...constants import LANGUAGES
from ...definitions import Entity
from ...utils.colors import get_label_color

from .interface import ExtractorInterface


# ===============================================
# Extractor class
# ===============================================
Expand All @@ -29,7 +30,7 @@ class NERExtractor(ExtractorInterface):
>>> from anonipy.anonymize.extractors import NERExtractor
>>> labels = [{"label": "PERSON", "type": "string"}]
>>> extractor = NERExtractor(labels, lang=LANGUAGES.ENGLISH)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]

Attributes:
Expand Down Expand Up @@ -92,15 +93,16 @@ def __init__(
warnings.filterwarnings("ignore", category=ResourceWarning)
self.pipeline = self._prepare_pipeline()

def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
"""Extract the entities from the text.

Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]

Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.

Returns:
The spacy document.
Expand All @@ -110,7 +112,12 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:

doc = self.pipeline(text)
anoni_entities, spacy_entities = self._prepare_entities(doc)
self._set_spacy_fields(doc, spacy_entities)

if detect_repeats:
anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style)

create_spacy_entities(doc, anoni_entities, self.spacy_style)

return doc, anoni_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -226,46 +233,9 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]:
# TODO: make this part more generic
anoni_entities = []
spacy_entities = []
for s in self._get_spacy_fields(doc):
for s in get_doc_entity_spans(doc, self.spacy_style):
label = list(filter(lambda x: x["label"] == s.label_, self.labels))[0]
if re.match(label["regex"], s.text):
anoni_entities.append(convert_spacy_to_entity(s, **label))
spacy_entities.append(s)
return anoni_entities, spacy_entities

def _get_spacy_fields(self, doc: Doc) -> List[Span]:
"""Get the spacy doc entity spans.

args:
doc: The spacy doc to get the entity spans from.

Returns:
The list of Spans from the spacy doc.

"""

if self.spacy_style == "ent":
return doc.ents
elif self.spacy_style == "span":
return doc.spans["sc"]
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")

def _set_spacy_fields(self, doc: Doc, entities: List[Span]) -> None:
"""Set the spacy doc entity spans.

Args:
doc: The spacy doc to set the entity spans.
entities: The entity spans to set.

Returns:
None

"""

if self.spacy_style == "ent":
doc.ents = entities
elif self.spacy_style == "span":
doc.spans["sc"] = entities
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
77 changes: 17 additions & 60 deletions anonipy/anonymize/extractors/pattern_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re

import importlib
from typing import List, Tuple, Optional, Callable

Expand All @@ -8,7 +7,7 @@
from spacy.language import Language
from spacy.matcher import Matcher

from ..helpers import convert_spacy_to_entity
from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans, create_spacy_entities
from ...constants import LANGUAGES
from ...definitions import Entity
from ...utils.colors import get_label_color
Expand All @@ -29,7 +28,7 @@ class PatternExtractor(ExtractorInterface):
>>> from anonipy.anonymize.extractors import PatternExtractor
>>> labels = [{"label": "PERSON", "type": "string", "regex": "([A-Z][a-z]+ [A-Z][a-z]+)"}]
>>> extractor = PatternExtractor(labels, lang=LANGUAGES.ENGLISH)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]

Attributes:
Expand Down Expand Up @@ -79,15 +78,16 @@ def __init__(
self.token_matchers = self._prepare_token_matchers()
self.global_matchers = self._prepare_global_matchers()

def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
"""Extract the entities from the text.

Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]

Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.

Returns:
The spacy document.
Expand All @@ -99,7 +99,12 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
self.token_matchers(doc) if self.token_matchers else None
self.global_matchers(doc) if self.global_matchers else None
anoni_entities, spacy_entities = self._prepare_entities(doc)
self._set_doc_entity_spans(doc, spacy_entities)

if detect_repeats:
anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style)

create_spacy_entities(doc, anoni_entities, self.spacy_style)

return doc, anoni_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -195,15 +200,9 @@ def global_matchers(doc: Doc) -> None:
if not entity:
continue
entity._.score = 1.0
entities = [convert_spacy_to_entity(entity)]
# add the entity to the previous entity list
prev_entities = self._get_doc_entity_spans(doc)
if self.spacy_style == "ent":
prev_entities = util.filter_spans(prev_entities + (entity,))
elif self.spacy_style == "span":
prev_entities.append(entity)
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
self._set_doc_entity_spans(doc, prev_entities)
create_spacy_entities(doc, entities, self.spacy_style)

return global_matchers

Expand All @@ -222,7 +221,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]:
# TODO: make this part more generic
anoni_entities = []
spacy_entities = []
for e in self._get_doc_entity_spans(doc):
for e in get_doc_entity_spans(doc, self.spacy_style):
label = list(filter(lambda x: x["label"] == e.label_, self.labels))[0]
anoni_entities.append(convert_spacy_to_entity(e, **label))
spacy_entities.append(e)
Expand All @@ -246,49 +245,7 @@ def add_event_ent(matcher, doc, i, matches):
if not entity:
return
entity._.score = 1.0
# add the entity to the previous entity list
prev_entities = self._get_doc_entity_spans(doc)
if self.spacy_style == "ent":
prev_entities = util.filter_spans(prev_entities + (entity,))
elif self.spacy_style == "span":
prev_entities.append(entity)
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
self._set_doc_entity_spans(doc, prev_entities)

return add_event_ent

def _get_doc_entity_spans(self, doc: Doc) -> List[Span]:
"""Get the spacy doc entity spans.

Args:
doc: The spacy doc to get the entity spans from.

Returns:
The list of entity spans.

"""

if self.spacy_style == "ent":
return doc.ents
if self.spacy_style == "span":
if "sc" not in doc.spans:
doc.spans["sc"] = []
return doc.spans["sc"]
raise ValueError(f"Invalid spacy style: {self.spacy_style}")

def _set_doc_entity_spans(self, doc: Doc, entities: List[Span]) -> None:
"""Set the spacy doc entity spans.

Args:
doc: The spacy doc to set the entity spans.
entities: The entity spans to assign the doc.

"""
entities = [convert_spacy_to_entity(entity)]
create_spacy_entities(doc, entities, self.spacy_style)

if self.spacy_style == "ent":
doc.ents = entities
elif self.spacy_style == "span":
doc.spans["sc"] = entities
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
return add_event_ent
Loading
Loading