diff --git a/anonipy/anonymize/extractors/multi_extractor.py b/anonipy/anonymize/extractors/multi_extractor.py index 64345dd..dc3f588 100644 --- a/anonipy/anonymize/extractors/multi_extractor.py +++ b/anonipy/anonymize/extractors/multi_extractor.py @@ -1,5 +1,4 @@ -from typing import List, Set, Tuple, Iterable - +from typing import List, Tuple import itertools from spacy import displacy @@ -7,6 +6,7 @@ from ...definitions import Entity from ...utils.colors import get_label_color +from ..helpers import merge_entities from .interface import ExtractorInterface @@ -27,7 +27,7 @@ class MultiExtractor: >>> PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH), >>> ] >>> extractor = MultiExtractor(extractors) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Attributes: @@ -67,24 +67,27 @@ def __init__(self, extractors: List[ExtractorInterface]): self.extractors = extractors def __call__( - self, text: str + self, text: str, detect_repeats: bool = False ) -> Tuple[List[Tuple[Doc, List[Entity]]], List[Entity]]: """Extract the entities from the text using the provided extractors. Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The list of extractor outputs containing the tuple (spacy document, extracted entities). The list of joint entities. 
+ """ - extractor_outputs = [e(text) for e in self.extractors] - joint_entities = self._merge_entities(extractor_outputs) + extractor_outputs = [e(text, detect_repeats) for e in self.extractors] + joint_entities = merge_entities(extractor_outputs) + return extractor_outputs, joint_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -112,61 +115,3 @@ def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: return displacy.render( doc, style="ent", options=options, page=page, jupyter=jupyter ) - - def _merge_entities( - self, extractor_outputs: List[Tuple[Doc, List[Entity]]] - ) -> List[Entity]: - """Merges the entities returned by the extractors. - - Args: - extractor_outputs: The list of extractor outputs. - - Returns: - The merged entities list. - - """ - - if len(extractor_outputs) == 0: - return [] - if len(extractor_outputs) == 1: - return extractor_outputs[0][1] - - joint_entities = self._filter_entities( - list( - itertools.chain.from_iterable( - [entity[1] for entity in extractor_outputs] - ) - ) - ) - return joint_entities - - def _filter_entities(self, entities: Iterable[Entity]) -> List[Entity]: - """Filters the entities based on their start and end indices. - - Args: - entities: The entities to filter. - - Returns: - The filtered entities. 
- - """ - - def get_sort_key(entity): - return ( - entity.end_index - entity.start_index, - -entity.start_index, - ) - - sorted_entities = sorted(entities, key=get_sort_key, reverse=True) - result = [] - seen_tokens: Set[int] = set() - for entity in sorted_entities: - # Check for end - 1 here because boundaries are inclusive - if ( - entity.start_index not in seen_tokens - and entity.end_index - 1 not in seen_tokens - ): - result.append(entity) - seen_tokens.update(range(entity.start_index, entity.end_index)) - result = sorted(result, key=lambda entity: entity.start_index) - return result diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index cabd55e..936745b 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -8,7 +8,7 @@ from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -16,6 +16,7 @@ from .interface import ExtractorInterface + # =============================================== # Extractor class # =============================================== @@ -29,7 +30,7 @@ class NERExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import NERExtractor >>> labels = [{"label": "PERSON", "type": "string"}] >>> extractor = NERExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -92,15 +93,16 @@ def __init__( warnings.filterwarnings("ignore", category=ResourceWarning) self.pipeline = self._prepare_pipeline() - def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, 
List[Entity]]: + def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]: """Extract the entities from the text. Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The spacy document. @@ -110,7 +112,11 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: doc = self.pipeline(text) anoni_entities, spacy_entities = self._prepare_entities(doc) - self._set_spacy_fields(doc, spacy_entities) + set_doc_entity_spans(self.spacy_style, doc, spacy_entities) + + if (detect_repeats): + anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -226,46 +232,9 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for s in self._get_spacy_fields(doc): + for s in get_doc_entity_spans(self.spacy_style, doc): label = list(filter(lambda x: x["label"] == s.label_, self.labels))[0] if re.match(label["regex"], s.text): anoni_entities.append(convert_spacy_to_entity(s, **label)) spacy_entities.append(s) return anoni_entities, spacy_entities - - def _get_spacy_fields(self, doc: Doc) -> List[Span]: - """Get the spacy doc entity spans. - - args: - doc: The spacy doc to get the entity spans from. - - Returns: - The list of Spans from the spacy doc. - - """ - - if self.spacy_style == "ent": - return doc.ents - elif self.spacy_style == "span": - return doc.spans["sc"] - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - - def _set_spacy_fields(self, doc: Doc, entities: List[Span]) -> None: - """Set the spacy doc entity spans. 
- - Args: - doc: The spacy doc to set the entity spans. - entities: The entity spans to set. - - Returns: - None - - """ - - if self.spacy_style == "ent": - doc.ents = entities - elif self.spacy_style == "span": - doc.spans["sc"] = entities - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index 32d1257..a7be84d 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -1,5 +1,4 @@ import re - import importlib from typing import List, Tuple, Optional, Callable @@ -8,7 +7,7 @@ from spacy.language import Language from spacy.matcher import Matcher -from ..helpers import convert_spacy_to_entity +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans from ...constants import LANGUAGES from ...definitions import Entity from ...utils.colors import get_label_color @@ -29,7 +28,7 @@ class PatternExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import PatternExtractor >>> labels = [{"label": "PERSON", "type": "string", "regex": "([A-Z][a-z]+ [A-Z][a-z]+)"}] >>> extractor = PatternExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -79,15 +78,16 @@ def __init__( self.token_matchers = self._prepare_token_matchers() self.global_matchers = self._prepare_global_matchers() - def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: + def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]: """Extract the entities from the text. 
Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The spacy document. @@ -99,7 +99,11 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: self.token_matchers(doc) if self.token_matchers else None self.global_matchers(doc) if self.global_matchers else None anoni_entities, spacy_entities = self._prepare_entities(doc) - self._set_doc_entity_spans(doc, spacy_entities) + set_doc_entity_spans(self.spacy_style, doc, spacy_entities) + + if (detect_repeats): + anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -196,14 +200,14 @@ def global_matchers(doc: Doc) -> None: continue entity._.score = 1.0 # add the entity to the previous entity list - prev_entities = self._get_doc_entity_spans(doc) + prev_entities = get_doc_entity_spans(self.spacy_style, doc) if self.spacy_style == "ent": prev_entities = util.filter_spans(prev_entities + (entity,)) elif self.spacy_style == "span": prev_entities.append(entity) else: raise ValueError(f"Invalid spacy style: {self.spacy_style}") - self._set_doc_entity_spans(doc, prev_entities) + set_doc_entity_spans(self.spacy_style, doc, prev_entities) return global_matchers @@ -222,7 +226,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for e in self._get_doc_entity_spans(doc): + for e in get_doc_entity_spans(self.spacy_style, doc): label = list(filter(lambda x: x["label"] == e.label_, self.labels))[0] anoni_entities.append(convert_spacy_to_entity(e, **label)) spacy_entities.append(e) @@ -247,48 +251,13 @@ def add_event_ent(matcher, 
doc, i, matches): return entity._.score = 1.0 # add the entity to the previous entity list - prev_entities = self._get_doc_entity_spans(doc) + prev_entities = get_doc_entity_spans(self.spacy_style, doc) if self.spacy_style == "ent": prev_entities = util.filter_spans(prev_entities + (entity,)) elif self.spacy_style == "span": prev_entities.append(entity) else: raise ValueError(f"Invalid spacy style: {self.spacy_style}") - self._set_doc_entity_spans(doc, prev_entities) - - return add_event_ent - - def _get_doc_entity_spans(self, doc: Doc) -> List[Span]: - """Get the spacy doc entity spans. - - Args: - doc: The spacy doc to get the entity spans from. - - Returns: - The list of entity spans. - - """ - - if self.spacy_style == "ent": - return doc.ents - if self.spacy_style == "span": - if "sc" not in doc.spans: - doc.spans["sc"] = [] - return doc.spans["sc"] - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - - def _set_doc_entity_spans(self, doc: Doc, entities: List[Span]) -> None: - """Set the spacy doc entity spans. - - Args: - doc: The spacy doc to set the entity spans. - entities: The entity spans to assign the doc. 
- - """ + set_doc_entity_spans(self.spacy_style, doc, prev_entities) - if self.spacy_style == "ent": - doc.ents = entities - elif self.spacy_style == "span": - doc.spans["sc"] = entities - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") + return add_event_ent \ No newline at end of file diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index 38fd79d..44f49f9 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -1,11 +1,14 @@ import re -from typing import List, Union +from typing import List, Union, Tuple, Iterable, Set +import itertools -from spacy.tokens import Span +from spacy import util +from spacy.tokens import Span, Doc from ..definitions import Entity, Replacement from ..constants import ENTITY_TYPES + # ===================================== # Entity converters # ===================================== @@ -72,3 +75,162 @@ def anonymize(text: str, replacements: List[Replacement]) -> str: + anonymized_text[replacement["end_index"] :] ) return anonymized_text, s_replacements[::-1] + + +# ===================================== +# Entity helpers +# ===================================== + + +def merge_entities(extractor_outputs: List[Tuple[Doc, List[Entity]]]) -> List[Entity]: + """Merges the entities returned by the extractors. + + Args: + extractor_outputs: The list of extractor outputs. + + Returns: + The merged entities list. + + """ + + if len(extractor_outputs) == 0: + return [] + if len(extractor_outputs) == 1: + return extractor_outputs[0][1] + + joint_entities = _filter_entities( + list( + itertools.chain.from_iterable( + [entity[1] for entity in extractor_outputs] + ) + ) + ) + return joint_entities + +def _filter_entities(entities: Iterable[Entity]) -> List[Entity]: + """Filters the entities based on their start and end indices. + + Args: + entities: The entities to filter. + + Returns: + The filtered entities. 
+ + """ + + def get_sort_key(entity): + return ( + entity.end_index - entity.start_index, + -entity.start_index, + ) + + sorted_entities = sorted(entities, key=get_sort_key, reverse=True) + result = [] + seen_tokens: Set[int] = set() + for entity in sorted_entities: + # Check for end - 1 here because boundaries are inclusive + if ( + entity.start_index not in seen_tokens + and entity.end_index - 1 not in seen_tokens + ): + result.append(entity) + seen_tokens.update(range(entity.start_index, entity.end_index)) + result = sorted(result, key=lambda entity: entity.start_index) + return result + +def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) -> List[Entity]: + """Detects repeated entities in the text. + + Args: + entities: The extracted entities whose text is searched for again. + doc: The spacy doc to detect entities in. + spacy_style: The style in which the entities are stored in the spacy doc ("ent" or "span"). + + Returns: + The original and repeated entities, overlap-filtered and sorted by start index. + + """ + + repeated_entities = [] + + for entity in entities: + matches = re.finditer(re.escape(entity.text), doc.text) + for match in matches: + start_index, end_index = match.start(), match.end() + if (start_index == entity.start_index and end_index == entity.end_index): + continue + repeated_entities.append( + Entity( + text = entity.text, + label = entity.label, + start_index = start_index, + end_index = end_index, + score = entity.score, + type = entity.type, + regex = entity.text + ) + ) + + filtered_entities = _filter_entities(entities + repeated_entities) + new_entities = [ent for ent in filtered_entities if ent not in entities] + updated_spans = get_doc_entity_spans(spacy_style, doc) + for entity in new_entities: + span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) + if not span: + continue + span._.score = entity.score + if spacy_style == "ent": + updated_spans = util.filter_spans([*updated_spans, span]) # unpack: filter_spans returns a list, doc.ents is a tuple + elif spacy_style == "span": + updated_spans.append(span) + else: + raise ValueError(f"Invalid 
spacy style: {spacy_style}") + + set_doc_entity_spans(spacy_style, doc, updated_spans) + + final_entities = sorted(filtered_entities, key=lambda e: e.start_index) + + return final_entities + + +# ==================================== +# Spacy helpers +# ==================================== + + +def get_doc_entity_spans(spacy_style: str, doc: Doc) -> List[Span]: + """Get the spacy doc entity spans. + + Args: + spacy_style: The spacy style to use. + doc: The spacy doc to get the entity spans from. + + Returns: + The list of entity spans. + + """ + + if spacy_style == "ent": + return doc.ents + if spacy_style == "span": + if "sc" not in doc.spans: + doc.spans["sc"] = [] + return doc.spans["sc"] + raise ValueError(f"Invalid spacy style: {spacy_style}") + +def set_doc_entity_spans(spacy_style: str, doc: Doc, entities: List[Span]) -> None: + """Set the spacy doc entity spans. + + Args: + spacy_style: The spacy style to use. + doc: The spacy doc to set the entity spans. + entities: The entity spans to assign the doc. 
+ + """ + + if spacy_style == "ent": + doc.ents = entities + elif spacy_style == "span": + doc.spans["sc"] = entities + else: + raise ValueError(f"Invalid spacy style: {spacy_style}") \ No newline at end of file diff --git a/test/test_extractors.py b/test/test_extractors.py index 59df2e0..e2e9c2f 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -7,6 +7,7 @@ from anonipy.definitions import Entity from anonipy.anonymize.extractors import NERExtractor, PatternExtractor, MultiExtractor from anonipy.constants import LANGUAGES +from anonipy.anonymize.helpers import _filter_entities # disable transformers logging logging.set_verbosity_error() @@ -343,7 +344,7 @@ def test_multi_extractor_extract_default(multi_extractor): # check the performance of the joint entities generation for p_entity, t_entity in zip( joint_entities, - multi_extractor._filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), + _filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), ): assert p_entity.text == t_entity.text assert p_entity.label == t_entity.label