Skip to content

Commit

Permalink
Merge pull request #9 from ninakokalj/main
Browse files Browse the repository at this point in the history
Detect repeated entities
  • Loading branch information
eriknovak authored Dec 18, 2024
2 parents 32c146b + 706e9e7 commit c652040
Show file tree
Hide file tree
Showing 5 changed files with 419 additions and 174 deletions.
75 changes: 10 additions & 65 deletions anonipy/anonymize/extractors/multi_extractor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List, Set, Tuple, Iterable

from typing import List, Tuple
import itertools

from spacy import displacy
from spacy.tokens import Doc

from ...definitions import Entity
from ...utils.colors import get_label_color
from ..helpers import merge_entities

from .interface import ExtractorInterface

Expand All @@ -27,7 +27,7 @@ class MultiExtractor:
>>> PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH),
>>> ]
>>> extractor = MultiExtractor(extractors)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
[(Doc, [Entity]), (Doc, [Entity])], [Entity]
Attributes:
Expand Down Expand Up @@ -67,24 +67,27 @@ def __init__(self, extractors: List[ExtractorInterface]):
self.extractors = extractors

def __call__(
self, text: str
self, text: str, detect_repeats: bool = False
) -> Tuple[List[Tuple[Doc, List[Entity]]], List[Entity]]:
"""Extract the entities from the text using the provided extractors.
Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
[(Doc, [Entity]), (Doc, [Entity])], [Entity]
Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.
Returns:
The list of extractor outputs containing the tuple (spacy document, extracted entities).
The list of joint entities.
"""

extractor_outputs = [e(text) for e in self.extractors]
joint_entities = self._merge_entities(extractor_outputs)
extractor_outputs = [e(text, detect_repeats) for e in self.extractors]
joint_entities = merge_entities(extractor_outputs)

return extractor_outputs, joint_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -112,61 +115,3 @@ def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
return displacy.render(
doc, style="ent", options=options, page=page, jupyter=jupyter
)

def _merge_entities(
self, extractor_outputs: List[Tuple[Doc, List[Entity]]]
) -> List[Entity]:
"""Merges the entities returned by the extractors.
Args:
extractor_outputs: The list of extractor outputs.
Returns:
The merged entities list.
"""

if len(extractor_outputs) == 0:
return []
if len(extractor_outputs) == 1:
return extractor_outputs[0][1]

joint_entities = self._filter_entities(
list(
itertools.chain.from_iterable(
[entity[1] for entity in extractor_outputs]
)
)
)
return joint_entities

def _filter_entities(self, entities: Iterable[Entity]) -> List[Entity]:
"""Filters the entities based on their start and end indices.
Args:
entities: The entities to filter.
Returns:
The filtered entities.
"""

def get_sort_key(entity):
return (
entity.end_index - entity.start_index,
-entity.start_index,
)

sorted_entities = sorted(entities, key=get_sort_key, reverse=True)
result = []
seen_tokens: Set[int] = set()
for entity in sorted_entities:
# Check for end - 1 here because boundaries are inclusive
if (
entity.start_index not in seen_tokens
and entity.end_index - 1 not in seen_tokens
):
result.append(entity)
seen_tokens.update(range(entity.start_index, entity.end_index))
result = sorted(result, key=lambda entity: entity.start_index)
return result
56 changes: 13 additions & 43 deletions anonipy/anonymize/extractors/ner_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from spacy.tokens import Doc, Span
from spacy.language import Language

from ..helpers import convert_spacy_to_entity
from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities
from ...utils.regex import regex_mapping
from ...constants import LANGUAGES
from ...definitions import Entity
from ...utils.colors import get_label_color

from .interface import ExtractorInterface


# ===============================================
# Extractor class
# ===============================================
Expand All @@ -29,7 +30,7 @@ class NERExtractor(ExtractorInterface):
>>> from anonipy.anonymize.extractors import NERExtractor
>>> labels = [{"label": "PERSON", "type": "string"}]
>>> extractor = NERExtractor(labels, lang=LANGUAGES.ENGLISH)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]
Attributes:
Expand Down Expand Up @@ -92,15 +93,16 @@ def __init__(
warnings.filterwarnings("ignore", category=ResourceWarning)
self.pipeline = self._prepare_pipeline()

def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
"""Extract the entities from the text.
Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]
Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.
Returns:
The spacy document.
Expand All @@ -110,7 +112,12 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:

doc = self.pipeline(text)
anoni_entities, spacy_entities = self._prepare_entities(doc)
self._set_spacy_fields(doc, spacy_entities)

if detect_repeats:
anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style)

create_spacy_entities(doc, anoni_entities, self.spacy_style)

return doc, anoni_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -226,46 +233,9 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]:
# TODO: make this part more generic
anoni_entities = []
spacy_entities = []
for s in self._get_spacy_fields(doc):
for s in get_doc_entity_spans(doc, self.spacy_style):
label = list(filter(lambda x: x["label"] == s.label_, self.labels))[0]
if re.match(label["regex"], s.text):
anoni_entities.append(convert_spacy_to_entity(s, **label))
spacy_entities.append(s)
return anoni_entities, spacy_entities

def _get_spacy_fields(self, doc: Doc) -> List[Span]:
"""Get the spacy doc entity spans.
args:
doc: The spacy doc to get the entity spans from.
Returns:
The list of Spans from the spacy doc.
"""

if self.spacy_style == "ent":
return doc.ents
elif self.spacy_style == "span":
return doc.spans["sc"]
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")

def _set_spacy_fields(self, doc: Doc, entities: List[Span]) -> None:
"""Set the spacy doc entity spans.
Args:
doc: The spacy doc to set the entity spans.
entities: The entity spans to set.
Returns:
None
"""

if self.spacy_style == "ent":
doc.ents = entities
elif self.spacy_style == "span":
doc.spans["sc"] = entities
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
77 changes: 17 additions & 60 deletions anonipy/anonymize/extractors/pattern_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re

import importlib
from typing import List, Tuple, Optional, Callable

Expand All @@ -8,7 +7,7 @@
from spacy.language import Language
from spacy.matcher import Matcher

from ..helpers import convert_spacy_to_entity
from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans, create_spacy_entities
from ...constants import LANGUAGES
from ...definitions import Entity
from ...utils.colors import get_label_color
Expand All @@ -29,7 +28,7 @@ class PatternExtractor(ExtractorInterface):
>>> from anonipy.anonymize.extractors import PatternExtractor
>>> labels = [{"label": "PERSON", "type": "string", "regex": "([A-Z][a-z]+ [A-Z][a-z]+)"}]
>>> extractor = PatternExtractor(labels, lang=LANGUAGES.ENGLISH)
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]
Attributes:
Expand Down Expand Up @@ -79,15 +78,16 @@ def __init__(
self.token_matchers = self._prepare_token_matchers()
self.global_matchers = self._prepare_global_matchers()

def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
"""Extract the entities from the text.
Examples:
>>> extractor("John Doe is a 19 year old software engineer.")
>>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False)
Doc, [Entity]
Args:
text: The text to extract entities from.
detect_repeats: Whether to check text again for repeated entities.
Returns:
The spacy document.
Expand All @@ -99,7 +99,12 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]:
self.token_matchers(doc) if self.token_matchers else None
self.global_matchers(doc) if self.global_matchers else None
anoni_entities, spacy_entities = self._prepare_entities(doc)
self._set_doc_entity_spans(doc, spacy_entities)

if detect_repeats:
anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style)

create_spacy_entities(doc, anoni_entities, self.spacy_style)

return doc, anoni_entities

def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str:
Expand Down Expand Up @@ -195,15 +200,9 @@ def global_matchers(doc: Doc) -> None:
if not entity:
continue
entity._.score = 1.0
entities = [convert_spacy_to_entity(entity)]
# add the entity to the previous entity list
prev_entities = self._get_doc_entity_spans(doc)
if self.spacy_style == "ent":
prev_entities = util.filter_spans(prev_entities + (entity,))
elif self.spacy_style == "span":
prev_entities.append(entity)
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
self._set_doc_entity_spans(doc, prev_entities)
create_spacy_entities(doc, entities, self.spacy_style)

return global_matchers

Expand All @@ -222,7 +221,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]:
# TODO: make this part more generic
anoni_entities = []
spacy_entities = []
for e in self._get_doc_entity_spans(doc):
for e in get_doc_entity_spans(doc, self.spacy_style):
label = list(filter(lambda x: x["label"] == e.label_, self.labels))[0]
anoni_entities.append(convert_spacy_to_entity(e, **label))
spacy_entities.append(e)
Expand All @@ -246,49 +245,7 @@ def add_event_ent(matcher, doc, i, matches):
if not entity:
return
entity._.score = 1.0
# add the entity to the previous entity list
prev_entities = self._get_doc_entity_spans(doc)
if self.spacy_style == "ent":
prev_entities = util.filter_spans(prev_entities + (entity,))
elif self.spacy_style == "span":
prev_entities.append(entity)
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
self._set_doc_entity_spans(doc, prev_entities)

return add_event_ent

def _get_doc_entity_spans(self, doc: Doc) -> List[Span]:
"""Get the spacy doc entity spans.
Args:
doc: The spacy doc to get the entity spans from.
Returns:
The list of entity spans.
"""

if self.spacy_style == "ent":
return doc.ents
if self.spacy_style == "span":
if "sc" not in doc.spans:
doc.spans["sc"] = []
return doc.spans["sc"]
raise ValueError(f"Invalid spacy style: {self.spacy_style}")

def _set_doc_entity_spans(self, doc: Doc, entities: List[Span]) -> None:
"""Set the spacy doc entity spans.
Args:
doc: The spacy doc to set the entity spans.
entities: The entity spans to assign the doc.
"""
entities = [convert_spacy_to_entity(entity)]
create_spacy_entities(doc, entities, self.spacy_style)

if self.spacy_style == "ent":
doc.ents = entities
elif self.spacy_style == "span":
doc.spans["sc"] = entities
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
return add_event_ent
Loading

0 comments on commit c652040

Please sign in to comment.