From 30b06d677b834a4999a155de2a64a957889aa331 Mon Sep 17 00:00:00 2001 From: eriknovak Date: Wed, 11 Dec 2024 21:42:56 +0100 Subject: [PATCH] Change unit tests to use pytest --- .github/workflows/unittests.yaml | 5 +- anonipy/anonymize/strategies/masking.py | 2 +- anonipy/anonymize/strategies/redaction.py | 2 +- pyproject.toml | 9 +- test/test_anonymize.py | 16 +- test/test_entity.py | 74 ++- test/test_extractors.py | 583 ++++++++++------------ test/test_file_system.py | 36 +- test/test_generators.py | 302 +++++------ test/test_language_detector.py | 197 ++++---- test/test_pipeline.py | 168 ++++--- test/test_regex.py | 26 +- test/test_strategies.py | 239 +++++---- 13 files changed, 806 insertions(+), 853 deletions(-) diff --git a/.github/workflows/unittests.yaml b/.github/workflows/unittests.yaml index e2bc177..62c0922 100644 --- a/.github/workflows/unittests.yaml +++ b/.github/workflows/unittests.yaml @@ -25,6 +25,5 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[test] - - name: Test with unittest - run: | - python -m unittest discover test + - name: Test with pytest + run: pytest diff --git a/anonipy/anonymize/strategies/masking.py b/anonipy/anonymize/strategies/masking.py index 1c652b6..05a7a27 100644 --- a/anonipy/anonymize/strategies/masking.py +++ b/anonipy/anonymize/strategies/masking.py @@ -40,7 +40,7 @@ def __init__(self, substitute_label: str = "*", *args, **kwargs): """ super().__init__(*args, **kwargs) - self.substitute_label = substitute_label + self.substitute_label = substitute_label or "*" def anonymize( self, text: str, entities: List[Entity], *args, **kwargs diff --git a/anonipy/anonymize/strategies/redaction.py b/anonipy/anonymize/strategies/redaction.py index 5e075a9..688bee9 100644 --- a/anonipy/anonymize/strategies/redaction.py +++ b/anonipy/anonymize/strategies/redaction.py @@ -39,7 +39,7 @@ def __init__(self, substitute_label: str = "[REDACTED]", *args, **kwargs) -> Non """ super().__init__(*args, **kwargs) - self.substitute_label = substitute_label + self.substitute_label = substitute_label or "[REDACTED]" def anonymize( self, text: str, entities: List[Entity], *args, **kwargs diff --git a/pyproject.toml b/pyproject.toml index f0e7af4..480355e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,10 @@ build-backend = 'setuptools.build_meta' [project] name = "anonipy" description = "The data anonymization package" -authors=[{ name = "Erik Novak" }] +authors=[ + { name = "Erik Novak" }, + { name = "Nina Kokalj" } +] maintainers = [{ name = "Erik Novak" }] readme = "README.md" license = { file = "LICENSE" } @@ -36,8 +39,8 @@ dev = [ "mkdocstrings[python]", ] test = [ - "coverage", - "nbmake", + "pytest", + "pytest-cov", ] all = ["anonipy[dev,test]"] diff --git a/test/test_anonymize.py b/test/test_anonymize.py index 7448d34..f9bcad9 100644 --- a/test/test_anonymize.py +++ b/test/test_anonymize.py @@ -1,5 +1,3 @@ -import unittest - from anonipy.anonymize import anonymize # ===================================== @@ -30,18 +28,12 @@ }, ] - # ===================================== # Test Anonymize # ===================================== -class TestAnonymize(unittest.TestCase): - def test_anonymize(self): - anonymized_text, replacements = anonymize(test_text, test_replacements) - self.assertEqual(anonymized_text, test_text_anonymized) - self.assertEqual(replacements, test_replacements) - - -if __name__ == "__main__": - unittest.main() +def test_anonymize(): + anonymized_text, replacements = anonymize(test_text, test_replacements) + assert anonymized_text == test_text_anonymized + assert replacements == test_replacements diff --git a/test/test_entity.py b/test/test_entity.py index e8b51e6..e25756f 100644 --- a/test/test_entity.py +++ b/test/test_entity.py @@ -1,48 +1,40 @@ -import unittest - from anonipy.definitions import Entity - # ===================================== # Test Entity # ===================================== -class TestEntity(unittest.TestCase): - - def test_init_default(self): - entity = Entity( - text="test", - label="test", - start_index=0, - end_index=4, - ) - self.assertEqual(entity.text, "test") - self.assertEqual(entity.label, "test") - self.assertEqual(entity.start_index, 0) - self.assertEqual(entity.end_index, 4) - self.assertEqual(entity.score, 1.0) - self.assertEqual(entity.type, None) - self.assertEqual(entity.regex, ".*") - - def test_init_custom(self): - entity = Entity( - text="test", - label="test", - start_index=0, - end_index=4, - score=0.89, - type="test", - regex="test", - ) - self.assertEqual(entity.text, "test") - self.assertEqual(entity.label, "test") - self.assertEqual(entity.start_index, 0) - self.assertEqual(entity.end_index, 4) - self.assertEqual(entity.score, 0.89) - self.assertEqual(entity.type, "test") - self.assertEqual(entity.regex, "test") - - -if __name__ == "__main__": - unittest.main() +def test_init_default(): + entity = Entity( + text="test", + label="test", + start_index=0, + end_index=4, + ) + assert entity.text == "test" + assert entity.label == "test" + assert entity.start_index == 0 + assert entity.end_index == 4 + assert entity.score == 1.0 + assert entity.type is None + assert entity.regex == ".*" + + +def test_init_custom(): + entity = Entity( + text="test", + label="test", + start_index=0, + end_index=4, + score=0.89, + type="test", + regex="test", + ) + assert entity.text == "test" + assert entity.label == "test" + assert entity.start_index == 0 + assert entity.end_index == 4 + assert entity.score == 0.89 + assert entity.type == "test" + assert entity.regex == "test" diff --git a/test/test_extractors.py b/test/test_extractors.py index 1f187e8..59df2e0 100644 --- a/test/test_extractors.py +++ b/test/test_extractors.py @@ -1,6 +1,6 @@ -import unittest import warnings +import pytest import torch from transformers import logging @@ -15,7 +15,7 @@ # Helper functions # ===================================== -original_text = """\ +TEST_ORIGINAL_TEXT = """\ Medical Record Patient Name: John Doe @@ -34,7 +34,7 @@ 15-11-2024 """ -ner_entities = [ +TEST_NER_ENTITIES = [ Entity( text="John Doe", label="name", @@ -80,7 +80,7 @@ ), ] -pattern_entities = [ +TEST_PATTERN_ENTITIES = [ Entity( text="15-01-1985", label="date", @@ -127,172 +127,105 @@ ] -# ===================================== -# Test NER Extractor -# ===================================== - - -class TestNERExtractor(unittest.TestCase): - - def setUp(self): - warnings.filterwarnings("ignore", category=ImportWarning) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - # define the labels to be extracted and anonymized - self.labels = [ - {"label": "name", "type": "string"}, - { - "label": "social security number", - "type": "custom", - "regex": "[0-9]{3}-[0-9]{2}-[0-9]{4}", - }, - {"label": "date of birth", "type": "date"}, - {"label": "date", "type": "date"}, - ] - - def test_init(self): - with self.assertRaises(TypeError): - NERExtractor() - - def test_init_inputs(self): - extractor = NERExtractor( - labels=self.labels, lang=LANGUAGES.ENGLISH, score_th=0.5 - ) - self.assertEqual(extractor.__class__, NERExtractor) - - def test_init_gpu(self): - if torch.cuda.is_available(): - extractor = NERExtractor( - labels=self.labels, lang=LANGUAGES.ENGLISH, score_th=0.5, use_gpu=True - ) - self.assertEqual(extractor.__class__, NERExtractor) - - def test_methods(self): - extractor = NERExtractor( - labels=self.labels, lang=LANGUAGES.ENGLISH, score_th=0.5 - ) - self.assertEqual(hasattr(extractor, "__call__"), True) - self.assertEqual(hasattr(extractor, "display"), True) - - def test_extract_default_params(self): - extractor = NERExtractor(labels=self.labels) - _, entities = extractor(original_text) - for p_entity, t_entity in zip(entities, ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - def test_extract_default_params_input(self): - extractor = NERExtractor( - labels=self.labels, - lang=LANGUAGES.ENGLISH, - gliner_model="urchade/gliner_multi_pii-v1", - score_th=0.5, - ) - _, entities = extractor(original_text) - for p_entity, t_entity in zip(entities, ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - def test_extract_custom_params_input(self): +@pytest.fixture(autouse=True) +def suppress_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=ResourceWarning) + + +@pytest.fixture(scope="module") +def ner_extractor(): + labels = [ + {"label": "name", "type": "string"}, + { + "label": "social security number", + "type": "custom", + "regex": "[0-9]{3}-[0-9]{2}-[0-9]{4}", + }, + {"label": "date of birth", "type": "date"}, + {"label": "date", "type": "date"}, + ] + return NERExtractor(labels=labels, lang=LANGUAGES.ENGLISH) + + +@pytest.fixture(scope="module") +def pattern_extractor(): + labels = [ + { + "label": "symptoms", + "type": "string", + "regex": r"\((.*)\)", # symptoms are enclosed in parentheses + }, + { + "label": "medicine", + "type": "string", + "pattern": [[{"IS_ALPHA": True}, {"LIKE_NUM": True}, {"LOWER": "mg"}]], + }, + { + "label": "date", + "type": "date", + "pattern": [ # represent the date as a sequence of digits using spacy + [ + {"SHAPE": "dd"}, + {"TEXT": "-"}, + {"SHAPE": "dd"}, + {"TEXT": "-"}, + {"SHAPE": "dddd"}, + ] + ], + }, + ] + return PatternExtractor(labels=labels, lang=LANGUAGES.ENGLISH) + + +@pytest.fixture(scope="module") +def multi_extractor(ner_extractor, pattern_extractor): + return MultiExtractor([ner_extractor, pattern_extractor]) + + +def test_ner_extractor_init(): + with pytest.raises(TypeError): + NERExtractor() + + +def test_ner_extractor_init_inputs(ner_extractor): + extractor = NERExtractor( + labels=ner_extractor.labels, lang=LANGUAGES.ENGLISH, score_th=0.5 + ) + assert isinstance(extractor, NERExtractor) + + +def test_ner_extractor_init_gpu(ner_extractor): + if torch.cuda.is_available(): extractor = NERExtractor( - labels=self.labels, + labels=ner_extractor.labels, lang=LANGUAGES.ENGLISH, - gliner_model="E3-JSI/gliner-multi-pii-domains-v1", score_th=0.5, + use_gpu=True, ) - _, entities = extractor(original_text) - for p_entity, t_entity in zip(entities, ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) + assert isinstance(extractor, NERExtractor) -# ===================================== -# Test Pattern Extractor -# ===================================== +def test_ner_extractor_methods(ner_extractor): + assert hasattr(ner_extractor, "__call__") + assert hasattr(ner_extractor, "display") -class TestPatternExtractor(unittest.TestCase): +def test_ner_extractor_extract_default_params(ner_extractor): + _, entities = ner_extractor(TEST_ORIGINAL_TEXT) + for p_entity, t_entity in zip(entities, TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 - def setUp(self): - warnings.filterwarnings("ignore", category=ImportWarning) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - # define the labels to be extracted and anonymized - self.labels = [ - { - "label": "symptoms", - "type": "string", - "regex": r"\((.*)\)", # symptoms are enclosed in parentheses - }, - { - "label": "medicine", - "type": "string", - "pattern": [[{"IS_ALPHA": True}, {"LIKE_NUM": True}, {"LOWER": "mg"}]], - }, - { - "label": "date", - "type": "date", - "pattern": [ # represent the date as a sequence of digits using spacy - [ - {"SHAPE": "dd"}, - {"TEXT": "-"}, - {"SHAPE": "dd"}, - {"TEXT": "-"}, - {"SHAPE": "dddd"}, - ] - ], - }, - ] - - def test_init(self): - with self.assertRaises(TypeError): - PatternExtractor() - - def test_init_inputs(self): - extractor = PatternExtractor(labels=self.labels, lang=LANGUAGES.ENGLISH) - self.assertEqual(extractor.__class__, PatternExtractor) - - def test_methods(self): - extractor = PatternExtractor(labels=self.labels, lang=LANGUAGES.ENGLISH) - self.assertEqual(hasattr(extractor, "__call__"), True) - self.assertEqual(hasattr(extractor, "display"), True) - - def test_extract_default(self): - extractor = PatternExtractor(labels=self.labels, lang=LANGUAGES.ENGLISH) - doc, entities = extractor(original_text) - for p_entity, t_entity in zip(entities, pattern_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score == 1.0, True) - - -class TestMultiExtractor(unittest.TestCase): - - def setUp(self): - warnings.filterwarnings("ignore", category=ImportWarning) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - # define the labels to be extracted and anonymized - self.ner_labels = [ + +def test_ner_extractor_extract_default_params_input(): + extractor = NERExtractor( + labels=[ {"label": "name", "type": "string"}, { "label": "social security number", @@ -301,157 +234,181 @@ def setUp(self): }, {"label": "date of birth", "type": "date"}, {"label": "date", "type": "date"}, - ] - self.pattern_labels = [ - { - "label": "symptoms", - "type": "string", - "regex": r"\((.*)\)", # symptoms are enclosed in parentheses - }, - { - "label": "medicine", - "type": "string", - "pattern": [[{"IS_ALPHA": True}, {"LIKE_NUM": True}, {"LOWER": "mg"}]], - }, + ], + lang=LANGUAGES.ENGLISH, + gliner_model="urchade/gliner_multi_pii-v1", + score_th=0.5, + ) + _, entities = extractor(TEST_ORIGINAL_TEXT) + for p_entity, t_entity in zip(entities, TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + +def test_ner_extractor_extract_custom_params_input(): + extractor = NERExtractor( + labels=[ + {"label": "name", "type": "string"}, { - "label": "date", - "type": "date", - "pattern": [ # represent the date as a sequence of digits using spacy - [ - {"SHAPE": "dd"}, - {"TEXT": "-"}, - {"SHAPE": "dd"}, - {"TEXT": "-"}, - {"SHAPE": "dddd"}, - ] - ], + "label": "social security number", + "type": "custom", + "regex": "[0-9]{3}-[0-9]{2}-[0-9]{4}", }, - ] - - def test_init(self): - with self.assertRaises(TypeError): - MultiExtractor() - - def test_init_inputs(self): - extractors = [ - NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH), - PatternExtractor(labels=self.pattern_labels, lang=LANGUAGES.ENGLISH), - ] - extractor = MultiExtractor(extractors) - self.assertEqual(extractor.__class__, MultiExtractor) - - def test_methods(self): - extractors = [ - NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH), - PatternExtractor(labels=self.pattern_labels, lang=LANGUAGES.ENGLISH), - ] - extractor = MultiExtractor(extractors) - self.assertEqual(hasattr(extractor, "__call__"), True) - self.assertEqual(hasattr(extractor, "display"), True) - - def test_extract_default(self): - extractors = [ - NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH), - PatternExtractor(labels=self.pattern_labels, lang=LANGUAGES.ENGLISH), - ] - extractor = MultiExtractor(extractors) - extractor_outputs, joint_entities = extractor(original_text) - - # check the performance of the first extractor - for p_entity, t_entity in zip(extractor_outputs[0][1], ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - # check the performance of the second extractor - for p_entity, t_entity in zip(extractor_outputs[1][1], pattern_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score == 1.0, True) - - # check the performance of the joint entities generation - for p_entity, t_entity in zip( - joint_entities, extractor._filter_entities(ner_entities + pattern_entities) - ): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - def test_extract_empty_extractor_list(self): - extractors = [] - with self.assertRaises(ValueError): - MultiExtractor(extractors) - - def test_extract_invalid_extractor_list(self): - extractors = [ - NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH), - "invalid", - ] - with self.assertRaises(ValueError): - MultiExtractor(extractors) - - def test_extract_single_extractor_ner(self): - extractors = [NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH)] - extractor = MultiExtractor(extractors) - extractor_outputs, joint_entities = extractor(original_text) - - # check the performance of the extractor - for p_entity, t_entity in zip(extractor_outputs[0][1], ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - # check the performance of the joint entities generation - for p_entity, t_entity in zip(joint_entities, ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - def test_extract_single_extractor_pattern(self): - extractors = [NERExtractor(labels=self.ner_labels, lang=LANGUAGES.ENGLISH)] - extractor = MultiExtractor(extractors) - extractor_outputs, joint_entities = extractor(original_text) - - # check the performance of the extractor - for p_entity, t_entity in zip(extractor_outputs[0][1], ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - # check the performance of the joint entities generation - for p_entity, t_entity in zip(joint_entities, ner_entities): - self.assertEqual(p_entity.text, t_entity.text) - self.assertEqual(p_entity.label, t_entity.label) - self.assertEqual(p_entity.start_index, t_entity.start_index) - self.assertEqual(p_entity.end_index, t_entity.end_index) - self.assertEqual(p_entity.type, t_entity.type) - self.assertEqual(p_entity.regex, t_entity.regex) - self.assertEqual(p_entity.score >= 0.5, True) - - -if __name__ == "__main__": - unittest.main() + {"label": "date of birth", "type": "date"}, + {"label": "date", "type": "date"}, + ], + lang=LANGUAGES.ENGLISH, + gliner_model="E3-JSI/gliner-multi-pii-domains-v1", + score_th=0.5, + ) + _, entities = extractor(TEST_ORIGINAL_TEXT) + for p_entity, t_entity in zip(entities, TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + +def test_pattern_extractor_init(): + with pytest.raises(TypeError): + PatternExtractor() + + +def test_pattern_extractor_init_inputs(pattern_extractor): + assert isinstance(pattern_extractor, PatternExtractor) + + +def test_pattern_extractor_methods(pattern_extractor): + assert hasattr(pattern_extractor, "__call__") + assert hasattr(pattern_extractor, "display") + + +def test_pattern_extractor_extract_default(pattern_extractor): + doc, entities = pattern_extractor(TEST_ORIGINAL_TEXT) + for p_entity, t_entity in zip(entities, TEST_PATTERN_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score == 1.0 + + +def test_multi_extractor_init(): + with pytest.raises(TypeError): + MultiExtractor() + + +def test_multi_extractor_init_inputs(multi_extractor): + assert isinstance(multi_extractor, MultiExtractor) + + +def test_multi_extractor_methods(multi_extractor): + assert hasattr(multi_extractor, "__call__") + assert hasattr(multi_extractor, "display") + + +def test_multi_extractor_extract_default(multi_extractor): + extractor_outputs, joint_entities = multi_extractor(TEST_ORIGINAL_TEXT) + + # check the performance of the first extractor + for p_entity, t_entity in zip(extractor_outputs[0][1], TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + # check the performance of the second extractor + for p_entity, t_entity in zip(extractor_outputs[1][1], TEST_PATTERN_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score == 1.0 + + # check the performance of the joint entities generation + for p_entity, t_entity in zip( + joint_entities, + multi_extractor._filter_entities(TEST_NER_ENTITIES + TEST_PATTERN_ENTITIES), + ): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + +def test_multi_extractor_extract_empty_extractor_list(): + with pytest.raises(ValueError): + MultiExtractor([]) + + +def test_multi_extractor_extract_invalid_extractor_list(): + with pytest.raises(ValueError): + MultiExtractor([NERExtractor(labels=[], lang=LANGUAGES.ENGLISH), "invalid"]) + + +def test_multi_extractor_extract_single_extractor_ner(multi_extractor): + extractor = MultiExtractor([multi_extractor.extractors[0]]) + extractor_outputs, joint_entities = extractor(TEST_ORIGINAL_TEXT) + + # check the performance of the extractor + for p_entity, t_entity in zip(extractor_outputs[0][1], TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + # check the performance of the joint entities generation + for p_entity, t_entity in zip(joint_entities, TEST_NER_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + +def test_multi_extractor_extract_single_extractor_pattern(multi_extractor): + extractor = MultiExtractor([multi_extractor.extractors[1]]) + extractor_outputs, joint_entities = extractor(TEST_ORIGINAL_TEXT) + + # check the performance of the extractor + for p_entity, t_entity in zip(extractor_outputs[0][1], TEST_PATTERN_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 + + # check the performance of the joint entities generation + for p_entity, t_entity in zip(joint_entities, TEST_PATTERN_ENTITIES): + assert p_entity.text == t_entity.text + assert p_entity.label == t_entity.label + assert p_entity.start_index == t_entity.start_index + assert p_entity.end_index == t_entity.end_index + assert p_entity.type == t_entity.type + assert p_entity.regex == t_entity.regex + assert p_entity.score >= 0.5 diff --git a/test/test_file_system.py b/test/test_file_system.py index fcf9f82..5d3913c 100644 --- a/test/test_file_system.py +++ b/test/test_file_system.py @@ -1,34 +1,22 @@ -import unittest +import pytest from anonipy.utils.file_system import open_file - -# ===================================== -# Helper functions -# ===================================== - from test.resources.example_outputs import WORD_TEXT, PDF_TEXT, TXT_TEXT -resources = { +RESOURCES = { "word": "./test/resources/example.docx", "pdf": "./test/resources/example.pdf", "txt": "./test/resources/example.txt", } -# ===================================== -# Test Entity Extractor -# ===================================== - - -class TestFileSystem(unittest.TestCase): - def test_open_file_word(self): - self.assertEqual(open_file(resources["word"]), WORD_TEXT) - - def test_open_file_pdf(self): - self.assertEqual(open_file(resources["pdf"]), PDF_TEXT) - - def test_open_file_txt(self): - self.assertEqual(open_file(resources["txt"]), TXT_TEXT) - -if __name__ == "__main__": - unittest.main() +@pytest.mark.parametrize( + "file_type, expected_output", + [ + ("word", WORD_TEXT), + ("pdf", PDF_TEXT), + ("txt", TXT_TEXT), + ], +) +def test_open_file(file_type, expected_output): + assert open_file(RESOURCES[file_type]) == expected_output diff --git a/test/test_generators.py b/test/test_generators.py index f07b8f7..382cc07 100644 --- a/test/test_generators.py +++ b/test/test_generators.py @@ -1,7 +1,7 @@ import re -import unittest import warnings +import pytest from transformers import logging from anonipy.definitions import Entity @@ -93,7 +93,7 @@ # ===================================== -original_text = """\ +TEST_ORIGINAL_TEXT = """\ Medical Record Patient Name: John Doe @@ -112,7 +112,7 @@ 15-11-2024 """ -test_entities = { +TEST_ENTITIES = { "name": Entity( text="John Doe", label="name", @@ -177,40 +177,42 @@ # ===================================== -class TestLLMLabelGenerator(unittest.TestCase): +@pytest.fixture(scope="module") +def llm_label_generator(): + return LLMLabelGenerator() - @classmethod - def setUpClass(self): - self.generator = LLMLabelGenerator() - def test_has_methods(self): - self.assertEqual(hasattr(self.generator, "generate"), True) +def test_llm_label_generator_has_methods(llm_label_generator): + assert hasattr(llm_label_generator, "generate") - def test_generate_default(self): - entity = test_entities["name"] - generated_text = self.generator.generate(entity) - regex = entity.get_regex_group() or entity.regex - match = re.match(regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) - def test_generate_custom(self): - entity = test_entities["name"] - generated_text = self.generator.generate( - entity, add_entity_attrs="Spanish", temperature=0.5 - ) - regex = entity.get_regex_group() or entity.regex - match = re.match(regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) +def test_llm_label_generator_generate_default(llm_label_generator): + entity = TEST_ENTITIES["name"] + generated_text = llm_label_generator.generate(entity) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) + assert match is not None + assert match.group(0) == generated_text + + +def test_llm_label_generator_generate_custom(llm_label_generator): + entity = TEST_ENTITIES["name"] + generated_text = llm_label_generator.generate( + entity, add_entity_attrs="Spanish", temperature=0.5 + ) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) + assert match is not None + assert match.group(0) == generated_text + - def test_generate_pattern(self): - entity = test_entities["name:pattern"] - generated_text = self.generator.generate(entity) - regex = entity.get_regex_group() or entity.regex - match = re.match(regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) +def test_llm_label_generator_generate_pattern(llm_label_generator): + entity = TEST_ENTITIES["name:pattern"] + generated_text = llm_label_generator.generate(entity) + regex = entity.get_regex_group() or entity.regex + match = re.match(regex, generated_text) + assert match is not None + assert match.group(0) == generated_text # ===================================== @@ -218,26 +220,28 @@ def test_generate_pattern(self): # ===================================== -class TestMaskLabelGenerator(unittest.TestCase): +@pytest.fixture(scope="module") +def mask_label_generator(): + return MaskLabelGenerator() - @classmethod - def setUpClass(self): - self.generator = MaskLabelGenerator() - def setUp(self): - warnings.filterwarnings("ignore", category=ImportWarning) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) +@pytest.fixture(autouse=True) +def suppress_warnings(): + warnings.filterwarnings("ignore", category=ImportWarning) + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=FutureWarning) - def test_has_methods(self): - self.assertEqual(hasattr(self.generator, "generate"), True) - def test_generate_default(self): - entity = test_entities["name"] - generated_text = self.generator.generate(entity, text=original_text) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) +def test_mask_label_generator_has_methods(mask_label_generator): + assert hasattr(mask_label_generator, "generate") + + +def test_mask_label_generator_generate_default(mask_label_generator): + entity = TEST_ENTITIES["name"] + generated_text = mask_label_generator.generate(entity, text=TEST_ORIGINAL_TEXT) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text # ===================================== @@ -245,88 +249,89 @@ def test_generate_default(self): # ===================================== -class TestDateGenerator(unittest.TestCase): +@pytest.fixture(scope="module") +def date_generator(): + return DateGenerator(lang="en") - @classmethod - def setUpClass(self): - self.generator = DateGenerator(lang="en") - def test_has_methods(self): - self.assertEqual(hasattr(self.generator, "generate"), True) +def test_date_generator_has_methods(date_generator): + assert hasattr(date_generator, "generate") - def test_generate_default(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate(entity) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) - def test_generate_custom_date_format(self): - entity = test_entities["date"][0] - generator = DateGenerator(date_format="dd-MM-yyyy") - generated_text = generator.generate(entity) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) +def test_date_generator_generate_default(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate(entity) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text - def text_generate_uncorrect_date_format(self): - entity = test_entities["date"][0] - generator = DateGenerator(date_format="yyyy-MM-dd") - try: - generator.generate(entity) - except Exception as e: - self.assertRaises(TypeError, e) - - def test_generate_first_day_of_the_month(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate( - entity, sub_variant="FIRST_DAY_OF_THE_MONTH" - ) - self.assertEqual(generated_text, "01-05-2024") - def test_generate_last_day_of_the_month(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate( - entity, sub_variant="LAST_DAY_OF_THE_MONTH" - ) - self.assertEqual(generated_text, "31-05-2024") +def test_date_generator_generate_custom_date_format(): + entity = TEST_ENTITIES["date"][0] + generator = DateGenerator(date_format="dd-MM-yyyy") + generated_text = generator.generate(entity) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text - def test_generate_middle_of_the_month(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate( - entity, sub_variant="MIDDLE_OF_THE_MONTH" - ) - self.assertEqual(generated_text, "15-05-2024") - def test_generate_middle_of_the_year(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate( - entity, sub_variant="MIDDLE_OF_THE_YEAR" - ) - self.assertEqual(generated_text, "01-07-2024") +def test_date_generator_generate_non_matching_date_format(): + entity = TEST_ENTITIES["date"][0] + generator = DateGenerator(date_format="yyyy-MM-dd") + custom_date = generator.generate(entity, sub_variant="FIRST_DAY_OF_THE_MONTH") + assert custom_date == "2024-05-01" - def test_generate_random(self): - entity = test_entities["date"][0] - generated_text = self.generator.generate(entity, sub_variant="RANDOM") - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) - def test_generate_uncorrect_type(self): - entity = test_entities["name"] - try: - self.generator.generate(entity) - except Exception as e: - self.assertEqual(type(e), ValueError) +def test_date_generator_generate_first_day_of_the_month(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate( + entity, sub_variant="FIRST_DAY_OF_THE_MONTH" + ) + assert generated_text == "01-05-2024" + - def test_process_different_formats(self): - for entity in test_entities["date"]: - try: - self.generator.generate(entity, sub_variant="RANDOM") - except ValueError: - self.fail( - f"self.generator.generate() raised ValueError unexpectedly for date: {entity.text}" - ) +def test_date_generator_generate_last_day_of_the_month(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate( + entity, sub_variant="LAST_DAY_OF_THE_MONTH" + ) + assert generated_text == "31-05-2024" + + +def test_date_generator_generate_middle_of_the_month(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate(entity, sub_variant="MIDDLE_OF_THE_MONTH") + assert generated_text == "15-05-2024" + + +def test_date_generator_generate_middle_of_the_year(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate(entity, sub_variant="MIDDLE_OF_THE_YEAR") + assert generated_text == "01-07-2024" + + +def test_date_generator_generate_random(date_generator): + entity = TEST_ENTITIES["date"][0] + generated_text = date_generator.generate(entity, sub_variant="RANDOM") + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text + + +def test_date_generator_generate_uncorrect_type(date_generator): + entity = TEST_ENTITIES["name"] + with pytest.raises(ValueError): + date_generator.generate(entity) + + +def test_date_generator_process_different_formats(date_generator): + for entity in TEST_ENTITIES["date"]: + try: + date_generator.generate(entity, sub_variant="RANDOM") + except ValueError: + pytest.fail( + f"date_generator.generate() raised ValueError unexpectedly for date: {entity.text}" + ) # ===================================== @@ -334,41 +339,40 @@ def test_process_different_formats(self): # ===================================== -class TestNumberGenerator(unittest.TestCase): +@pytest.fixture(scope="module") +def number_generator(): + return NumberGenerator() + + +def test_number_generator_has_methods(number_generator): + assert hasattr(number_generator, "generate") - @classmethod - def setUpClass(self): - self.generator = NumberGenerator() - def test_has_methods(self): - self.assertEqual(hasattr(self.generator, "generate"), True) +def test_number_generator_generate_integer(number_generator): + entity = TEST_ENTITIES["integer"] + generated_text = number_generator.generate(entity) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text - def test_generate_integer(self): - entity = test_entities["integer"] - generated_text = self.generator.generate(entity) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) - def test_generate_float(self): - entity = test_entities["float"] - generated_text = self.generator.generate(entity) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) +def test_number_generator_generate_float(number_generator): + entity = TEST_ENTITIES["float"] + generated_text = number_generator.generate(entity) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text - def test_generate_custom(self): - entity = test_entities["custom"] - generated_text = self.generator.generate(entity) - match = re.match(entity.regex, generated_text) - self.assertNotEqual(match, None) - self.assertEqual(match.group(0), generated_text) - def test_generate_uncorrect_type(self): - entity = test_entities["name"] - with self.assertRaises(ValueError): - self.generator.generate(entity) +def test_number_generator_generate_custom(number_generator): + entity = TEST_ENTITIES["custom"] + generated_text = number_generator.generate(entity) + match = re.match(entity.regex, generated_text) + assert match is not None + assert match.group(0) == generated_text -if __name__ == "__main__": - unittest.main() +def test_number_generator_generate_uncorrect_type(number_generator): + entity = TEST_ENTITIES["name"] + with pytest.raises(ValueError): + number_generator.generate(entity) diff --git a/test/test_language_detector.py b/test/test_language_detector.py index 60790eb..535355d 100644 --- a/test/test_language_detector.py +++ b/test/test_language_detector.py @@ -1,6 +1,7 @@ -import unittest import warnings +import pytest + from anonipy.utils.language_detector import LanguageDetector from anonipy.constants import LANGUAGES @@ -9,102 +10,98 @@ # ===================================== -class TestLanguageDetector(unittest.TestCase): - - def setUp(self): - warnings.filterwarnings("ignore", category=ImportWarning) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - - def test_init(self): - language_detector = LanguageDetector() - self.assertEqual(language_detector.__class__, LanguageDetector) - - def test_has_methods(self): - language_detector = LanguageDetector() - self.assertEqual(hasattr(language_detector, "detect"), True) - - def test_detect_english(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "This test verifies that the method is working correctly" - ) - self.assertEqual(language[0], "en") - self.assertEqual(language[1], "English") - self.assertEqual(language, LANGUAGES.ENGLISH) - - def test_detect_slovene(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Ta test preverja, ali metoda dela pravilno" - ) - self.assertEqual(language[0], "sl") - self.assertEqual(language[1], "Slovene") - self.assertEqual(language, LANGUAGES.SLOVENE) - - def test_detect_german(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Dieser Test überprüft, ob die Methode ordnungsgemäß funktioniert" - ) - self.assertEqual(language[0], "de") - self.assertEqual(language[1], "German") - self.assertEqual(language, LANGUAGES.GERMAN) - - def test_detect_dutch(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Deze test verifieert dat de methode correct werkt" - ) - self.assertEqual(language[0], "nl") - self.assertEqual(language[1], "Dutch") - self.assertEqual(language, LANGUAGES.DUTCH) - - def test_detect_spanish(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Esta prueba verifica que el método está funcionando correctamente" - ) - self.assertEqual(language[0], "es") - self.assertEqual(language[1], "Spanish") - self.assertEqual(language, LANGUAGES.SPANISH) - - def test_detect_greek(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Αυτή η δοκιμή επαληθεύει ότι η μέθοδος λειτουργεί σωστά" - ) - self.assertEqual(language[0], "el") - self.assertEqual(language[1], "Greek") - self.assertEqual(language, LANGUAGES.GREEK) - - def test_detect_italian(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Questo test verifica che il metodo funzioni correttamente" - ) - self.assertEqual(language[0], "it") - self.assertEqual(language[1], "Italian") - self.assertEqual(language, LANGUAGES.ITALIAN) - - def test_detect_french(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Ce test vérifie que la méthode fonctionne correctement" - ) - self.assertEqual(language[0], "fr") - self.assertEqual(language[1], "French") - self.assertEqual(language, LANGUAGES.FRENCH) - - def test_detect_ukrainian(self): - language_detector = LanguageDetector() - language = language_detector.detect( - "Цей тест перевіряє, чи метод працює правильно" - ) - self.assertEqual(language[0], "uk") - self.assertEqual(language[1], "Ukrainian") - self.assertEqual(language, LANGUAGES.UKRAINIAN) - - -if __name__ == "__main__": - unittest.main() +@pytest.fixture(scope="module", autouse=True) +def suppress_warnings(): + warnings.filterwarnings("ignore", category=ImportWarning) + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=FutureWarning) + + +@pytest.fixture +def language_detector(): + return LanguageDetector() + + +def test_init(language_detector): + assert isinstance(language_detector, LanguageDetector) + + +def test_has_methods(language_detector): + assert hasattr(language_detector, "detect") + + +def test_detect_english(language_detector): + language = language_detector.detect( + "This test verifies that the method is working correctly" + ) + assert language[0] == "en" + assert language[1] == "English" + assert language == LANGUAGES.ENGLISH + + +def test_detect_slovene(language_detector): + language = language_detector.detect("Ta test preverja, ali metoda dela pravilno") + assert language[0] == "sl" + assert language[1] == "Slovene" + assert language == LANGUAGES.SLOVENE + + +def test_detect_german(language_detector): + language = language_detector.detect( + "Dieser Test überprüft, ob die Methode ordnungsgemäß funktioniert" + ) + assert language[0] == "de" + assert language[1] == "German" + assert language == LANGUAGES.GERMAN + + +def test_detect_dutch(language_detector): + language = language_detector.detect( + "Deze test verifieert dat de methode correct werkt" + ) + assert language[0] == "nl" + assert language[1] == "Dutch" + assert language == LANGUAGES.DUTCH + + +def test_detect_spanish(language_detector): + language = language_detector.detect( + "Esta prueba verifica que el método está funcionando correctamente" + ) + assert language[0] == "es" + assert language[1] == "Spanish" + assert language == LANGUAGES.SPANISH + + +def test_detect_greek(language_detector): + language = language_detector.detect( + "Αυτή η δοκιμή επαληθεύει ότι η μέθοδος λειτουργεί σωστά" + ) + assert language[0] == "el" + assert language[1] == "Greek" + assert language == LANGUAGES.GREEK + + +def test_detect_italian(language_detector): + language = language_detector.detect( + "Questo test verifica che il metodo funzioni correttamente" + ) + assert language[0] == "it" + assert language[1] == "Italian" + assert language == LANGUAGES.ITALIAN + + +def test_detect_french(language_detector): + language = language_detector.detect( + "Ce test vérifie que la méthode fonctionne correctement" + ) + assert language[0] == "fr" + assert language[1] == "French" + assert language == LANGUAGES.FRENCH + + +def test_detect_ukrainian(language_detector): + language = language_detector.detect("Цей тест перевіряє, чи метод працює правильно") + assert language[0] == "uk" + assert language[1] == "Ukrainian" + assert language == LANGUAGES.UKRAINIAN diff --git a/test/test_pipeline.py b/test/test_pipeline.py index ce06285..ca70061 100644 --- a/test/test_pipeline.py +++ b/test/test_pipeline.py @@ -1,7 +1,8 @@ import os -import unittest import shutil +import warnings +import pytest from transformers import logging from anonipy.anonymize.pipeline import Pipeline @@ -13,87 +14,94 @@ logging.set_verbosity_error() # ===================================== -# Helper functions +# Test Pipeline # ===================================== -# ===================================== -# Test Pipeline -# ===================================== +@pytest.fixture(autouse=True) +def suppress_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + + +@pytest.fixture(scope="module") +def setup(): + ner_labels = [{"label": "PERSON", "type": "string"}] + pattern_labels = [{"label": "DATE", "type": "regex", "regex": r"\d{4}-\d{2}-\d{2}"}] + extractors = [ + NERExtractor(ner_labels, lang=LANGUAGES.ENGLISH), + PatternExtractor(pattern_labels, lang=LANGUAGES.ENGLISH), + ] + multi_extractor = MultiExtractor(extractors) + strategy = RedactionStrategy() + input_dir = "test/resources" + output_dir = "test/output" + yield extractors, multi_extractor, strategy, input_dir, output_dir + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + +def test_init(): + with pytest.raises(TypeError): + Pipeline() + + +def test_init_extractor_single(setup): + extractors, _, strategy, _, _ = setup + pipeline = Pipeline(extractors[0], strategy) + assert isinstance(pipeline, Pipeline) + + +def test_init_extractor_list(setup): + extractors, _, strategy, _, _ = setup + pipeline = Pipeline(extractors, strategy) + assert isinstance(pipeline, Pipeline) + + +def test_init_extractor_multi(setup): + _, multi_extractor, strategy, _, _ = setup + pipeline = Pipeline(multi_extractor, strategy) + assert isinstance(pipeline, Pipeline) + + +def test_methods(setup): + _, multi_extractor, strategy, _, _ = setup + pipeline = Pipeline(multi_extractor, strategy) + assert hasattr(pipeline, "anonymize") + + +def test_anonymize(setup): + _, multi_extractor, strategy, input_dir, output_dir = setup + pipeline = Pipeline(multi_extractor, strategy) + pipeline.anonymize(input_dir, output_dir) + + assert os.path.exists(output_dir) + for root, _, files in os.walk(output_dir): + for file in files: + with open(os.path.join(root, file), "r") as f: + assert f.read() + + +def test_anonymize_flatten(setup): + _, multi_extractor, strategy, input_dir, output_dir = setup + pipeline = Pipeline(multi_extractor, strategy) + pipeline.anonymize(input_dir, output_dir, flatten=True) + + assert os.path.exists(output_dir) + for root, _, files in os.walk(output_dir): + for file in files: + with open(os.path.join(root, file), "r") as f: + assert f.read() + + +def test_anonymize_invalid_input_dir(setup): + _, multi_extractor, strategy, _, output_dir = setup + pipeline = Pipeline(multi_extractor, strategy) + with pytest.raises(ValueError): + pipeline.anonymize("invalid", output_dir) -class TestPipeline(unittest.TestCase): - - def setUp(self): - self.ner_labels = [{"label": "PERSON", "type": "string"}] - self.pattern_labels = [ - {"label": "DATE", "type": "regex", "regex": r"\d{4}-\d{2}-\d{2}"} - ] - self.extractors = [ - NERExtractor(self.ner_labels, lang=LANGUAGES.ENGLISH), - PatternExtractor(self.pattern_labels, lang=LANGUAGES.ENGLISH), - ] - self.multi_extractor = MultiExtractor(self.extractors) - self.strategy = RedactionStrategy() - self.input_dir = "test/resources" - self.output_dir = "test/output" - - def tearDown(self): - if not os.path.exists(self.output_dir): - return - # remove the output directory - shutil.rmtree(self.output_dir) - - def test_init(self): - with self.assertRaises(TypeError): - Pipeline() - - def test_init_extractor_single(self): - pipeline = Pipeline(self.extractors[0], self.strategy) - self.assertEqual(pipeline.__class__, Pipeline) - - def test_init_extractor_list(self): - pipeline = Pipeline(self.extractors, self.strategy) - self.assertEqual(pipeline.__class__, Pipeline) - - def test_init_extractor_multi(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - self.assertEqual(pipeline.__class__, Pipeline) - - def test_methods(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - self.assertTrue(hasattr(pipeline, "anonymize")) - - def test_anonymize(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - pipeline.anonymize(self.input_dir, self.output_dir) - - self.assertTrue(os.path.exists(self.output_dir)) - for root, _, files in os.walk(self.output_dir): - for file in files: - with open(os.path.join(root, file), "r") as f: - self.assertTrue(f.read()) - - def test_anonymize_flatten(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - pipeline.anonymize(self.input_dir, self.output_dir, flatten=True) - - self.assertTrue(os.path.exists(self.output_dir)) - for root, _, files in os.walk(self.output_dir): - for file in files: - with open(os.path.join(root, file), "r") as f: - self.assertTrue(f.read()) - - def test_anonymize_invalid_input_dir(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - with self.assertRaises(ValueError): - pipeline.anonymize("invalid", self.output_dir) - - def test_anonymize_invalid_output_dir(self): - pipeline = Pipeline(self.multi_extractor, self.strategy) - with self.assertRaises(ValueError): - pipeline.anonymize(self.input_dir, self.input_dir) - - -if __name__ == "__main__": - unittest.main() +def test_anonymize_invalid_output_dir(setup): + _, multi_extractor, strategy, input_dir, _ = setup + pipeline = Pipeline(multi_extractor, strategy) + with pytest.raises(ValueError): + pipeline.anonymize(input_dir, input_dir) diff --git a/test/test_regex.py b/test/test_regex.py index 8a8d382..81d971f 100644 --- a/test/test_regex.py +++ b/test/test_regex.py @@ -1,4 +1,4 @@ -import unittest +import pytest from anonipy.utils.regex import ( regex_mapping, @@ -13,7 +13,6 @@ ) from anonipy.constants import ENTITY_TYPES - # ===================================== # Test Cases # ===================================== @@ -46,26 +45,17 @@ {"value": "test", "entity": "test", "regex": ".*"}, ] - # ===================================== # Test Entity # ===================================== -class TestRegex(unittest.TestCase): - - def test_init(self): - self.assertEqual(regex_mapping.__class__, RegexMapping) - self.assertEqual(hasattr(regex_mapping, "regex_mapping"), True) - - def test_regex_mapping(self): - - for test_case in TEST_CASES: - self.assertEqual(regex_mapping[test_case["entity"]], test_case["regex"]) - self.assertEqual( - regex_mapping[test_case["value"]], regex_mapping[test_case["entity"]] - ) +def test_init(): + assert isinstance(regex_mapping, RegexMapping) + assert hasattr(regex_mapping, "regex_mapping") -if __name__ == "__main__": - unittest.main() +@pytest.mark.parametrize("test_case", TEST_CASES) +def test_regex_mapping(test_case): + assert regex_mapping[test_case["entity"]] == test_case["regex"] + assert regex_mapping[test_case["value"]] == regex_mapping[test_case["entity"]] diff --git a/test/test_strategies.py b/test/test_strategies.py index c2d1878..5b5b097 100644 --- a/test/test_strategies.py +++ b/test/test_strategies.py @@ -1,4 +1,4 @@ -import unittest +import pytest from anonipy.definitions import Entity from anonipy.anonymize.strategies import ( @@ -11,8 +11,8 @@ # Helper functions # ===================================== -test_text = "Test this string, and this test too!" -test_entities = [ +TEST_TEXT = "Test this string, and this test too!" +TEST_ENTITIES = [ Entity(text="Test", label="test", start_index=0, end_index=4), Entity(text="string", label="type", start_index=10, end_index=16), Entity(text="test", label="test", start_index=27, end_index=31), @@ -27,36 +27,52 @@ def anonymization_mapping(text, entity): return "[REDACTED]" +@pytest.fixture +def redaction_strategy(): + return RedactionStrategy() + + +@pytest.fixture +def masking_strategy(): + return MaskingStrategy() + + +@pytest.fixture +def pseudonymization_strategy(): + return PseudonymizationStrategy(mapping=anonymization_mapping) + + # ===================================== # Test Redaction Strategy # ===================================== -class TestRedactionStrategy(unittest.TestCase): - def test_init(self): - strategy = RedactionStrategy() - self.assertEqual(strategy.__class__, RedactionStrategy) +def test_redaction_strategy_init(redaction_strategy): + assert redaction_strategy.__class__ == RedactionStrategy - def test_has_methods(self): - strategy = RedactionStrategy() - self.assertEqual(hasattr(strategy, "anonymize"), True) - def test_default_inputs(self): - strategy = RedactionStrategy() - self.assertEqual(strategy.substitute_label, "[REDACTED]") +def test_redaction_strategy_has_methods(redaction_strategy): + assert hasattr(redaction_strategy, "anonymize") - def test_custom_inputs(self): - strategy = RedactionStrategy(substitute_label="[TEST]") - self.assertEqual(strategy.substitute_label, "[TEST]") - def test_anonymize_default_inputs(self): - strategy = RedactionStrategy() - anonymized_text, replacements = strategy.anonymize(test_text, test_entities) - self.assertEqual( - anonymized_text, "[REDACTED] this [REDACTED], and this [REDACTED] too!" - ) - self.assertEqual( - replacements, +@pytest.mark.parametrize( + "substitute_label, expected_label", + [ + (None, "[REDACTED]"), + ("[TEST]", "[TEST]"), + ], +) +def test_redaction_strategy_inputs(substitute_label, expected_label): + strategy = RedactionStrategy(substitute_label=substitute_label) + assert strategy.substitute_label == expected_label + + +@pytest.mark.parametrize( + "substitute_label, expected_text, expected_replacements", + [ + ( + None, + "[REDACTED] this [REDACTED], and this [REDACTED] too!", [ { "original_text": "Test", @@ -80,14 +96,10 @@ def test_anonymize_default_inputs(self): "anonymized_text": "[REDACTED]", }, ], - ) - - def test_anonymize_custom_inputs(self): - strategy = RedactionStrategy(substitute_label="[TEST]") - anonymized_text, replacements = strategy.anonymize(test_text, test_entities) - self.assertEqual(anonymized_text, "[TEST] this [TEST], and this [TEST] too!") - self.assertEqual( - replacements, + ), + ( + "[TEST]", + "[TEST] this [TEST], and this [TEST] too!", [ { "original_text": "Test", @@ -111,7 +123,16 @@ def test_anonymize_custom_inputs(self): "anonymized_text": "[TEST]", }, ], - ) + ), + ], +) +def test_redaction_strategy_anonymize( + substitute_label, expected_text, expected_replacements +): + strategy = RedactionStrategy(substitute_label=substitute_label) + anonymized_text, replacements = strategy.anonymize(TEST_TEXT, TEST_ENTITIES) + assert anonymized_text == expected_text + assert replacements == expected_replacements # ===================================== @@ -119,29 +140,32 @@ def test_anonymize_custom_inputs(self): # ===================================== -class TestMaskingStrategy(unittest.TestCase): - def test_init(self): - strategy = MaskingStrategy() - self.assertEqual(strategy.__class__, MaskingStrategy) +def test_masking_strategy_init(masking_strategy): + assert masking_strategy.__class__ == MaskingStrategy - def test_methods(self): - strategy = MaskingStrategy() - self.assertEqual(hasattr(strategy, "anonymize"), True) - def test_default_inputs(self): - strategy = MaskingStrategy() - self.assertEqual(strategy.substitute_label, "*") +def test_masking_strategy_methods(masking_strategy): + assert hasattr(masking_strategy, "anonymize") - def test_custom_inputs(self): - strategy = MaskingStrategy(substitute_label="A") - self.assertEqual(strategy.substitute_label, "A") - def test_anonymize_default_inputs(self): - strategy = MaskingStrategy() - anonymized_text, replacements = strategy.anonymize(test_text, test_entities) - self.assertEqual(anonymized_text, "**** this ******, and this **** too!") - self.assertEqual( - replacements, +@pytest.mark.parametrize( + "substitute_label, expected_label", + [ + (None, "*"), + ("A", "A"), + ], +) +def test_masking_strategy_inputs(substitute_label, expected_label): + strategy = MaskingStrategy(substitute_label=substitute_label) + assert strategy.substitute_label == expected_label + + +@pytest.mark.parametrize( + "substitute_label, expected_text, expected_replacements", + [ + ( + None, + "**** this ******, and this **** too!", [ { "original_text": "Test", @@ -165,14 +189,10 @@ def test_anonymize_default_inputs(self): "anonymized_text": "****", }, ], - ) - - def test_anonymize_custom_inputs(self): - strategy = MaskingStrategy(substitute_label="A") - anonymized_text, replacements = strategy.anonymize(test_text, test_entities) - self.assertEqual(anonymized_text, "AAAA this AAAAAA, and this AAAA too!") - self.assertEqual( - replacements, + ), + ( + "A", + "AAAA this AAAAAA, and this AAAA too!", [ { "original_text": "Test", @@ -196,7 +216,16 @@ def test_anonymize_custom_inputs(self): "anonymized_text": "AAAA", }, ], - ) + ), + ], +) +def test_masking_strategy_anonymize( + substitute_label, expected_text, expected_replacements +): + strategy = MaskingStrategy(substitute_label=substitute_label) + anonymized_text, replacements = strategy.anonymize(TEST_TEXT, TEST_ENTITIES) + assert anonymized_text == expected_text + assert replacements == expected_replacements # ===================================== @@ -204,50 +233,44 @@ def test_anonymize_custom_inputs(self): # ===================================== -class TestPseudonymizationStrategy(unittest.TestCase): - def test_init(self): - with self.assertRaises(TypeError): - PseudonymizationStrategy() - - def test_init_inputs(self): - strategy = PseudonymizationStrategy(mapping=anonymization_mapping) - self.assertEqual(strategy.__class__, PseudonymizationStrategy) - - def test_methods(self): - strategy = PseudonymizationStrategy(mapping=anonymization_mapping) - self.assertEqual(hasattr(strategy, "anonymize"), True) - - def test_anonymize_inputs(self): - strategy = PseudonymizationStrategy(mapping=anonymization_mapping) - anonymized_text, replacements = strategy.anonymize(test_text, test_entities) - self.assertEqual(anonymized_text, "[TEST] this [TYPE], and this [TEST] too!") - self.assertEqual( - replacements, - [ - { - "original_text": "Test", - "label": "test", - "start_index": 0, - "end_index": 4, - "anonymized_text": "[TEST]", - }, - { - "original_text": "string", - "label": "type", - "start_index": 10, - "end_index": 16, - "anonymized_text": "[TYPE]", - }, - { - "original_text": "test", - "label": "test", - "start_index": 27, - "end_index": 31, - "anonymized_text": "[TEST]", - }, - ], - ) - - -if __name__ == "__main__": - unittest.main() +def test_pseudonymization_strategy_init(): + with pytest.raises(TypeError): + PseudonymizationStrategy() + + +def test_pseudonymization_strategy_init_inputs(pseudonymization_strategy): + assert pseudonymization_strategy.__class__ == PseudonymizationStrategy + + +def test_pseudonymization_strategy_methods(pseudonymization_strategy): + assert hasattr(pseudonymization_strategy, "anonymize") + + +def test_pseudonymization_strategy_anonymize_inputs(pseudonymization_strategy): + anonymized_text, replacements = pseudonymization_strategy.anonymize( + TEST_TEXT, TEST_ENTITIES + ) + assert anonymized_text == "[TEST] this [TYPE], and this [TEST] too!" + assert replacements == [ + { + "original_text": "Test", + "label": "test", + "start_index": 0, + "end_index": 4, + "anonymized_text": "[TEST]", + }, + { + "original_text": "string", + "label": "type", + "start_index": 10, + "end_index": 16, + "anonymized_text": "[TYPE]", + }, + { + "original_text": "test", + "label": "test", + "start_index": 27, + "end_index": 31, + "anonymized_text": "[TEST]", + }, + ]