From cde4b3af0d5ea60e5581fa6d696ee0ca819fb56a Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Fri, 13 Dec 2024 08:45:24 +0100 Subject: [PATCH 1/8] Add function to detect repeated entities --- .../anonymize/extractors/multi_extractor.py | 69 +------ anonipy/anonymize/extractors/ner_extractor.py | 13 +- .../anonymize/extractors/pattern_extractor.py | 56 ++---- anonipy/anonymize/helpers.py | 170 +++++++++++++++++- 4 files changed, 200 insertions(+), 108 deletions(-) diff --git a/anonipy/anonymize/extractors/multi_extractor.py b/anonipy/anonymize/extractors/multi_extractor.py index b96cf0f..421853e 100644 --- a/anonipy/anonymize/extractors/multi_extractor.py +++ b/anonipy/anonymize/extractors/multi_extractor.py @@ -1,5 +1,4 @@ from typing import List, Set, Tuple, Iterable - import itertools from spacy import displacy @@ -7,6 +6,7 @@ from ...definitions import Entity from ...utils.colors import get_label_color +from ..helpers import merge_entities from .interface import ExtractorInterface @@ -63,7 +63,7 @@ def __init__(self, extractors: List[ExtractorInterface]): self.extractors = extractors def __call__( - self, text: str + self, text: str, detect_repeats: bool = False ) -> Tuple[List[Tuple[Doc, List[Entity]]], List[Entity]]: """Extract the entities fron the text using the provided extractors. @@ -73,14 +73,17 @@ def __call__( Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The list of extractor outputs containing the tuple (spacy document, extracted entities). The list of joint entities. """ - extractor_outputs = [e(text) for e in self.extractors] - joint_entities = self._merge_entities(extractor_outputs) + if (detect_repeats): + extractor_outputs = [e(text, detect_repeats) for e in self.extractors] # A JE TO PROU + joint_entities = merge_entities(extractor_outputs) + return extractor_outputs, joint_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -109,60 +112,6 @@ def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: doc, style="ent", options=options, page=page, jupyter=jupyter ) - def _merge_entities( - self, extractor_outputs: List[Tuple[Doc, List[Entity]]] - ) -> List[Entity]: - """Merges the entities returned by the extractors. - - Args: - extractor_outputs: The list of extractor outputs. - - Returns: - The merged entities list. - - """ - - if len(extractor_outputs) == 0: - return [] - if len(extractor_outputs) == 1: - return extractor_outputs[1] - - joint_entities = self._filter_entities( - list( - itertools.chain.from_iterable( - [entity[1] for entity in extractor_outputs] - ) - ) - ) - return joint_entities - - def _filter_entities(self, entities: Iterable[Entity]) -> List[Entity]: - """Filters the entities based on their start and end indices. - - Args: - entities: The entities to filter. - - Returns: - The filtered entities. - - """ + + - def get_sort_key(entity): - return ( - entity.end_index - entity.start_index, - -entity.start_index, - ) - - sorted_entities = sorted(entities, key=get_sort_key, reverse=True) - result = [] - seen_tokens: Set[int] = set() - for entity in sorted_entities: - # Check for end - 1 here because boundaries are inclusive - if ( - entity.start_index not in seen_tokens - and entity.end_index - 1 not in seen_tokens - ): - result.append(entity) - seen_tokens.update(range(entity.start_index, entity.end_index)) - result = sorted(result, key=lambda entity: entity.start_index) - return result diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index 20ef52c..fab3a84 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -1,14 +1,15 @@ import re import warnings import importlib -from typing import List, Tuple +import itertools +from typing import List, Tuple, Iterable, Set import torch from spacy import displacy from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity +from ..helpers import convert_spacy_to_entity, detect_repeated_entities from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -16,6 +17,7 @@ from .interface import ExtractorInterface + # =============================================== # Extractor class # =============================================== @@ -88,7 +90,7 @@ def __init__( self.labels = self._prepare_labels(labels) self.pipeline = self._prepare_pipeline() - def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: + def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]: """Extract the entities from the text. Examples: @@ -97,6 +99,7 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The spacy document. @@ -107,6 +110,10 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: doc = self.pipeline(text) anoni_entities, spacy_entities = self._prepare_entities(doc) self._set_spacy_fields(doc, spacy_entities) + + if (detect_repeats): + anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index 32d1257..efe4d84 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -8,7 +8,7 @@ from spacy.language import Language from spacy.matcher import Matcher -from ..helpers import convert_spacy_to_entity +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans from ...constants import LANGUAGES from ...definitions import Entity from ...utils.colors import get_label_color @@ -79,7 +79,7 @@ def __init__( self.token_matchers = self._prepare_token_matchers() self.global_matchers = self._prepare_global_matchers() - def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: + def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> Tuple[Doc, List[Entity]]: """Extract the entities from the text. Examples: @@ -88,6 +88,7 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: Args: text: The text to extract entities from. + detect_repeats: Whether to check text again for repeated entities. Returns: The spacy document. @@ -99,7 +100,11 @@ def __call__(self, text: str, *args, **kwargs) -> Tuple[Doc, List[Entity]]: self.token_matchers(doc) if self.token_matchers else None self.global_matchers(doc) if self.global_matchers else None anoni_entities, spacy_entities = self._prepare_entities(doc) - self._set_doc_entity_spans(doc, spacy_entities) + set_doc_entity_spans(self.spacy_style, doc, spacy_entities) + + if (detect_repeats): + anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -196,14 +201,14 @@ def global_matchers(doc: Doc) -> None: continue entity._.score = 1.0 # add the entity to the previous entity list - prev_entities = self._get_doc_entity_spans(doc) + prev_entities = get_doc_entity_spans(self.spacy_style, doc) if self.spacy_style == "ent": prev_entities = util.filter_spans(prev_entities + (entity,)) elif self.spacy_style == "span": prev_entities.append(entity) else: raise ValueError(f"Invalid spacy style: {self.spacy_style}") - self._set_doc_entity_spans(doc, prev_entities) + set_doc_entity_spans(self.spacy_style, doc, prev_entities) return global_matchers @@ -222,7 +227,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for e in self._get_doc_entity_spans(doc): + for e in get_doc_entity_spans(self.spacy_style, doc): label = list(filter(lambda x: x["label"] == e.label_, self.labels))[0] anoni_entities.append(convert_spacy_to_entity(e, **label)) spacy_entities.append(e) @@ -247,48 +252,13 @@ def add_event_ent(matcher, doc, i, matches): return entity._.score = 1.0 # add the entity to the previous entity list - prev_entities = self._get_doc_entity_spans(doc) + prev_entities = get_doc_entity_spans(self.spacy_style, doc) if self.spacy_style == "ent": prev_entities = util.filter_spans(prev_entities + (entity,)) elif self.spacy_style == "span": prev_entities.append(entity) else: raise ValueError(f"Invalid spacy style: {self.spacy_style}") - self._set_doc_entity_spans(doc, prev_entities) + set_doc_entity_spans(self.spacy_style, doc, prev_entities) return add_event_ent - - def _get_doc_entity_spans(self, doc: Doc) -> List[Span]: - """Get the spacy doc entity spans. - - Args: - doc: The spacy doc to get the entity spans from. - - Returns: - The list of entity spans. - - """ - - if self.spacy_style == "ent": - return doc.ents - if self.spacy_style == "span": - if "sc" not in doc.spans: - doc.spans["sc"] = [] - return doc.spans["sc"] - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - - def _set_doc_entity_spans(self, doc: Doc, entities: List[Span]) -> None: - """Set the spacy doc entity spans. - - Args: - doc: The spacy doc to set the entity spans. - entities: The entity spans to assign the doc. - - """ - - if self.spacy_style == "ent": - doc.ents = entities - elif self.spacy_style == "span": - doc.spans["sc"] = entities - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index 38fd79d..24f19d5 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -1,7 +1,9 @@ import re -from typing import List, Union +from typing import List, Union, Tuple, Iterable, Set +import itertools -from spacy.tokens import Span +from spacy import util +from spacy.tokens import Span, Doc from ..definitions import Entity, Replacement from ..constants import ENTITY_TYPES @@ -72,3 +74,167 @@ def anonymize(text: str, replacements: List[Replacement]) -> str: + anonymized_text[replacement["end_index"] :] ) return anonymized_text, s_replacements[::-1] + + +# ===================================== +# Entity helpers +# ===================================== + + +def merge_entities(extractor_outputs: List[Tuple[Doc, List[Entity]]]) -> List[Entity]: + """Merges the entities returned by the extractors. + + Args: + extractor_outputs: The list of extractor outputs. + + Returns: + The merged entities list. + + """ + + if len(extractor_outputs) == 0: + return [] + if len(extractor_outputs) == 1: + return extractor_outputs[1] + + joint_entities = _filter_entities( + list( + itertools.chain.from_iterable( + [entity[1] for entity in extractor_outputs] + ) + ) + ) + return joint_entities + +def _filter_entities(entities: Iterable[Entity]) -> List[Entity]: + """Filters the entities based on their start and end indices. + + Args: + entities: The entities to filter. + + Returns: + The filtered entities. + + """ + + def get_sort_key(entity): + return ( + entity.end_index - entity.start_index, + -entity.start_index, + ) + + sorted_entities = sorted(entities, key=get_sort_key, reverse=True) + result = [] + seen_tokens: Set[int] = set() + for entity in sorted_entities: + # Check for end - 1 here because boundaries are inclusive + if ( + entity.start_index not in seen_tokens + and entity.end_index - 1 not in seen_tokens + ): + result.append(entity) + seen_tokens.update(range(entity.start_index, entity.end_index)) + result = sorted(result, key=lambda entity: entity.start_index) + return result + + +def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) -> List[Entity]: + """Detects repeated entities in the text. + + Args: + entities: The entities to detect. + doc: The spacy doc to detect entities in. + spacy_style: The spacy style to use. + + Returns: + The list of all entities. + + """ + + repeated_entities = [] + + for entity in entities: + matches = re.finditer(re.escape(entity.text), doc.text) + + for match in matches: + start_index, end_index = match.start(), match.end() + if (start_index != entity.start_index and end_index != entity.end_index): + repeated_entities.append( + Entity( + text = entity.text, + label = entity.label, + start_index = start_index, + end_index = end_index, + score = entity.score, + type = entity.type, + regex = entity.text + ) + ) + + filtered_entities = _filter_entities(entities + repeated_entities) + new_entities = [ent for ent in filtered_entities if ent not in entities] + updated_spans = get_doc_entity_spans(spacy_style, doc) + for entity in new_entities: + span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) + if span: + span._.score = entity.score + if spacy_style == "ent": + updated_spans = util.filter_spans(updated_spans + (span,)) + elif spacy_style == "span": + updated_spans.append(span) + else: + raise ValueError(f"Invalid spacy style: {spacy_style}") + + set_doc_entity_spans(spacy_style, doc, updated_spans) + + final_entities = sorted(filtered_entities, key=lambda e: e.start_index) + + return final_entities + +# Filtriramo entitete tako, da če se slučajno kje prekrivajo entiteti, vzamemo tisto, ki ima večji span. Mislim da imamo že eno metodo merge_entities (poglej v MultiExtractor) +# Entity(text='John Doe', label='name', start_index=30, end_index=38, score=0.9963099360466003, type='string', regex='.*') +# covert spacy to entities +# pip install -e .[all] + + +# ==================================== +# Spacy helpers +# ==================================== + + +def get_doc_entity_spans(spacy_style: str, doc: Doc) -> List[Span]: + """Get the spacy doc entity spans. + + Args: + spacy_style: The spacy style to use. + doc: The spacy doc to get the entity spans from. + + Returns: + The list of entity spans. + + """ + + if spacy_style == "ent": + return doc.ents + if spacy_style == "span": + if "sc" not in doc.spans: + doc.spans["sc"] = [] + return doc.spans["sc"] + raise ValueError(f"Invalid spacy style: {spacy_style}") + +def set_doc_entity_spans(spacy_style: str, doc: Doc, entities: List[Span]) -> None: + """Set the spacy doc entity spans. + + Args: + spacy_style: The spacy style to use. + doc: The spacy doc to set the entity spans. + entities: The entity spans to assign the doc. + + """ + + if spacy_style == "ent": + doc.ents = entities + elif spacy_style == "span": + doc.spans["sc"] = entities + else: + raise ValueError(f"Invalid spacy style: {spacy_style}") \ No newline at end of file From ab2e084893cdfa5b7de1ba13446c07b20a0f39ab Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Sun, 15 Dec 2024 12:35:49 +0100 Subject: [PATCH 2/8] Add a new function to detect repeated entities --- .../anonymize/extractors/multi_extractor.py | 10 ++-- anonipy/anonymize/extractors/ner_extractor.py | 50 +++--------------- .../anonymize/extractors/pattern_extractor.py | 7 ++- anonipy/anonymize/helpers.py | 52 +++++++++---------- 4 files changed, 38 insertions(+), 81 deletions(-) diff --git a/anonipy/anonymize/extractors/multi_extractor.py b/anonipy/anonymize/extractors/multi_extractor.py index 421853e..b51d9d9 100644 --- a/anonipy/anonymize/extractors/multi_extractor.py +++ b/anonipy/anonymize/extractors/multi_extractor.py @@ -1,4 +1,4 @@ -from typing import List, Set, Tuple, Iterable +from typing import List, Tuple import itertools from spacy import displacy @@ -27,7 +27,7 @@ class MultiExtractor: >>> PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH), >>> ] >>> extractor = MultiExtractor(extractors) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Attributes: @@ -68,7 +68,7 @@ def __call__( """Extract the entities fron the text using the provided extractors. Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Args: @@ -78,10 +78,10 @@ def __call__( Returns: The list of extractor outputs containing the tuple (spacy document, extracted entities). The list of joint entities. + """ - if (detect_repeats): - extractor_outputs = [e(text, detect_repeats) for e in self.extractors] # A JE TO PROU + extractor_outputs = [e(text, detect_repeats) for e in self.extractors] joint_entities = merge_entities(extractor_outputs) return extractor_outputs, joint_entities diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index fab3a84..7f21bea 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -1,15 +1,14 @@ import re import warnings import importlib -import itertools -from typing import List, Tuple, Iterable, Set +from typing import List, Tuple import torch from spacy import displacy from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity, detect_repeated_entities +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -31,7 +30,7 @@ class NERExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import NERExtractor >>> labels = [{"label": "PERSON", "type": "string"}] >>> extractor = NERExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -94,7 +93,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> """Extract the entities from the text. Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: @@ -109,7 +108,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> doc = self.pipeline(text) anoni_entities, spacy_entities = self._prepare_entities(doc) - self._set_spacy_fields(doc, spacy_entities) + set_doc_entity_spans(self.spacy_style, doc, spacy_entities) if (detect_repeats): anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) @@ -229,46 +228,9 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for s in self._get_spacy_fields(doc): + for s in get_doc_entity_spans(self.spacy_style, doc): label = list(filter(lambda x: x["label"] == s.label_, self.labels))[0] if re.match(label["regex"], s.text): anoni_entities.append(convert_spacy_to_entity(s, **label)) spacy_entities.append(s) return anoni_entities, spacy_entities - - def _get_spacy_fields(self, doc: Doc) -> List[Span]: - """Get the spacy doc entity spans. - - args: - doc: The spacy doc to get the entity spans from. - - Returns: - The list of Spans from the spacy doc. - - """ - - if self.spacy_style == "ent": - return doc.ents - elif self.spacy_style == "span": - return doc.spans["sc"] - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - - def _set_spacy_fields(self, doc: Doc, entities: List[Span]) -> None: - """Set the spacy doc entity spans. - - Args: - doc: The spacy doc to set the entity spans. - entities: The entity spans to set. - - Returns: - None - - """ - - if self.spacy_style == "ent": - doc.ents = entities - elif self.spacy_style == "span": - doc.spans["sc"] = entities - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index efe4d84..a7be84d 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -1,5 +1,4 @@ import re - import importlib from typing import List, Tuple, Optional, Callable @@ -29,7 +28,7 @@ class PatternExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import PatternExtractor >>> labels = [{"label": "PERSON", "type": "string", "regex": "([A-Z][a-z]+ [A-Z][a-z]+)"}] >>> extractor = PatternExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -83,7 +82,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> """Extract the entities from the text. Examples: - >>> extractor("John Doe is a 19 year old software engineer.") + >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: @@ -261,4 +260,4 @@ def add_event_ent(matcher, doc, i, matches): raise ValueError(f"Invalid spacy style: {self.spacy_style}") set_doc_entity_spans(self.spacy_style, doc, prev_entities) - return add_event_ent + return add_event_ent \ No newline at end of file diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index 24f19d5..8aa8fca 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -8,6 +8,7 @@ from ..definitions import Entity, Replacement from ..constants import ENTITY_TYPES + # ===================================== # Entity converters # ===================================== @@ -137,14 +138,13 @@ def get_sort_key(entity): result = sorted(result, key=lambda entity: entity.start_index) return result - def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) -> List[Entity]: """Detects repeated entities in the text. Args: entities: The entities to detect. doc: The spacy doc to detect entities in. - spacy_style: The spacy style to use. + spacy_style: The style the entities should be stored in the spacy doc. Returns: The list of all entities. @@ -155,35 +155,36 @@ def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) for entity in entities: matches = re.finditer(re.escape(entity.text), doc.text) - for match in matches: start_index, end_index = match.start(), match.end() - if (start_index != entity.start_index and end_index != entity.end_index): - repeated_entities.append( - Entity( - text = entity.text, - label = entity.label, - start_index = start_index, - end_index = end_index, - score = entity.score, - type = entity.type, - regex = entity.text - ) + if (start_index == entity.start_index and end_index == entity.end_index): + continue + repeated_entities.append( + Entity( + text = entity.text, + label = entity.label, + start_index = start_index, + end_index = end_index, + score = entity.score, + type = entity.type, + regex = entity.text ) + ) filtered_entities = _filter_entities(entities + repeated_entities) new_entities = [ent for ent in filtered_entities if ent not in entities] updated_spans = get_doc_entity_spans(spacy_style, doc) for entity in new_entities: - span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) - if span: - span._.score = entity.score - if spacy_style == "ent": - updated_spans = util.filter_spans(updated_spans + (span,)) - elif spacy_style == "span": - updated_spans.append(span) - else: - raise ValueError(f"Invalid spacy style: {spacy_style}") + span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) + if not span: + continue + span._.score = entity.score + if spacy_style == "ent": + updated_spans = util.filter_spans(updated_spans + (span,)) + elif spacy_style == "span": + updated_spans.append(span) + else: + raise ValueError(f"Invalid spacy style: {spacy_style}") set_doc_entity_spans(spacy_style, doc, updated_spans) @@ -191,11 +192,6 @@ def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) return final_entities -# Filtriramo entitete tako, da če se slučajno kje prekrivajo entiteti, vzamemo tisto, ki ima večji span. Mislim da imamo že eno metodo merge_entities (poglej v MultiExtractor) -# Entity(text='John Doe', label='name', start_index=30, end_index=38, score=0.9963099360466003, type='string', regex='.*') -# covert spacy to entities -# pip install -e .[all] - # ==================================== # Spacy helpers From 2f5bbb187b5477132883b21a6fe9f8dd6381b0f1 Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Sun, 15 Dec 2024 13:36:09 +0100 Subject: [PATCH 3/8] Fix an extractor test due to function location change --- anonipy/anonymize/helpers.py | 2 +- test/test_extractors.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index 8aa8fca..44f49f9 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -96,7 +96,7 @@ def merge_entities(extractor_outputs: List[Tuple[Doc, List[Entity]]]) -> List[En if len(extractor_outputs) == 0: return [] if len(extractor_outputs) == 1: - return extractor_outputs[1] + return extractor_outputs[0][1] joint_entities = _filter_entities( list( diff --git a/test/test_extractors.py b/test/test_extractors.py index 59df2e0..e2e9c2f 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -7,6 +7,7 @@ from anonipy.definitions import Entity from anonipy.anonymize.extractors import NERExtractor, PatternExtractor, MultiExtractor from anonipy.constants import LANGUAGES +from anonipy.anonymize.helpers import _filter_entities # disable transformers logging logging.set_verbosity_error() @@ -343,7 +344,7 @@ def test_multi_extractor_extract_default(multi_extractor): # check the performance of the joint entities generation for p_entity, t_entity in zip( joint_entities, - multi_extractor._filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), + _filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), ): assert p_entity.text == t_entity.text assert p_entity.label == t_entity.label From 2d970b2be66a86818d4302901f03d97c8ae81fb9 Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Sun, 15 Dec 2024 23:04:08 +0100 Subject: [PATCH 4/8] Create a new function for creating spacy entities --- .../anonymize/extractors/multi_extractor.py | 4 +- anonipy/anonymize/extractors/ner_extractor.py | 13 ++-- .../anonymize/extractors/pattern_extractor.py | 32 ++++------ anonipy/anonymize/helpers.py | 59 +++++++++++-------- test/test_extractors.py | 4 +- 5 files changed, 59 insertions(+), 53 deletions(-) diff --git a/anonipy/anonymize/extractors/multi_extractor.py b/anonipy/anonymize/extractors/multi_extractor.py index dc3f588..7ae7c15 100644 --- a/anonipy/anonymize/extractors/multi_extractor.py +++ b/anonipy/anonymize/extractors/multi_extractor.py @@ -27,7 +27,7 @@ class MultiExtractor: >>> PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH), >>> ] >>> extractor = MultiExtractor(extractors) - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Attributes: @@ -72,7 +72,7 @@ def __call__( """Extract the entities fron the text using the provided extractors. Examples: - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) [(Doc, [Entity]), (Doc, [Entity])], [Entity] Args: diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index 936745b..a0c598b 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -8,7 +8,7 @@ from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -30,7 +30,7 @@ class NERExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import NERExtractor >>> labels = [{"label": "PERSON", "type": "string"}] >>> extractor = NERExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -97,7 +97,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> """Extract the entities from the text. Examples: - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: @@ -112,10 +112,11 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> doc = self.pipeline(text) anoni_entities, spacy_entities = self._prepare_entities(doc) - set_doc_entity_spans(self.spacy_style, doc, spacy_entities) - if (detect_repeats): - anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + if detect_repeats: + anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style) + + create_spacy_entities(doc, anoni_entities, self.spacy_style) return doc, anoni_entities diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index a7be84d..71e0a09 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -7,7 +7,7 @@ from spacy.language import Language from spacy.matcher import Matcher -from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, set_doc_entity_spans, create_spacy_entities from ...constants import LANGUAGES from ...definitions import Entity from ...utils.colors import get_label_color @@ -28,7 +28,7 @@ class PatternExtractor(ExtractorInterface): >>> from anonipy.anonymize.extractors import PatternExtractor >>> labels = [{"label": "PERSON", "type": "string", "regex": "([A-Z][a-z]+ [A-Z][a-z]+)"}] >>> extractor = PatternExtractor(labels, lang=LANGUAGES.ENGLISH) - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Attributes: @@ -82,7 +82,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> """Extract the entities from the text. Examples: - >>> extractor(text="John Doe is a 19 year old software engineer.", detect_repeats=False) + >>> extractor("John Doe is a 19 year old software engineer.", detect_repeats=False) Doc, [Entity] Args: @@ -99,10 +99,11 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> self.token_matchers(doc) if self.token_matchers else None self.global_matchers(doc) if self.global_matchers else None anoni_entities, spacy_entities = self._prepare_entities(doc) - set_doc_entity_spans(self.spacy_style, doc, spacy_entities) - if (detect_repeats): - anoni_entities = detect_repeated_entities(anoni_entities, doc, self.spacy_style) + if detect_repeats: + anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style) + + create_spacy_entities(doc, anoni_entities, self.spacy_style) return doc, anoni_entities @@ -200,14 +201,14 @@ def global_matchers(doc: Doc) -> None: continue entity._.score = 1.0 # add the entity to the previous entity list - prev_entities = get_doc_entity_spans(self.spacy_style, doc) + prev_entities = get_doc_entity_spans(doc, self.spacy_style) if self.spacy_style == "ent": prev_entities = util.filter_spans(prev_entities + (entity,)) elif self.spacy_style == "span": prev_entities.append(entity) else: raise ValueError(f"Invalid spacy style: {self.spacy_style}") - set_doc_entity_spans(self.spacy_style, doc, prev_entities) + set_doc_entity_spans(doc, prev_entities, self.spacy_style) return global_matchers @@ -226,7 +227,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for e in get_doc_entity_spans(self.spacy_style, doc): + for e in get_doc_entity_spans(doc, self.spacy_style): label = list(filter(lambda x: x["label"] == e.label_, self.labels))[0] anoni_entities.append(convert_spacy_to_entity(e, **label)) spacy_entities.append(e) @@ -249,15 +250,8 @@ def add_event_ent(matcher, doc, i, matches): entity = Span(doc, start, end, label=label) if not entity: return - entity._.score = 1.0 - # add the entity to the previous entity list - prev_entities = get_doc_entity_spans(self.spacy_style, doc) - if self.spacy_style == "ent": - prev_entities = util.filter_spans(prev_entities + (entity,)) - elif self.spacy_style == "span": - prev_entities.append(entity) - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - set_doc_entity_spans(self.spacy_style, doc, prev_entities) + entities = [convert_spacy_to_entity(entity)] + + create_spacy_entities(doc, entities, self.spacy_style) return add_event_ent \ No newline at end of file diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index 44f49f9..f65db3b 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -98,7 +98,7 @@ def merge_entities(extractor_outputs: List[Tuple[Doc, List[Entity]]]) -> List[En if len(extractor_outputs) == 1: return extractor_outputs[0][1] - joint_entities = _filter_entities( + joint_entities = filter_entities( list( itertools.chain.from_iterable( [entity[1] for entity in extractor_outputs] @@ -107,7 +107,7 @@ def merge_entities(extractor_outputs: List[Tuple[Doc, List[Entity]]]) -> List[En ) return joint_entities -def _filter_entities(entities: Iterable[Entity]) -> List[Entity]: +def filter_entities(entities: Iterable[Entity]) -> List[Entity]: """Filters the entities based on their start and end indices. Args: @@ -138,12 +138,12 @@ def get_sort_key(entity): result = sorted(result, key=lambda entity: entity.start_index) return result -def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) -> List[Entity]: +def detect_repeated_entities(doc: Doc, entities: List[Entity], spacy_style: str) -> List[Entity]: """Detects repeated entities in the text. Args: - entities: The entities to detect. doc: The spacy doc to detect entities in. + entities: The entities to detect. spacy_style: The style the entities should be stored in the spacy doc. Returns: @@ -157,7 +157,7 @@ def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) matches = re.finditer(re.escape(entity.text), doc.text) for match in matches: start_index, end_index = match.start(), match.end() - if (start_index == entity.start_index and end_index == entity.end_index): + if start_index == entity.start_index and end_index == entity.end_index: continue repeated_entities.append( Entity( @@ -171,10 +171,30 @@ def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) ) ) - filtered_entities = _filter_entities(entities + repeated_entities) - new_entities = [ent for ent in filtered_entities if ent not in entities] - updated_spans = get_doc_entity_spans(spacy_style, doc) - for entity in new_entities: + filtered_entities = filter_entities(entities + repeated_entities) + final_entities = sorted(filtered_entities, key=lambda e: e.start_index) + + return final_entities + + +# ==================================== +# Spacy helpers +# ==================================== + + +def create_spacy_entities(doc: Doc, entities: List[Entity], spacy_style: str) -> None: + """Create spacy entities in the spacy doc. + + Args: + doc: The spacy doc to create entities in. + entities: The entities to create. + spacy_style: The style the entities should be stored in the spacy doc. + + """ + + updated_spans = get_doc_entity_spans(doc, spacy_style) + + for entity in entities: span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) if not span: continue @@ -186,24 +206,15 @@ def detect_repeated_entities(entities: List[Entity], doc: Doc, spacy_style: str) else: raise ValueError(f"Invalid spacy style: {spacy_style}") - set_doc_entity_spans(spacy_style, doc, updated_spans) - - final_entities = sorted(filtered_entities, key=lambda e: e.start_index) + set_doc_entity_spans(doc, updated_spans, spacy_style) - return final_entities - - -# ==================================== -# Spacy helpers -# ==================================== - - -def get_doc_entity_spans(spacy_style: str, doc: Doc) -> List[Span]: + +def get_doc_entity_spans(doc: Doc, spacy_style: str) -> List[Span]: """Get the spacy doc entity spans. Args: - spacy_style: The spacy style to use. doc: The spacy doc to get the entity spans from. + spacy_style: The spacy style to use. Returns: The list of entity spans. @@ -218,13 +229,13 @@ def get_doc_entity_spans(spacy_style: str, doc: Doc) -> List[Span]: return doc.spans["sc"] raise ValueError(f"Invalid spacy style: {spacy_style}") -def set_doc_entity_spans(spacy_style: str, doc: Doc, entities: List[Span]) -> None: +def set_doc_entity_spans(doc: Doc, entities: List[Span], spacy_style: str) -> None: """Set the spacy doc entity spans. Args: - spacy_style: The spacy style to use. doc: The spacy doc to set the entity spans. entities: The entity spans to assign the doc. + spacy_style: The spacy style to use. """ diff --git a/test/test_extractors.py b/test/test_extractors.py index e2e9c2f..e63a5fc 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -7,7 +7,7 @@ from anonipy.definitions import Entity from anonipy.anonymize.extractors import NERExtractor, PatternExtractor, MultiExtractor from anonipy.constants import LANGUAGES -from anonipy.anonymize.helpers import _filter_entities +from anonipy.anonymize.helpers import filter_entities # disable transformers logging logging.set_verbosity_error() @@ -344,7 +344,7 @@ def test_multi_extractor_extract_default(multi_extractor): # check the performance of the joint entities generation for p_entity, t_entity in zip( joint_entities, - _filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), + filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), ): assert p_entity.text == t_entity.text assert p_entity.label == t_entity.label From f16b835e0022b91002fb91d37a99ab6a666da2f5 Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Mon, 16 Dec 2024 13:27:31 +0100 Subject: [PATCH 5/8] Fix errors --- anonipy/anonymize/extractors/ner_extractor.py | 7 +++---- anonipy/anonymize/extractors/pattern_extractor.py | 2 +- anonipy/anonymize/helpers.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index a0c598b..8995410 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -8,7 +8,7 @@ from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities, set_doc_entity_spans from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -112,12 +112,11 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> doc = self.pipeline(text) anoni_entities, spacy_entities = self._prepare_entities(doc) - + if detect_repeats: anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style) create_spacy_entities(doc, anoni_entities, self.spacy_style) - return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: @@ -233,7 +232,7 @@ def _prepare_entities(self, doc: Doc) -> Tuple[List[Entity], List[Span]]: # TODO: make this part more generic anoni_entities = [] spacy_entities = [] - for s in get_doc_entity_spans(self.spacy_style, doc): + for s in get_doc_entity_spans(doc, self.spacy_style): label = list(filter(lambda x: x["label"] == s.label_, self.labels))[0] if re.match(label["regex"], s.text): anoni_entities.append(convert_spacy_to_entity(s, **label)) diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index 71e0a09..266b554 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -102,7 +102,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> if detect_repeats: anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style) - + create_spacy_entities(doc, anoni_entities, self.spacy_style) return doc, anoni_entities diff --git a/anonipy/anonymize/helpers.py b/anonipy/anonymize/helpers.py index f65db3b..35d3d90 100644 --- a/anonipy/anonymize/helpers.py +++ b/anonipy/anonymize/helpers.py @@ -167,7 +167,7 @@ def detect_repeated_entities(doc: Doc, entities: List[Entity], spacy_style: str) end_index = end_index, score = entity.score, type = entity.type, - regex = entity.text + regex = entity.regex ) ) @@ -196,7 +196,7 @@ def create_spacy_entities(doc: Doc, entities: List[Entity], spacy_style: str) -> for entity in entities: span = doc.char_span(entity.start_index, entity.end_index, label=entity.label) - if not span: + if not span or span in updated_spans: continue span._.score = entity.score if spacy_style == "ent": From 3567e83d5634bad5d7b6d913f8686f35c67adcba Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Tue, 17 Dec 2024 13:12:51 +0100 Subject: [PATCH 6/8] Add new tests for detecting repeated entities --- test/test_extractors.py | 205 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 202 insertions(+), 3 deletions(-) diff --git a/test/test_extractors.py b/test/test_extractors.py index e63a5fc..3ca5472 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -33,6 +33,8 @@ Lisinopril 10 mg: Take one tablet daily to manage high blood pressure. Next Examination Date: 15-11-2024 + +The examination took place on 20-05-2024. John Doe was prescribed Ibuprofen 200 mg and Lisinopril 10 mg. """ TEST_NER_ENTITIES = [ @@ -79,8 +81,23 @@ end_index=727, type="date", ), + Entity( + text="20-05-2024", + label="date", + start_index=759, + end_index=769, + type="date", + ), +] +TEST_REPEATS_ENTITIES = [ + Entity( + text="John Doe", + label="name", + start_index=771, + end_index=779, + type="string", + ), ] - TEST_PATTERN_ENTITIES = [ Entity( text="15-01-1985", @@ -125,8 +142,87 @@ end_index=727, type="date", ), + Entity( + text="20-05-2024", + label="date", + start_index=759, + end_index=769, + type="date", + ), + Entity( + text="Ibuprofen 200 mg", + label="medicine", + start_index=795, + end_index=811, + type="string", + ), + Entity( + text="Lisinopril 10 mg", + label="medicine", + start_index=816, + end_index=832, + type="string", + ), +] +TEST_PATTERN_DETECT_REPEATS = [ + Entity( + text="20-05-2024", + label="date", + start_index=86, + end_index=96, + type="date", + regex=r"Date of Examination: (.*)" + ), + # Repeated entity + Entity( + text="20-05-2024", + label="date", + start_index=759, + end_index=769, + type="date", + regex=r"Date of Examination: (.*)" + ), +] +TEST_MULTI_REPEATS = [ + Entity( + text="John Doe", + label="name", + start_index=30, + end_index=38, + type="string", + ), + Entity( + text="20-05-2024", + label="date", + start_index=86, + end_index=96, + type="date", + regex=r"Date of Examination: (.*)", + ), + Entity( + text="John Doe", + label="name", + start_index=157, + end_index=165, + type="string", + ), + # Repeated entities + Entity( + text="20-05-2024", + label="date", + start_index=759, + end_index=769, + type="date", + regex=r"Date of Examination: (.*)", + ), + Entity( + text="John Doe", + label="name", + start_index=771, + end_index=779, + type="string", + ), ] - @pytest.fixture(autouse=True) def suppress_warnings(): @@ -174,7 +270,7 @@ def pattern_extractor(): {"SHAPE": "dddd"}, ] ], - }, + } ] return PatternExtractor(labels=labels, lang=LANGUAGES.ENGLISH) @@ -278,6 +374,19 @@ def test_ner_extractor_extract_custom_params_input(): assert p_entity.score >= 0.5 +def test_ner_extractor_detect_repeats_true(ner_extractor): + _, entities = ner_extractor(TEST_ORIGINAL_TEXT, detect_repeats=True) + expected_entities = TEST_NER_ENTITIES + TEST_REPEATS_ENTITIES + for p_entity, t_entity in zip(entities, expected_entities): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + def test_pattern_extractor_init(): with pytest.raises(TypeError): PatternExtractor() @@ -303,6 +412,48 @@ def test_pattern_extractor_extract_default(pattern_extractor): assert p_entity.regex == t_entity.regex assert p_entity.score == 1.0 +def test_pattern_extractor_detect_repeats_false(): + extractor = PatternExtractor( + labels=[ + { + "label": "date", + "type": "date", + "regex": r"Date of Examination: (.*)", + } + ], + lang=LANGUAGES.ENGLISH, + ) + _, entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=False) + excepted_entity = TEST_PATTERN_DETECT_REPEATS[0] + assert len(entities) == 1 + assert excepted_entity.text == entities[0].text + assert excepted_entity.label == entities[0].label + assert excepted_entity.start_index == entities[0].start_index + assert excepted_entity.end_index == entities[0].end_index + assert excepted_entity.type == entities[0].type + assert excepted_entity.regex == entities[0].regex + assert excepted_entity.score >= 0.5 + +def test_pattern_extractor_detect_repeats_true(): + extractor = PatternExtractor( + labels=[ + { + "label": "date", + "type": "date", + "regex": r"Date of Examination: (.*)", + } + ], + lang=LANGUAGES.ENGLISH, + ) + _, entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=True) + for p_entity, t_entity in zip(entities, TEST_PATTERN_DETECT_REPEATS): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 def test_multi_extractor_init(): with pytest.raises(TypeError): @@ -413,3 +564,51 @@ def test_multi_extractor_extract_single_extractor_pattern(multi_extractor): assert p_entity.type == t_entity.type assert p_entity.regex == t_entity.regex assert p_entity.score >= 0.5 + + +def test_multi_extractor_detect_repeats_false(): + extractors = [ + NERExtractor(labels=[ + {"label": "name", "type": "string"}, + ]), + PatternExtractor(labels=[ + { + "label": "date", + "type": "date", + "regex": r"Date of Examination: (.*)", + }, + ])] + extractor = MultiExtractor(extractors) + _, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=False) + for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS[:3]): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + +def test_multi_extractor_detect_repeats_true(): + extractors = [ + NERExtractor(labels=[ + {"label": "name", "type": "string"}, + ]), + PatternExtractor(labels=[ + { + "label": "date", + "type": "date", + "regex": r"Date of Examination: (.*)", + }, + ])] + extractor = MultiExtractor(extractors) + _, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=True) + for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 \ No newline at end of file From 58f8dcc1b8db4fd38f0967d57ae7d574b2c1ebce Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Tue, 17 Dec 2024 13:14:19 +0100 Subject: [PATCH 7/8] Fix detect_repeats problems --- anonipy/anonymize/extractors/ner_extractor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/anonipy/anonymize/extractors/ner_extractor.py b/anonipy/anonymize/extractors/ner_extractor.py index 8995410..b3b87f0 100644 --- a/anonipy/anonymize/extractors/ner_extractor.py +++ b/anonipy/anonymize/extractors/ner_extractor.py @@ -8,7 +8,7 @@ from spacy.tokens import Doc, Span from spacy.language import Language -from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities, set_doc_entity_spans +from ..helpers import convert_spacy_to_entity, detect_repeated_entities, get_doc_entity_spans, create_spacy_entities from ...utils.regex import regex_mapping from ...constants import LANGUAGES from ...definitions import Entity @@ -117,6 +117,7 @@ def __call__(self, text: str, detect_repeats: bool = False, *args, **kwargs) -> anoni_entities = detect_repeated_entities(doc, anoni_entities, self.spacy_style) create_spacy_entities(doc, anoni_entities, self.spacy_style) + return doc, anoni_entities def display(self, doc: Doc, page: bool = False, jupyter: bool = None) -> str: From 706e9e785a98d3cd5c9d1ed9d245881ae32ebd74 Mon Sep 17 00:00:00 2001 From: ninakokalj Date: Tue, 17 Dec 2024 13:35:03 +0100 Subject: [PATCH 8/8] Use create_spacy_entities --- anonipy/anonymize/extractors/pattern_extractor.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/anonipy/anonymize/extractors/pattern_extractor.py b/anonipy/anonymize/extractors/pattern_extractor.py index 266b554..385fb7d 100644 --- a/anonipy/anonymize/extractors/pattern_extractor.py +++ b/anonipy/anonymize/extractors/pattern_extractor.py @@ -200,15 +200,9 @@ def global_matchers(doc: Doc) -> None: if not entity: continue entity._.score = 1.0 + entities = [convert_spacy_to_entity(entity)] # add the entity to the previous entity list - prev_entities = get_doc_entity_spans(doc, self.spacy_style) - if self.spacy_style == "ent": - prev_entities = util.filter_spans(prev_entities + (entity,)) - elif self.spacy_style == "span": - prev_entities.append(entity) - else: - raise ValueError(f"Invalid spacy style: {self.spacy_style}") - set_doc_entity_spans(doc, prev_entities, self.spacy_style) + create_spacy_entities(doc, entities, self.spacy_style) return global_matchers @@ -250,8 +244,8 @@ def add_event_ent(matcher, doc, i, matches): entity = Span(doc, start, end, label=label) if not entity: return + entity._.score = 1.0 entities = [convert_spacy_to_entity(entity)] - create_spacy_entities(doc, entities, self.spacy_style) return add_event_ent \ No newline at end of file