Skip to content

Commit

Permalink
Merge pull request #24 from eriknovak/feature/llms
Browse files Browse the repository at this point in the history
Enable CPU utilization for LLMLabelGenerator
  • Loading branch information
eriknovak authored Nov 9, 2024
2 parents 15f7019 + e057b18 commit e883f45
Show file tree
Hide file tree
Showing 13 changed files with 228 additions and 301 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
restore-keys: |
mkdocs-material-
- run: pip install mkdocs-material mkdocstrings[python]
- run: mkdocs gh-deploy --force
- run: mkdocs gh-deploy --force
4 changes: 2 additions & 2 deletions .github/workflows/unittests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -25,4 +25,4 @@ jobs:
pip install -e .[test]
- name: Test with unittest
run: |
python -m unittest discover test
python -m unittest discover test
20 changes: 11 additions & 9 deletions anonipy/anonymize/extractors/multi_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,22 @@ def _filter_entities(self, entities: Iterable[Entity]) -> List[Entity]:
"""

get_sort_key = lambda entity: (
entity.end_index - entity.start_index,
-entity.start_index,
)
def get_sort_key(entity):
return (
entity.end_index - entity.start_index,
-entity.start_index,
)

sorted_entities = sorted(entities, key=get_sort_key, reverse=True)
result = []
seen_tokens: Set[int] = set()
for entities in sorted_entities:
for entity in sorted_entities:
# Check for end - 1 here because boundaries are inclusive
if (
entities.start_index not in seen_tokens
and entities.end_index - 1 not in seen_tokens
entity.start_index not in seen_tokens
and entity.end_index - 1 not in seen_tokens
):
result.append(entities)
seen_tokens.update(range(entities.start_index, entities.end_index))
result.append(entity)
seen_tokens.update(range(entity.start_index, entity.end_index))
result = sorted(result, key=lambda entity: entity.start_index)
return result
4 changes: 2 additions & 2 deletions anonipy/anonymize/extractors/ner_extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import warnings
import importlib
from typing import List, Tuple
import warnings

import torch
from spacy import displacy
Expand Down Expand Up @@ -52,12 +52,12 @@ class NERExtractor(ExtractorInterface):
def __init__(
self,
labels: List[dict],
*args,
lang: LANGUAGES = LANGUAGES.ENGLISH,
score_th: float = 0.5,
use_gpu: bool = False,
gliner_model: str = "urchade/gliner_multi_pii-v1",
spacy_style: str = "ent",
*args,
**kwargs,
):
"""Initialize the named entity recognition (NER) extractor.
Expand Down
7 changes: 3 additions & 4 deletions anonipy/anonymize/extractors/pattern_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class PatternExtractor(ExtractorInterface):
def __init__(
self,
labels: List[dict],
*args,
lang: LANGUAGES = LANGUAGES.ENGLISH,
spacy_style: str = "ent",
*args,
**kwargs,
):
"""Initialize the pattern extractor.
Expand Down Expand Up @@ -271,12 +271,11 @@ def _get_doc_entity_spans(self, doc: Doc) -> List[Span]:

if self.spacy_style == "ent":
return doc.ents
elif self.spacy_style == "span":
if self.spacy_style == "span":
if "sc" not in doc.spans:
doc.spans["sc"] = []
return doc.spans["sc"]
else:
raise ValueError(f"Invalid spacy style: {self.spacy_style}")
raise ValueError(f"Invalid spacy style: {self.spacy_style}")

def _set_doc_entity_spans(self, doc: Doc, entities: List[Span]) -> None:
"""Set the spacy doc entity spans.
Expand Down
4 changes: 2 additions & 2 deletions anonipy/anonymize/generators/date_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DateGenerator(GeneratorInterface):
"""

def __init__(self, date_format: str = "auto", day_sigma: int = 30, *args, **kwargs):
def __init__(self, *args, date_format: str = "auto", day_sigma: int = 30, **kwargs):
"""Initializes he date generator.
Examples:
Expand All @@ -137,8 +137,8 @@ def __init__(self, date_format: str = "auto", day_sigma: int = 30, *args, **kwar
def generate(
self,
entity: Entity,
sub_variant: DATE_TRANSFORM_VARIANTS = DATE_TRANSFORM_VARIANTS.RANDOM,
*args,
sub_variant: DATE_TRANSFORM_VARIANTS = DATE_TRANSFORM_VARIANTS.RANDOM,
**kwargs,
) -> str:
"""Generate the entity substitute based on the input parameters.
Expand Down
Loading

0 comments on commit e883f45

Please sign in to comment.